├── WGCNA_gbmlgg.R ├── example_data ├── input_adjacency_matrix │ └── split1_adjacency_matrix.csv └── input_features_labels │ ├── split1_test_320d_features_labels.csv │ └── split1_train_320d_features_labels.csv ├── figs └── pipeline.png ├── gradients_to_feature_importance.py ├── model_GAT.py ├── model_GAT_v4.py ├── options.py ├── pretrained_models └── split1_grad_720d_all_50epochs.pt ├── readme.md ├── test_cv.py ├── test_model.py ├── train_model.py ├── train_test_new.py └── utils.py /WGCNA_gbmlgg.R: -------------------------------------------------------------------------------- 1 | ### For each split of the features, perform WGCNA and generate adjacency matrix. 2 | 3 | data_folder = "./data/RNAseq_graph/RNAseq" 4 | output_folder = "./data/RNAseq_graph/wgcna_output" 5 | 6 | cancer = "GBMLGG" 7 | data_file = paste(data_folder, cancer, "split15_train_320d_features_labels.csv", sep='/') 8 | dir.create(paste(output_folder, cancer, sep='/')) 9 | # WGCNA parameters 10 | wgcna_power = 6 11 | wgcna_minModuleSize = 10 12 | wgcna_mergeCutHeight = 0.25 13 | data = read.csv(data_file, header=F) # each row is a patient 14 | geneExp = as.matrix(data[2:dim(data)[1], 83:322]) 15 | 16 | # gene as columns for WGCNA 17 | # geneExp = t(geneExp) 18 | dim(geneExp) 19 | 20 | ## imputate the NA by zero values. 21 | geneExp[is.na(geneExp)]<-0 22 | 23 | library(WGCNA) 24 | adjacency = adjacency(geneExp, power = wgcna_power) 25 | write.csv(adjacency,file=paste(output_folder, cancer, "split15_adjacency_matrix.csv", sep='/'),quote=F,row.names = F) 26 | 27 | -------------------------------------------------------------------------------- /figs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentAILabHealthcare/MLA-GNN/6f950e05768a7d03ed28c60fe578ca78ee68a4df/figs/pipeline.png -------------------------------------------------------------------------------- /gradients_to_feature_importance.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script converts the feature gradients to the importance score of each node in each layer. 3 | Input: gradients, 3 layers * 240 nodes = 720 features in total. gradients.shape = [n, 720]. 4 | Output: the feature importance scores of each layer, 240 rows (nodes) * 4columns (classes). 5 | """ 6 | 7 | import numpy as np 8 | from sklearn.preprocessing import normalize 9 | import pandas as pd 10 | 11 | 12 | def gradients_each_class(gradients, labels): 13 | """ 14 | according to the labels, separate the gradients for each class. 15 | return: all_class_importance: [3, 240, 3], the first dimension is 3 classes(class0, class1, class2, overall), 16 | the second dimension is 240 nodes, the last dimension is 3 GAT layers. 
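    Example (hypothetical arrays, shapes as used in this script):
        gradients = np.random.rand(100, 720)           # 100 samples, 3 layers x 240 nodes
        labels = np.random.randint(0, 3, size=100)     # grade classes 0, 1, 2
        gradients_each_class(gradients, labels).shape  # -> (4, 240, 3): class0/1/2 + overall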
17 | """ 18 | class0_index = np.argwhere(labels == 0) 19 | class0_gradients = gradients[class0_index[:, 0], :] 20 | class0_importance = np.mean(class0_gradients, axis=0) 21 | class0_importance = np.reshape(class0_importance, (3, 240)).T 22 | class0_importance = np.expand_dims(normalize(class0_importance, axis=0, norm='max'), axis=0) 23 | 24 | class1_index = np.argwhere(labels == 1) 25 | class1_gradients = gradients[class1_index[:, 0], :] 26 | class1_importance = np.mean(class1_gradients, axis=0) 27 | class1_importance = np.reshape(class1_importance, (3, 240)).T 28 | class1_importance = np.expand_dims(normalize(class1_importance, axis=0, norm='max'), axis=0) 29 | 30 | class2_index = np.argwhere(labels == 2) 31 | class2_gradients = gradients[class2_index[:, 0], :] 32 | class2_importance = np.mean(class2_gradients, axis=0) 33 | class2_importance = np.reshape(class2_importance, (3, 240)).T 34 | class2_importance = np.expand_dims(normalize(class2_importance, axis=0, norm='max'), axis=0) 35 | 36 | # print(class0_gradients.shape, class1_gradients.shape, class2_gradients.shape) 37 | # print(class0_importance.shape, class1_importance.shape, class2_importance.shape) 38 | 39 | overall_importance = (class0_importance+class1_importance+class2_importance)/3.0 40 | all_class_importance = np.concatenate(( 41 | class0_importance, class1_importance, class2_importance, overall_importance), axis=0) # [4, 240, 3] 42 | # print(all_class_importance.shape) 43 | 44 | return all_class_importance 45 | 46 | 47 | 48 | 49 | def all_class_feature_importance(): 50 | """ 51 | load the feature gradients, and then separate the gradients of the samples from each class, 52 | and calculate the average gradients as the feature importance of each class. 53 | """ 54 | 55 | feature_importance_all = np.zeros((4, 240, 3)) 56 | 57 | for k in range(1, 16): 58 | gradients_labels = np.array(pd.read_csv( 59 | './results/grad/feature_gradients/split' + str(k) + '_720d.csv')).astype(float)[:, 1:] 60 | gradients = gradients_labels[:, :720] 61 | labels = gradients_labels[:, -1] 62 | 63 | all_class_importance = gradients_each_class(gradients, labels)# [4, 240, 3] 64 | ### average the feature importance from three GAT layers. 65 | # all_class_importance = np.mean(all_class_importance, axis=1, keepdims=True) 66 | feature_importance_all += all_class_importance # [4, 240, 3] 67 | 68 | avg_layer_importance = np.mean(feature_importance_all, axis=-1, keepdims=True) 69 | feature_importance_all = np.concatenate((feature_importance_all, avg_layer_importance), axis=-1) # [4,240,4] 70 | print(feature_importance_all.shape) 71 | 72 | header = np.array(['class0', 'class1', 'class2', 'all_classes']).reshape(1, 4) 73 | layers = ['layer1', 'layer2', 'layer3', 'overall'] 74 | 75 | for layer in range(4): 76 | layer_importance = feature_importance_all[:, :, layer].T # [240, 4] 77 | layer_importance = np.concatenate((header, layer_importance), axis=0) 78 | # print(layer_importance) 79 | pd.DataFrame(layer_importance).to_csv( 80 | "./results/grad/feature_gradients/" + layers[layer] + "_feature_importance.csv", header=0, index=0) 81 | 82 | 83 | all_class_feature_importance() 84 | -------------------------------------------------------------------------------- /model_GAT.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the implementation of Graph Attention Network. 
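The model stacks multi-head GraphAttentionLayer blocks, pools each block's node embeddings to one scalar per node (pool1/pool2), and concatenates the node-wise mean of the input with the pooled layer outputs (3 x 240 = 720 features) before the fc encoder and classifier.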
3 | The code is inspired by "https://github.com/Diego999/pyGAT" 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.optim import Adam 9 | from torch_geometric.nn import global_mean_pool as gap 10 | 11 | from torch.nn import init, Parameter 12 | import torch.optim.lr_scheduler as lr_scheduler 13 | 14 | from utils import * 15 | 16 | 17 | class GAT(nn.Module): 18 | def __init__(self, opt, input_dim, omic_dim, label_dim, dropout, alpha): 19 | 20 | super(GAT, self).__init__() 21 | self.dropout = dropout 22 | self.act = define_act_layer(act_type=opt.act_type) 23 | 24 | self.nhids = [8, 16, 12] 25 | self.nheads = [4, 3, 4] 26 | self.fc_dim = [64, 48, 32] 27 | 28 | self.attentions1 = [GraphAttentionLayer( 29 | input_dim, self.nhids[0], dropout=dropout, alpha=alpha, concat=True) for _ in range(self.nheads[0])] 30 | for i, attention1 in enumerate(self.attentions1): 31 | self.add_module('attention1_{}'.format(i), attention1) 32 | 33 | self.attentions2 = [GraphAttentionLayer( 34 | self.nhids[0]*self.nheads[0], self.nhids[1], dropout=dropout, alpha=alpha, concat=True) for _ in range(self.nheads[1])] 35 | for i, attention2 in enumerate(self.attentions2): 36 | self.add_module('attention2_{}'.format(i), attention2) 37 | 38 | self.attentions3 = [GraphAttentionLayer( 39 | self.nhids[1]*self.nheads[1], self.nhids[2], dropout=dropout, alpha=alpha, concat=True) for _ in range(self.nheads[2])] 40 | for i, attention3 in enumerate(self.attentions3): 41 | self.add_module('attention3_{}'.format(i), attention3) 42 | 43 | self.dropout_layer = nn.Dropout(p=self.dropout) 44 | 45 | # lin_input_dim = self.nhids[0]*self.nheads[0] + self.nhids[1]*self.nheads[1] + self.nhids[2]*self.nheads[2] 46 | lin_input_dim = opt.lin_input_dim 47 | 48 | # self.lin1 = torch.nn.Linear(lin_input_dim, lin_dim1) 49 | # self.lin2 = torch.nn.Linear(lin_dim1, label_dim) 50 | 51 | self.pool1 = torch.nn.Linear(self.nhids[0]*self.nheads[0], 1) 52 | self.pool2 = torch.nn.Linear(self.nhids[1]*self.nheads[1], 1) 53 | self.pool3 = torch.nn.Linear(self.nhids[2] * self.nheads[2], 1) 54 | 55 | fc1 = nn.Sequential( 56 | nn.Linear(lin_input_dim, self.fc_dim[0]), 57 | nn.ELU(), 58 | nn.AlphaDropout(p=self.dropout, inplace=False)) 59 | 60 | fc2 = nn.Sequential( 61 | nn.Linear(self.fc_dim[0], self.fc_dim[1]), 62 | nn.ELU(), 63 | nn.AlphaDropout(p=self.dropout, inplace=False)) 64 | 65 | fc3 = nn.Sequential( 66 | nn.Linear(self.fc_dim[1], self.fc_dim[2]), 67 | nn.ELU(), 68 | nn.AlphaDropout(p=self.dropout, inplace=False)) 69 | 70 | fc4 = nn.Sequential( 71 | nn.Linear(self.fc_dim[2], omic_dim), 72 | nn.ELU(), 73 | nn.AlphaDropout(p=self.dropout, inplace=False)) 74 | 75 | self.encoder = nn.Sequential(fc1, fc2, fc3, fc4) 76 | self.classifier = nn.Sequential(nn.Linear(omic_dim, label_dim)) 77 | 78 | 79 | self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False) 80 | self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False) 81 | 82 | 83 | def forward(self, x, adj, grad_labels, opt): 84 | 85 | # print("input shape:", x.shape) 86 | batch = torch.linspace(0, x.size(0) - 1, x.size(0), dtype=torch.long) 87 | batch = batch.unsqueeze(1).repeat(1, x.size(1)).view(-1).cuda() 88 | 89 | if opt.cnv_dim == 80: 90 | cnv_feature = torch.mean(x[:, :80, :], dim=-1) 91 | x = x[:, 80:, :] 92 | x0 = torch.mean(x, dim=-1) 93 | # print("x0:", x0.shape) 94 | 95 | x = self.dropout_layer(x) 96 | x = torch.cat([att(x, adj) for att in self.attentions1], dim=-1) # [bs, N, nhid1*nhead1] 97 | 98 | x1 = 
self.pool1(x).squeeze(-1) 99 | # print("x1:", x1.shape) 100 | 101 | x = self.dropout_layer(x) 102 | x = torch.cat([att(x, adj) for att in self.attentions2], dim=-1) # [bs, N, nhid2*nhead2] 103 | 104 | x2 = self.pool2(x).squeeze(-1) 105 | # print("x2:", x2) 106 | 107 | 108 | if opt.lin_input_dim == 800 or opt.lin_input_dim == 720: 109 | x = torch.cat([x0, x1, x2], dim=1) 110 | elif opt.lin_input_dim == 320 or opt.lin_input_dim == 240: 111 | if opt.which_layer == 'layer1': 112 | x = x0 113 | elif opt.which_layer == 'layer2': 114 | x = x1 115 | elif opt.which_layer == 'layer3': 116 | x = x2 117 | 118 | if opt.cnv_dim == 80: 119 | x = torch.cat([cnv_feature, x], dim=1) 120 | 121 | GAT_features = x 122 | 123 | # print("feature shape:", x.shape) 124 | 125 | features = self.encoder(x) 126 | out = self.classifier(features) 127 | 128 | fc_features = features 129 | 130 | if self.act is not None: 131 | out = self.act(out) 132 | 133 | if isinstance(self.act, nn.Sigmoid): 134 | out = out * self.output_range + self.output_shift 135 | 136 | if opt.task == "grad": 137 | one_hot_labels = torch.zeros(grad_labels.shape[0], 3).cuda().scatter(1, grad_labels.reshape(-1, 1), 1) 138 | y_c = torch.sum(one_hot_labels*out) 139 | elif opt.task == "surv": 140 | y_c = torch.sum(out) 141 | # print(out, y_c) 142 | GAT_features.grad = None 143 | GAT_features.retain_grad() 144 | y_c.backward(retain_graph=True) 145 | gradients = np.maximum(GAT_features.grad.detach().cpu().numpy(), 0)# (batch_size, 720) 146 | feature_importance = np.mean(gradients, 0) 147 | 148 | return GAT_features, fc_features, out, gradients, feature_importance 149 | 150 | 151 | 152 | class GraphAttentionLayer(nn.Module): 153 | 154 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 155 | super(GraphAttentionLayer, self).__init__() 156 | self.dropout = dropout 157 | self.in_features = in_features 158 | self.out_features = out_features 159 | self.alpha = alpha 160 | self.concat = concat 161 | 162 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 163 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 164 | self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1))) 165 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 166 | 167 | self.leakyrelu = nn.LeakyReLU(self.alpha) 168 | self.dropout_layer = nn.Dropout(p=self.dropout) 169 | 170 | def forward(self, input, adj): 171 | """ 172 | input: mini-batch input. size: [batch_size, num_nodes, node_feature_dim] 173 | adj: adjacency matrix. size: [num_nodes, num_nodes]. need to be expanded to batch_adj later. 
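        Example (illustrative sizes, matching the 240 gene nodes used for GBMLGG):
            layer = GraphAttentionLayer(in_features=1, out_features=8, dropout=0.2, alpha=0.2)
            x = torch.randn(4, 240, 1)     # 4 patients, one expression value per node
            adj = torch.ones(240, 240)     # toy fully-connected graph
            layer(x, adj).shape            # -> [4, 240, 8], ELU-activated because concat=True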
174 | """ 175 | h = torch.matmul(input, self.W)# [bs, N, F] 176 | bs, N, _ = h.size() 177 | 178 | a_input = torch.cat([h.repeat(1, 1, N).view(bs, N * N, -1), h.repeat(1, N, 1)], dim=-1).view(bs, N, -1, 2 * self.out_features) 179 | # print("h size:", a_input.shape) 180 | 181 | e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3)) 182 | 183 | batch_adj = torch.unsqueeze(adj, 0).repeat(bs, 1, 1) 184 | # print("batch adj size:", batch_adj.shape) 185 | 186 | zero_vec = -9e15*torch.ones_like(e) 187 | attention = torch.where(batch_adj > 0, e, zero_vec) 188 | attention = self.dropout_layer(F.softmax(attention, dim=-1)) # [bs, N, N] 189 | # print("attention shape:", attention.shape) 190 | h_prime = torch.bmm(attention, h)# [bs, N, F] 191 | # print("h_prime:", h_prime.shape) 192 | 193 | if self.concat: 194 | return F.elu(h_prime) 195 | else: 196 | return h_prime 197 | 198 | def __repr__(self): 199 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 200 | 201 | 202 | def define_optimizer(opt, model): 203 | optimizer = None 204 | if opt.optimizer_type == 'adabound': 205 | optimizer = adabound.AdaBound(model.parameters(), lr=opt.lr, final_lr=opt.final_lr) 206 | elif opt.optimizer_type == 'adam': 207 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), weight_decay=opt.weight_decay) 208 | elif opt.optimizer_type == 'adagrad': 209 | optimizer = torch.optim.Adagrad(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, initial_accumulator_value=0.1) 210 | else: 211 | raise NotImplementedError('initialization method [%s] is not implemented' % opt.optimizer) 212 | return optimizer 213 | 214 | 215 | def define_reg(model): 216 | 217 | for W in model.parameters(): 218 | loss_reg = torch.abs(W).sum() 219 | 220 | return loss_reg 221 | 222 | 223 | def define_scheduler(opt, optimizer): 224 | if opt.lr_policy == 'linear': 225 | def lambda_rule(epoch): 226 | lr_l = 1.0 - max(0, epoch + 1) / float(opt.num_epochs + 1) 227 | return lr_l 228 | 229 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 230 | elif opt.lr_policy == 'exp': 231 | scheduler = lr_scheduler.ExponentialLR(optimizer, 0.1, last_epoch=-1) 232 | elif opt.lr_policy == 'step': 233 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 234 | elif opt.lr_policy == 'plateau': 235 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 236 | elif opt.lr_policy == 'cosine': 237 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 238 | else: 239 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 240 | return scheduler 241 | 242 | -------------------------------------------------------------------------------- /model_GAT_v4.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch_geometric.nn import GATConv, SAGPooling 6 | from torch.nn import LayerNorm, Parameter 7 | import torch.optim.lr_scheduler as lr_scheduler 8 | from torch_geometric.utils import to_dense_batch, to_dense_adj 9 | from utils import * 10 | 11 | 12 | class GAT(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(GAT, self).__init__() 15 | self.fc_dropout = opt.fc_dropout 16 | self.GAT_dropout = opt.GAT_dropout 17 | self.act = define_act_layer(act_type=opt.act_type) 18 | 19 | self.nhids = [8, 16, 12] 
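        # Each GATConv below concatenates its attention heads, so layer i outputs
        # nhids[i]*nheads[i] channels per node; pool1/pool2 then map every node embedding
        # to a single scalar, giving one num_nodes-dimensional feature vector per level.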
20 | self.nheads = [4, 3, 4] 21 | self.fc_dim = [64, 48, 32] 22 | 23 | self.conv1 = GATConv(opt.input_dim, self.nhids[0], heads=self.nheads[0], 24 | dropout=self.GAT_dropout) 25 | self.conv2 = GATConv(self.nhids[0]*self.nheads[0], self.nhids[1], heads=self.nheads[1], 26 | dropout=self.GAT_dropout) 27 | self.conv3 = GATConv(self.nhids[1]*self.nheads[1], self.nhids[2], heads=self.nheads[2], 28 | dropout=self.GAT_dropout) 29 | 30 | self.pool1 = torch.nn.Linear(self.nhids[0]*self.nheads[0], 1) 31 | self.pool2 = torch.nn.Linear(self.nhids[1]*self.nheads[1], 1) 32 | self.pool3 = torch.nn.Linear(self.nhids[2]*self.nheads[2], 1) 33 | 34 | self.layer_norm0 = LayerNorm(opt.num_nodes) 35 | self.layer_norm1 = LayerNorm(opt.num_nodes) 36 | self.layer_norm2 = LayerNorm(opt.num_nodes) 37 | self.layer_norm3 = LayerNorm(opt.num_nodes) 38 | 39 | fc1 = nn.Sequential( 40 | nn.Linear(opt.lin_input_dim, self.fc_dim[0]), 41 | nn.ELU(), 42 | nn.AlphaDropout(p=self.fc_dropout, inplace=True)) 43 | 44 | fc2 = nn.Sequential( 45 | nn.Linear(self.fc_dim[0], self.fc_dim[1]), 46 | nn.ELU(), 47 | nn.AlphaDropout(p=self.fc_dropout, inplace=False)) 48 | 49 | fc3 = nn.Sequential( 50 | nn.Linear(self.fc_dim[1], self.fc_dim[2]), 51 | nn.ELU(), 52 | nn.AlphaDropout(p=self.fc_dropout, inplace=False)) 53 | 54 | fc4 = nn.Sequential( 55 | nn.Linear(self.fc_dim[2], opt.omic_dim), 56 | nn.ELU(), 57 | nn.AlphaDropout(p=self.fc_dropout, inplace=False)) 58 | 59 | self.encoder = nn.Sequential(fc1, fc2, fc3, fc4) 60 | self.classifier = nn.Sequential(nn.Linear(opt.omic_dim, opt.label_dim)) 61 | 62 | self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False) 63 | self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False) 64 | 65 | 66 | def forward(self, x, adj, grad_labels, batch, opt): 67 | 68 | ### layer1 69 | x = x.requires_grad_() 70 | x0 = to_dense_batch(torch.mean(x, dim=-1), batch=batch)[0] #[bs, nodes] 71 | 72 | ### layer2 73 | x = F.dropout(x, p=0.2, training=self.training) 74 | x = F.elu(self.conv1(x, adj)) #[bs*nodes, nhids[0]*nheads[0]] 75 | 76 | x1 = to_dense_batch(self.pool1(x).squeeze(-1), batch=batch)[0] #[bs, nodes] 77 | 78 | x = F.dropout(x, p=0.2, training=self.training) 79 | x = F.elu(self.conv2(x, adj)) # [bs*nodes, nhids[0]*nheads[0]] 80 | 81 | x2 = to_dense_batch(self.pool2(x).squeeze(-1), batch=batch)[0] # [bs, nodes] 82 | 83 | 84 | if opt.layer_norm == "True": 85 | x0 = self.layer_norm0(x0) 86 | x1 = self.layer_norm1(x1) 87 | x2 = self.layer_norm0(x2) 88 | 89 | if opt.which_layer == 'all': 90 | x = torch.cat([x0, x1, x2], dim=1) 91 | 92 | elif opt.which_layer == 'layer1': 93 | x = x0 94 | elif opt.which_layer == 'layer2': 95 | x = x1 96 | elif opt.which_layer == 'layer3': 97 | x = x2 98 | 99 | GAT_features = x 100 | 101 | features = self.encoder(x) 102 | out = self.classifier(features) 103 | 104 | fc_features = features 105 | 106 | if self.act is not None: 107 | out = self.act(out) 108 | 109 | if isinstance(self.act, nn.Sigmoid): 110 | out = out * self.output_range + self.output_shift 111 | 112 | 113 | return GAT_features, fc_features, out 114 | 115 | 116 | def define_optimizer(opt, model): 117 | optimizer = None 118 | if opt.optimizer_type == 'adabound': 119 | optimizer = adabound.AdaBound(model.parameters(), lr=opt.lr, final_lr=opt.final_lr) 120 | elif opt.optimizer_type == 'adam': 121 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), weight_decay=opt.weight_decay) 122 | elif opt.optimizer_type == 'adagrad': 123 | optimizer = 
torch.optim.Adagrad(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, initial_accumulator_value=0.1) 124 | else: 125 | raise NotImplementedError('initialization method [%s] is not implemented' % opt.optimizer) 126 | return optimizer 127 | 128 | 129 | def define_reg(model): 130 | 131 | for W in model.parameters(): 132 | loss_reg = torch.abs(W).sum() 133 | 134 | return loss_reg 135 | 136 | 137 | def define_scheduler(opt, optimizer): 138 | if opt.lr_policy == 'linear': 139 | def lambda_rule(epoch): 140 | lr_l = 1.0 - max(0, epoch + 1) / float(opt.num_epochs + 1) 141 | return lr_l 142 | 143 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 144 | elif opt.lr_policy == 'exp': 145 | scheduler = lr_scheduler.ExponentialLR(optimizer, 0.1, last_epoch=-1) 146 | elif opt.lr_policy == 'step': 147 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 148 | elif opt.lr_policy == 'plateau': 149 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 150 | elif opt.lr_policy == 'cosine': 151 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 152 | else: 153 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 154 | return scheduler -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | 6 | 7 | ### Parser 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.') 13 | parser.add_argument('--batch_size', type=int, default=8, help='Number of batches to train/test for. 
Default: 64') 14 | parser.add_argument('--num_epochs', type=int, default=50, help='Number of epochs for training') 15 | parser.add_argument('--model_dir', type=str, default='./pretrained_models', help='models are saved here') 16 | parser.add_argument('--results_dir', type=str, default='./results', help='models are saved here') 17 | 18 | parser.add_argument('--lambda_cox', type=float, default=1) 19 | parser.add_argument('--lambda_reg', type=float, default=3e-4) 20 | parser.add_argument('--lambda_nll', type=float, default=1) 21 | 22 | parser.add_argument('--task', type=str, default='grad', help='surv | grad') 23 | parser.add_argument('--label_dim', type=int, default=3, help='size of output, grad task: label_dim=2, surv task: label_dim=1') 24 | parser.add_argument('--input_dim', type=int, default=1, help="input_size for omic vector") 25 | parser.add_argument('--lin_input_dim', type=int, default=720, help="the feature extracted by GAT layers") 26 | parser.add_argument('--which_layer', type=str, default='all', help='layer1 | layer2 | layer3, which GAT layer as the input of fc layers.') 27 | parser.add_argument('--cnv_dim', type=int, default=0, help="if use CNV as input, dim=80, if do not use CNV, dim=0") 28 | parser.add_argument('--omic_dim', type=int, default=32, help="dimension of the linear layer") 29 | parser.add_argument('--act_type', type=str, default="none", help='activation function') 30 | 31 | 32 | parser.add_argument('--optimizer_type', type=str, default='adam') 33 | parser.add_argument('--lr_policy', default='linear', type=str, help='5e-4 for Adam | 1e-3 for AdaBound') 34 | parser.add_argument('--lr', default=0.002, type=float, help='5e-4 for Adam | 1e-3 for AdaBound') 35 | parser.add_argument('--final_lr', default=0.1, type=float, help='Used for AdaBound') 36 | parser.add_argument('--weight_decay', default=5e-4, type=float, help='Used for Adam. L2 Regularization on weights.') 37 | parser.add_argument('--dropout', default=0.2, type=float, help='Dropout rate') 38 | parser.add_argument('--adj_thresh', default=0.08, type=float, help='Threshold convert the similarity matrix to adjacency matrix') 39 | parser.add_argument('--alpha', default=0.2, type=float, help='Used in the leaky relu') 40 | parser.add_argument('--patience', default=0.005, type=float) 41 | parser.add_argument('--gpu_ids', type=str, default='3,4,5', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 42 | 43 | opt = parser.parse_known_args()[0] 44 | print_options(parser, opt) 45 | 46 | return opt 47 | 48 | 49 | def print_options(parser, opt): 50 | """Print and save options 51 | 52 | It will print both current options and default values(if different). 
53 | It will save options into a text file / [checkpoints_dir] / opt.txt 54 | """ 55 | message = '' 56 | message += '----------------- Options ---------------\n' 57 | for k, v in sorted(vars(opt).items()): 58 | comment = '' 59 | default = parser.get_default(k) 60 | if v != default: 61 | comment = '\t[default: %s]' % str(default) 62 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 63 | message += '----------------- End -------------------' 64 | print(message) 65 | 66 | # save to the disk 67 | mkdirs(opt.model_dir) 68 | file_name = os.path.join(opt.model_dir, '{}_opt.txt'.format('train')) 69 | with open(file_name, 'wt') as opt_file: 70 | opt_file.write(message) 71 | opt_file.write('\n') 72 | 73 | 74 | def mkdirs(paths): 75 | """create empty directories if they don't exist 76 | 77 | Parameters: 78 | paths (str list) -- a list of directory paths 79 | """ 80 | if isinstance(paths, list) and not isinstance(paths, str): 81 | for path in paths: 82 | mkdir(path) 83 | else: 84 | mkdir(paths) 85 | 86 | 87 | def mkdir(path): 88 | """create a single empty directory if it didn't exist 89 | 90 | Parameters: 91 | path (str) -- a single directory path 92 | """ 93 | if not os.path.exists(path): 94 | os.makedirs(path) 95 | -------------------------------------------------------------------------------- /pretrained_models/split1_grad_720d_all_50epochs.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentAILabHealthcare/MLA-GNN/6f950e05768a7d03ed28c60fe578ca78ee68a4df/pretrained_models/split1_grad_720d_all_50epochs.pt -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # MLA-GNN 2 | 3 | 4 | 5 | This repository is an official PyTorch implementation of the paper 6 | **"Multi-Level Attention Graph Neural Network Based on Co-expression Gene Modules for Disease Diagnosis and Prognosis" 7 | submitted to **Bioinformatics 2021**. 8 | 9 | ![](./figs/pipeline.png) 10 | 11 | ## Installation 12 | ### Dependencies 13 | ``` 14 | Python 3.6 15 | PyTorch >= 1.5.0 16 | torch_geometric 17 | numpy 18 | pandas 19 | scipy 20 | sklearn 21 | opencv 22 | random 23 | ``` 24 | 25 | ## Data Description 26 | For the glioma dataset, 15-fold cross validation is conducted to evaluate the model performance. For each train-test split, we compute the adjacency matrix by performing the WGCNA algorithm on the training set. The **input features and adjacency matrix** are saved at: ./example_data/... 27 | 28 | The data structure is: 29 | ```bash 30 | ./example_data 31 | 32 | ├── input_features_labels 33 | ├── split1_train_320d_features_labels.csv 34 | ├── split1_test_320d_features_labels.csv 35 | ├── ... 36 | ├── ... 37 | ├── split15_train_320d_features_labels.csv 38 | ├── split15_test_320d_features_labels.csv 39 | 40 | ├── input_adjacency_matrix 41 | ├── split1_adjacency_matrix.csv 42 | ├── split2_adjacency_matrix.csv 43 | ├── ... 44 | ├── ... 45 | ├── split15_adjacency_matrix.csv 46 | ``` 47 | 48 | + For each train or test set, the RNAseq features and labels (for survival prediction and histological grading) are contained in the "xxx_xxx_320d_features_labels.csv" in the following format: 49 | 50 | 0 | 1 | 2 | 3 | ... | 321 | 322 | 323 51 | --- | --- | --- | --- | --- | --- | --- | --- 52 | TCGA-06-0141 |-0.751162972 | -1.72656962 | 0.876216622 | ... | 1 | 313 | 2 53 | TCGA-06-0187 |-0.751162972 | -1.72656962 | 2.305385481 | ... 
| 1 | 828 | 2 54 | ... | ... | ... | ... | ... | ... | ... 55 | TCGA-S9-A7R3 | -0.751162972 | 0.57918313 | -0.55295223 | ... | 0 | 3013 | 0 56 | 57 | ** Each row represents the features and labels of one patient. 58 | 59 | ** Since we used the preprocessed data from the [Pathomic fusion paper](https://ieeexplore.ieee.org/abstract/document/9186053), the features are 320d, containing 80d CNV features (1-80 columns) and 240d RNAseq features (81-320 columns). In our work, we only used the RNAseq data, so we extracted the 81-320 columns from this file as input features. 60 | 61 | ** The last three columns represent the labels. The 321-th column indicates whether a patient is censored, the 322-th column denotes the survival time, while the last column shows the ground-truth class for the histological grading task (0, 1, 2 denotes grade II, III, IV, respectively). 62 | 63 | + For each train-test split, the adjacency matrix is computed based on the training set and saved in the following format: 64 | 65 | V83 | V84 | V85 | ... | V321 | V322 66 | --- | --- | --- | --- | --- | --- 67 | 1 | 0.001676229 | 3.07E-06 | ... | 5.30E-07 | 5.89E-09 68 | 0.001676229 | 1 | 1.93E-07 | ... | 3.33E-10 | 1.69E-07 69 | 3.07E-06 | 1.93E-07 | 1 | ... | 1.98E-09 | 0.000125699 70 | ... | ... | ... | ... | ... | ... | ... 71 | 5.89E-09 | 1.69E-07 | 0.000125699 | ... | 1.04E-06 | 1 72 | 73 | ** The shape of the adjacency matrix is 240*240. 74 | 75 | ** Each element E_i, j represents the correlation between the i-th and j-th genes. 76 | 77 | 78 | ## Output files 79 | 80 | + After training the MLA-GNN model, the model will be save at: "./pretrained_models/". 81 | 82 | + The **pretrained files** are named as '/split' + str(k) + '_' + opt.task + '_' + str(opt.lin_input_dim) + 'd_' + opt.which_layer + '_' +str(opt.num_epochs) +'epochs.pt'. As an example, the model for the grading task trained with all three level features is saved as "split1_grad_720d_all_50epochs.pt". 83 | 84 | + After training, the model will also output the **feature importance** computed by the FGS mechanism, showing the contribution of each gene to the prediction task. 85 | 86 | + The **model prediction** is saved in the .pkl format, including the following information: risk_pred_all, survtime_all, censor_all, probs_all, gt_all. 87 | 88 | 89 | ## Usage 90 | 91 | ### Step1: gene co-expression network computation. 92 | ```shell script 93 | run WGCNA_gbmlgg.R 94 | ``` 95 | ** The adjacency matrix computed by the WGCNA algorithm should be saved in the folder "./example_data". 96 | 97 | 98 | ### Step2: Model inference. 99 | ```shell script 100 | ### The users should parse different arguments to change the experiment settings and evaluate different models. 101 | python3 test_cv.py 102 | ``` 103 | 104 | ### Step3: Model interpretation. 105 | ```shell script 106 | ### You can compute the feature importance with the FGS mechanism with the following script. 107 | python3 gradients_to_feature_importance.py 108 | ``` 109 | 110 | ### Scripts 111 | ```bash 112 | test_cv.py: Load the well-trained model from the folder “/pretrained_models/…” and test the performance on the testing set of the 15 splits. 113 | 114 | test_model.py: the definitions for "test". 115 | 116 | model_GAT.py: the definitions for the network optimizer and the GAT network, which can be selected as 720d model(GAT_features = layer1+layer2+layer3) or 240d (either use the layer1, or layer2, or layer3 as the GAT features) model. 
117 | 118 | model_GAT_v4.py: Optimized implementation of the GAT layer, we employed the “GATconv” function encapsulated in the pytorch Geometric package, which is optimized to save computational cost. 119 | 120 | utils.py: contains data_loader and other functions (cindex, cox_loss, …). 121 | 122 | options.py: Contains all the options for the argparser. 123 | 124 | WGCNA_gbmlgg.R: compute the adjacency matrix using WGCNA method. 125 | 126 | gradients_to_feature_importance.py: combine the gradients produced by different splits, and obtain the feature importance according to the proposed FGS mechanism. 127 | ``` 128 | 129 | ## Disclaimer 130 | 131 | This tool is for research purpose and not approved for clinical use. 132 | 133 | This is not an official Tencent product. 134 | 135 | ## Copyright 136 | 137 | This tool is developed in Tencent AI Lab. 138 | 139 | The copyright holder for this project is Tencent AI Lab. 140 | 141 | All rights reserved. -------------------------------------------------------------------------------- /test_cv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import pandas as pd 5 | import random 6 | import pickle 7 | 8 | import torch 9 | 10 | # Env 11 | from utils import * 12 | from model_GAT import * 13 | from options import parse_args 14 | from test_model import test 15 | from model_GAT import * 16 | 17 | 18 | ### 1. Initializes parser and device 19 | opt = parse_args() 20 | device = torch.device('cuda:0') 21 | num_splits = 15 22 | results = [] 23 | 24 | ### 2. Sets-Up Main Loop 25 | for k in range(1, num_splits+1): 26 | print("*******************************************") 27 | print("************** SPLIT (%d/%d) **************" % (k, num_splits)) 28 | print("*******************************************") 29 | 30 | tr_features, tr_labels, te_features, te_labels, adj_matrix = load_csv_data(k, opt) 31 | load_path = opt.model_dir + '/split' + str(k) + '_' + opt.task + '_' + str( 32 | opt.lin_input_dim) + 'd_all_' + str(opt.num_epochs) + 'epochs.pt' 33 | model_ckpt = torch.load(load_path, map_location=device) 34 | 35 | #### Loading Env 36 | model_state_dict = model_ckpt['model_state_dict'] 37 | # hasattr(target, attr) 用于判断对象中是否含有某个属性,有则返回true. 38 | if hasattr(model_state_dict, '_metadata'): 39 | del model_state_dict._metadata 40 | 41 | model = GAT(opt=opt, input_dim=opt.input_dim, omic_dim=opt.omic_dim, label_dim=opt.label_dim, 42 | dropout=opt.dropout, alpha=opt.alpha).cuda() 43 | 44 | ### multiple GPU 45 | # model = torch.nn.DataParallel(model) 46 | # torch.backends.cudnn.benchmark = True 47 | 48 | if isinstance(model, torch.nn.DataParallel): model = model.module 49 | 50 | print('Loading the model from %s' % load_path) 51 | model.load_state_dict(model_state_dict) 52 | 53 | 54 | ### 3.2 Test the model. 
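    # test() (defined in test_model.py) returns the split's loss, the survival metrics
    # (C-index, log-rank p-value, survival accuracy), the grading accuracy, the raw
    # predictions and the GAT / fc feature matrices that are exported to csv below.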
55 | loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test, te_features, te_fc_features = test( 56 | opt, model, te_features, te_labels, adj_matrix) 57 | GAT_te_features_labels = np.concatenate((te_features, te_fc_features, te_labels), axis=1) 58 | 59 | # print("model preds:", list(np.argmax(pred_test[3], axis=1))) 60 | # print("ground truth:", pred_test[4]) 61 | # print(te_labels[:, 2]) 62 | 63 | pd.DataFrame(GAT_te_features_labels).to_csv( 64 | "./results/"+opt.task+"/GAT_features_"+str(opt.lin_input_dim)+"d_model/split"+str(k)+"_"+ opt.which_layer+"_GAT_te_features.csv") 65 | 66 | if opt.task == 'surv': 67 | print("[Final] Apply model to testing set: C-Index: %.10f, P-Value: %.10e" % (cindex_test, pvalue_test)) 68 | logging.info("[Final] Apply model to testing set: cC-Index: %.10f, P-Value: %.10e" % (cindex_test, pvalue_test)) 69 | results.append(cindex_test) 70 | elif opt.task == 'grad': 71 | print("[Final] Apply model to testing set: Loss: %.10f, Acc: %.4f" % (loss_test, grad_acc_test)) 72 | logging.info("[Final] Apply model to testing set: Loss: %.10f, Acc: %.4f" % (loss_test, grad_acc_test)) 73 | results.append(grad_acc_test) 74 | 75 | test_preds_labels = np.concatenate((pred_test[3], np.expand_dims(pred_test[4], axis=1)), axis=1) 76 | print(test_preds_labels.shape) 77 | pd.DataFrame(test_preds_labels, columns=["class1", "class2", "class3", "pred_class"]).to_csv( 78 | "./results/" + opt.task + "/preds/split" + str(k) + "_" + opt.which_layer + "_test_preds_labels.csv") 79 | # pickle.dump(pred_test, open(os.path.join(opt.results_dir, opt.task, 80 | # 'preds/split%d_pred_test_%dd_%s_%depochs.pkl' % (k, opt.lin_input_dim, opt.which_layer, opt.num_epochs)), 'wb')) 81 | 82 | print('Split Results:', results) 83 | print("Average:", np.array(results).mean()) 84 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | import argparse 7 | import torch.utils.data as Data 8 | from sklearn.model_selection import StratifiedKFold 9 | 10 | from utils import * 11 | from model_GAT import * 12 | 13 | 14 | 15 | def test(opt, model, te_features, te_labels, adj_matrix): 16 | 17 | model.eval() 18 | 19 | test_dataset = Data.TensorDataset(te_features, te_labels) 20 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False) 21 | 22 | risk_pred_all, censor_all, survtime_all = np.array([]), np.array([]), np.array([]) 23 | probs_all, gt_all = None, np.array([]) 24 | loss_test, grad_acc_test = 0, 0 25 | 26 | for batch_idx, (batch_features, batch_labels) in enumerate(test_loader): 27 | censor = batch_labels[:, 0] 28 | survtime = batch_labels[:, 1] 29 | grade = batch_labels[:, 2] 30 | censor_batch_labels = censor.cuda() if "surv" in opt.task else censor 31 | surv_batch_labels = survtime 32 | # print(surv_batch_labels) 33 | grad_batch_labels = grade.cuda() if "grad" in opt.task else grade 34 | te_features, te_fc_features, te_preds, gradients, feature_importance = model( 35 | batch_features.cuda(), adj_matrix.cuda(), grad_batch_labels, opt) 36 | 37 | # print("surv_batch_labels:", surv_batch_labels) 38 | # print("te_preds:", te_preds) 39 | 40 | if batch_idx == 0: 41 | features_all = te_features.detach().cpu().numpy() 42 | fc_features_all = te_fc_features.detach().cpu().numpy() 43 | else: 44 | 
features_all = np.concatenate((features_all, te_features.detach().cpu().numpy()), axis=0) 45 | fc_features_all = np.concatenate((fc_features_all, te_fc_features.detach().cpu().numpy()), axis=0) 46 | # print(features_all.shape, te_features.shape) 47 | 48 | loss_cox = CoxLoss(surv_batch_labels, censor_batch_labels, te_preds) if opt.task == "surv" else 0 49 | loss_reg = define_reg(model) 50 | loss_func = nn.CrossEntropyLoss() 51 | grad_loss = loss_func(te_preds, grad_batch_labels) if opt.task == "grad" else 0 52 | loss = opt.lambda_cox * loss_cox + opt.lambda_nll * grad_loss + opt.lambda_reg * loss_reg 53 | loss_test += loss.data.item() 54 | 55 | gt_all = np.concatenate((gt_all, grad_batch_labels.detach().cpu().numpy().reshape(-1))) # Logging Information 56 | 57 | if opt.task == "surv": 58 | risk_pred_all = np.concatenate((risk_pred_all, te_preds.detach().cpu().numpy().reshape(-1))) # Logging Information 59 | censor_all = np.concatenate((censor_all, censor_batch_labels.detach().cpu().numpy().reshape(-1))) # Logging Information 60 | survtime_all = np.concatenate((survtime_all, surv_batch_labels.detach().cpu().numpy().reshape(-1))) # Logging Information 61 | 62 | elif opt.task == "grad": 63 | pred = te_preds.argmax(dim=1, keepdim=True) 64 | grad_acc_test += pred.eq(grad_batch_labels.view_as(pred)).sum().item() 65 | probs_np = te_preds.detach().cpu().numpy() 66 | probs_all = probs_np if probs_all is None else np.concatenate((probs_all, probs_np), axis=0) # Logging Information 67 | 68 | # print(survtime_all) 69 | ################################################### 70 | # ==== Measuring Test Loss, C-Index, P-Value ==== # 71 | ################################################### 72 | loss_test /= len(test_loader.dataset) 73 | cindex_test = CIndex_lifeline(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None 74 | pvalue_test = cox_log_rank(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None 75 | surv_acc_test = accuracy_cox(risk_pred_all, censor_all) if opt.task == 'surv' else None 76 | grad_acc_test = grad_acc_test / len(test_loader.dataset) if opt.task == 'grad' else None 77 | pred_test = [risk_pred_all, survtime_all, censor_all, probs_all, gt_all] 78 | 79 | return loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test, features_all, fc_features_all -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import numpy as np 5 | import random 6 | import pickle 7 | 8 | import torch 9 | import pandas as pd 10 | 11 | # Env 12 | from utils import * 13 | from options import parse_args 14 | from train_test_new import train, test 15 | 16 | from load_CRCSC_data import load_features_labels, CRC_Dataset, construct_graph 17 | 18 | ### 1. 
Initializes parser and device 19 | opt = parse_args() 20 | 21 | log_path = "./logs/" 22 | snapshot_path = opt.model_dir + opt.exp + "/" 23 | 24 | 25 | if __name__ == "__main__": 26 | 27 | logging.basicConfig(filename=log_path + opt.exp +".txt", level=logging.INFO, 28 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 29 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 30 | logging.info(str(opt)) 31 | 32 | tr_features, norm_tr_features, te_features, tr_labels, te_labels, keep_idx, all_test_cohorts, \ 33 | tr_sample_ids, te_sample_ids = load_features_labels(retain_dim=opt.num_nodes) 34 | adj_matrix, edge_index = construct_graph(opt, norm_tr_features, keep_idx) 35 | print("train data:", tr_features.shape, tr_labels.shape) 36 | print("test data:", te_features.shape, te_labels.shape) 37 | 38 | train_dataset = CRC_Dataset(feature=tr_features, label=tr_labels, edge=edge_index) 39 | test_dataset = CRC_Dataset(feature=te_features, label=te_labels, edge=edge_index) 40 | 41 | print("====================dataset loaded====================") 42 | 43 | ### 2.1 Train Model 44 | # model, optimizer, metric_logger = train(opt, train_dataset, test_dataset) 45 | model, optimizer, metric_logger, class_edge_weights, class_node_importance, \ 46 | overall_edge_weights, overall_node_importance = train(opt, train_dataset, test_dataset) 47 | 48 | ### 2.2 Test Model 49 | loss_train, grad_acc_train, pred_train, tr_features, tr_fc_features = test(opt, model, train_dataset) 50 | loss_test, grad_acc_test, pred_test, te_features, te_fc_features = test(opt, model, test_dataset) 51 | # test_probs = np.exp(pred_test)/np.sum(np.exp(pred_test), axis=1) 52 | # print("test GAT features:", te_features.shape) 53 | # print("test fc features:", te_fc_features.shape) 54 | # print("test preds:", test_probs) 55 | all_metrics = compute_cohort_metrics(pred_test[0], np.uint(pred_test[1]), all_test_cohorts) 56 | print(all_metrics) 57 | test_results = {'sample_id': te_sample_ids, 'GNN_pred': np.argmax(pred_test[0], axis=1), 58 | 'CMS_network': np.uint(pred_test[1])} 59 | # pd.DataFrame(test_results).to_csv('./results/GNN_HumanNet_preds.csv') 60 | # pd.DataFrame(test_results).to_csv('./results/GNN_sim_graph_preds.csv') 61 | # pd.DataFrame(test_results).to_csv('./GNN_sim_top_genes_predictions/top' + str(opt.num_nodes) + '_test.csv') 62 | # 63 | # train_results = {'GNN_pred': np.argmax(pred_test[0], axis=1), 'label': np.uint(pred_test[1])} 64 | # pd.DataFrame(train_results).to_csv('./GNN_sim_top_genes_predictions/top' + str(opt.num_nodes) + '_train.csv') 65 | 66 | # edge_weights_file = opt.results_dir + "GNN_sim_graph_edge_weights_wo_elu/" 67 | # feat_importance_file = opt.results_dir + "GNN_sim_graph_feature_importance_wo_elu/" 68 | # # print(edge_weights_file, feat_importance_file) 69 | # 70 | # # edge_weights_file = opt.results_dir + "GNN_HumanNet_edge_weights/" 71 | # # feat_importance_file = opt.results_dir + "GNN_HumanNet_feature_importance/" 72 | # 73 | # for i in range(opt.label_dim): 74 | # # print(class_edge_weights[i]) 75 | # pd.DataFrame(class_edge_weights[i]).to_csv(edge_weights_file + "class" + str(i) + ".csv") 76 | # pd.DataFrame(class_node_importance[i]).to_csv(feat_importance_file + "class" + str(i) + ".csv") 77 | # 78 | # pd.DataFrame(overall_edge_weights).to_csv(edge_weights_file + "overall.csv") 79 | # pd.DataFrame(overall_node_importance).to_csv(feat_importance_file + "overall.csv") 80 | 81 | print("[Final] Apply model to training set: Loss: %.10f, Acc: %.4f" % (loss_train, 
grad_acc_train)) 82 | print("[Final] Apply model to testing set: Loss: %.10f, Acc: %.4f" % (loss_test, grad_acc_test)) 83 | logging.info("[Final] Apply model to testing set: Loss: %.10f, Acc: %.4f" % (loss_test, grad_acc_test)) 84 | 85 | 86 | ### 2.3 Saves Model 87 | model_state_dict = model.state_dict() 88 | save_path = opt.model_dir + opt.exp + '.pt' 89 | print("Saving model at:", save_path) 90 | 91 | torch.save({ 92 | 'opt': opt, 93 | 'epoch': opt.num_epochs, 94 | 'model_state_dict': model_state_dict, 95 | 'optimizer_state_dict': optimizer.state_dict(), 96 | 'metrics': metric_logger}, 97 | save_path) 98 | 99 | 100 | # pickle.dump(pred_train, open(os.path.join(opt.results_dir, 101 | # 'preds/pred_train_%s_%depochs.pkl' % (opt.which_layer, opt.num_epochs)), 'wb')) 102 | # pickle.dump(pred_test, open(os.path.join(opt.results_dir, 103 | # 'preds/pred_test_%s_%depochs.pkl' % (opt.which_layer, opt.num_epochs)), 'wb')) 104 | -------------------------------------------------------------------------------- /train_test_new.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | import argparse 7 | import logging 8 | from torch_geometric.utils import to_dense_batch 9 | from torch_geometric.data import DataLoader 10 | from sklearn.model_selection import StratifiedKFold 11 | 12 | import torch.nn.functional as F 13 | 14 | from utils import * 15 | from model_GAT_v4 import * 16 | 17 | 18 | def train(opt, train_dataset, test_dataset): 19 | 20 | cudnn.deterministic = True 21 | torch.cuda.manual_seed_all(2020) 22 | torch.manual_seed(2020) 23 | random.seed(2020) 24 | 25 | model = MLA_GNN(opt).cuda() 26 | for name, parameters in model.named_parameters(): 27 | print(name, ':', parameters.size()) 28 | optimizer = define_optimizer(opt, model) 29 | scheduler = define_scheduler(opt, optimizer) 30 | 31 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, 32 | shuffle=True, num_workers=8) 33 | metric_logger = {'train': {'loss': [], 'grad_acc': []}, 'test': {'loss': [], 'grad_acc': []}} 34 | 35 | # print("============ finish dataloader ===========") 36 | 37 | 38 | for epoch in range(opt.num_epochs): 39 | model.train() 40 | loss_epoch, grad_acc_epoch = 0, 0 41 | 42 | class_edge_weights = torch.zeros(opt.label_dim, opt.num_nodes, opt.num_nodes).cuda() 43 | class_node_importance = torch.zeros(opt.label_dim, opt.num_nodes).cuda() 44 | overall_edge_weights = torch.zeros(opt.num_nodes, opt.num_nodes).cuda() 45 | overall_node_importance = torch.zeros(opt.num_nodes).cuda() 46 | # print(class_edge_weights.shape, class_node_importance.shape) 47 | # print(overall_edge_weights.shape, overall_node_importance.shape) 48 | 49 | for batch_idx, data in enumerate(train_loader): 50 | batch_idx += 1 51 | tr_features, tr_fc_features, tr_preds, edge_weights, feature_importance = model( 52 | data.x.cuda(), data.edge_index.cuda(), data.y.cuda(), data.batch.cuda(), opt) 53 | # print(data.y) 54 | sample_weight = torch.tensor(cal_sample_weight( 55 | data.y, opt.label_dim, use_sample_weight=True)).cuda() 56 | # print("===========", data.y, sample_weight) 57 | 58 | """ 59 | Compute the edge weights and node importance for each class. 
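            Shapes implied by the accumulators defined above: edge_weights is
            [batch, num_nodes, num_nodes] and feature_importance is [batch, num_nodes];
            samples of class i are selected with data.y == i and their batch means are
            accumulated into the per-class tensors.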
60 | """ 61 | overall_edge_weights += torch.mean(edge_weights, 0) 62 | overall_node_importance += torch.mean(feature_importance, 0) 63 | 64 | for i in range(opt.label_dim): 65 | index = np.nonzero(data.y == i).view(1, -1)[0] 66 | # print(torch.index_select(edge_weights, 0, index[0].cuda()).shape) 67 | if index.shape[0] > 0: 68 | # print(index.shape[0]) 69 | class_edge_weights[i] += torch.mean( 70 | torch.index_select(edge_weights, 0, index.cuda()), 0) 71 | class_node_importance[i] += torch.mean( 72 | torch.index_select(feature_importance, 0, index.cuda()), 0) 73 | # print(class_node_importance[i]) 74 | 75 | loss_reg = define_reg(model) 76 | loss_func = nn.CrossEntropyLoss(reduction='none') 77 | grad_loss = torch.mean(torch.mul(loss_func(tr_preds, data.y.cuda()), sample_weight)) 78 | # print(grad_loss) 79 | # one_hot_labels = one_hot_tensor(batch_labels.cuda(), opt.label_dim) 80 | # grad_loss = torch.mean(torch.mul(tr_preds - one_hot_labels, tr_preds - one_hot_labels)) 81 | loss = opt.lambda_nll * grad_loss + opt.lambda_reg * loss_reg 82 | loss_epoch += loss.data.item() 83 | 84 | optimizer.zero_grad() 85 | # tr_features.retain_grad() 86 | loss.backward(retain_graph=True) 87 | optimizer.step() 88 | 89 | pred = tr_preds.argmax(dim=1, keepdim=True) 90 | grad_acc_epoch += pred.eq(data.y.cuda().view_as(pred)).sum().item() 91 | 92 | scheduler.step() 93 | 94 | class_edge_weights = (class_edge_weights/batch_idx).detach().cpu().numpy() 95 | class_node_importance = (class_node_importance/batch_idx).detach().cpu().numpy() 96 | overall_edge_weights = (overall_edge_weights/batch_idx).detach().cpu().numpy() 97 | overall_node_importance = (overall_node_importance/batch_idx).detach().cpu().numpy() 98 | 99 | # print(gradients_all.shape, importance_all.shape) 100 | # print(class_edge_weights, class_node_importance) 101 | 102 | loss_epoch /= len(train_loader.dataset) 103 | grad_acc_epoch = grad_acc_epoch / len(train_loader.dataset) 104 | loss_test, grad_acc_test, pred_test, features_test, _ = test(opt, model, test_dataset) 105 | # loss_test, grad_acc_test, pred_test, features_test, _ = test(opt, model, train_dataset) 106 | 107 | metric_logger['train']['loss'].append(loss_epoch) 108 | metric_logger['train']['grad_acc'].append(grad_acc_epoch) 109 | 110 | metric_logger['test']['loss'].append(loss_test) 111 | metric_logger['test']['grad_acc'].append(grad_acc_test) 112 | 113 | # pickle.dump(pred_test, open(os.path.join( 114 | # opt.results_dir, 'split%d_%d_pred_test.pkl' % (k, epoch)), 'wb')) 115 | 116 | 117 | logging.info('\nEpoch {:02d}/{:02d}, [{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}'.format( 118 | epoch + 1, opt.num_epochs, 'Train', loss_epoch, 'Accuracy', grad_acc_epoch)) 119 | logging.info('\nEpoch {:02d}/{:02d}, [{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}\n'.format( 120 | epoch + 1, opt.num_epochs, 'Test', loss_test, 'Accuracy', grad_acc_test)) 121 | 122 | # print("=========gradients_all:", gradients_all) 123 | 124 | return model, optimizer, metric_logger, class_edge_weights, class_node_importance, \ 125 | overall_edge_weights, overall_node_importance 126 | 127 | 128 | def test(opt, model, test_dataset): 129 | 130 | model.eval() 131 | test_loader = DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False) 132 | 133 | probs_all, gt_all = None, np.array([]) 134 | loss_test, grad_acc_test = 0, 0 135 | 136 | for batch_idx, (data) in enumerate(test_loader): 137 | 138 | te_features, te_fc_features, te_preds, _, _ = model( 139 | data.x.cuda(), data.edge_index.cuda(), data.y.cuda(), data.batch.cuda(), opt) 
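        # te_features: multi-level GAT features per sample, te_fc_features: encoder output,
        # te_preds: class logits; the edge-weight and node-importance outputs are discarded
        # at test time since they are only accumulated during training.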
140 | 141 | # print(te_preds, te_fc_features) 142 | # print(data.y, torch.argmax(te_preds, dim=1)) 143 | # print(data.y, te_preds) 144 | # print("te_preds:", te_preds) 145 | 146 | if batch_idx == 0: 147 | features_all = te_features.detach().cpu().numpy() 148 | fc_features_all = te_fc_features.detach().cpu().numpy() 149 | else: 150 | features_all = np.concatenate((features_all, te_features.detach().cpu().numpy()), axis=0) 151 | fc_features_all = np.concatenate((fc_features_all, te_fc_features.detach().cpu().numpy()), axis=0) 152 | # print(features_all.shape, te_features.shape) 153 | 154 | loss_reg = define_reg(model) 155 | loss_func = nn.CrossEntropyLoss() 156 | grad_loss = loss_func(te_preds, data.y.cuda()) 157 | # one_hot_labels = one_hot_tensor(batch_labels.cuda(), opt.label_dim) 158 | # grad_loss = torch.mean(torch.mul(te_preds - one_hot_labels, te_preds - one_hot_labels)) 159 | loss = opt.lambda_nll * grad_loss + opt.lambda_reg * loss_reg 160 | loss_test += loss.data.item() 161 | 162 | gt_all = np.concatenate((gt_all, data.y.reshape(-1))) # Logging Information 163 | 164 | pred = te_preds.argmax(dim=1, keepdim=True) 165 | grad_acc_test += pred.eq(data.y.cuda().view_as(pred)).sum().item() 166 | probs_np = te_preds.detach().cpu().numpy() 167 | probs_all = probs_np if probs_all is None else np.concatenate((probs_all, probs_np), axis=0) # Logging Information 168 | 169 | # print("total batch:", batch_idx) 170 | 171 | # print(survtime_all) 172 | ################################################### 173 | # ==== Measuring Test Loss, C-Index, P-Value ==== # 174 | ################################################### 175 | loss_test /= len(test_loader.dataset) 176 | grad_acc_test = grad_acc_test / len(test_loader.dataset) 177 | pred_test = [probs_all, gt_all] 178 | # print(probs_all.shape, gt_all.shape) 179 | 180 | return loss_test, grad_acc_test, pred_test, features_all, fc_features_all 181 | 182 | 183 | 184 | 185 | def external_test(opt, model, test_dataset): 186 | 187 | model.eval() 188 | test_loader = DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False) 189 | probs_all = None 190 | 191 | for batch_idx, (data) in enumerate(test_loader): 192 | 193 | te_features, te_fc_features, te_preds, _, _ = model( 194 | data.x.cuda(), data.edge_index.cuda(), data.y.cuda(), data.batch.cuda(), opt) 195 | # print("features:", te_features) 196 | probs_np = F.softmax(te_preds, 1).detach().cpu().numpy() 197 | probs_all = probs_np if probs_all is None else np.concatenate((probs_all, probs_np), axis=0) 198 | 199 | # print(probs_all) 200 | 201 | return probs_all 202 | 203 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import pandas as pd 6 | import torch.utils.data as Data 7 | from torch.utils.data.dataset import Dataset 8 | 9 | # import lifelines 10 | # from lifelines.utils import concordance_index 11 | # from lifelines.statistics import logrank_test 12 | 13 | from sklearn.metrics import auc, f1_score, roc_curve, precision_score, recall_score, cohen_kappa_score 14 | from sklearn.preprocessing import LabelBinarizer 15 | 16 | ################ 17 | # Data Utils 18 | ################ 19 | 20 | def load_csv_data(k, opt): 21 | folder_path = './example_data/input_features_labels/split' 22 | print("Loading data from:", folder_path+str(k)) 23 | train_data_path = 
folder_path+str(k)+'_train_320d_features_labels.csv' 24 | train_data = np.array(pd.read_csv(train_data_path, header=None))[1:, 2:].astype(float) 25 | 26 | tr_features = torch.FloatTensor(train_data[:, :320].reshape(-1, 320, 1)).requires_grad_() 27 | tr_labels = torch.LongTensor(train_data[:, 320:]) 28 | print("Training features and labels:", tr_features.shape, tr_labels.shape) 29 | 30 | test_data_path = folder_path+str(k)+'_test_320d_features_labels.csv' 31 | test_data = np.array(pd.read_csv(test_data_path, header=None))[1:, 2:].astype(float) 32 | 33 | te_features = torch.FloatTensor(test_data[:, :320].reshape(-1, 320, 1)).requires_grad_() 34 | te_labels = torch.LongTensor(test_data[:, 320:]) 35 | print("Testing features and labels:", te_features.shape, te_labels.shape) 36 | 37 | similarity_matrix = np.array(pd.read_csv( 38 | './example_data/input_adjacency_matrix/split'+str(k)+'_adjacency_matrix.csv')).astype(float) 39 | adj_matrix = torch.LongTensor(np.where(similarity_matrix > opt.adj_thresh, 1, 0)) 40 | print("Adjacency matrix:", adj_matrix.shape) 41 | print("Number of edges:", adj_matrix.sum()) 42 | 43 | if opt.task == "grad": 44 | tr_idx = tr_labels[:, 2] >= 0 45 | tr_labels = tr_labels[tr_idx] 46 | tr_features = tr_features[tr_idx] 47 | print("Training features and grade labels after deleting NA labels:", tr_features.shape, tr_labels.shape) 48 | 49 | te_idx = te_labels[:, 2] >= 0 50 | te_labels = te_labels[te_idx] 51 | te_features = te_features[te_idx] 52 | print("Testing features and grade labels after deleting NA labels:", te_features.shape, te_labels.shape) 53 | 54 | return tr_features, tr_labels, te_features, te_labels, adj_matrix 55 | 56 | 57 | ################ 58 | # Grading Utils 59 | ################ 60 | def accuracy(output, labels): 61 | preds = output.max(1)[1].type_as(labels) 62 | correct = preds.eq(labels).double() 63 | correct = correct.sum() 64 | return correct / len(labels) 65 | 66 | 67 | def print_model(model, optimizer): 68 | print(model) 69 | print("Model's state_dict:") 70 | # Print model's state_dict 71 | for param_tensor in model.state_dict(): 72 | print(param_tensor,"\t", model.state_dict()[param_tensor].size()) 73 | print("optimizer's state_dict:") 74 | # Print optimizer's state_dict 75 | for var_name in optimizer.state_dict(): 76 | print(var_name,"\t", optimizer.state_dict()[var_name]) 77 | 78 | 79 | def init_max_weights(module): 80 | for m in module.modules(): 81 | if type(m) == nn.Linear: 82 | stdv = 1. / math.sqrt(m.weight.size(1)) 83 | m.weight.data.normal_(0, stdv) 84 | m.bias.data.zero_() 85 | 86 | 87 | 88 | def compute_ROC_AUC(test_pred, gt_labels): 89 | 90 | enc = LabelBinarizer() 91 | enc.fit(gt_labels) 92 | labels_oh = enc.transform(gt_labels) ## convert to one_hot grade labels. 93 | # print(gt_labels, labels_oh, test_pred.shape) 94 | fpr, tpr, thresh = roc_curve(labels_oh.ravel(), test_pred.ravel()) 95 | aucroc = auc(fpr, tpr) 96 | 97 | return aucroc 98 | 99 | def compute_metrics(test_pred, gt_labels): 100 | 101 | enc = LabelBinarizer() 102 | enc.fit(gt_labels) 103 | labels_oh = enc.transform(gt_labels) ## convert to one_hot grade labels. 
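    # Binarize the predicted class indices in the same label space as the ground truth:
    # fitting the binarizer on labels and predictions together keeps the one-hot columns
    # aligned even when a class is missing from one of the two arrays.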
104 | 105 | # print(gt_labels, labels_oh, test_pred.shape) 106 | # print(labels_oh, test_pred) 107 | idx = np.argmax(test_pred, axis=1) 108 | # print(gt_labels, idx) 109 | labels_and_pred = np.concatenate((gt_labels, idx)) 110 | test_pred = enc.fit(labels_and_pred).transform(labels_and_pred)[gt_labels.shape[0]:, :] 111 | # print(test_pred) 112 | macro_f1_score = f1_score(labels_oh, test_pred, average='macro') 113 | # micro_f1_score = f1_score(labels_oh, test_pred, average='micro') #equal to accuracy. 114 | precision = precision_score(labels_oh, test_pred, average='macro') 115 | recall = recall_score(labels_oh, test_pred, average='macro') 116 | # kappa = cohen_kappa_score(labels_oh, test_pred) 117 | 118 | return macro_f1_score, precision, recall 119 | 120 | 121 | ################ 122 | # Survival Utils 123 | ################ 124 | def CoxLoss(survtime, censor, hazard_pred): 125 | # This calculation credit to Travers Ching https://github.com/traversc/cox-nnet 126 | # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data 127 | current_batch_len = len(survtime) 128 | R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int) 129 | # print("R mat shape:", R_mat.shape) 130 | for i in range(current_batch_len): 131 | for j in range(current_batch_len): 132 | R_mat[i, j] = survtime[j] >= survtime[i] 133 | 134 | R_mat = torch.FloatTensor(R_mat).cuda() 135 | theta = hazard_pred.reshape(-1) 136 | exp_theta = torch.exp(theta) 137 | # print("censor and theta shape:", censor.shape, theta.shape) 138 | loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * censor) 139 | return loss_cox 140 | 141 | 142 | 143 | def accuracy_cox(hazardsdata, labels): 144 | # This accuracy is based on estimated survival events against true survival events 145 | median = np.median(hazardsdata) 146 | hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int) 147 | hazards_dichotomize[hazardsdata > median] = 1 148 | correct = np.sum(hazards_dichotomize == labels) 149 | return correct / len(labels) 150 | 151 | 152 | def cox_log_rank(hazardsdata, labels, survtime_all): 153 | median = np.median(hazardsdata) 154 | hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int) 155 | hazards_dichotomize[hazardsdata > median] = 1 156 | idx = hazards_dichotomize == 0 157 | T1 = survtime_all[idx] 158 | T2 = survtime_all[~idx] 159 | E1 = labels[idx] 160 | E2 = labels[~idx] 161 | results = logrank_test(T1, T2, event_observed_A=E1, event_observed_B=E2) 162 | pvalue_pred = results.p_value 163 | return(pvalue_pred) 164 | 165 | 166 | def CIndex(hazards, labels, survtime_all): 167 | concord = 0. 168 | total = 0. 
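    # Exhaustive pairwise C-index: every observed event i is compared with every sample j
    # that survives longer; a pair is concordant when i receives the higher predicted hazard,
    # and tied hazards count as half-concordant.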
169 |     N_test = labels.shape[0]
170 |     for i in range(N_test):
171 |         if labels[i] == 1:
172 |             for j in range(N_test):
173 |                 if survtime_all[j] > survtime_all[i]:
174 |                     total += 1
175 |                     if hazards[j] < hazards[i]: concord += 1
176 |                     elif hazards[j] == hazards[i]: concord += 0.5  # ties count as half-concordant
177 | 
178 |     return(concord/total)
179 | 
180 | 
181 | def CIndex_lifeline(hazards, labels, survtime_all):
182 |     return(concordance_index(survtime_all, -hazards, labels))
183 | 
184 | 
185 | 
186 | ################
187 | # Layer Utils
188 | ################
189 | def define_act_layer(act_type='Tanh'):
190 |     if act_type == 'Tanh':
191 |         act_layer = nn.Tanh()
192 |     elif act_type == 'ReLU':
193 |         act_layer = nn.ReLU()
194 |     elif act_type == 'Sigmoid':
195 |         act_layer = nn.Sigmoid()
196 |     elif act_type == 'LSM':
197 |         act_layer = nn.LogSoftmax(dim=1)
198 |     elif act_type == "none":
199 |         act_layer = None
200 |     else:
201 |         raise NotImplementedError('activation layer [%s] is not found' % act_type)
202 |     return act_layer
203 | 
204 | 
205 | 
--------------------------------------------------------------------------------