├── .gitignore
├── LRP.py
├── Main.py
├── NN.py
├── README.md
├── crossant.py
├── data.py
├── requirements.txt
├── settings_tt.py
└── training.py

/.gitignore:
--------------------------------------------------------------------------------
results
data
__pycache__
--------------------------------------------------------------------------------
/LRP.py:
--------------------------------------------------------------------------------
import torch as tc
import pandas as pd
import os
import numpy as np


def reverse_feature_expansion(frame):
    # collapse the expanded '_rev' features back onto their antero counterparts
    rev_features = frame[frame['therapy_diagnostics'].str.endswith('_rev')].copy()
    rev_features['therapy_diagnostics_antero'] = rev_features['therapy_diagnostics'].str.split('_rev').str[0]
    rev_features = rev_features.rename(columns={'LRP': 'LRP_rev', 'input_score': 'input_score_rev'})
    antero_features = frame[~frame['therapy_diagnostics'].str.endswith('_rev')].copy()
    antero_features = antero_features.rename(columns={'LRP': 'LRP_antero', 'therapy_diagnostics': 'therapy_diagnostics_antero', 'input_score': 'input_score_antero'})

    frame_unexpanded = antero_features.merge(rev_features[['therapy_diagnostics_antero', 'LRP_rev', 'sample_name', 'input_score_rev']],
                                             how='left', on=['therapy_diagnostics_antero', 'sample_name'])

    # the total relevance of a feature is the sum over the feature and its '_rev' twin
    frame_unexpanded['LRP'] = frame_unexpanded[['LRP_antero', 'LRP_rev']].sum(axis=1, skipna=True)
    return frame_unexpanded


def np2pd(nparray, sample_names, feature_names):
    # wide relevance matrix -> long frame with one row per (sample, feature) pair
    frame = pd.DataFrame(nparray[:, :], index=np.array(sample_names), columns=feature_names)
    frame['sample_name'] = sample_names
    long_frame = pd.melt(frame, id_vars='sample_name', var_name='therapy_diagnostics', value_name='LRP')

    return long_frame.reset_index(drop=True)


def calculate_LRP_simple(model, data_collection, setting, PATH='./results/LRP/', fold=None):
    if not os.path.exists(PATH):
        os.makedirs(PATH)
    device = tc.device(setting['LRP_device'])

    # this is really the test+val set
    patdata_test, surv_test = data_collection.get_test_set()
    model.eval().to(device)

    patdata_test, surv_test = patdata_test.to(device), surv_test.to(device)

    R = model(patdata_test)

    sample_names = data_collection.get_test_names()

    # propagate the predicted risk back onto the input features
    input_relevance = model.relprop(R).cpu().detach().numpy()
    LRP_scores_long = np2pd(input_relevance, sample_names, data_collection.f_feature_names)
    input_long = data_collection.get_input_long()

    LRP_scores_long_and_inputs = LRP_scores_long.merge(input_long, on=['sample_name', 'therapy_diagnostics'], how='left')

    LRP_scores_long_and_inputs_unexpanded = reverse_feature_expansion(LRP_scores_long_and_inputs)

    risk_prediction = pd.DataFrame({'sample_name': sample_names,
                                    'risk_prediction_all': R.clone().detach().cpu().numpy().squeeze()})
    LRP_scores_long_and_inputs_unexpanded = LRP_scores_long_and_inputs_unexpanded.merge(risk_prediction, how='left')

    LRP_scores_long_and_inputs_unexpanded.to_csv(PATH + 'LRP_' + str(model.classname) + '_scores_input_' + str(fold) + '.csv')
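
# Minimal round-trip sketch (not part of the pipeline): illustrates, on a toy
# frame, how np2pd and reverse_feature_expansion interact. The feature and
# sample names below are made up for illustration only.
if __name__ == '__main__':
    toy_relevance = np.array([[0.2, 0.3], [0.1, -0.1]])
    toy_long = np2pd(toy_relevance, ['s1', 's2'], ['featA', 'featA_rev'])
    toy_long['input_score'] = [0.7, 0.4, 0.3, 0.6]
    collapsed = reverse_feature_expansion(toy_long)
    # 'featA' for sample 's1' now carries LRP = 0.2 + 0.3 = 0.5
    print(collapsed[['sample_name', 'therapy_diagnostics_antero', 'LRP']])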
--------------------------------------------------------------------------------
/Main.py:
--------------------------------------------------------------------------------
from data import combine_data
from settings_tt import setting
from NN import Simple_Model
import os
from crossant import crossvalidate


def main():
    if not os.path.exists('./results/data/'):
        os.makedirs('./results/data/')
    data_coll = combine_data(setting, current_test_split=0, splits=5)

    ################
    # construct model
    ################

    model = Simple_Model(data_coll, setting)
    print(model)

    #################
    # crossvalidate
    #################
    crossvalidate(model, data_coll, setting, train_on_all=True, train_on_parts=True)


if __name__ == '__main__':
    print("let's go")
    main()
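
# Single-fold sketch (commented out; an assumption about typical usage, not part
# of the published pipeline): instead of the full five-fold crossvalidation, a
# single fold can be trained and explained with the same building blocks:
#
#   import torch as tc
#   from training import train_test
#   from LRP import calculate_LRP_simple
#
#   tc.save(model.state_dict(), './results/raw_params.pt')  # train_test reloads this
#   data_coll.change_test_set(0)
#   model = train_test(model, data_coll, setting, fold=0)
#   calculate_LRP_simple(model, data_coll, setting, fold=0)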
--------------------------------------------------------------------------------
/NN.py:
--------------------------------------------------------------------------------
import torch as tc
import torch.nn as nn
import copy


class LRP_Linear(nn.Module):
    def __init__(self, inp, outp, gamma=0.01, eps=1e-5):
        super(LRP_Linear, self).__init__()
        self.A_dict = {}
        self.linear = nn.Linear(inp, outp)
        nn.init.xavier_uniform_(self.linear.weight, gain=nn.init.calculate_gain('relu'))
        self.gamma = tc.tensor(gamma)
        self.eps = tc.tensor(eps)
        self.rho = None
        self.iteration = None

    def forward(self, x):
        # cache the activations during eval so relprop can reuse them
        if not self.training:
            self.A_dict[self.iteration] = x.clone()
        return self.linear(x)

    def relprop(self, R):
        device = next(self.parameters()).device

        A = self.A_dict[self.iteration].clone()
        A, self.eps = A.to(device), self.eps.to(device)

        Ap = A.clamp(min=0).detach().data.requires_grad_(True)
        Am = A.clamp(max=0).detach().data.requires_grad_(True)

        zpp = self.newlayer(1).forward(Ap)
        zmm = self.newlayer(-1, no_bias=True).forward(Am)

        zmp = self.newlayer(1, no_bias=True).forward(Am)
        zpm = self.newlayer(-1).forward(Ap)

        with tc.no_grad():
            Y = self.forward(A).data

        sp = ((Y > 0).float() * R / (zpp + zmm + self.eps * ((zpp + zmm == 0).float() + tc.sign(zpp + zmm)))).data
        sm = ((Y < 0).float() * R / (zmp + zpm + self.eps * ((zmp + zpm == 0).float() + tc.sign(zmp + zpm)))).data

        (zpp * sp).sum().backward()
        cpp = Ap.grad
        Ap.grad = None
        Ap.requires_grad_(True)

        (zpm * sm).sum().backward()
        cpm = Ap.grad
        Ap.grad = None
        Ap.requires_grad_(True)

        (zmp * sm).sum().backward()
        cmp = Am.grad
        Am.grad = None
        Am.requires_grad_(True)

        (zmm * sp).sum().backward()
        cmm = Am.grad
        Am.grad = None
        Am.requires_grad_(True)

        R_1 = (Ap * cpp).data
        R_2 = (Ap * cpm).data
        R_3 = (Am * cmp).data
        R_4 = (Am * cmm).data

        return R_1 + R_2 + R_3 + R_4

    def newlayer(self, sign, no_bias=False):
        # gamma rule: amplify the positive (sign=1) or negative (sign=-1) weights
        if sign == 1:
            rho = lambda p: p + self.gamma * p.clamp(min=0)
        else:
            rho = lambda p: p + self.gamma * p.clamp(max=0)

        layer_new = copy.deepcopy(self.linear)

        try:
            layer_new.weight = nn.Parameter(rho(self.linear.weight))
        except AttributeError:
            pass

        try:
            layer_new.bias = nn.Parameter(self.linear.bias * 0 if no_bias else rho(self.linear.bias))
        except AttributeError:
            pass

        return layer_new


class LRP_ReLU(nn.Module):
    def __init__(self):
        super(LRP_ReLU, self).__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x)

    def relprop(self, R):
        # ReLU passes relevance through unchanged
        return R


class LRP_DropOut(nn.Module):
    def __init__(self, p):
        super(LRP_DropOut, self).__init__()
        self.dropout = nn.Dropout(p)

    def forward(self, x):
        return self.dropout(x)

    def relprop(self, R):
        return R


class LRP_cat(nn.Module):
    def __init__(self):
        super(LRP_cat, self).__init__()

    def forward(self, list_of_tensors):
        self.sizes = [tensor.shape[1] for tensor in list_of_tensors]
        return tc.cat(list_of_tensors, axis=1)

    def relprop(self, R):
        splitted_R = tc.split(R, self.sizes, dim=1)
        return splitted_R


class Simple_Model(nn.Module):
    classname = 'Simple Model'

    def __init__(self, data_coll, setting):
        super(Simple_Model, self).__init__()

        self.inp, self.hidden, outp = data_coll.f_nfeatures, int(setting['factor_hidden_nodes'] * data_coll.f_nfeatures), 1
        self.hidden_depth = setting['hidden_depth_simple']
        self.layers = nn.Sequential(LRP_DropOut(p=setting['input_dropout']), LRP_Linear(self.inp, self.hidden, gamma=0.01), LRP_ReLU())
        for i in range(self.hidden_depth):
            # module names must be unique; reusing a name would silently overwrite the module
            self.layers.add_module('dropout' + str(i + 1), LRP_DropOut(p=setting['dropout']))
            self.layers.add_module('LRP_Linear' + str(i + 1), LRP_Linear(self.hidden, self.hidden, gamma=0.01))
            self.layers.add_module('LRP_ReLU' + str(i + 1), LRP_ReLU())
        self.layers.add_module('dropout_last', LRP_DropOut(p=setting['dropout']))
        self.layers.add_module('LRP_Linear_last', LRP_Linear(self.hidden, outp, gamma=0.01))

        count_parameters(self)

    def forward(self, x):
        return self.layers.forward(x)

    def relprop(self, R):
        assert not self.training, 'relprop does not work during training time'
        # walk the layers in reverse and redistribute the relevance R
        for module in self.layers[::-1]:
            R = module.relprop(R)
        return R


def count_parameters(model):
    nparams = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('{} contains {} trainable parameters'.format(model.classname, nparams))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Decoding pan-cancer treatment outcomes

Code base for [**Decoding pan-cancer treatment outcomes using multimodal real-world data and explainable AI**](https://www.nature.com/articles/s43018-024-00891-1)

The required packages are listed in requirements.txt.

To start the analysis, run the main file:


```
python Main.py
```

Data are prepared and loaded in data.py. Here, this is an artificial dataset in which the outcome depends on a single variable.
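
In essence, the synthetic outcome is built like this (a simplified sketch of the logic in data.py, not code taken from the repository):

```
import numpy as np
import pandas as pd

features = pd.DataFrame(np.random.rand(10000, 6) * 2)   # random input features
duration = features.iloc[:, 0]                          # survival time = first feature
event = (np.random.rand(10000) > 0.2) * 1.0             # ~80% of samples are uncensored
survival = pd.DataFrame({'duration': duration, 'event': event})
```

Since the outcome is driven by the first feature alone, the LRP scores should attribute most of the relevance to this variable.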

The *crossvalidate* function will split the data into five folds and start:

- the neural network training on the pan-cancer data,
- the neural network training on the single-entity cancer data, and
- the LRP computation based on the trained model from the pan-cancer training.

Two new folders (*results/training* and *results/LRP*) are generated and contain the training and LRP results, respectively.

To change settings with respect to model architecture or training regimen, see settings_tt.py. By default, the code runs on the GPU; this can be changed via the 'training_device' and 'LRP_device' entries in settings_tt.py.
--------------------------------------------------------------------------------
/crossant.py:
--------------------------------------------------------------------------------
from training import train_test, train_test_individual_cancers, save_predictions
from LRP import calculate_LRP_simple
import torch as tc


def crossvalidate(model, data_coll, setting, train_on_all=True, train_on_parts=True):
    # store the untrained parameters so every fold starts from the same state
    tc.save(model.state_dict(), './results/raw_params.pt')

    for i in range(data_coll.splits):
        data_coll.change_test_set(i)

        if train_on_all:
            model = train_test(model, data_coll, setting, fold=i)
            save_predictions(model, data_coll, setting, fold=i)

            if model.classname == 'Simple Model':
                calculate_LRP_simple(model, data_coll, setting, fold=i)

        if train_on_parts:
            train_test_individual_cancers(model, data_coll, setting, fold=i)
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import torch as tc
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from settings_tt import place
import random
from collections import Counter

np.random.seed(0)


class DataSimulator:
    def __init__(self):
        pass

    def simulate_sample(self, i):
        tc.manual_seed(i)
        therapy = tc.bernoulli(tc.ones(3) * 0.5)
        diagnostics = tc.randn(10)

        sum1 = therapy[0] * diagnostics[0] + 2 * therapy[0] * diagnostics[1]
        sum2 = therapy[1] * diagnostics[2] if (diagnostics[2] > 0) else 0
        sum3 = therapy[1] * diagnostics[2] if (diagnostics[2] < -1.0) else 0
        sum4 = therapy[2] * diagnostics[3] * diagnostics[4]

        survival = sum1 + sum2 + sum3 + sum4 + 0.01 * tc.randn(1)
        return therapy, diagnostics, survival

    def simulate_data(self, n):
        therapy, diagnostics, survival = zip(*[self.simulate_sample(i) for i in range(n)])
        return tc.stack(therapy), tc.stack(diagnostics), tc.stack(survival)


def combine_data(setting, current_test_split, splits):

    if place == 'M':
        nsamples = 10000
        synthetic_data = pd.DataFrame(np.random.rand(nsamples, 6) * 2)

        synthetic_data.columns = ['somefeature1_' + str(col) for col in synthetic_data.columns[:5]] + ['somefeature2_' + str(col) for col in synthetic_data.columns[5:]]
        synthetic_data = synthetic_data.sort_index(ascending=True)

        synthetic_data_rev = 1 - synthetic_data.copy()
        synthetic_data_rev.columns = [str(col) + '_rev' for col in synthetic_data_rev.columns]
        full_data = pd.concat((synthetic_data, synthetic_data_rev), axis=1)
        full_data = synthetic_data.copy()  # NOTE: overrides the concat above, so the demo runs without the '_rev' features
        random_numbers = np.random.rand(full_data.shape[0])
        full_data['Cancer_C0'] = (random_numbers < 0.5) * 1.0
        full_data['Cancer_C1'] = (random_numbers >= 0.5) * 1.0

        survival_data = pd.DataFrame({'duration': full_data.iloc[:, 0]})
        survival_data['event'] = ((tc.rand_like(tc.tensor(survival_data['duration'])) > 0.2) * 1.0).numpy()
        survival_data = survival_data.rename(columns={setting['duration_name']: 'duration', setting['event_name']: 'event'})

    #######################

    cancer_type = pd.DataFrame(full_data[[str(col) for col in full_data.columns if str(col).startswith('Cancer_')]].idxmax(axis=1),
                               columns=['cancer_type'], index=full_data.index)

    data_collection = Data_Collection(full_data, survival_data, cancer_type, current_test_split=current_test_split, splits=splits)

    return data_collection


class Data_Collection(Dataset):
    def __init__(self, full_data, survival_data, cancer_type, current_test_split, splits):
        self.splits = splits
        full_data = full_data.reindex(sorted(full_data.columns), axis=1)

        self.full_data = full_data
        self.survival_data = survival_data

        self.f_tensor = tc.tensor(full_data.to_numpy()).float()
        self.f_sample_names = np.array(full_data.index)
        self.f_feature_names = full_data.columns

        self.s_tensor = tc.tensor(survival_data.to_numpy()).float()
        self.s_sample_names = survival_data.index
        self.s_feature_names = survival_data.columns

        self.cancer_type = cancer_type
        print('cancer types:', self.cancer_type['cancer_type'])
        self.unique_cancer_types = self.cancer_type['cancer_type'].unique()

        self.nsamples = self.f_tensor.shape[0]
        self.f_nfeatures = self.f_tensor.shape[1]
        self.s_nfeatures = self.s_tensor.shape[1]

        tc.manual_seed(0)
        self.random_sequence = tc.randperm(self.nsamples)

        # generate train and test ids for all splits, stratified by cancer type
        self.test_splits = self.generate_stratified_test_sets(self.random_sequence, cancer_type['cancer_type'], splits)

        # select current training and test ids
        self.change_test_set(current_test_split)

        self.unique_features_dict = self.make_dict(full_data)

        assert self.nsamples == self.s_tensor.shape[0], 'the input data have different sample lengths'
        assert all(self.f_sample_names == self.s_sample_names), 'samples (full, survival) have different names'

    def get_train_set(self):
        return self.f_tensor[self.training_ids, :], self.s_tensor[self.training_ids, :]

    def get_test_set(self):
        return self.f_tensor[self.test_ids, :], self.s_tensor[self.test_ids, :]

    def generate_stratified_test_sets(self, sequence, group_var, splits):
        unique_groups = group_var.unique()
        permuted_group_vars = np.array(group_var)[sequence]

        lists_of_group_ids = [sequence[permuted_group_vars == g] for g in unique_groups]
        lists_of_split_data = [np.array_split(group_seq, splits) for group_seq in lists_of_group_ids]

        random.seed(0)
        [random.shuffle(seq) for seq in lists_of_split_data]

        test_splits = []
        for i in range(splits):
            one_split = tc.cat([group_seq[i] for group_seq in lists_of_split_data])

            # permute one_split for a fair test-val separation during training
            one_split_new = one_split[tc.randperm(one_split.shape[0])]

            test_splits.append(one_split_new)

        for test_ids in test_splits:
            print(Counter(list(np.array(group_var)[test_ids])))

        return test_splits

    def change_test_set(self, new_test_split):
        self.current_test_split = new_test_split
        self.test_ids = self.test_splits[new_test_split]

        self.training_ids = np.setdiff1d(self.random_sequence, self.test_ids, assume_unique=True)

        self.train_len = self.training_ids.shape[0]
        self.test_len = self.test_ids.shape[0]

    def get_test_names(self):
        return self.f_sample_names[self.test_ids]

    def get_train_names(self):
        return self.f_sample_names[self.training_ids]

    def get_cancer_types_test(self):
        return np.array(self.cancer_type)[self.test_ids].squeeze()

    def get_cancer_types_train(self):
        return np.array(self.cancer_type)[self.training_ids].squeeze()

    def get_input_long(self):
        full_data = self.full_data.copy()  # copy so the stored frame is not mutated
        full_data['sample_name'] = full_data.index
        return pd.melt(full_data, id_vars='sample_name', var_name='therapy_diagnostics', value_name='input_score')

    def make_dict(self, dat):
        # count how many columns belong to each feature class (prefix before the first '_')
        featureclass = [col.split('_')[0] for col in dat.columns]
        counts = np.unique(featureclass, return_counts=True)
        count_dict = {counts[0][i]: counts[1][i] for i, _ in enumerate(counts[0])}
        return count_dict
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
Brotli==1.1.0
certifi==2024.2.2
charset-normalizer==3.3.2
contourpy==1.2.0
cycler==0.12.1
Cython==3.0.9
feather-format==0.4.1
filelock==3.13.1
fonttools==4.50.0
fsspec==2024.3.0
h5py==3.10.0
idna==3.6
inflate64==1.0.0
Jinja2==3.1.3
joblib==1.3.2
kiwisolver==1.4.5
llvmlite==0.42.0
MarkupSafe==2.1.5
matplotlib==3.8.3
mpmath==1.3.0
multivolumefile==0.2.3
networkx==3.2.1
numba==0.59.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
packaging==24.0
pandas==1.1.5
pillow==10.2.0
psutil==5.9.8
py7zr==0.21.0
pyarrow==15.0.1
pybcj==1.0.2
pycox==0.2.3
pycryptodomex==3.20.0
pyparsing==3.1.2
pyppmd==1.1.0
python-dateutil==2.9.0.post0
pytz==2024.1
pyzstd==0.15.9
requests==2.31.0
scikit-learn==1.4.1.post1
scipy==1.12.0
six==1.16.0
sklearn==0.0
sympy==1.12
texttable==1.7.0
threadpoolctl==3.3.0
torch==2.2.1
torchtuples==0.2.2
tqdm==4.66.2
triton==2.2.0
typing_extensions==4.10.0
urllib3==2.2.1
--------------------------------------------------------------------------------
/settings_tt.py:
--------------------------------------------------------------------------------
place = 'M'


if place == 'M':
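    # 'place' selects a site-specific configuration block. A further site could
    # add its own block, e.g. (hypothetical):
    #
    #   elif place == 'X':
    #       setting = {...}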
    setting = {
        # data params
        'DATAPATH': '../some_data/',
        'therapy_file': 'therapy.csv',
        'diagnostics_file': 'diagnostics.csv',
        'survival_file': 'survival.csv',
        'duration_name': '0',
        'event_name': 'event',

        # training params
        'hidden_depth_simple': 0,

        'factor_hidden_nodes': 10,  # determines the width of hidden layers -> width = factor_hidden_nodes * input_width

        'training_device': 'cuda:0',

        'training_batch_size': 1024,
        'reduce_lr_epochs': [50, 50, 50],  # up to 50 epochs per stage; the learning rate drops tenfold per stage
        'lr': 1e-4,
        'dropout': 0.5,
        'input_dropout': 0.5,
        'weight_decay': 0,

        # LRP params
        'nepochs': 16,
        'batch_size': 8,

        'LRP_device': 'cuda:0',
        'LRP_batch_size': 1000}
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
import torch as tc
import torchtuples as tt
import numpy as np
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv
import os
import pandas as pd
import gc


def train_test(model, data_collection, setting, fold, PATH='./results/training/'):
    if not os.path.exists(PATH):
        os.makedirs(PATH)

    # assert that data_collection is set to the correct fold
    assert fold == data_collection.current_test_split, 'fold not right'

    device = tc.device(setting['training_device'])

    # reset the model to the same initial state for each training fold
    model.load_state_dict(tc.load('./results/raw_params.pt'))
    model.train().to(device)

    # get train and test/val sets
    patdata_train, surv_train = data_collection.get_train_set()
    patdata_test_val, surv_test_val = data_collection.get_test_set()

    tv_length = patdata_test_val.shape[0]

    # the first half of the held-out split is the test set, the second half the validation set
    patdata_test, surv_test = patdata_test_val[:tv_length // 2, :], surv_test_val[:tv_length // 2, :]
    patdata_val, surv_val = patdata_test_val[tv_length // 2:, :], surv_test_val[tv_length // 2:, :]

    ####################################################################################################
    # save test and validation names for later quality control: e.g. that each sample
    # appears in only one test/val set
    samplenames_test_val = data_collection.get_test_names()
    test_names = samplenames_test_val[:tv_length // 2]
    val_names = samplenames_test_val[tv_length // 2:]

    test_name_frame = pd.DataFrame({'sample_name': test_names, 'type': 'test', 'fold': fold})
    val_name_frame = pd.DataFrame({'sample_name': val_names, 'type': 'val', 'fold': fold})
    train_name_frame = pd.DataFrame({'sample_name': data_collection.get_train_names(), 'type': 'train', 'fold': fold})
    name_frame = pd.concat([test_name_frame, val_name_frame, train_name_frame], axis=0)
    name_frame.to_csv('./results/training/name_frame.csv', mode='w' if fold == 0 else 'a', header=fold == 0)
    ######################################################################################################

    print('train size:', patdata_train.shape, 'test_size:', patdata_test.shape)

    # prepare test and val data for the pycox workflow
    test_data = (patdata_test, (surv_test[:, 0], surv_test[:, 1]))
    val_data = (patdata_val, (surv_val[:, 0], surv_val[:, 1]))

    coxph_model = CoxPH(model, tt.optim.Adam(setting['lr'], weight_decay=setting['weight_decay']), device=device)

    effective_epochs = []
    for exp, training_epochs in enumerate(setting['reduce_lr_epochs']):
        callbacks = [tt.callbacks.EarlyStopping(), tt.callbacks.ClipGradNorm(model, max_norm=1.0)]
        print('train for {} epochs with lr {}'.format(training_epochs, setting['lr'] * 10 ** (-exp)))
        coxph_model.optimizer.set_lr(setting['lr'] * 10 ** (-exp))
        log = coxph_model.fit(patdata_train, (surv_train[:, 0], surv_train[:, 1]), batch_size=setting['training_batch_size'], epochs=training_epochs, callbacks=callbacks,
                              val_data=val_data, val_batch_size=setting['training_batch_size'], verbose=1)
        effective_epochs.append(log.epoch)

    _ = coxph_model.compute_baseline_hazards()
    surv_pred = coxph_model.predict_surv_df(test_data[0])

    # compute concordance and integrated Brier scores
    ev = EvalSurv(surv_pred, np.array(test_data[1][0]).squeeze(), np.array(test_data[1][1]).squeeze(), censor_surv='km')
    concordance = ev.concordance_td()
    time_grid = np.linspace(np.array(test_data[1][0]).squeeze().min(), np.array(test_data[1][0]).squeeze().max(), 100)
    integrated_brier_score = ev.integrated_brier_score(time_grid)
    integrated_nbll = ev.integrated_nbll(time_grid)

    print('concordance:', concordance)
    print('integrated_brier_score:', integrated_brier_score)
    print('integrated_nbll:', integrated_nbll)

    concordance_scores = pd.DataFrame({'fold': [fold], 'conc_score': concordance, 'brier_score': integrated_brier_score, 'ncancers': test_data[0].shape[0]})
    concordance_scores.to_csv('./results/training/conc_scores.csv', mode='w' if fold == 0 else 'a', header=fold == 0)

    strat_conc_scores = conc_score_per_cancer(coxph_model, data_collection)
    strat_conc_scores['fold'] = fold
    strat_conc_scores.to_csv('./results/training/stratified_conc_scores.csv', mode='w' if fold == 0 else 'a', header=fold == 0)

    return model


def conc_score_per_cancer(trained_model, data_collection):
    patdata_test_val, surv_test_val = data_collection.get_test_set()
    tv_length = patdata_test_val.shape[0]

    patdata_test, surv_test = patdata_test_val[:tv_length // 2, :], surv_test_val[:tv_length // 2, :]
    patdata_val, surv_val = patdata_test_val[tv_length // 2:, :], surv_test_val[tv_length // 2:, :]  # not needed

    cancer_types_test_val = data_collection.get_cancer_types_test()
    cancer_types_test = cancer_types_test_val[:tv_length // 2]
    cancer_types_val = cancer_types_test_val[tv_length // 2:]  # not needed

    unique_cancer_types_test = data_collection.unique_cancer_types
    results = []
    for cancer_type in unique_cancer_types_test:
        current_ids = cancer_types_test == cancer_type
        test_data_stratified = (patdata_test[current_ids, :], (surv_test[current_ids, 0], surv_test[current_ids, 1]))

        # skip cancer types with too few test samples or too few events
        if (test_data_stratified[0].shape[0] < 10) | (np.array(test_data_stratified[1][1]).sum() <= 5):
            continue

        surv_pred = trained_model.predict_surv_df(test_data_stratified[0])

        ev_strat = EvalSurv(surv_pred, np.array(test_data_stratified[1][0]).squeeze(), np.array(test_data_stratified[1][1]).squeeze(), censor_surv='km')

        concordance_strat = ev_strat.concordance_td()

        time_grid = np.linspace(np.array(test_data_stratified[1][0]).squeeze().min(), np.array(test_data_stratified[1][0]).squeeze().max(), 100)
        integrated_brier_score_strat = ev_strat.integrated_brier_score(time_grid)

        results.append(pd.DataFrame({'cancer_type': [cancer_type], 'concordance': concordance_strat, 'integrated_brier': integrated_brier_score_strat,
                                     'ncancers_test': test_data_stratified[0].shape[0]}))

        gc.collect()

    stratified_concordance_scores = pd.concat(results, axis=0)

    return stratified_concordance_scores


def train_test_individual_cancers(model, data_collection, setting, fold, PATH='./results/training/'):
    if not os.path.exists(PATH):
        os.makedirs(PATH)

    device = tc.device(setting['training_device'])
    model.train().to(device)

    patdata_train, surv_train = data_collection.get_train_set()
    patdata_test_val, surv_test_val = data_collection.get_test_set()

    tv_length = patdata_test_val.shape[0]

    patdata_test, surv_test = patdata_test_val[:tv_length // 2, :], surv_test_val[:tv_length // 2, :]
    patdata_val, surv_val = patdata_test_val[tv_length // 2:, :], surv_test_val[tv_length // 2:, :]

    cancer_types_test_val = data_collection.get_cancer_types_test()
    print(cancer_types_test_val)
    cancer_types_test = cancer_types_test_val[:tv_length // 2]
    cancer_types_val = cancer_types_test_val[tv_length // 2:]

    cancer_types_train = data_collection.get_cancer_types_train()
    unique_cancer_types_test = data_collection.unique_cancer_types
    results = []

    for cancer_type in unique_cancer_types_test:
        print(cancer_type)

        # restart from the untrained parameters for every cancer type
        model.load_state_dict(tc.load('./results/raw_params.pt'))

        current_test_ids = cancer_types_test == cancer_type
        current_train_ids = cancer_types_train == cancer_type
        current_val_ids = cancer_types_val == cancer_type

        test_data_now = (patdata_test[current_test_ids, :], (surv_test[current_test_ids, 0], surv_test[current_test_ids, 1]))
        val_data_now = (patdata_val[current_val_ids, :], (surv_val[current_val_ids, 0], surv_val[current_val_ids, 1]))

        patdata_train_now = patdata_train[current_train_ids, :]
        surv_train_now = surv_train[current_train_ids, :]

        print('train_samples:', patdata_train_now.shape[0], 'val_samples:', val_data_now[0].shape[0], 'test_samples:', test_data_now[0].shape[0],
              'train_ratio:', patdata_train_now.shape[0] / (val_data_now[0].shape[0] + test_data_now[0].shape[0] + patdata_train_now.shape[0]))

        # skip cancer types with too few samples or too few events in the test or val set
        if (test_data_now[0].shape[0] < 10) | (np.array(test_data_now[1][1]).sum() <= 5) | (val_data_now[0].shape[0] < 10) | (np.array(val_data_now[1][1]).sum() <= 5):
            continue

        coxph_model = CoxPH(model, tt.optim.Adam(setting['lr'], weight_decay=setting['weight_decay']), device=device)

        effective_epochs = []
        for exp, training_epochs in enumerate(setting['reduce_lr_epochs']):
            callbacks = [tt.callbacks.EarlyStopping(), tt.callbacks.ClipGradNorm(model, max_norm=1.0)]
            print('train for {} epochs with lr {}'.format(training_epochs, setting['lr'] * 10 ** (-exp)))
            coxph_model.optimizer.set_lr(setting['lr'] * 10 ** (-exp))
            log = coxph_model.fit(patdata_train[current_train_ids, :], (surv_train[current_train_ids, 0], surv_train[current_train_ids, 1]),
                                  batch_size=setting['training_batch_size'], epochs=training_epochs, callbacks=callbacks,
                                  val_data=val_data_now, val_batch_size=setting['training_batch_size'], verbose=1)
            effective_epochs.append(log.epoch)

        _ = coxph_model.compute_baseline_hazards()
        surv_pred = coxph_model.predict_surv_df(test_data_now[0])

        ev = EvalSurv(surv_pred, np.array(test_data_now[1][0]).squeeze(), np.array(test_data_now[1][1]).squeeze(), censor_surv='km')
        concordance = ev.concordance_td()
        time_grid = np.linspace(np.array(test_data_now[1][0]).squeeze().min(), np.array(test_data_now[1][0]).squeeze().max(), 100)
        integrated_brier_score = ev.integrated_brier_score(time_grid)
        integrated_nbll = ev.integrated_nbll(time_grid)

        print('concordance:', concordance)
        print('integrated_brier_score:', integrated_brier_score)
        print('integrated_nbll:', integrated_nbll)

        results.append(pd.DataFrame({'cancer_type': [cancer_type], 'concordance': concordance, 'integrated_brier': integrated_brier_score, 'ncancers_test': test_data_now[0].shape[0]}))

        # free GPU memory before the next cancer type
        del coxph_model
        del log
        del callbacks
        del ev
        tc.cuda.empty_cache()

        gc.collect()

    strat_conc_scores = pd.concat(results, axis=0)
    strat_conc_scores['fold'] = fold
    strat_conc_scores.to_csv('./results/training/stratified_conc_scores_individual.csv', mode='w' if fold == 0 else 'a', header=fold == 0)


def save_predictions(model, data_collection, setting, fold, PATH='./results/training/'):
    device = tc.device(setting['training_device'])
    model.eval().to(device)

    patdata_test_val, surv_test_val = data_collection.get_test_set()
    patdata_test_val, surv_test_val = patdata_test_val.to(device), surv_test_val.to(device)
    sample_names = data_collection.get_test_names()

    pred = model(patdata_test_val).cpu().detach().numpy().squeeze()

    # maybe implement different risks here

    df = pd.DataFrame({'sample_name': sample_names, 'risk_prediction_all': pred})

    df.to_csv(PATH + 'risk_predictions.csv', mode='w' if fold == 0 else 'a', header=fold == 0)
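
# Post-hoc aggregation sketch (an assumption about typical downstream use, not
# part of the pipeline itself): the per-fold scores written above can be
# summarized across folds, e.g.
#
#   scores = pd.read_csv('./results/training/conc_scores.csv')
#   print('mean concordance across folds:', scores['conc_score'].mean())
#   print('mean Brier score across folds:', scores['brier_score'].mean())
--------------------------------------------------------------------------------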