├── utils
│   ├── __init__.py
│   └── data.py
├── model
│   ├── __init__.py
│   ├── MAGNN_dti.py
│   └── base_MAGNN_dti.py
├── README.md
├── run_MHAN_dti.py
├── train_preprocess.py
└── test_preprocess.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from model.MAGNN_dti import MAGNN_lp
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MHGNN-DTI
2 | 
3 | This repository provides the source code and datasets for "Metapath aggregated heterogeneous graph neural network for drug-target interaction prediction".
4 | 
5 | 1. Download the dataset from Baidu Yun (link: https://pan.baidu.com/s/1R1lpNFzVNlywy4T_001gjw, code: 4jkh).
6 | 2. Run train_preprocess.py to generate metapaths for the training set.
7 | 3. Run test_preprocess.py to generate metapaths for the test set.
8 | 4. Run run_MHAN_dti.py to obtain the DTI prediction results.
9 | 
--------------------------------------------------------------------------------
/model/MAGNN_dti.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.nn import functional as F
5 | from model.base_MAGNN_dti import MAGNN_ctr_ntype_specific
6 | 
7 | def adj_normalize(adj):
8 |     rowsum = adj.sum(1)
9 |     r_inv = torch.pow(rowsum, -0.5).flatten()
10 |     r_inv[torch.isinf(r_inv)] = 0.
11 | r_mat_inv = torch.diag(r_inv) 12 | adj_ = r_mat_inv * adj * r_mat_inv 13 | return adj_ 14 | 15 | def MinMax_scalar(x): 16 | min = x.min(1).values 17 | min = min.view(-1, 1).repeat(1, x.shape[1]) 18 | max = x.max(1).values 19 | max = max.view(-1, 1).repeat(1, x.shape[1]) 20 | scalar = (x - min) / (max - min) 21 | return scalar 22 | 23 | def normalize(x): 24 | rowsum = x.sum(1) 25 | rowsum = rowsum.view(-1, 1).repeat(1, x.shape[1]) 26 | x_norm = x / rowsum 27 | return x_norm 28 | 29 | 30 | # for link prediction task 31 | class MAGNN_lp_layer(nn.Module): 32 | def __init__(self, 33 | num_metapaths_list, 34 | num_edge_type, 35 | etypes_lists, 36 | in_dim, 37 | out_dim, 38 | num_heads, 39 | attn_vec_dim, 40 | rnn_type='gru', 41 | attn_drop=0.5, 42 | attn_switch=False, 43 | args=None): 44 | super(MAGNN_lp_layer, self).__init__() 45 | self.in_dim = in_dim 46 | self.out_dim = out_dim 47 | self.num_heads = num_heads 48 | 49 | # etype-specific parameters 50 | r_vec = None 51 | if rnn_type == 'TransE0': 52 | r_vec = nn.Parameter(torch.empty(size=(num_edge_type // 2, in_dim))) 53 | elif rnn_type == 'TransE1': 54 | r_vec = nn.Parameter(torch.empty(size=(num_edge_type, in_dim))) 55 | elif rnn_type == 'RotatE0': 56 | r_vec = nn.Parameter(torch.empty(size=(num_edge_type // 2, in_dim // 2, 2))) 57 | elif rnn_type == 'RotatE1': 58 | r_vec = nn.Parameter(torch.empty(size=(num_edge_type, in_dim // 2, 2))) 59 | if r_vec is not None: 60 | nn.init.xavier_normal_(r_vec.data, gain=1.414) 61 | 62 | # ctr_ntype-specific layers 63 | self.user_layer = MAGNN_ctr_ntype_specific(num_metapaths_list[0], 64 | etypes_lists[0], 65 | in_dim, 66 | num_heads, 67 | attn_vec_dim, 68 | rnn_type, 69 | r_vec, 70 | attn_drop, 71 | use_minibatch=True, 72 | attn_switch=attn_switch, 73 | args=args) 74 | self.item_layer = MAGNN_ctr_ntype_specific(num_metapaths_list[1], 75 | etypes_lists[1], 76 | in_dim, 77 | num_heads, 78 | attn_vec_dim, 79 | rnn_type, 80 | r_vec, 81 | attn_drop, 82 | use_minibatch=True, 83 | attn_switch=attn_switch, 84 | args=args) 85 | 86 | # note that the acutal input dimension should consider the number of heads 87 | # as multiple head outputs are concatenated together 88 | # self.fc_user = nn.Linear(in_dim * num_heads * num_metapaths_list[0], out_dim * num_heads, bias=True) 89 | # self.fc_item = nn.Linear(in_dim * num_heads * num_metapaths_list[1], out_dim * num_heads, bias=True) 90 | # nn.init.xavier_normal_(self.fc_user.weight, gain=1.414) 91 | # nn.init.xavier_normal_(self.fc_item.weight, gain=1.414) 92 | 93 | def forward(self, inputs): 94 | g_lists, features, type_mask, edge_metapath_indices_lists, target_idx_lists = inputs 95 | 96 | # ctr_ntype-specific layers 97 | h_user = self.user_layer( 98 | (g_lists[0], features, type_mask, edge_metapath_indices_lists[0], target_idx_lists[0])) 99 | h_item = self.item_layer( 100 | (g_lists[1], features, type_mask, edge_metapath_indices_lists[1], target_idx_lists[1])) 101 | 102 | return [h_user, h_item] 103 | 104 | # logits_user = self.fc_user(h_user) 105 | # logits_item = self.fc_item(h_item) 106 | # return [logits_user, logits_item], [h_user, h_item] 107 | 108 | class GCN_layer(nn.Module): 109 | def __init__(self, dim, mp_ls=None): 110 | super(GCN_layer, self).__init__() 111 | self.gcn1 = nn.Parameter(torch.zeros([dim, 128]), requires_grad=True) 112 | self.gcn2 = nn.Parameter(torch.zeros([128, 2]), requires_grad=True) 113 | nn.init.xavier_normal_(self.gcn1, gain=1.414) 114 | nn.init.xavier_normal_(self.gcn2, gain=1.414) 115 | 116 | def forward(self, x, adj): 117 | # x = 
MinMax_scalar(x) 118 | # x = normalize(x) 119 | adj = F.softmax(torch.matmul(x, x.T), dim=-1) 120 | # adj = adj_normalize(adj) 121 | x = F.relu(torch.matmul(torch.matmul(adj, x), self.gcn1)) 122 | x = torch.matmul(torch.matmul(adj, x), self.gcn2) 123 | 124 | return x 125 | 126 | 127 | class linear_module(nn.Module): 128 | def __init__(self, dim): 129 | super(linear_module, self).__init__() 130 | self.fc1 = nn.Linear(dim, int(dim/2), bias=True) 131 | self.fc2 = nn.Linear(int(dim/2), 2, bias=False) 132 | # weight initialization 133 | nn.init.xavier_normal_(self.fc1.weight, gain=1.414) 134 | nn.init.xavier_normal_(self.fc2.weight, gain=1.414) 135 | 136 | def forward(self, x, x2=None): 137 | x = F.relu(self.fc1(x)) 138 | x = self.fc2(x) 139 | 140 | return x 141 | 142 | class MAGNN_lp(nn.Module): 143 | def __init__(self, 144 | num_metapaths_list, 145 | num_edge_type, 146 | etypes_lists, 147 | feats_dim_list, 148 | hidden_dim, 149 | out_dim, 150 | num_heads, 151 | attn_vec_dim, 152 | rnn_type='gru', 153 | dropout_rate=0.5, 154 | attn_switch=False, 155 | args=None): 156 | super(MAGNN_lp, self).__init__() 157 | self.hidden_dim = hidden_dim 158 | self.args = args 159 | 160 | # ntype-specific transformation 161 | self.fc_list = nn.ModuleList([nn.Linear(feats_dim, hidden_dim, bias=True) for feats_dim in feats_dim_list]) 162 | # feature dropout after trainsformation 163 | if dropout_rate > 0: 164 | self.feat_drop = nn.Dropout(dropout_rate) 165 | else: 166 | self.feat_drop = lambda x: x 167 | # initialization of fc layers 168 | for fc in self.fc_list: 169 | nn.init.xavier_normal_(fc.weight, gain=1.414) 170 | 171 | # MAGNN_lp layers 172 | self.layer1 = MAGNN_lp_layer(num_metapaths_list, 173 | num_edge_type, 174 | etypes_lists, 175 | hidden_dim, 176 | out_dim, 177 | num_heads, 178 | attn_vec_dim, 179 | rnn_type, 180 | attn_drop=dropout_rate, 181 | attn_switch=attn_switch, 182 | args=args) 183 | dim = out_dim * num_heads * 2 184 | if self.args.semantic_fusion == 'concatenation': 185 | dim = out_dim * num_heads * (num_metapaths_list[0] + num_metapaths_list[1]) 186 | # predictor 187 | if self.args.predictor =='gcn': 188 | self.classifier = GCN_layer(dim=dim) 189 | elif self.args.predictor == 'linear': 190 | self.classifier = linear_module(dim=dim) 191 | 192 | def forward(self, inputs): 193 | g_lists, features_list, type_mask, edge_metapath_indices_lists, target_idx_lists, adj = inputs 194 | 195 | # ntype-specific transformation 196 | transformed_features = torch.zeros(type_mask.shape[0], self.hidden_dim, device=features_list[0].device) 197 | for i, fc in enumerate(self.fc_list): 198 | node_indices = np.where(type_mask == i)[0] 199 | transformed_features[node_indices] = fc(features_list[i]) 200 | transformed_features = self.feat_drop(transformed_features) 201 | 202 | # hidden layers 203 | # [logits_user, logits_item], [h_user, h_item] = self.layer1( 204 | # (g_lists, transformed_features, type_mask, edge_metapath_indices_lists, target_idx_lists)) 205 | [h_user, h_item] = self.layer1((g_lists, transformed_features, type_mask, edge_metapath_indices_lists, target_idx_lists)) 206 | x = torch.cat([h_user, h_item], dim=1) 207 | x_out = self.classifier(x, adj) 208 | 209 | return F.softmax(x_out, dim=-1) 210 | 211 | def feature_transform(self, type_mask, features_list): 212 | # ntype-specific transformation 213 | transformed_features = torch.zeros(type_mask.shape[0], self.hidden_dim, device=features_list[0].device) 214 | for i, fc in enumerate(self.fc_list): 215 | node_indices = np.where(type_mask == i)[0] 216 | 
transformed_features[node_indices] = fc(features_list[i]) 217 | transformed_features = self.feat_drop(transformed_features) 218 | 219 | return transformed_features 220 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import scipy 4 | import pickle 5 | import json 6 | from torch.utils.data import DataLoader, Dataset 7 | import dgl 8 | import torch 9 | 10 | def read_adjlist(fp): 11 | in_file = open(fp, 'r') 12 | adjlist = [line.strip() for line in in_file] 13 | in_file.close() 14 | return adjlist 15 | 16 | def read_pickle(fp): 17 | in_file = open(fp, 'rb') 18 | idx = pickle.load(in_file) 19 | in_file.close() 20 | return idx 21 | 22 | def read_json(fp): 23 | in_file = open(fp, 'r') 24 | idx = json.load(in_file) 25 | in_file.close() 26 | return idx 27 | 28 | def fold_train_test_idx(pos_folds, neg_folds, nFold, foldID): 29 | train_pos_idx = [] 30 | train_neg_idx = [] 31 | test_fold_idx = [] 32 | for fold in range(nFold): 33 | if fold == foldID: 34 | continue 35 | train_pos_idx.append(pos_folds['fold_' + str(fold)]) 36 | train_neg_idx.append(neg_folds['fold_' + str(fold)]) 37 | train_pos_idx = np.concatenate(train_pos_idx, axis=1) 38 | train_neg_idx = np.concatenate(train_neg_idx, axis=1) 39 | train_fold_idx = np.concatenate([train_pos_idx, train_neg_idx], axis=1) 40 | 41 | test_fold_idx.append(pos_folds['fold_' + str(foldID)]) 42 | test_fold_idx.append(neg_folds['fold_' + str(foldID)]) 43 | test_fold_idx = np.concatenate(test_fold_idx, axis=1) 44 | return train_fold_idx.T, test_fold_idx.T 45 | 46 | def load_data(args, fold, rp, neg_times=1, train_test='train'): 47 | # Drug related 48 | if train_test == 'train': 49 | prefix0 = args.data_dir + '/processed/repeat{}/0/{}/fold_{}/'.format(rp, train_test, fold) 50 | else: 51 | prefix0 = args.data_dir + '/processed/repeat{}/0/{}/neg_times_{}/fold_{}/'.format(rp, train_test, neg_times, fold) 52 | 53 | adjlist00 = read_pickle(prefix0 + '/0-0.adjlist.pkl') 54 | adjlist01 = read_pickle(prefix0 + '/0-1-0.adjlist.pkl') 55 | adjlist02 = read_pickle(prefix0 + '/0-2-0.adjlist.pkl') 56 | adjlist03 = read_pickle(prefix0 + '/0-3-0.adjlist.pkl') 57 | adjlist04 = read_pickle(prefix0 + '/0-1-1-0.adjlist.pkl') 58 | # adjlist05 = read_pickle(prefix0 + '/0-1-0-1-0.adjlist.pkl') 59 | # adjlist06 = read_pickle(prefix0 + '/0-2-0-2-0.adjlist.pkl') 60 | # adjlist07 = read_pickle(prefix0 + '/0-3-0-3-0.adjlist.pkl') 61 | # adjlist08 = read_pickle(prefix0 + '/0-1-2-1-0.adjlist.pkl') 62 | # adjlist09 = read_pickle(prefix0 + '/0-2-1-2-0.adjlist.pkl') 63 | 64 | idx00 = read_pickle(prefix0 + '/0-0.idx.pkl') 65 | idx01 = read_pickle(prefix0 + '/0-1-0.idx.pkl') 66 | idx02 = read_pickle(prefix0 + '/0-2-0.idx.pkl') 67 | idx03 = read_pickle(prefix0 + '/0-3-0.idx.pkl') 68 | idx04 = read_pickle(prefix0 + '/0-1-1-0.idx.pkl') 69 | # idx05 = read_pickle(prefix0 + '/0-1-0-1-0.idx.pkl') 70 | # idx06 = read_pickle(prefix0 + '/0-2-0-2-0.idx.pkl') 71 | # idx07 = read_pickle(prefix0 + '/0-3-0-3-0.idx.pkl') 72 | # idx08 = read_pickle(prefix0 + '/0-1-2-1-0.idx.pkl') 73 | # idx09 = read_pickle(prefix0 + '/0-2-1-2-0.idx.pkl') 74 | 75 | # Protein related 76 | if train_test == 'train': 77 | prefix1 = args.data_dir + '/processed/repeat{}/1/{}/fold_{}/'.format(rp, train_test, fold) 78 | else: 79 | prefix1 = args.data_dir + '/processed/repeat{}/1/{}/neg_times_{}/fold_{}/'.format(rp, train_test, neg_times, fold) 80 | 
adjlist10 = read_pickle(prefix1 + '/1-1.adjlist.pkl') 81 | adjlist11 = read_pickle(prefix1 + '/1-0-1.adjlist.pkl') 82 | adjlist12 = read_pickle(prefix1 + '/1-2-1.adjlist.pkl') 83 | adjlist13 = read_pickle(prefix1 + '/1-0-0-1.adjlist.pkl') 84 | # adjlist14 = read_pickle(prefix1 + '/1-0-1-0-1.adjlist.pkl') 85 | # adjlist15 = read_pickle(prefix1 + '/1-0-2-0-1.adjlist.pkl') 86 | # adjlist16 = read_pickle(prefix1 + '/1-2-0-2-1.adjlist.pkl') 87 | # adjlist17 = read_pickle(prefix1 + '/1-2-1-2-1.adjlist.pkl') 88 | 89 | idx10 = read_pickle(prefix1 + '/1-1.idx.pkl') 90 | idx11 = read_pickle(prefix1 + '/1-0-1.idx.pkl') 91 | idx12 = read_pickle(prefix1 + '/1-2-1.idx.pkl') 92 | idx13 = read_pickle(prefix1 + '/1-0-0-1.idx.pkl') 93 | # idx14 = read_pickle(prefix1 + '/1-0-1-0-1.idx.pkl') 94 | # idx15 = read_pickle(prefix1 + '/1-0-2-0-1.idx.pkl') 95 | # idx16 = read_pickle(prefix1 + '/1-2-0-2-1.idx.pkl') 96 | idx17 = read_pickle(prefix1 + '/1-2-1-2-1.idx.pkl') 97 | 98 | return [[adjlist00, adjlist01, adjlist02, adjlist03, adjlist04], 99 | [adjlist10, adjlist11, adjlist12, adjlist13]], \ 100 | [[idx00, idx01, idx02, idx03, idx04], 101 | [idx10, idx11, idx12, idx13]] 102 | # return [[adjlist00, adjlist01, adjlist02, adjlist03, adjlist04, adjlist05, adjlist06, adjlist07, adjlist08, adjlist09], 103 | # [adjlist10, adjlist11, adjlist12, adjlist13, adjlist14, adjlist15, adjlist16, adjlist17]],\ 104 | # [[idx00, idx01, idx02, idx03, idx04, idx05, idx06, idx07, idx08, idx09], 105 | # [idx10, idx11, idx12, idx13, idx14, idx15, idx16, idx17]] 106 | 107 | 108 | 109 | class mydataset(Dataset): 110 | def __init__(self, drug_protein_idx, y_true): 111 | self.drug_protein_idx = drug_protein_idx 112 | self.Y = y_true 113 | 114 | def __len__(self): 115 | return len(self.drug_protein_idx) 116 | 117 | def __getitem__(self, index): 118 | d_p_idx = self.drug_protein_idx[index].tolist() 119 | y = self.Y[index] 120 | 121 | return d_p_idx, y 122 | 123 | class collate_fc(object): 124 | def __init__(self, adjlists, edge_metapath_indices_list, num_samples, offset, device): 125 | self.adjlists = adjlists 126 | self.edge_metapath_indices_list = edge_metapath_indices_list 127 | self.num_samples = num_samples 128 | self.offset = offset 129 | self.device = device 130 | 131 | def collate_func(self, batch_list): 132 | y_true = [y for _, y in batch_list] 133 | batch_list = [idx for idx, _ in batch_list] 134 | 135 | g_lists = [[], []] 136 | result_indices_lists = [[], []] 137 | idx_batch_mapped_lists = [[], []] 138 | for mode, (adjlists, edge_metapath_indices_list) in enumerate(zip(self.adjlists, self.edge_metapath_indices_list)): 139 | for adjlist, indices in zip(adjlists, edge_metapath_indices_list): 140 | edges, result_indices, num_nodes, mapping = parse_adjlist([adjlist[row[mode]] for row in batch_list], 141 | [indices[row[mode]] for row in batch_list], 142 | self.num_samples, offset=self.offset, mode=mode) 143 | 144 | g = dgl.DGLGraph() 145 | g.add_nodes(num_nodes) 146 | if len(edges) > 0: 147 | sorted_index = sorted(range(len(edges)), key=lambda i: edges[i]) 148 | g.add_edges(*list(zip(*[(edges[i][1], edges[i][0]) for i in sorted_index]))) 149 | result_indices = torch.LongTensor(result_indices[sorted_index]).to(self.device) 150 | else: 151 | result_indices = torch.LongTensor(result_indices).to(self.device) 152 | g_lists[mode].append(g) 153 | result_indices_lists[mode].append(result_indices) 154 | idx_batch_mapped_lists[mode].append(np.array([mapping[row[mode]] for row in batch_list])) 155 | 156 | return g_lists, result_indices_lists, 
idx_batch_mapped_lists, y_true, batch_list 157 | 158 | def parse_adjlist(adjlist, edge_metapath_indices, samples=None, offset=None, mode=None): 159 | edges = [] 160 | nodes = set() 161 | result_indices = [] 162 | for row, indices in zip(adjlist, edge_metapath_indices): 163 | row_parsed = list(map(int, row)) 164 | nodes.add(row_parsed[0]) 165 | if len(row_parsed) > 1: 166 | # sampling neighbors 167 | if samples is None: 168 | neighbors = row_parsed[1:] 169 | result_indices.append(indices) 170 | else: 171 | # undersampling frequent neighbors 172 | unique, counts = np.unique(row_parsed[1:], return_counts=True) 173 | p = [] 174 | for count in counts: 175 | p += [(count ** (3 / 4)) / count] * count 176 | p = np.array(p) 177 | p = p / p.sum() 178 | samples = min(samples, len(row_parsed) - 1) 179 | sampled_idx = np.sort(np.random.choice(len(row_parsed) - 1, samples, replace=False, p=p)) 180 | neighbors = [row_parsed[i + 1] for i in sampled_idx] 181 | result_indices.append(indices[sampled_idx]) 182 | else: 183 | neighbors = [row_parsed[0]] 184 | indices = np.array([[row_parsed[0]] * indices.shape[1]]) 185 | if mode == 1: 186 | indices += offset 187 | result_indices.append(indices) 188 | for dst in neighbors: 189 | nodes.add(dst) 190 | edges.append((row_parsed[0], dst)) 191 | mapping = {map_from: map_to for map_to, map_from in enumerate(sorted(nodes))} 192 | edges = list(map(lambda tup: (mapping[tup[0]], mapping[tup[1]]), edges)) 193 | result_indices = np.vstack(result_indices) 194 | return edges, result_indices, len(nodes), mapping 195 | 196 | def get_features(args, type_mask): 197 | features_list = [] 198 | in_dims = [] 199 | if args.feats_type == 0: 200 | for i in range(args.num_ntype): 201 | dim = (type_mask == i).sum() 202 | in_dims.append(dim) 203 | indices = np.vstack((np.arange(dim), np.arange(dim))) 204 | indices = torch.LongTensor(indices) 205 | values = torch.FloatTensor(np.ones(dim)) 206 | features_list.append(torch.sparse.FloatTensor(indices, values, torch.Size([dim, dim])).to(args.device)) 207 | elif args.feats_type == 1: 208 | for i in range(args.num_ntype): 209 | dim = 10 210 | num_nodes = (type_mask == i).sum() 211 | in_dims.append(dim) 212 | features_list.append(torch.zeros((num_nodes, 10)).to(args.device)) 213 | 214 | return features_list, in_dims 215 | 216 | -------------------------------------------------------------------------------- /model/base_MAGNN_dti.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl.function as fn 5 | from dgl.nn.pytorch import edge_softmax 6 | 7 | 8 | class MAGNN_metapath_specific(nn.Module): 9 | def __init__(self, 10 | etypes, 11 | out_dim, 12 | num_heads, 13 | rnn_type='gru', 14 | r_vec=None, 15 | attn_drop=0.5, 16 | alpha=0.01, 17 | use_minibatch=False, 18 | attn_switch=False): 19 | super(MAGNN_metapath_specific, self).__init__() 20 | self.out_dim = out_dim 21 | self.num_heads = num_heads 22 | self.rnn_type = rnn_type 23 | self.etypes = etypes 24 | self.r_vec = r_vec 25 | self.use_minibatch = use_minibatch 26 | self.attn_switch = attn_switch 27 | 28 | # rnn-like metapath instance aggregator 29 | # consider multiple attention heads 30 | if rnn_type == 'gru': 31 | self.rnn = nn.GRU(out_dim, num_heads * out_dim) 32 | elif rnn_type == 'lstm': 33 | self.rnn = nn.LSTM(out_dim, num_heads * out_dim) 34 | elif rnn_type == 'bi-gru': 35 | self.rnn = nn.GRU(out_dim, num_heads * out_dim // 2, bidirectional=True) 36 | elif rnn_type == 
'bi-lstm': 37 | self.rnn = nn.LSTM(out_dim, num_heads * out_dim // 2, bidirectional=True) 38 | elif rnn_type == 'linear': 39 | self.rnn = nn.Linear(out_dim, num_heads * out_dim) 40 | elif rnn_type == 'max-pooling': 41 | self.rnn = nn.Linear(out_dim, num_heads * out_dim) 42 | elif rnn_type == 'neighbor-linear': 43 | self.rnn = nn.Linear(out_dim, num_heads * out_dim) 44 | 45 | # node-level attention 46 | # attention considers the center node embedding or not 47 | if self.attn_switch: 48 | self.attn1 = nn.Linear(out_dim, num_heads, bias=False) 49 | self.attn2 = nn.Parameter(torch.empty(size=(1, num_heads, out_dim))) 50 | else: 51 | self.attn = nn.Parameter(torch.empty(size=(1, num_heads, out_dim))) 52 | self.leaky_relu = nn.LeakyReLU(alpha) 53 | self.softmax = edge_softmax 54 | if attn_drop: 55 | self.attn_drop = nn.Dropout(attn_drop) 56 | else: 57 | self.attn_drop = lambda x: x 58 | 59 | # weight initialization 60 | if self.attn_switch: 61 | nn.init.xavier_normal_(self.attn1.weight, gain=1.414) 62 | nn.init.xavier_normal_(self.attn2.data, gain=1.414) 63 | else: 64 | nn.init.xavier_normal_(self.attn.data, gain=1.414) 65 | 66 | def edge_softmax(self, g): 67 | attention = self.softmax(g, g.edata.pop('a')) 68 | # Dropout attention scores and save them 69 | g.edata['a_drop'] = self.attn_drop(attention) 70 | 71 | def message_passing(self, edges): 72 | ft = edges.data['eft'] * edges.data['a_drop'] 73 | return {'ft': ft} 74 | 75 | def forward(self, inputs): 76 | # features: num_all_nodes x out_dim 77 | if self.use_minibatch: 78 | g, features, type_mask, edge_metapath_indices, target_idx = inputs 79 | else: 80 | g, features, type_mask, edge_metapath_indices = inputs 81 | 82 | g.to(features.device) 83 | 84 | # Embedding layer 85 | # use torch.nn.functional.embedding or torch.embedding here 86 | # do not use torch.nn.embedding 87 | # edata: E x Seq x out_dim 88 | edata = F.embedding(edge_metapath_indices, features) ## node features on each metapath 89 | 90 | # apply rnn to metapath-based feature sequence 91 | if self.rnn_type == 'gru': 92 | _, hidden = self.rnn(edata.permute(1, 0, 2)) 93 | elif self.rnn_type == 'lstm': 94 | _, (hidden, _) = self.rnn(edata.permute(1, 0, 2)) 95 | elif self.rnn_type == 'bi-gru': 96 | _, hidden = self.rnn(edata.permute(1, 0, 2)) 97 | hidden = hidden.permute(1, 0, 2).reshape(-1, self.out_dim, self.num_heads).permute(0, 2, 1).reshape( 98 | -1, self.num_heads * self.out_dim).unsqueeze(dim=0) 99 | elif self.rnn_type == 'bi-lstm': 100 | _, (hidden, _) = self.rnn(edata.permute(1, 0, 2)) 101 | hidden = hidden.permute(1, 0, 2).reshape(-1, self.out_dim, self.num_heads).permute(0, 2, 1).reshape( 102 | -1, self.num_heads * self.out_dim).unsqueeze(dim=0) 103 | elif self.rnn_type == 'average': 104 | hidden = torch.mean(edata, dim=1) 105 | hidden = torch.cat([hidden] * self.num_heads, dim=1) 106 | hidden = hidden.unsqueeze(dim=0) 107 | elif self.rnn_type == 'linear': 108 | hidden = self.rnn(torch.mean(edata, dim=1)) 109 | hidden = hidden.unsqueeze(dim=0) 110 | elif self.rnn_type == 'max-pooling': 111 | hidden, _ = torch.max(self.rnn(edata), dim=1) 112 | hidden = hidden.unsqueeze(dim=0) 113 | elif self.rnn_type == 'TransE0' or self.rnn_type == 'TransE1': 114 | r_vec = self.r_vec 115 | if self.rnn_type == 'TransE0': 116 | r_vec = torch.stack((r_vec, -r_vec), dim=1) 117 | r_vec = r_vec.reshape(self.r_vec.shape[0] * 2, self.r_vec.shape[1]) # etypes x out_dim 118 | edata = F.normalize(edata, p=2, dim=2) 119 | for i in range(edata.shape[1] - 1): 120 | # consider None edge (symmetric relation) 
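                # (None entries in self.etypes mark same-type hops, e.g. drug-drug or protein-protein,
                # that carry no relation vector; the node at position i is then translated by the sum of
                # the remaining relation vectors along the metapath, a TransE-style composition towards
                # the target node.)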
121 | temp_etypes = [etype for etype in self.etypes[i:] if etype is not None] 122 | edata[:, i] = edata[:, i] + r_vec[temp_etypes].sum(dim=0) 123 | hidden = torch.mean(edata, dim=1) 124 | hidden = torch.cat([hidden] * self.num_heads, dim=1) 125 | hidden = hidden.unsqueeze(dim=0) 126 | elif self.rnn_type == 'RotatE0' or self.rnn_type == 'RotatE1': 127 | r_vec = F.normalize(self.r_vec, p=2, dim=2) 128 | if self.rnn_type == 'RotatE0': 129 | r_vec = torch.stack((r_vec, r_vec), dim=1) 130 | r_vec[:, 1, :, 1] = -r_vec[:, 1, :, 1] 131 | r_vec = r_vec.reshape(self.r_vec.shape[0] * 2, self.r_vec.shape[1], 2) # etypes x out_dim/2 x 2 132 | edata = edata.reshape(edata.shape[0], edata.shape[1], edata.shape[2] // 2, 2) 133 | final_r_vec = torch.zeros([edata.shape[1], self.out_dim // 2, 2], device=edata.device) 134 | final_r_vec[-1, :, 0] = 1 135 | for i in range(final_r_vec.shape[0] - 2, -1, -1): 136 | # consider None edge (symmetric relation) 137 | if self.etypes[i] is not None: 138 | final_r_vec[i, :, 0] = final_r_vec[i + 1, :, 0].clone() * r_vec[self.etypes[i], :, 0] -\ 139 | final_r_vec[i + 1, :, 1].clone() * r_vec[self.etypes[i], :, 1] 140 | final_r_vec[i, :, 1] = final_r_vec[i + 1, :, 0].clone() * r_vec[self.etypes[i], :, 1] +\ 141 | final_r_vec[i + 1, :, 1].clone() * r_vec[self.etypes[i], :, 0] 142 | else: 143 | final_r_vec[i, :, 0] = final_r_vec[i + 1, :, 0].clone() 144 | final_r_vec[i, :, 1] = final_r_vec[i + 1, :, 1].clone() 145 | for i in range(edata.shape[1] - 1): 146 | temp1 = edata[:, i, :, 0].clone() * final_r_vec[i, :, 0] -\ 147 | edata[:, i, :, 1].clone() * final_r_vec[i, :, 1] 148 | temp2 = edata[:, i, :, 0].clone() * final_r_vec[i, :, 1] +\ 149 | edata[:, i, :, 1].clone() * final_r_vec[i, :, 0] 150 | edata[:, i, :, 0] = temp1 151 | edata[:, i, :, 1] = temp2 152 | edata = edata.reshape(edata.shape[0], edata.shape[1], -1) 153 | hidden = torch.mean(edata, dim=1) 154 | hidden = torch.cat([hidden] * self.num_heads, dim=1) 155 | hidden = hidden.unsqueeze(dim=0) 156 | elif self.rnn_type == 'neighbor': 157 | hidden = edata[:, 0] 158 | hidden = torch.cat([hidden] * self.num_heads, dim=1) 159 | hidden = hidden.unsqueeze(dim=0) 160 | elif self.rnn_type == 'neighbor-linear': 161 | hidden = self.rnn(edata[:, 0]) 162 | hidden = hidden.unsqueeze(dim=0) 163 | 164 | eft = hidden.permute(1, 0, 2).view(-1, self.num_heads, self.out_dim) # E x num_heads x out_dim 165 | if self.attn_switch: 166 | center_node_feat = F.embedding(edge_metapath_indices[:, -1], features) # E x out_dim 167 | a1 = self.attn1(center_node_feat) # E x num_heads 168 | a2 = (eft * self.attn2).sum(dim=-1) # E x num_heads 169 | a = (a1 + a2).unsqueeze(dim=-1) # E x num_heads x 1 170 | else: 171 | a = (eft * self.attn).sum(dim=-1).unsqueeze(dim=-1) # E x num_heads x 1 172 | a = self.leaky_relu(a) 173 | g.edata.update({'eft': eft, 'a': a}) 174 | # compute softmax normalized attention values 175 | self.edge_softmax(g) # Compute softmax over weights of incoming edges for every node. 176 | # compute the aggregated node features scaled by the dropped, 177 | # unnormalized attention values. 
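        # (message_passing multiplies each edge's metapath-instance feature 'eft' by its dropped,
        # softmax-normalised attention weight 'a_drop'; fn.sum then accumulates these messages into
        # the destination node's 'ft'.)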
178 | g.update_all(self.message_passing, fn.sum('ft', 'ft')) # Send messages along all the edges of the specified type and update all the nodes of the corresponding destination type 179 | ret = g.ndata['ft'] # E x num_heads x out_dim 180 | 181 | if self.use_minibatch: 182 | return ret[target_idx] 183 | else: 184 | return ret 185 | 186 | 187 | class MAGNN_ctr_ntype_specific(nn.Module): 188 | def __init__(self, 189 | num_metapaths, 190 | etypes_list, 191 | out_dim, 192 | num_heads, 193 | attn_vec_dim, 194 | rnn_type='gru', 195 | r_vec=None, 196 | attn_drop=0.5, 197 | use_minibatch=False, 198 | attn_switch=False, 199 | args=None): 200 | super(MAGNN_ctr_ntype_specific, self).__init__() 201 | self.out_dim = out_dim 202 | self.num_heads = num_heads 203 | self.use_minibatch = use_minibatch 204 | self.args = args 205 | 206 | # metapath-specific layers 207 | self.metapath_layers = nn.ModuleList() 208 | for i in range(num_metapaths): 209 | self.metapath_layers.append(MAGNN_metapath_specific(etypes_list[i], 210 | out_dim, 211 | num_heads, 212 | rnn_type, 213 | r_vec, 214 | attn_drop=attn_drop, 215 | use_minibatch=use_minibatch, 216 | attn_switch=attn_switch)) 217 | 218 | # metapath-level attention 219 | # note that the acutal input dimension should consider the number of heads 220 | # as multiple head outputs are concatenated together 221 | if self.args.semantic_fusion == 'attention': 222 | self.fc1 = nn.Linear(out_dim * num_heads, attn_vec_dim, bias=True) 223 | self.fc2 = nn.Linear(attn_vec_dim, 1, bias=False) 224 | # weight initialization 225 | nn.init.xavier_normal_(self.fc1.weight, gain=1.414) 226 | nn.init.xavier_normal_(self.fc2.weight, gain=1.414) 227 | 228 | def forward(self, inputs): 229 | if self.use_minibatch: 230 | g_list, features, type_mask, edge_metapath_indices_list, target_idx_list = inputs 231 | 232 | # metapath-specific layers 233 | metapath_outs = [F.elu(metapath_layer((g, features, type_mask, edge_metapath_indices, target_idx)).view(-1, self.num_heads * self.out_dim)) 234 | for g, edge_metapath_indices, target_idx, metapath_layer in zip(g_list, edge_metapath_indices_list, target_idx_list, self.metapath_layers)] 235 | else: 236 | g_list, features, type_mask, edge_metapath_indices_list = inputs 237 | 238 | # metapath-specific layers 239 | metapath_outs = [F.elu(metapath_layer((g, features, type_mask, edge_metapath_indices)).view(-1, self.num_heads * self.out_dim)) 240 | for g, edge_metapath_indices, metapath_layer in zip(g_list, edge_metapath_indices_list, self.metapath_layers)] 241 | 242 | if self.args.semantic_fusion == 'attention': 243 | beta = [] 244 | for metapath_out in metapath_outs: 245 | fc1 = torch.tanh(self.fc1(metapath_out)) 246 | fc1_mean = torch.mean(fc1, dim=0) # metapath specific vector 247 | fc2 = self.fc2(fc1_mean) # metapath importance 248 | beta.append(fc2) 249 | beta = torch.cat(beta, dim=0) 250 | beta = F.softmax(beta, dim=0) 251 | beta = torch.unsqueeze(beta, dim=-1) 252 | beta = torch.unsqueeze(beta, dim=-1) 253 | metapath_outs = [torch.unsqueeze(metapath_out, dim=0) for metapath_out in metapath_outs] 254 | metapath_outs = torch.cat(metapath_outs, dim=0) 255 | h = torch.sum(beta * metapath_outs, dim=0) 256 | elif self.args.semantic_fusion == 'average': 257 | h = torch.mean(torch.stack(metapath_outs, dim=0), dim=0) 258 | elif self.args.semantic_fusion == 'max-pooling': 259 | h, _ = torch.max(torch.stack(metapath_outs, dim=0), dim=0) 260 | elif self.args.semantic_fusion == 'concatenation': 261 | h = torch.cat(metapath_outs, dim=1) 262 | return h 263 | 
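To make the inter-metapath (semantic) fusion above concrete: with args.semantic_fusion == 'attention', MAGNN_ctr_ntype_specific summarises each metapath's node embeddings through fc1/fc2, softmax-normalises the resulting scores, and takes the weighted sum. Below is a minimal, self-contained sketch of that step only; the toy sizes and random inputs are illustrative and not part of the repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

num_nodes, num_metapaths = 4, 3                                  # toy sizes (illustrative only)
out_dim, num_heads, attn_vec_dim = 8, 2, 16

fc1 = nn.Linear(out_dim * num_heads, attn_vec_dim, bias=True)    # metapath summary transform
fc2 = nn.Linear(attn_vec_dim, 1, bias=False)                     # metapath importance score

# One (num_nodes x out_dim*num_heads) embedding matrix per metapath,
# standing in for the outputs of the metapath-specific layers.
metapath_outs = [torch.randn(num_nodes, out_dim * num_heads) for _ in range(num_metapaths)]

beta = []
for metapath_out in metapath_outs:
    summary = torch.tanh(fc1(metapath_out)).mean(dim=0)          # metapath-specific summary vector
    beta.append(fc2(summary))                                    # scalar importance of this metapath
beta = F.softmax(torch.cat(beta, dim=0), dim=0)                  # normalise across metapaths
stacked = torch.stack(metapath_outs, dim=0)                      # num_metapaths x num_nodes x dim
h = (beta.view(-1, 1, 1) * stacked).sum(dim=0)                   # num_nodes x dim, fused embedding
print(h.shape)                                                   # torch.Size([4, 16])

With semantic_fusion == 'concatenation' (the default in run_MHAN_dti.py below), h is instead torch.cat(metapath_outs, dim=1), which is why MAGNN_lp widens the predictor input to out_dim * num_heads * (num_metapaths_list[0] + num_metapaths_list[1]).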
--------------------------------------------------------------------------------
/run_MHAN_dti.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import argparse
5 | import numpy as np
6 | import json
7 | import pandas as pd
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.utils.data import DataLoader
13 | from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
14 | 
15 | # from utils.pytorchtools import EarlyStopping  # not included in this repository; the EarlyStopping calls below are commented out
16 | from utils.data import load_data, fold_train_test_idx, mydataset, collate_fc, get_features
17 | from model.MAGNN_dti import MAGNN_lp
18 | 
19 | loss_bec = nn.BCELoss()
20 | 
21 | def setup_seed(seed):
22 |     torch.manual_seed(seed)
23 |     torch.cuda.manual_seed_all(seed)
24 |     np.random.seed(seed)
25 |     # random.seed(seed)
26 |     # torch.backends.cudnn.deterministic = True
27 | 
28 | setup_seed(20)
29 | 
30 | 
31 | 
32 | class Logger(object):
33 |     def __init__(self, fileN="Default.log"):
34 |         self.terminal = sys.stdout
35 |         self.log = open(fileN, "a")
36 | 
37 |     def write(self, message):
38 |         self.terminal.write(message)
39 |         self.terminal.flush()
40 |         self.log.write(message)
41 |         self.log.flush()
42 | 
43 |     def flush(self):
44 |         pass
45 | 
46 | def make_dir(fp):
47 |     if not os.path.exists(fp):
48 |         os.makedirs(fp, exist_ok=True)
49 | 
50 | def get_MSE(y,f):
51 |     mse = ((y - f)**2).mean(axis=0)
52 |     return mse
53 | 
54 | def get_adj(dti):
55 |     len_dti = len(dti)
56 |     dpp_adj = np.zeros((len_dti, len_dti), dtype=int)
57 |     for i, dpp1 in enumerate(dti):
58 |         for j, dpp2 in enumerate(dti):
59 |             if (dpp1[0] == dpp2[0]) | (dpp1[1] == dpp2[1]):
60 |                 dpp_adj[i][j] = 1
61 |     return dpp_adj
62 | 
63 | 
64 | def training(net, optimizer, train_loader, features_list, type_mask):
65 |     net.train()
66 |     train_loss = 0
67 |     total = 0
68 |     for i, (train_g_lists, train_indices_lists, train_idx_batch_mapped_lists, y_train, batch_list) in enumerate(train_loader):
69 |         y_train = torch.tensor(y_train).long().to(features_list[0].device)
70 |         adj_i = get_adj(batch_list)
71 |         adj_i = torch.FloatTensor(adj_i).to(features_list[0].device)
72 |         # forward
73 |         output = net((train_g_lists, features_list, type_mask, train_indices_lists, train_idx_batch_mapped_lists, adj_i))
74 |         loss = F.nll_loss(torch.log(output), y_train)
75 |         # loss = loss_bec(output[:, 1], y_train.float())  # equivalent to the NLL loss above
76 | 
77 |         # autograd
78 |         optimizer.zero_grad()
79 |         loss.backward()
80 |         optimizer.step()
81 |         train_loss = train_loss + loss.item() * len(y_train)
82 |         total = total + len(y_train)
83 | 
84 |     return train_loss/total
85 | 
86 | def evaluate(net, test_loader, features_list, type_mask, y_true):
87 |     net.eval()
88 |     pred_val = []
89 |     y_true_s = []
90 |     with torch.no_grad():
91 |         for i, (val_g_lists, val_indices_lists, val_idx_batch_mapped_lists, y_true, batch_list) in enumerate(test_loader):
92 |             # forward
93 |             adj_i = get_adj(batch_list)
94 |             adj_i = torch.FloatTensor(adj_i).to(features_list[0].device)
95 |             output = net((val_g_lists, features_list, type_mask, val_indices_lists, val_idx_batch_mapped_lists, adj_i))
96 |             pred_val.append(output)
97 |             y_true_s.append(torch.tensor(y_true).long().to(features_list[0].device))
98 | 
99 |     val_pred = torch.cat(pred_val)
100 |     y_true = torch.cat(y_true_s)
101 |     val_loss = F.nll_loss(torch.log(val_pred), y_true)
102 |     val_pred = val_pred.cpu().numpy()
103 |     y_true = y_true.cpu().numpy()
104 |     acc = accuracy_score(y_true, np.argmax(val_pred,
axis=1)) 105 | auc = roc_auc_score(y_true, val_pred[:, 1]) 106 | aupr = average_precision_score(y_true, val_pred[:, 1]) 107 | 108 | return val_loss, acc, auc, aupr, val_pred 109 | 110 | def testing(net, test_loader, features_list, type_mask, y_true_test): 111 | net.eval() 112 | proba_list = [] 113 | with torch.no_grad(): 114 | for i, (test_g_lists, test_indices_lists, test_idx_batch_mapped_lists, y_test, batch_list) in enumerate(test_loader): 115 | # forward 116 | adj_i = get_adj(batch_list) 117 | adj_i = torch.FloatTensor(adj_i).to(features_list[0].device) 118 | output = net((test_g_lists, features_list, type_mask, test_indices_lists, test_idx_batch_mapped_lists, adj_i)) 119 | proba_list.append(output) 120 | 121 | y_proba_test = torch.cat(proba_list) 122 | y_proba_test = y_proba_test.cpu().numpy() 123 | auc = roc_auc_score(y_true_test, y_proba_test[:, 1]) 124 | aupr = average_precision_score(y_true_test, y_proba_test[:, 1]) 125 | return auc, aupr, y_true_test, y_proba_test 126 | 127 | def run_model(args): 128 | fold_path = args.data_dir + '/{}_folds/'.format(str(args.nFold)) 129 | pos_folds = json.load(open(fold_path + 'pos_folds.json', 'r')) 130 | neg_folds = json.load(open(fold_path + 'neg_folds_times_{}.json'.format(str(args.neg_times), 'r'))) 131 | type_mask = np.load(args.data_dir + '/processed/node_types.npy') 132 | drug_protein = np.loadtxt(args.data_dir + '/mat_data/mat_drug_protein.txt', dtype=int) 133 | 134 | f_csv = open(args.save_dir + 'results.csv', 'a') 135 | f_csv.write('Fold,AUC,AUPR\n') 136 | f_csv.close() 137 | 138 | for fold in range(args.nFold): 139 | results = {} 140 | print('\nThis is fold ', fold, '...') 141 | if os.path.exists(args.save_dir + '/checkpoint/checkpoint_fold_{}_best.pt'.format(fold)): 142 | print('The training of this fold has been completed!\n') 143 | continue 144 | train_fold_idx, test_fold_idx = fold_train_test_idx(pos_folds, neg_folds, args.nFold, fold) 145 | train_adjlists, train_edge_metapath_indices_list = load_data(args, fold, args.rp, args.neg_times, 'train') 146 | test_adjlists, test_edge_metapath_indices_list = load_data(args, fold, args.rp, args.neg_times, 'test') 147 | y_true_train = drug_protein[train_fold_idx[:,0], train_fold_idx[:,1]] 148 | y_true_test = drug_protein[test_fold_idx[:,0], test_fold_idx[:,1]] 149 | [num_metapaths_drug, num_metapaths_protein] = len(train_adjlists[0]), len(train_adjlists[1]) 150 | 151 | # training set 152 | train_dataset = mydataset(train_fold_idx, y_true_train) 153 | train_collate = collate_fc(train_adjlists, train_edge_metapath_indices_list, num_samples=args.samples, 154 | offset=drug_protein.shape[0], device=args.device) 155 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=False, 156 | collate_fn=train_collate.collate_func) 157 | 158 | # test set 159 | test_dataset = mydataset(test_fold_idx, y_true_test) 160 | test_collate = collate_fc(test_adjlists, test_edge_metapath_indices_list, num_samples=args.samples, 161 | offset=drug_protein.shape[0], device=args.device) 162 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False, 163 | collate_fn=test_collate.collate_func) 164 | 165 | # Input features 166 | features_list, in_dims = get_features(args, type_mask) 167 | 168 | # network 169 | net = MAGNN_lp([num_metapaths_drug, num_metapaths_protein], args.num_etypes, args.etypes_lists, in_dims, 170 | args.hidden_dim, args.hidden_dim, args.num_heads, args.attn_vec_dim, 
args.rnn_type, 171 | args.dropout_rate, args.attn_switch, args) 172 | net.to(args.device) 173 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) 174 | # early_stopping = EarlyStopping(patience=args.patience, verbose=True, 175 | # save_path=args.save_dir + '/checkpoint/checkpoint_{}.pt'.format(fold)) 176 | make_dir(args.save_dir + '/checkpoint') 177 | 178 | best_acc = 0 179 | best_auc = 0 180 | best_aupr = 0 181 | pred = None 182 | counter = 0 183 | if args.only_test: 184 | # Test 185 | net.load_state_dict(torch.load(args.save_dir + '/checkpoint/checkpoint_fold_{}.pt'.format(fold))) 186 | auc, aupr, ground_truth, y_pred = testing(net, test_loader, features_list, type_mask, y_true_test) 187 | best_auc, best_aupr, pred = auc, aupr, y_pred 188 | else: 189 | if os.path.exists(args.save_dir + '/checkpoint/checkpoint_fold_{}.pt'.format(fold)): 190 | print('Load mdeol weights from /checkpoint/checkpoint_fold_{}.pt'.format(fold)) 191 | net.load_state_dict(torch.load(args.save_dir + '/checkpoint/checkpoint_fold_{}.pt'.format(fold), map_location=args.device)) 192 | for epoch in range(args.epoch): 193 | # training 194 | train_loss = training(net, optimizer, train_loader, features_list, type_mask) 195 | # validation 196 | val_loss, acc, auc, aupr, y_pred = evaluate(net, test_loader, features_list, type_mask, y_true_test) 197 | print('Epoch {:d} | Train loss {:.6f} | Val loss {:.6f} | acc {:.4f} | auc {:.4f} | aupr {:.4f}'.format( 198 | epoch, train_loss, val_loss, acc, auc, aupr)) 199 | # early stopping 200 | if (best_aupr < aupr) | (best_acc < acc): 201 | best_acc = acc 202 | best_auc = auc 203 | best_aupr, pred = aupr, y_pred 204 | torch.save(net.state_dict(), args.save_dir + '/checkpoint/checkpoint_fold_{}.pt'.format(fold)) 205 | counter = 0 206 | else: 207 | counter += 1 208 | 209 | if counter > args.patience: 210 | print('Early stopping!') 211 | break 212 | f_csv = open(args.save_dir + 'results.csv', 'a') 213 | f_csv.write(','.join(map(str, [fold, best_auc, best_aupr])) + '\n') 214 | f_csv.close() 215 | best_weights = torch.load(args.save_dir + '/checkpoint/checkpoint_fold_{}.pt'.format(fold), map_location=args.device) 216 | torch.save(best_weights, args.save_dir + '/checkpoint/checkpoint_fold_{}_best.pt'.format(fold)) 217 | 218 | results['pred'] = pred.tolist() 219 | results['ground_truth'] = y_true_test.tolist() 220 | results['AUC'] = best_auc.item() 221 | results['AUPR'] = best_aupr.item() 222 | json.dump(results, open(args.save_dir.format(rp) + f'/fold000{fold}_pred_results.json', 'w')) 223 | 224 | 225 | res = pd.read_csv(args.save_dir + 'results.csv') 226 | try: 227 | auc_list = [float(res[res['Fold'] == i]['AUC'].values[0]) for i in range(args.nFold)] 228 | aupr_list = [float(res[res['Fold'] == i]['AUPR'].values[0]) for i in range(args.nFold)] 229 | except: 230 | auc_list = [float(res[res['Fold'] == str(i)]['AUC'].values[0]) for i in range(args.nFold)] 231 | aupr_list = [float(res[res['Fold'] == str(i)]['AUPR'].values[0]) for i in range(args.nFold)] 232 | 233 | print('----------------------------------------------------------------') 234 | print('Link Prediction Tests Summary') 235 | print('AUC_mean = {}, AUC_std = {}'.format(np.mean(auc_list), np.std(auc_list))) 236 | print('AUPR_mean = {}, AUPR_std = {}'.format(np.mean(aupr_list), np.std(aupr_list))) 237 | 238 | f_csv = open(args.save_dir + 'results.csv', 'a') 239 | f_csv.write(','.join(map(str, ['mean', np.mean(auc_list), np.mean(aupr_list)])) + '\n') 240 | f_csv.write(','.join(map(str, ['std', 
np.std(auc_list), np.std(aupr_list)])) + '\n') 241 | f_csv.close() 242 | 243 | 244 | # Params 245 | def parser(): 246 | ap = argparse.ArgumentParser(description='MRGNN testing for the recommendation dataset') 247 | ap.add_argument('--device', default='cuda:0') 248 | ap.add_argument('--feats_type', type=int, default=0, 249 | help='Type of the node features used. 0 - all id vectors; 1 - all zero vector. Default is 0.') 250 | ap.add_argument('--hidden_dim', type=int, default=64, help='Dimension of the node hidden state. Default is 64.') 251 | ap.add_argument('--num_heads', type=int, default=8, help='Number of the attention heads. Default is 8.') 252 | ap.add_argument('--attn_vec_dim', type=int, default=128, help='Dimension of the attention vector. Default is 128.') 253 | ap.add_argument('--attn_switch', type=bool, default=True, help='attention considers the center node embedding or not') 254 | ap.add_argument('--rnn_type', default='max-pooling', help='Type of the aggregator. max-pooling, average, linear, neighbor, RotatE0.') 255 | ap.add_argument('--predictor', default='gcn', help='options: linear, gcn.') 256 | ap.add_argument('--semantic_fusion', default='concatenation', help='options: concatenation, attention, max-pooling, average.') 257 | ap.add_argument('--epoch', type=int, default=200, help='Number of epochs. Default is 100.') 258 | ap.add_argument('--patience', type=int, default=15, help='Patience. Default is 5.') 259 | ap.add_argument('--batch_size', type=int, default=256, help='Batch size. Default is 8.') 260 | ap.add_argument('--samples', type=int, default=100, help='Number of neighbors sampled. Default is 100.') 261 | ap.add_argument('--repeat', type=int, default=1, help='Repeat the training and testing for N times. Default is 1.') 262 | ap.add_argument('--num_ntype', default=4, type=int, help='Number of node types') 263 | ap.add_argument('--lr', default=0.0001) 264 | ap.add_argument('--weight_decay', default=1e-5) 265 | ap.add_argument('--dropout_rate', default=0.5) 266 | ap.add_argument('--num_workers', default=0, type=int) 267 | 268 | ap.add_argument('--nFold', default=10, type=int) 269 | ap.add_argument('--neg_times', default=1, type=int, help='The ratio between positive samples and negative samples') 270 | ap.add_argument('--data_dir', default='/media/data2/lm/Experiments/MHAN-DTI/hetero_dataset/{}/') 271 | ap.add_argument('--save_dir', 272 | default='./results_dpp/{}/repeat{}/neg_times{}_{}_{}_{}_num_head{}_hidden_dim{}_batch_sz{}_semantic_fusion_{}' 273 | '_predictor_{}/', 274 | help='Postfix for the saved model and result. 
Default is LastFM.') 275 | ap.add_argument('--only_test', default=False, type=bool) 276 | args = ap.parse_args() 277 | return args 278 | 279 | if __name__ == '__main__': 280 | args = parser() 281 | args.dataset = 'data' 282 | args.data_dir = args.data_dir.format(args.dataset) 283 | # args.save_dir = args.save_dir.format(args.dataset) 284 | # make_dir(args.save_dir) 285 | 286 | etypes_lists = [ 287 | [[None], [0, 1], [2, 3], [4, 5], [0, None, 1]],# [0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [0, 6, 7, 1], [2, 7, 6, 3]], 288 | [[None], [1, 0], [6, 7], [1, None, 0]]#, [1, 0, 1, 0], [1, 2, 3, 0], [6, 3, 2, 7], [6, 7, 6, 7]] 289 | ] 290 | 291 | expected_metapaths = [ 292 | [(0, 0), (0, 1, 0), (0, 2, 0), (0, 3, 0), (0, 1, 1, 0)], 293 | # (0, 1, 0, 1, 0), (0, 2, 0, 2, 0), (0, 3, 0, 3, 0), (0, 1, 2, 1, 0), (0, 2, 1, 2, 0)], 294 | [(1, 1), (1, 0, 1), (1, 2, 1), (1, 0, 0, 1)], 295 | # (1, 0, 1, 0, 1), (1, 0, 2, 0, 1), (1, 2, 0, 2, 1), (1, 2, 1, 2, 1)] 296 | ] 297 | 298 | args.etypes_lists = etypes_lists 299 | args.num_etypes = 8 300 | args.expected_metapaths = expected_metapaths 301 | 302 | for rp in range(args.repeat): 303 | print('This is repeat ', rp) 304 | args.rp = rp 305 | save_dir = args.save_dir 306 | args.save_dir = args.save_dir.format(args.dataset, args.rp, args.neg_times, args.rnn_type.capitalize(), 307 | len(args.expected_metapaths[0]), len(args.expected_metapaths[1]), 308 | args.num_heads, args.hidden_dim, args.batch_size, 309 | args.semantic_fusion, args.predictor) 310 | print('Save path ', args.save_dir) 311 | make_dir(args.save_dir) 312 | sys.stdout = Logger(args.save_dir + 'log.txt') 313 | 314 | run_model(args) 315 | 316 | print('Save path ', args.save_dir) 317 | args.save_dir = save_dir 318 | -------------------------------------------------------------------------------- /train_preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import os 4 | import pathlib 5 | import pickle 6 | import random 7 | import numpy as np 8 | import scipy.sparse as sp 9 | import scipy.io 10 | import pandas as pd 11 | import json 12 | 13 | seed = 123 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | 17 | def make_dir(fp): 18 | if not os.path.exists(fp): 19 | os.makedirs(fp, exist_ok=True) 20 | 21 | def metapath_xx(x_x_list, num, sample=None): 22 | x_x = [] 23 | for x, x_list in x_x_list.items(): 24 | if sample is not None: 25 | candidate_list = np.random.choice(len(x_list), min(len(x_list), sample), replace=False) 26 | x_list = x_list[candidate_list] 27 | x_x.extend([(x, x1) for x1 in x_list]) 28 | x_x = np.array(x_x) 29 | x_x = x_x + num 30 | sorted_index = sorted(list(range(len(x_x))), key=lambda i: x_x[i].tolist()) 31 | x_x = x_x[sorted_index] 32 | return x_x 33 | 34 | def metapath_yxy(x_y_list, num1, num2, sample=None): 35 | y_x_y = [] 36 | for x, y_list in x_y_list.items(): 37 | if sample is not None: 38 | candidate_list1 = np.random.choice(len(y_list), min(len(y_list), sample), replace=False) 39 | candidate_list2 = np.random.choice(len(y_list), min(len(y_list), sample), replace=False) 40 | # print(len(candidate_list1)) 41 | y_list1 = y_list[candidate_list1] 42 | y_list2 = y_list[candidate_list2] 43 | y_x_y.extend([(y1, x, y2) for y1 in y_list1 for y2 in y_list2]) 44 | else: 45 | y_x_y.extend([(y1, x, y2) for y1 in y_list for y2 in y_list]) 46 | y_x_y = np.array(y_x_y) 47 | y_x_y[:, [0, 2]] += num1 48 | y_x_y[:, 1] += num2 49 | sorted_index = sorted(list(range(len(y_x_y))), key=lambda i: y_x_y[i, [0, 2, 
1]].tolist()) 50 | y_x_y = y_x_y[sorted_index] 51 | return y_x_y 52 | 53 | def metapath_yxxy(x_x, x_y_list, num1, num2, sample=None): 54 | y_x_x_y = [] 55 | for x1, x2 in x_x: 56 | if sample is not None: 57 | candidate_list1 = np.random.choice(len(x_y_list[x1 - num2]), min(len(x_y_list[x1 - num2]), sample), replace=False) 58 | candidate_list2 = np.random.choice(len(x_y_list[x2 - num2]), min(len(x_y_list[x2 - num2]), sample), replace=False) 59 | # print(len(candidate_list1)) 60 | x_y_list1 = x_y_list[x1 - num2][candidate_list1] 61 | x_y_list2 = x_y_list[x2 - num2][candidate_list2] 62 | y_x_x_y.extend([(y1, x1, x2, y2) for y1 in x_y_list1 for y2 in x_y_list2]) 63 | else: 64 | y_x_x_y.extend([(y1, x1, x2, y2) for y1 in x_y_list[x1 - num2] for y2 in x_y_list[x2 - num2]]) 65 | y_x_x_y = np.array(y_x_x_y) 66 | y_x_x_y[:, [0, 3]] += num1 67 | sorted_index = sorted(list(range(len(y_x_x_y))), key=lambda i: y_x_x_y[i, [0, 3, 1, 2]].tolist()) 68 | y_x_x_y = y_x_x_y[sorted_index] 69 | return y_x_x_y 70 | 71 | def metapath_zyxyz(y_x_y, y_z_list, num1, num2, ratio): 72 | # z-y-x-y-z 73 | z_y_x_y_z = [] 74 | for y1, x, y2 in y_x_y: 75 | if len(y_z_list[y1 - num2]) == 0 or len(y_z_list[y2 - num2]) == 0: 76 | continue 77 | if ratio <= 1: 78 | candidate_z1_list = np.random.choice(len(y_z_list[y1 - num2]), int(ratio * len(y_z_list[y1 - num2])) + 1, replace=False) 79 | candidate_z2_list = np.random.choice(len(y_z_list[y2 - num2]), int(ratio * len(y_z_list[y2 - num2])) + 1, replace=False) 80 | else: 81 | candidate_z1_list = np.random.choice(len(y_z_list[y1 - num2]), min(ratio, len(y_z_list[y1 - num2])), replace=False) 82 | candidate_z2_list = np.random.choice(len(y_z_list[y2 - num2]), min(ratio, len(y_z_list[y2 - num2])), replace=False) 83 | candidate_z1_list = y_z_list[y1 - num2][candidate_z1_list] 84 | candidate_z2_list = y_z_list[y2 - num2][candidate_z2_list] 85 | 86 | z_y_x_y_z.extend([(z1, y1, x, y2, z2) for z1 in candidate_z1_list for z2 in candidate_z2_list]) 87 | z_y_x_y_z = np.array(z_y_x_y_z) 88 | z_y_x_y_z[:, [0, 4]] += num1 89 | sorted_index = sorted(list(range(len(z_y_x_y_z))), key=lambda i: z_y_x_y_z[i, [0, 4, 1, 2, 3]].tolist()) 90 | z_y_x_y_z = z_y_x_y_z[sorted_index] 91 | return z_y_x_y_z 92 | 93 | def sampling(array_list, num, offset): 94 | target_list = np.arange(num).tolist() 95 | sampled_list = [] 96 | k = 100 # number of samiling 97 | 98 | left = 0 99 | right = 0 100 | for target_idx in target_list: 101 | while right < len(array_list) and array_list[right, 0] == target_idx + offset: 102 | right += 1 103 | target_array = array_list[left:right, :] 104 | 105 | if len(target_array) > 0: 106 | samples = min(k, len(target_array)) 107 | sampled_idx = np.random.choice(len(target_array), samples, replace=False) 108 | target_array = target_array[sampled_idx] 109 | 110 | sampled_list.append(target_array) 111 | left = right 112 | sampled_array = np.concatenate(sampled_list, axis=0) 113 | sorted_index = sorted(list(range(len(sampled_array))), key=lambda i: sampled_array[i, [0, 2, 1]].tolist()) 114 | sampled_array = sampled_array[sorted_index] 115 | 116 | return sampled_array 117 | 118 | 119 | def get_metapath(metapath, num_drug, num_protein, num_disease, num_se, save_prefix): 120 | if len(metapath) == 2: 121 | # (0, 0) 122 | if metapath == (0, 0): 123 | metapath_indices = metapath_xx(drug_drug_list, num=0) 124 | # (1, 1) 125 | elif metapath == (1, 1): 126 | metapath_indices = metapath_xx(protein_protein_list, num=num_drug) 127 | 128 | elif len(metapath) == 3: 129 | # (0, 1, 0) 130 | if metapath == (0, 
1, 0): 131 | metapath_indices = metapath_yxy(protein_drug_list, num1=0, num2=num_drug) 132 | # (0, 2, 0) 133 | elif metapath == (0, 2, 0): 134 | metapath_indices = metapath_yxy(disease_drug_list, num1=0, num2=num_drug + num_protein, sample=100) 135 | # (0, 3, 0) 136 | elif metapath == (0, 3, 0): 137 | metapath_indices = metapath_yxy(se_drug_list, num1=0, num2=num_drug + num_protein + num_disease, sample=100) 138 | # (1, 0, 1) 139 | elif metapath == (1, 0, 1): 140 | metapath_indices = metapath_yxy(drug_protein_list, num1=num_drug, num2=0) 141 | # (1, 2, 1) 142 | elif metapath == (1, 2, 1): 143 | metapath_indices = metapath_yxy(disease_protein_list, num1=num_drug, num2=num_drug + num_protein, sample=100) 144 | 145 | elif len(metapath) == 4: 146 | # (0, 1, 1, 0) 147 | if metapath == (0, 1, 1, 0): 148 | # if os.path.isfile(save_prefix + '-'.join(map(str, (1, 1))) + '.npy'): 149 | # p_p = np.load(save_prefix + '-'.join(map(str, (1, 1))) + '.npy') 150 | # # else: 151 | # p_p = metapath_xx(protein_protein_list, num=num_drug) 152 | # np.save(save_prefix + '-'.join(map(str, (1, 1))) + '.npy', p_p) 153 | p_p = metapath_xx(protein_protein_list, num=num_drug, sample=50) 154 | metapath_indices = metapath_yxxy(p_p, protein_drug_list, num1=0, num2=num_drug, sample=30) 155 | # (1, 0, 0, 1) 156 | elif metapath == (1, 0, 0, 1): 157 | # if os.path.isfile(save_prefix + '-'.join(map(str, (0, 0))) + '.npy'): 158 | # d_d = np.load(save_prefix + '-'.join(map(str, (0, 0))) + '.npy') 159 | # else: 160 | # d_d = metapath_xx(drug_drug_list, num=0) 161 | # np.save(save_prefix + '-'.join(map(str, (0, 0))) + '.npy', d_d) 162 | d_d = metapath_xx(drug_drug_list, num=0, sample=100) 163 | metapath_indices = metapath_yxxy(d_d, drug_protein_list, num1=num_drug, num2=0, sample=10) 164 | 165 | elif len(metapath) == 5: 166 | # 0-1-0-1-0 167 | if metapath == (0, 1, 0, 1, 0): 168 | # if os.path.isfile(save_prefix + '-'.join(map(str, (1, 0, 1))) + '.npy'): 169 | # p_d_p = np.load(save_prefix + '-'.join(map(str, (1, 0, 1))) + '.npy') 170 | # else: 171 | # p_d_p = metapath_yxy(drug_protein_list, num1=num_drug, num2=0) 172 | # np.save(save_prefix + '-'.join(map(str, (1, 0, 1))) + '.npy', p_d_p) 173 | p_d_p = metapath_yxy(drug_protein_list, num1=num_drug, num2=0, sample=20) 174 | p_d_p = sampling(p_d_p, num=num_protein, offset=num_drug) 175 | metapath_indices = metapath_zyxyz(p_d_p, protein_drug_list, num1=0, num2=num_drug, ratio=5) 176 | # 0-1-2-1-0 177 | elif metapath == (0, 1, 2, 1, 0): 178 | # if os.path.isfile(save_prefix + '-'.join(map(str, (1, 2, 1))) + '.npy'): 179 | # p_i_p = np.load(save_prefix + '-'.join(map(str, (1, 2, 1))) + '.npy') 180 | # else: 181 | # p_i_p = metapath_yxy(disease_protein_list, num1=num_drug, num2=num_drug + num_protein, sample=80) 182 | # np.save(save_prefix + '-'.join(map(str, (1, 2, 1))) + '.npy', p_i_p) 183 | p_i_p = metapath_yxy(disease_protein_list, num1=num_drug, num2=num_drug + num_protein, sample=80) 184 | p_i_p = sampling(p_i_p, num=num_protein, offset=num_drug) 185 | metapath_indices = metapath_zyxyz(p_i_p, protein_drug_list, num1=0, num2=num_drug, ratio=5) 186 | # 0-2-0-2-0 187 | elif metapath == (0, 2, 0, 2, 0): 188 | # if os.path.isfile(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy'): 189 | # i_d_i = np.load(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy') 190 | # else: 191 | # i_d_i = metapath_yxy(drug_disease_lit, num1=num_drug + num_protein, num2=0, sample=80) 192 | # np.save(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy', i_d_i) 193 | i_d_i = 
metapath_yxy(drug_disease_lit, num1=num_drug + num_protein, num2=0, sample=80) 194 | i_d_i = sampling(i_d_i, num=num_disease, offset=num_drug + num_protein) 195 | metapath_indices = metapath_zyxyz(i_d_i, disease_drug_list, num1=0, num2=num_drug + num_protein, ratio=5) 196 | # 0-3-0-3-0 197 | elif metapath == (0, 3, 0, 3, 0): 198 | # if os.path.isfile(save_prefix + '-'.join(map(str, (3, 0, 3))) + '.npy'): 199 | # s_d_s = np.load(save_prefix + '-'.join(map(str, (3, 0, 3))) + '.npy') 200 | # else: 201 | # s_d_s = metapath_yxy(drug_se_list, num1=num_drug + num_protein + num_disease, num2=0, sample=80) 202 | # np.save(save_prefix + '-'.join(map(str, (3, 0, 3))) + '.npy', s_d_s) 203 | s_d_s = metapath_yxy(drug_se_list, num1=num_drug + num_protein + num_disease, num2=0, sample=80) 204 | s_d_s = sampling(s_d_s, num=num_se, offset=num_drug + num_protein + num_disease) 205 | metapath_indices = metapath_zyxyz(s_d_s, se_drug_list, num1=0, num2=num_drug + num_protein + num_disease, ratio=5) 206 | # 0-2-1-2-0 207 | elif metapath == (0, 2, 1, 2, 0): 208 | # if os.path.isfile(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy'): 209 | # i_p_i = np.load(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy') 210 | # else: 211 | # i_p_i = metapath_yxy(protein_disease_list, num1=num_drug + num_protein, num2=num_drug, sample=80) 212 | # np.save(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy', i_p_i) 213 | i_p_i = metapath_yxy(protein_disease_list, num1=num_drug + num_protein, num2=num_drug, sample=80) 214 | i_p_i = sampling(i_p_i, num=num_disease, offset=num_drug + num_protein) 215 | metapath_indices = metapath_zyxyz(i_p_i, disease_drug_list, num1=0, num2=num_drug + num_protein, ratio=5) 216 | # 1-0-1-0-1 217 | elif metapath == (1, 0, 1, 0, 1): 218 | # if os.path.isfile(save_prefix + '-'.join(map(str, (0, 1, 0))) + '.npy'): 219 | # d_p_d = np.load(save_prefix + '-'.join(map(str, (0, 1, 0))) + '.npy') 220 | # else: 221 | # d_p_d = metapath_yxy(protein_drug_list, num1=0, num2=num_drug) 222 | # np.save(save_prefix + '-'.join(map(str, (0, 1, 0))) + '.npy', d_p_d) 223 | d_p_d = metapath_yxy(protein_drug_list, num1=0, num2=num_drug, sample=10) 224 | d_p_d = sampling(d_p_d, num=num_drug, offset=0) 225 | metapath_indices = metapath_zyxyz(d_p_d, drug_protein_list, num1=num_drug, num2=0, ratio=10) 226 | # 1-0-2-0-1 227 | elif metapath == (1, 0, 2, 0, 1): 228 | # if os.path.isfile(save_prefix + '-'.join(map(str, (0, 2, 0))) + '.npy'): 229 | # d_i_d = np.load(save_prefix + '-'.join(map(str, (0, 2, 0))) + '.npy') 230 | # else: 231 | # d_i_d = metapath_yxy(disease_drug_list, num1=0, num2=num_drug + num_protein, sample=80) 232 | # np.save(save_prefix + '-'.join(map(str, (0, 2, 0))) + '.npy', d_i_d) 233 | d_i_d = metapath_yxy(disease_drug_list, num1=0, num2=num_drug + num_protein, sample=80) 234 | d_i_d = sampling(d_i_d, num=num_drug, offset=0) 235 | metapath_indices = metapath_zyxyz(d_i_d, drug_protein_list, num1=num_drug, num2=0, ratio=5) 236 | # 1-2-0-2-1 237 | elif metapath == (1, 2, 0, 2, 1): 238 | # if os.path.isfile(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy'): 239 | # i_d_i = np.load(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy') 240 | # else: 241 | # i_d_i = metapath_yxy(drug_disease_lit, num1=num_drug + num_protein, num2=0, sample=80) 242 | # np.save(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy', i_d_i) 243 | i_d_i = metapath_yxy(drug_disease_lit, num1=num_drug + num_protein, num2=0, sample=80) 244 | i_d_i = sampling(i_d_i, num=num_disease, offset=num_drug + num_protein) 245 | 
metapath_indices = metapath_zyxyz(i_d_i, disease_protein_list, num1=num_drug, num2=num_drug + num_protein, ratio=5)
246 |         # 1-2-1-2-1
247 |         elif metapath == (1, 2, 1, 2, 1):
248 |             # if os.path.isfile(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy'):
249 |             # i_p_i = np.load(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy')
250 |             # else:
251 |             # i_p_i = metapath_yxy(protein_disease_list, num1=num_drug + num_protein, num2=num_drug, sample=80)
252 |             # np.save(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy', i_p_i)
253 |             i_p_i = metapath_yxy(protein_disease_list, num1=num_drug + num_protein, num2=num_drug, sample=80)
254 |             i_p_i = sampling(i_p_i, num=num_disease, offset=num_drug + num_protein)
255 |             metapath_indices = metapath_zyxyz(i_p_i, disease_protein_list, num1=num_drug, num2=num_drug + num_protein, ratio=5)
256 |
257 |     return metapath_indices
258 |
259 | def target_metapath_and_neightbors(edge_metapath_idx_array, target_idx_list, offset):
260 |     # group the metapath instances and metapath-based neighbors by target node
261 |     target_metapaths_mapping = {}
262 |     target_neighbors = {}
263 |     left = 0
264 |     right = 0
265 |     for target_idx in target_idx_list:
266 |         # target_metapaths_mapping = {}
267 |         # target_neighbors = {}
268 |         while right < len(edge_metapath_idx_array) and edge_metapath_idx_array[right, 0] == target_idx + offset:
269 |             right += 1
270 |         target_metapaths_mapping[target_idx] = edge_metapath_idx_array[left:right, ::-1]
271 |         neighbors = edge_metapath_idx_array[left:right, -1] - offset
272 |         # neighbors = list(map(str, neighbors))
273 |         target_neighbors[target_idx] = [target_idx] + neighbors.tolist()
274 |         left = right
275 |
276 |     return target_metapaths_mapping, target_neighbors
277 |
278 |
279 | def Load_Adj_Togerther(dir_lists, ratio=0.01):
280 |     a = np.loadtxt(dir_lists[0])
281 |     print('Before Interactions: ', sum(sum(a)))
282 |
283 |     for i in range(len(dir_lists) - 1):
284 |         b_new = np.zeros_like(a)
285 |
286 |         b = np.loadtxt(dir_lists[i + 1])
287 |         # remove diagonal elements
288 |         b = b - np.diag(np.diag(b))
289 |         # if the matrix is symmetric, keep only the upper triangle
290 |         if (b == b.T).all():
291 |             b = np.triu(b)
292 |         index = np.nonzero(b)
293 |         values = b[index]
294 |         index = np.transpose(index)
295 |         edgelist = np.concatenate([index, values.reshape(-1, 1)], axis=1)
296 |         topK_idx = np.argpartition(edgelist[:, 2], -int(ratio * len(edgelist)))[-(int(ratio * len(edgelist))):]
297 |         print(len(topK_idx))
298 |         select_idx = index[topK_idx]
299 |         b_new[select_idx[:, 0], select_idx[:, 1]] = b[select_idx[:, 0], select_idx[:, 1]]
300 |         a = a + b_new
301 |
302 |     a = a + a.T
303 |     a[a > 0] = 1
304 |     a[a <= 0] = 0
305 |     a = a + np.eye(a.shape[0], a.shape[1])
306 |     a = a.astype(int)
307 |     print('After Interactions: ', sum(sum(a)))
308 |
309 |     return a
310 |
311 | def get_adjM(drug_drug, drug_protein, drug_disease, drug_sideEffect, protein_protein, protein_disease,
312 |              num_drug, num_protein, num_disease, num_se):
313 |     # Drug-0, Protein-1, Disease-2, Side-effect-3
314 |     dim = num_drug + num_protein + num_disease + num_se
315 |     adjM = np.zeros((dim, dim), dtype=int)
316 |     adjM[:num_drug, :num_drug] = drug_drug
317 |     adjM[:num_drug, num_drug: num_drug + num_protein] = drug_protein
318 |     adjM[:num_drug, num_drug + num_protein: num_drug + num_protein + num_disease] = drug_disease
319 |     adjM[:num_drug, num_drug + num_protein + num_disease:] = drug_sideEffect
320 |     adjM[num_drug: num_drug + num_protein, num_drug: num_drug + num_protein] = protein_protein
321 |     adjM[num_drug: num_drug + num_protein, num_drug + num_protein: num_drug + num_protein + num_disease] = protein_disease
322 |
323 |     adjM[num_drug: num_drug + num_protein, :num_drug] = drug_protein.T
324 |     adjM[num_drug + num_protein: num_drug + num_protein + num_disease, :num_drug] = drug_disease.T
325 |     adjM[num_drug + num_protein + num_disease:, :num_drug] = drug_sideEffect.T
326 |     adjM[num_drug + num_protein: num_drug + num_protein + num_disease, num_drug: num_drug + num_protein] = protein_disease.T
327 |
328 |     return adjM
329 |
330 | def fold_train_idx(pos_folds, nFold, foldID):
331 |     fold_idx = []
332 |     for fold in range(nFold):
333 |         if fold == foldID:
334 |             continue
335 |         fold_idx.append(pos_folds['fold_' + str(fold)])
336 |     fold_idx = np.concatenate(fold_idx, axis=1)
337 |     return fold_idx
338 |
339 |
340 | def get_type_mask(num_drug, num_protein, num_disease, num_se):
341 |     # Drug-0, Protein-1, Disease-2, Side-effect-3
342 |     dim = num_drug + num_protein + num_disease + num_se
343 |     type_mask = np.zeros((dim), dtype=int)
344 |     type_mask[num_drug: num_drug + num_protein] = 1
345 |     type_mask[num_drug + num_protein: num_drug + num_protein + num_disease] = 2
346 |     type_mask[num_drug + num_protein + num_disease:] = 3
347 |     return type_mask
348 |
349 | if __name__ == '__main__':
350 |     data_set = 'data_luo'
351 |     nFold = 10
352 |     neg_times = 1
353 |     data_dir = './hetero_dataset/{}/'.format(data_set)
354 |     fold_path = data_dir + '/{}_folds/'.format(str(nFold))
355 |     pos_folds = json.load(open(fold_path + 'pos_folds.json', 'r'))
356 |     neg_folds = json.load(open(fold_path + 'neg_folds_times_{}.json'.format(str(neg_times)), 'r'))
357 |     num_repeats = 1
358 |
359 |     save_prefix = data_dir + '/processed/'
360 |     os.makedirs(save_prefix, exist_ok=True)
361 |
362 |     expected_metapaths = [[(0, 0), (0, 1, 0), (0, 2, 0), (0, 3, 0), (0, 1, 1, 0),
363 |                            (0, 1, 0, 1, 0), (0, 2, 0, 2, 0), (0, 3, 0, 3, 0), (0, 1, 2, 1, 0), (0, 2, 1, 2, 0)],
364 |                           [(1, 1), (1, 0, 1), (1, 2, 1), (1, 0, 0, 1),
365 |                            (1, 0, 1, 0, 1), (1, 0, 2, 0, 1), (1, 2, 0, 2, 1), (1, 2, 1, 2, 1)]]
366 |
367 |     ## Step 1: Reconstruct Drug-Drug interaction network and Protein-Protein interaction network
368 |     # Reconstruct Drug-Drug interaction network
369 |     # 1 interaction + 4 sim
370 |     drug_drug_path = data_dir + '/mat_data/mat_drug_drug.txt'
371 |     drug_drug_sim_chemical_path = data_dir + '/sim_network/Sim_mat_drugs.txt'
372 |     drug_drug_sim_interaction_path = data_dir + '/sim_network/Sim_mat_drug_drug.txt'
373 |     drug_drug_sim_se_path = data_dir + '/sim_network/Sim_mat_drug_se.txt'
374 |     drug_drug_sim_disease_path = data_dir + '/sim_network/Sim_mat_drug_disease.txt'
375 |
376 |     # Reconstruct Protein-Protein interaction network
377 |     # 1 interaction + 3 sim
378 |     protein_protein_path = data_dir + '/mat_data/mat_protein_protein.txt'
379 |     protein_protein_sim_sequence_path = data_dir + '/sim_network/Sim_mat_proteins.txt'
380 |     protein_protein_sim_disease_path = data_dir + '/sim_network/Sim_mat_protein_disease.txt'
381 |     protein_protein_sim_interaction_path = data_dir + '/sim_network/Sim_mat_protein_protein.txt'
382 |
383 |     # About drug and protein (others)...
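    # (the four paths below point to the raw bipartite association matrices of the
    # dataset: drug-protein interactions, drug-disease associations, drug-side effect
    # associations and protein-disease associations; they are loaded as dense 0/1
    # integer arrays further down)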
384 | drug_protein_path = data_dir + '/mat_data/mat_drug_protein.txt' 385 | drug_disease_path = data_dir + '/mat_data/mat_drug_disease.txt' 386 | drug_sideEffect_path = data_dir + '/mat_data/mat_drug_se.txt' 387 | protein_disease_path = data_dir + '/mat_data/mat_protein_disease.txt' 388 | 389 | # drug_drug and protein_protein combine the simNets and interactions 390 | # print('Load_Drug_Adj_Togerther ...') 391 | # drug_drug = Load_Adj_Togerther(dir_lists=[drug_drug_path, drug_drug_sim_chemical_path, 392 | # drug_drug_sim_interaction_path, drug_drug_sim_se_path, 393 | # drug_drug_sim_disease_path], ratio=0.01) 394 | # 395 | # print('Load_Protein_Adj_Togerther ...') 396 | # protein_protein = Load_Adj_Togerther(dir_lists=[protein_protein_path, protein_protein_sim_sequence_path, 397 | # protein_protein_sim_disease_path, protein_protein_sim_interaction_path], 398 | # ratio=0.005) 399 | 400 | drug_drug = np.loadtxt(drug_drug_path, dtype=int) 401 | drug_protein = np.loadtxt(drug_protein_path, dtype=int) 402 | drug_disease = np.loadtxt(drug_disease_path, dtype=int) 403 | protein_protein = np.loadtxt(protein_protein_path, dtype=int) 404 | drug_sideEffect = np.loadtxt(drug_sideEffect_path, dtype=int) 405 | protein_disease = np.loadtxt(protein_disease_path, dtype=int) 406 | 407 | num_drug, num_protein = drug_protein.shape 408 | num_disease = drug_disease.shape[1] 409 | num_se = drug_sideEffect.shape[1] 410 | type_mask = get_type_mask(num_drug, num_protein, num_disease, num_se) 411 | # np.save(save_prefix + 'node_types.npy', type_mask) 412 | 413 | for counter in range(num_repeats): # repeat ten times 414 | print('\nThis is the {} repeat...'.format(counter)) 415 | for foldID in range(nFold): 416 | print('\nThis is Fold ', str(foldID), '...') 417 | fold_drug_protein = np.zeros_like(drug_protein) 418 | fold_idx = fold_train_idx(pos_folds, nFold, foldID) 419 | fold_drug_protein[fold_idx[0], fold_idx[1]] = drug_protein[fold_idx[0], fold_idx[1]] 420 | 421 | ## Syep 2: Build the Adjacency Matrix 422 | # Drug-0, Protein-1, Disease-2, Side-effect-3 423 | adjM = get_adjM(drug_drug, fold_drug_protein, drug_disease, drug_sideEffect, protein_protein, protein_disease, 424 | num_drug, num_protein, num_disease, num_se) 425 | # sp.save_npz(save_prefix + 'adjM_train_fold_{}.npz'.format(foldID), sp.csr_matrix(adjM)) 426 | 427 | drug_drug_list = {i: adjM[i, :num_drug].nonzero()[0] for i in range(num_drug)} 428 | drug_protein_list = {i: adjM[i, num_drug:num_drug + num_protein].nonzero()[0] for i in range(num_drug)} 429 | drug_disease_lit = {i: adjM[i, num_drug + num_protein:num_drug + num_protein + num_disease].nonzero()[0] for i 430 | in range(num_drug)} 431 | drug_se_list = {i: adjM[i, num_drug + num_protein + num_disease:].nonzero()[0] for i in range(num_drug)} 432 | protein_drug_list = {i: adjM[num_drug + i, :num_drug].nonzero()[0] for i in range(num_protein)} 433 | protein_protein_list = {i: adjM[num_drug + i, num_drug:num_drug + num_protein].nonzero()[0] for i in 434 | range(num_protein)} 435 | protein_disease_list = { 436 | i: adjM[num_drug + i, num_drug + num_protein:num_drug + num_protein + num_disease].nonzero()[0] 437 | for i in range(num_protein)} 438 | disease_drug_list = {i: adjM[num_drug + num_protein + i, :num_drug].nonzero()[0] for i in range(num_disease)} 439 | disease_protein_list = {i: adjM[num_drug + num_protein + i, num_drug:num_drug + num_protein].nonzero()[0] for i 440 | in range(num_disease)} 441 | se_drug_list = {i: adjM[num_drug + num_protein + num_disease + i, : num_drug].nonzero()[0] for i in 
range(num_se)} 442 | 443 | # Step 3: Get target metapaths and neighbors for each train fold 444 | target_idx_lists = [np.arange(num_drug).tolist(), np.arange(num_protein).tolist()] 445 | offset_list = [0, num_drug] 446 | for i, metapaths in enumerate(expected_metapaths): 447 | # print(metapaths) 448 | for metapath in metapaths: 449 | metapath_fold_dir = save_prefix + 'repeat{}/{}/train/fold_{}/_'.format(counter, i, foldID) 450 | make_dir(metapath_fold_dir) 451 | # Get all the metapaths in the schema of 'metapath' 452 | if os.path.isfile(metapath_fold_dir + '-'.join(map(str, metapath)) + '.npy'): 453 | edge_metapath_idx_array = np.load(metapath_fold_dir + '-'.join(map(str, metapath)) + '.npy') 454 | else: 455 | edge_metapath_idx_array = get_metapath(metapath, num_drug, num_protein, num_disease, num_se, metapath_fold_dir) 456 | np.save(metapath_fold_dir + '-'.join(map(str, metapath)) + '.npy', edge_metapath_idx_array) 457 | print(metapath, len(edge_metapath_idx_array)) 458 | target_metapaths, target_neighbors = target_metapath_and_neightbors(edge_metapath_idx_array, target_idx_lists[i], offset=offset_list[i]) 459 | pickle.dump(target_metapaths, open(metapath_fold_dir + '-'.join(map(str, metapath)) + '.idx.pkl', 'wb')) 460 | pickle.dump(target_neighbors, open(metapath_fold_dir + '-'.join(map(str, metapath)) + '.adjlist.pkl', 'wb')) 461 | -------------------------------------------------------------------------------- /test_preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import os 4 | import pathlib 5 | import pickle 6 | import random 7 | import numpy as np 8 | import scipy.sparse as sp 9 | import scipy.io 10 | import pandas as pd 11 | import json 12 | 13 | seed = 12345 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | 17 | def make_dir(fp): 18 | if not os.path.exists(fp): 19 | os.makedirs(fp, exist_ok=True) 20 | 21 | def metapath_xx(x_x_list, num, sample=None): 22 | x_x = [] 23 | for x, x_list in x_x_list.items(): 24 | if sample is not None: 25 | candidate_list = np.random.choice(len(x_list), min(len(x_list), sample), replace=False) 26 | x_list = x_list[candidate_list] 27 | x_x.extend([(x, x1) for x1 in x_list]) 28 | x_x = np.array(x_x) 29 | x_x = x_x + num 30 | sorted_index = sorted(list(range(len(x_x))), key=lambda i: x_x[i].tolist()) 31 | x_x = x_x[sorted_index] 32 | return x_x 33 | 34 | def metapath_yxy(x_y_list, num1, num2, sample=None): 35 | y_x_y = [] 36 | for x, y_list in x_y_list.items(): 37 | if sample is not None: 38 | candidate_list1 = np.random.choice(len(y_list), min(len(y_list), sample), replace=False) 39 | candidate_list2 = np.random.choice(len(y_list), min(len(y_list), sample), replace=False) 40 | # print(len(candidate_list1)) 41 | y_list1 = y_list[candidate_list1] 42 | y_list2 = y_list[candidate_list2] 43 | y_x_y.extend([(y1, x, y2) for y1 in y_list1 for y2 in y_list2]) 44 | else: 45 | y_x_y.extend([(y1, x, y2) for y1 in y_list for y2 in y_list]) 46 | y_x_y = np.array(y_x_y) 47 | y_x_y[:, [0, 2]] += num1 48 | y_x_y[:, 1] += num2 49 | sorted_index = sorted(list(range(len(y_x_y))), key=lambda i: y_x_y[i, [0, 2, 1]].tolist()) 50 | y_x_y = y_x_y[sorted_index] 51 | return y_x_y 52 | 53 | def metapath_yxxy(x_x, x_y_list, num1, num2, sample=None): 54 | y_x_x_y = [] 55 | for x1, x2 in x_x: 56 | if sample is not None: 57 | candidate_list1 = np.random.choice(len(x_y_list[x1 - num2]), min(len(x_y_list[x1 - num2]), sample), replace=False) 58 | candidate_list2 = 
np.random.choice(len(x_y_list[x2 - num2]), min(len(x_y_list[x2 - num2]), sample), replace=False) 59 | # print(len(candidate_list1)) 60 | x_y_list1 = x_y_list[x1 - num2][candidate_list1] 61 | x_y_list2 = x_y_list[x2 - num2][candidate_list2] 62 | y_x_x_y.extend([(y1, x1, x2, y2) for y1 in x_y_list1 for y2 in x_y_list2]) 63 | else: 64 | y_x_x_y.extend([(y1, x1, x2, y2) for y1 in x_y_list[x1 - num2] for y2 in x_y_list[x2 - num2]]) 65 | y_x_x_y = np.array(y_x_x_y) 66 | y_x_x_y[:, [0, 3]] += num1 67 | sorted_index = sorted(list(range(len(y_x_x_y))), key=lambda i: y_x_x_y[i, [0, 3, 1, 2]].tolist()) 68 | y_x_x_y = y_x_x_y[sorted_index] 69 | return y_x_x_y 70 | 71 | def metapath_zyxyz(y_x_y, y_z_list, num1, num2, ratio): 72 | # z-y-x-y-z 73 | z_y_x_y_z = [] 74 | for y1, x, y2 in y_x_y: 75 | if len(y_z_list[y1 - num2]) == 0 or len(y_z_list[y2 - num2]) == 0: 76 | continue 77 | if ratio <= 1: 78 | candidate_z1_list = np.random.choice(len(y_z_list[y1 - num2]), int(ratio * len(y_z_list[y1 - num2])) + 1, replace=False) 79 | candidate_z2_list = np.random.choice(len(y_z_list[y2 - num2]), int(ratio * len(y_z_list[y2 - num2])) + 1, replace=False) 80 | else: 81 | candidate_z1_list = np.random.choice(len(y_z_list[y1 - num2]), min(ratio, len(y_z_list[y1 - num2])), replace=False) 82 | candidate_z2_list = np.random.choice(len(y_z_list[y2 - num2]), min(ratio, len(y_z_list[y2 - num2])), replace=False) 83 | candidate_z1_list = y_z_list[y1 - num2][candidate_z1_list] 84 | candidate_z2_list = y_z_list[y2 - num2][candidate_z2_list] 85 | 86 | z_y_x_y_z.extend([(z1, y1, x, y2, z2) for z1 in candidate_z1_list for z2 in candidate_z2_list]) 87 | z_y_x_y_z = np.array(z_y_x_y_z) 88 | z_y_x_y_z[:, [0, 4]] += num1 89 | sorted_index = sorted(list(range(len(z_y_x_y_z))), key=lambda i: z_y_x_y_z[i, [0, 4, 1, 2, 3]].tolist()) 90 | z_y_x_y_z = z_y_x_y_z[sorted_index] 91 | return z_y_x_y_z 92 | 93 | def sampling(array_list, num, offset): 94 | target_list = np.arange(num).tolist() 95 | sampled_list = [] 96 | k = 100 # number of samiling 97 | 98 | left = 0 99 | right = 0 100 | for target_idx in target_list: 101 | while right < len(array_list) and array_list[right, 0] == target_idx + offset: 102 | right += 1 103 | target_array = array_list[left:right, :] 104 | 105 | if len(target_array) > 0: 106 | samples = min(k, len(target_array)) 107 | sampled_idx = np.random.choice(len(target_array), samples, replace=False) 108 | target_array = target_array[sampled_idx] 109 | 110 | sampled_list.append(target_array) 111 | left = right 112 | sampled_array = np.concatenate(sampled_list, axis=0) 113 | sorted_index = sorted(list(range(len(sampled_array))), key=lambda i: sampled_array[i, [0, 2, 1]].tolist()) 114 | sampled_array = sampled_array[sorted_index] 115 | 116 | return sampled_array 117 | 118 | 119 | def get_metapath(metapath, num_drug, num_protein, num_disease, num_se, save_prefix): 120 | if len(metapath) == 2: 121 | # (0, 0) 122 | if metapath == (0, 0): 123 | metapath_indices = metapath_xx(drug_drug_list, num=0) 124 | # (1, 1) 125 | elif metapath == (1, 1): 126 | metapath_indices = metapath_xx(protein_protein_list, num=num_drug) 127 | 128 | elif len(metapath) == 3: 129 | # (0, 1, 0) 130 | if metapath == (0, 1, 0): 131 | metapath_indices = metapath_yxy(protein_drug_list, num1=0, num2=num_drug) 132 | # (0, 2, 0) 133 | elif metapath == (0, 2, 0): 134 | metapath_indices = metapath_yxy(disease_drug_list, num1=0, num2=num_drug + num_protein, sample=100) 135 | # (0, 3, 0) 136 | elif metapath == (0, 3, 0): 137 | metapath_indices = 
metapath_yxy(se_drug_list, num1=0, num2=num_drug + num_protein + num_disease, sample=100) 138 | # (1, 0, 1) 139 | elif metapath == (1, 0, 1): 140 | metapath_indices = metapath_yxy(drug_protein_list, num1=num_drug, num2=0) 141 | # (1, 2, 1) 142 | elif metapath == (1, 2, 1): 143 | metapath_indices = metapath_yxy(disease_protein_list, num1=num_drug, num2=num_drug + num_protein, sample=100) 144 | 145 | elif len(metapath) == 4: 146 | # (0, 1, 1, 0) 147 | if metapath == (0, 1, 1, 0): 148 | # if os.path.isfile(save_prefix + '-'.join(map(str, (1, 1))) + '.npy'): 149 | # p_p = np.load(save_prefix + '-'.join(map(str, (1, 1))) + '.npy') 150 | # # else: 151 | # p_p = metapath_xx(protein_protein_list, num=num_drug) 152 | # np.save(save_prefix + '-'.join(map(str, (1, 1))) + '.npy', p_p) 153 | p_p = metapath_xx(protein_protein_list, num=num_drug, sample=50) 154 | metapath_indices = metapath_yxxy(p_p, protein_drug_list, num1=0, num2=num_drug, sample=30) 155 | # (1, 0, 0, 1) 156 | elif metapath == (1, 0, 0, 1): 157 | # if os.path.isfile(save_prefix + '-'.join(map(str, (0, 0))) + '.npy'): 158 | # d_d = np.load(save_prefix + '-'.join(map(str, (0, 0))) + '.npy') 159 | # else: 160 | # d_d = metapath_xx(drug_drug_list, num=0) 161 | # np.save(save_prefix + '-'.join(map(str, (0, 0))) + '.npy', d_d) 162 | d_d = metapath_xx(drug_drug_list, num=0, sample=100) 163 | metapath_indices = metapath_yxxy(d_d, drug_protein_list, num1=num_drug, num2=0, sample=10) 164 | 165 | elif len(metapath) == 5: 166 | # 0-1-0-1-0 167 | if metapath == (0, 1, 0, 1, 0): 168 | # if os.path.isfile(save_prefix + '-'.join(map(str, (1, 0, 1))) + '.npy'): 169 | # p_d_p = np.load(save_prefix + '-'.join(map(str, (1, 0, 1))) + '.npy') 170 | # else: 171 | # p_d_p = metapath_yxy(drug_protein_list, num1=num_drug, num2=0) 172 | # np.save(save_prefix + '-'.join(map(str, (1, 0, 1))) + '.npy', p_d_p) 173 | p_d_p = metapath_yxy(drug_protein_list, num1=num_drug, num2=0, sample=30) 174 | p_d_p = sampling(p_d_p, num=num_protein, offset=num_drug) 175 | metapath_indices = metapath_zyxyz(p_d_p, protein_drug_list, num1=0, num2=num_drug, ratio=5) 176 | # 0-1-2-1-0 177 | elif metapath == (0, 1, 2, 1, 0): 178 | # if os.path.isfile(save_prefix + '-'.join(map(str, (1, 2, 1))) + '.npy'): 179 | # p_i_p = np.load(save_prefix + '-'.join(map(str, (1, 2, 1))) + '.npy') 180 | # else: 181 | # p_i_p = metapath_yxy(disease_protein_list, num1=num_drug, num2=num_drug + num_protein, sample=80) 182 | # np.save(save_prefix + '-'.join(map(str, (1, 2, 1))) + '.npy', p_i_p) 183 | p_i_p = metapath_yxy(disease_protein_list, num1=num_drug, num2=num_drug + num_protein, sample=80) 184 | p_i_p = sampling(p_i_p, num=num_protein, offset=num_drug) 185 | metapath_indices = metapath_zyxyz(p_i_p, protein_drug_list, num1=0, num2=num_drug, ratio=5) 186 | # 0-2-0-2-0 187 | elif metapath == (0, 2, 0, 2, 0): 188 | # if os.path.isfile(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy'): 189 | # i_d_i = np.load(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy') 190 | # else: 191 | # i_d_i = metapath_yxy(drug_disease_lit, num1=num_drug + num_protein, num2=0, sample=80) 192 | # np.save(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy', i_d_i) 193 | i_d_i = metapath_yxy(drug_disease_lit, num1=num_drug + num_protein, num2=0, sample=80) 194 | i_d_i = sampling(i_d_i, num=num_disease, offset=num_drug + num_protein) 195 | metapath_indices = metapath_zyxyz(i_d_i, disease_drug_list, num1=0, num2=num_drug + num_protein, ratio=5) 196 | # 0-3-0-3-0 197 | elif metapath == (0, 3, 0, 3, 0): 198 | # if 
os.path.isfile(save_prefix + '-'.join(map(str, (3, 0, 3))) + '.npy'): 199 | # s_d_s = np.load(save_prefix + '-'.join(map(str, (3, 0, 3))) + '.npy') 200 | # else: 201 | # s_d_s = metapath_yxy(drug_se_list, num1=num_drug + num_protein + num_disease, num2=0, sample=80) 202 | # np.save(save_prefix + '-'.join(map(str, (3, 0, 3))) + '.npy', s_d_s) 203 | s_d_s = metapath_yxy(drug_se_list, num1=num_drug + num_protein + num_disease, num2=0, sample=80) 204 | s_d_s = sampling(s_d_s, num=num_se, offset=num_drug + num_protein + num_disease) 205 | metapath_indices = metapath_zyxyz(s_d_s, se_drug_list, num1=0, num2=num_drug + num_protein + num_disease, ratio=5) 206 | # 0-2-1-2-0 207 | elif metapath == (0, 2, 1, 2, 0): 208 | # if os.path.isfile(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy'): 209 | # i_p_i = np.load(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy') 210 | # else: 211 | # i_p_i = metapath_yxy(protein_disease_list, num1=num_drug + num_protein, num2=num_drug, sample=80) 212 | # np.save(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy', i_p_i) 213 | i_p_i = metapath_yxy(protein_disease_list, num1=num_drug + num_protein, num2=num_drug, sample=80) 214 | i_p_i = sampling(i_p_i, num=num_disease, offset=num_drug + num_protein) 215 | metapath_indices = metapath_zyxyz(i_p_i, disease_drug_list, num1=0, num2=num_drug + num_protein, ratio=5) 216 | # 1-0-1-0-1 217 | elif metapath == (1, 0, 1, 0, 1): 218 | # if os.path.isfile(save_prefix + '-'.join(map(str, (0, 1, 0))) + '.npy'): 219 | # d_p_d = np.load(save_prefix + '-'.join(map(str, (0, 1, 0))) + '.npy') 220 | # else: 221 | # d_p_d = metapath_yxy(protein_drug_list, num1=0, num2=num_drug) 222 | # np.save(save_prefix + '-'.join(map(str, (0, 1, 0))) + '.npy', d_p_d) 223 | d_p_d = metapath_yxy(protein_drug_list, num1=0, num2=num_drug, sample=50) 224 | d_p_d = sampling(d_p_d, num=num_drug, offset=0) 225 | metapath_indices = metapath_zyxyz(d_p_d, drug_protein_list, num1=num_drug, num2=0, ratio=5) 226 | # 1-0-2-0-1 227 | elif metapath == (1, 0, 2, 0, 1): 228 | # if os.path.isfile(save_prefix + '-'.join(map(str, (0, 2, 0))) + '.npy'): 229 | # d_i_d = np.load(save_prefix + '-'.join(map(str, (0, 2, 0))) + '.npy') 230 | # else: 231 | # d_i_d = metapath_yxy(disease_drug_list, num1=0, num2=num_drug + num_protein, sample=80) 232 | # np.save(save_prefix + '-'.join(map(str, (0, 2, 0))) + '.npy', d_i_d) 233 | d_i_d = metapath_yxy(disease_drug_list, num1=0, num2=num_drug + num_protein, sample=80) 234 | d_i_d = sampling(d_i_d, num=num_drug, offset=0) 235 | metapath_indices = metapath_zyxyz(d_i_d, drug_protein_list, num1=num_drug, num2=0, ratio=5) 236 | # 1-2-0-2-1 237 | elif metapath == (1, 2, 0, 2, 1): 238 | # if os.path.isfile(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy'): 239 | # i_d_i = np.load(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy') 240 | # else: 241 | # i_d_i = metapath_yxy(drug_disease_lit, num1=num_drug + num_protein, num2=0, sample=80) 242 | # np.save(save_prefix + '-'.join(map(str, (2, 0, 2))) + '.npy', i_d_i) 243 | i_d_i = metapath_yxy(drug_disease_lit, num1=num_drug + num_protein, num2=0, sample=80) 244 | i_d_i = sampling(i_d_i, num=num_disease, offset=num_drug + num_protein) 245 | metapath_indices = metapath_zyxyz(i_d_i, disease_protein_list, num1=num_drug, num2=num_drug + num_protein, ratio=5) 246 | # 1-2-1-2-1 247 | elif metapath == (1, 2, 1, 2, 1): 248 | # if os.path.isfile(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy'): 249 | # i_p_i = np.load(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy') 250 | # else: 251 
| # i_p_i = metapath_yxy(protein_disease_list, num1=num_drug + num_protein, num2=num_drug, sample=80)
252 |             # np.save(save_prefix + '-'.join(map(str, (2, 1, 2))) + '.npy', i_p_i)
253 |             i_p_i = metapath_yxy(protein_disease_list, num1=num_drug + num_protein, num2=num_drug, sample=80)
254 |             i_p_i = sampling(i_p_i, num=num_disease, offset=num_drug + num_protein)
255 |             metapath_indices = metapath_zyxyz(i_p_i, disease_protein_list, num1=num_drug, num2=num_drug + num_protein, ratio=5)
256 |
257 |     return metapath_indices
258 |
259 | def target_metapath_and_neightbors(edge_metapath_idx_array, target_idx_list, offset):
260 |     # group the metapath instances and metapath-based neighbors by target node
261 |     target_metapaths_mapping = {}
262 |     target_neighbors = {}
263 |     left = 0
264 |     right = 0
265 |     for target_idx in target_idx_list:
266 |         # target_metapaths_mapping = {}
267 |         # target_neighbors = {}
268 |         while right < len(edge_metapath_idx_array) and edge_metapath_idx_array[right, 0] == target_idx + offset:
269 |             right += 1
270 |         target_metapaths_mapping[target_idx] = edge_metapath_idx_array[left:right, ::-1]
271 |         neighbors = edge_metapath_idx_array[left:right, -1] - offset
272 |         # neighbors = list(map(str, neighbors))
273 |         target_neighbors[target_idx] = [target_idx] + neighbors.tolist()
274 |         left = right
275 |
276 |     return target_metapaths_mapping, target_neighbors
277 |
278 |
279 | def Load_Adj_Togerther(dir_lists, ratio=0.01):
280 |     a = np.loadtxt(dir_lists[0])
281 |     print('Before Interactions: ', sum(sum(a)))
282 |
283 |     for i in range(len(dir_lists) - 1):
284 |         b_new = np.zeros_like(a)
285 |
286 |         b = np.loadtxt(dir_lists[i + 1])
287 |         # remove diagonal elements
288 |         b = b - np.diag(np.diag(b))
289 |         # if the matrix is symmetric, keep only the upper triangle
290 |         if (b == b.T).all():
291 |             b = np.triu(b)
292 |         index = np.nonzero(b)
293 |         values = b[index]
294 |         index = np.transpose(index)
295 |         edgelist = np.concatenate([index, values.reshape(-1, 1)], axis=1)
296 |         topK_idx = np.argpartition(edgelist[:, 2], -int(ratio * len(edgelist)))[-(int(ratio * len(edgelist))):]
297 |         print(len(topK_idx))
298 |         select_idx = index[topK_idx]
299 |         b_new[select_idx[:, 0], select_idx[:, 1]] = b[select_idx[:, 0], select_idx[:, 1]]
300 |         a = a + b_new
301 |
302 |     a = a + a.T
303 |     a[a > 0] = 1
304 |     a[a <= 0] = 0
305 |     a = a + np.eye(a.shape[0], a.shape[1])
306 |     a = a.astype(int)
307 |     print('After Interactions: ', sum(sum(a)))
308 |
309 |     return a
310 |
311 | def get_adjM(drug_drug, drug_protein, drug_disease, drug_sideEffect, protein_protein, protein_disease,
312 |              num_drug, num_protein, num_disease, num_se):
313 |     # Drug-0, Protein-1, Disease-2, Side-effect-3
314 |     dim = num_drug + num_protein + num_disease + num_se
315 |     adjM = np.zeros((dim, dim), dtype=int)
316 |     adjM[:num_drug, :num_drug] = drug_drug
317 |     adjM[:num_drug, num_drug: num_drug + num_protein] = drug_protein
318 |     adjM[:num_drug, num_drug + num_protein: num_drug + num_protein + num_disease] = drug_disease
319 |     adjM[:num_drug, num_drug + num_protein + num_disease:] = drug_sideEffect
320 |     adjM[num_drug: num_drug + num_protein, num_drug: num_drug + num_protein] = protein_protein
321 |     adjM[num_drug: num_drug + num_protein, num_drug + num_protein: num_drug + num_protein + num_disease] = protein_disease
322 |
323 |     adjM[num_drug: num_drug + num_protein, :num_drug] = drug_protein.T
324 |     adjM[num_drug + num_protein: num_drug + num_protein + num_disease, :num_drug] = drug_disease.T
325 |     adjM[num_drug + num_protein + num_disease:, :num_drug] = drug_sideEffect.T
326 |     adjM[num_drug + num_protein: num_drug + num_protein + num_disease, num_drug: num_drug + num_protein] = protein_disease.T
327 |
328 |     return adjM
329 |
330 | def fold_test_idx(pos_folds, neg_folds, foldID):
331 |     fold_idx = [[], []]
332 |     fold_posIdx = pos_folds['fold_' + str(foldID)]
333 |     fold_negIdx = neg_folds['fold_' + str(foldID)]
334 |     fold_idx[0] = fold_posIdx[0] + fold_negIdx[0]
335 |     fold_idx[1] = fold_posIdx[1] + fold_negIdx[1]
336 |     return fold_idx
337 |
338 | def get_type_mask(num_drug, num_protein, num_disease, num_se):
339 |     # Drug-0, Protein-1, Disease-2, Side-effect-3
340 |     dim = num_drug + num_protein + num_disease + num_se
341 |     type_mask = np.zeros((dim), dtype=int)
342 |     type_mask[num_drug: num_drug + num_protein] = 1
343 |     type_mask[num_drug + num_protein: num_drug + num_protein + num_disease] = 2
344 |     type_mask[num_drug + num_protein + num_disease:] = 3
345 |     return type_mask
346 |
347 | if __name__ == '__main__':
348 |     data_set = 'data_luo'
349 |     nFold = 10
350 |     neg_times = 1
351 |     data_dir = './hetero_dataset/{}/'.format(data_set)
352 |     fold_path = data_dir + '/{}_folds/'.format(str(nFold))
353 |     pos_folds = json.load(open(fold_path + 'pos_folds.json', 'r'))
354 |     neg_folds = json.load(open(fold_path + 'neg_folds_times_{}.json'.format(str(neg_times)), 'r'))
355 |     num_repeats = 1  # number of repeated preprocessing runs
356 |
357 |     save_prefix = data_dir + '/processed/'
358 |     os.makedirs(save_prefix, exist_ok=True)
359 |
360 |     expected_metapaths = [[(0, 0), (0, 1, 0), (0, 2, 0), (0, 3, 0), (0, 1, 1, 0),
361 |                            (0, 1, 0, 1, 0), (0, 2, 0, 2, 0), (0, 3, 0, 3, 0), (0, 1, 2, 1, 0), (0, 2, 1, 2, 0)],
362 |                           [(1, 1), (1, 0, 1), (1, 2, 1), (1, 0, 0, 1),
363 |                            (1, 0, 1, 0, 1), (1, 0, 2, 0, 1), (1, 2, 0, 2, 1), (1, 2, 1, 2, 1)]]
364 |
365 |     ## Step 1: Reconstruct Drug-Drug interaction network and Protein-Protein interaction network
366 |     # Reconstruct Drug-Drug interaction network
367 |     # 1 interaction + 4 sim
368 |     drug_drug_path = data_dir + '/mat_data/mat_drug_drug.txt'
369 |     drug_drug_sim_chemical_path = data_dir + '/sim_network/Sim_mat_drugs.txt'
370 |     drug_drug_sim_interaction_path = data_dir + '/sim_network/Sim_mat_drug_drug.txt'
371 |     drug_drug_sim_se_path = data_dir + '/sim_network/Sim_mat_drug_se.txt'
372 |     drug_drug_sim_disease_path = data_dir + '/sim_network/Sim_mat_drug_disease.txt'
373 |
374 |     # Reconstruct Protein-Protein interaction network
375 |     # 1 interaction + 3 sim
376 |     protein_protein_path = data_dir + '/mat_data/mat_protein_protein.txt'
377 |     protein_protein_sim_sequence_path = data_dir + '/sim_network/Sim_mat_proteins.txt'
378 |     protein_protein_sim_disease_path = data_dir + '/sim_network/Sim_mat_protein_disease.txt'
379 |     protein_protein_sim_interaction_path = data_dir + '/sim_network/Sim_mat_protein_protein.txt'
380 |
381 |     # About drug and protein (others)...
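    # (these are the same raw association matrices used in train_preprocess.py; note
    # that the test script builds the adjacency matrix from the full drug-protein
    # matrix rather than from a per-fold training mask)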
382 | drug_protein_path = data_dir + '/mat_data/mat_drug_protein.txt' 383 | drug_disease_path = data_dir + '/mat_data/mat_drug_disease.txt' 384 | drug_sideEffect_path = data_dir + '/mat_data/mat_drug_se.txt' 385 | protein_disease_path = data_dir + '/mat_data/mat_protein_disease.txt' 386 | 387 | # drug_drug and protein_protein combine the simNets and interactions 388 | # print('Load_Drug_Adj_Togerther ...') 389 | # drug_drug = Load_Adj_Togerther(dir_lists=[drug_drug_path, drug_drug_sim_chemical_path, 390 | # drug_drug_sim_interaction_path, drug_drug_sim_se_path, 391 | # drug_drug_sim_disease_path], ratio=0.01) 392 | # 393 | # print('Load_Protein_Adj_Togerther ...') 394 | # protein_protein = Load_Adj_Togerther(dir_lists=[protein_protein_path, protein_protein_sim_sequence_path, 395 | # protein_protein_sim_disease_path, protein_protein_sim_interaction_path], 396 | # ratio=0.005) 397 | 398 | drug_drug = np.loadtxt(drug_drug_path, dtype=int) 399 | drug_protein = np.loadtxt(drug_protein_path, dtype=int) 400 | drug_disease = np.loadtxt(drug_disease_path, dtype=int) 401 | protein_protein = np.loadtxt(protein_protein_path, dtype=int) 402 | drug_sideEffect = np.loadtxt(drug_sideEffect_path, dtype=int) 403 | protein_disease = np.loadtxt(protein_disease_path, dtype=int) 404 | 405 | print(sum(sum(drug_drug)), sum(sum(drug_protein)), sum(sum(drug_disease)), sum(sum(protein_protein)), 406 | sum(sum(drug_sideEffect)), sum(sum(protein_disease))) 407 | 408 | num_drug, num_protein = drug_protein.shape 409 | num_disease = drug_disease.shape[1] 410 | num_se = drug_sideEffect.shape[1] 411 | type_mask = get_type_mask(num_drug, num_protein, num_disease, num_se) 412 | np.save(save_prefix + 'node_types.npy', type_mask) 413 | 414 | ## Syep 2: Build the Adjacency Matrix 415 | # Drug-0, Protein-1, Disease-2, Side-effect-3 416 | adjM = get_adjM(drug_drug, drug_protein, drug_disease, drug_sideEffect, protein_protein, protein_disease, 417 | num_drug, num_protein, num_disease, num_se) 418 | # sp.save_npz(save_prefix + 'adjM_test.npz', sp.csr_matrix(adjM)) 419 | 420 | drug_drug_list = {i: adjM[i, :num_drug].nonzero()[0] for i in range(num_drug)} 421 | drug_protein_list = {i: adjM[i, num_drug:num_drug + num_protein].nonzero()[0] for i in range(num_drug)} 422 | drug_disease_lit = {i: adjM[i, num_drug + num_protein:num_drug + num_protein + num_disease].nonzero()[0] for i in range(num_drug)} 423 | drug_se_list = {i: adjM[i, num_drug + num_protein + num_disease:].nonzero()[0] for i in range(num_drug)} 424 | protein_drug_list = {i: adjM[num_drug + i, :num_drug].nonzero()[0] for i in range(num_protein)} 425 | protein_protein_list = {i: adjM[num_drug + i, num_drug:num_drug + num_protein].nonzero()[0] for i in range(num_protein)} 426 | protein_disease_list = { i: adjM[num_drug + i, num_drug + num_protein:num_drug + num_protein + num_disease].nonzero()[0] 427 | for i in range(num_protein)} 428 | disease_drug_list = {i: adjM[num_drug + num_protein + i, :num_drug].nonzero()[0] for i in range(num_disease)} 429 | disease_protein_list = {i: adjM[num_drug + num_protein + i, num_drug:num_drug + num_protein].nonzero()[0] for i in range(num_disease)} 430 | se_drug_list = {i: adjM[num_drug + num_protein + num_disease + i, : num_drug].nonzero()[0] for i in range(num_se)} 431 | 432 | # Step 3: Get target metapaths and neighbors for each test fold 433 | target_idx_lists = [np.arange(num_drug).tolist(), np.arange(num_protein).tolist()] 434 | offset_list = [0, num_drug] 435 | 436 | for counter in range(num_repeats): # repeat ten times 437 | 
print('\nThis is the {} repeat...'.format(counter)) 438 | for i, metapaths in enumerate(expected_metapaths): 439 | # print(metapaths) 440 | for metapath in metapaths: 441 | metapath_dir = save_prefix + 'repeat{}/{}/test/neg_times_{}/'.format(counter, i, neg_times) 442 | make_dir(metapath_dir) 443 | # Get all the metapaths in the schema of 'metapath' 444 | if os.path.isfile(metapath_dir + '-'.join(map(str, metapath)) + '.npy'): 445 | edge_metapath_idx_array = np.load(metapath_dir + '-'.join(map(str, metapath)) + '.npy') 446 | else: 447 | edge_metapath_idx_array = get_metapath(metapath, num_drug, num_protein, num_disease, num_se, metapath_dir) 448 | np.save(metapath_dir + '-'.join(map(str, metapath)) + '.npy', edge_metapath_idx_array) 449 | print(metapath, len(edge_metapath_idx_array)) 450 | 451 | target_metapaths, target_neighbors = target_metapath_and_neightbors(edge_metapath_idx_array, target_idx_lists[i], offset=offset_list[i]) 452 | # print('\n') 453 | for foldID in range(nFold): 454 | # print('Fold {}, metapath {}'.format(foldID, metapath)) 455 | metapath_fold_dir = metapath_dir + '/fold_{}/'.format(str(foldID)) + '-'.join(map(str, metapath)) 456 | make_dir(os.path.dirname(metapath_fold_dir)) 457 | test_fold_idx = fold_test_idx(pos_folds, neg_folds, foldID) 458 | fold_target_metapaths = {target: target_metapaths[target] for target in test_fold_idx[i]} 459 | fold_target_neighbors = {target: target_neighbors[target] for target in test_fold_idx[i]} 460 | pickle.dump(fold_target_metapaths, open(metapath_fold_dir + '.idx.pkl', 'wb')) 461 | pickle.dump(fold_target_neighbors, open(metapath_fold_dir + '.adjlist.pkl', 'wb')) 462 | --------------------------------------------------------------------------------
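The two preprocessing scripts above persist, for every repeat, fold, and metapath schema, a '.npy' array of metapath instances plus per-target '.idx.pkl' and '.adjlist.pkl' dictionaries. The snippet below is a minimal, illustrative sketch (not part of the repository) of how the training-fold outputs might be loaded and inspected; the dataset location, the repeat/fold indices, and the chosen (0, 1, 0) drug-protein-drug schema are assumptions made only for this example, and node_types.npy is the type array written by test_preprocess.py.

import pickle
import numpy as np

processed = './hetero_dataset/data_luo/processed/'   # assumed dataset location (see data_dir above)
metapath = (0, 1, 0)                                  # drug-protein-drug schema
name = '-'.join(map(str, metapath))
# train_preprocess.py saves drug-centred (group 0) outputs under repeat{}/0/train/fold_{}/ with a '_' prefix
fold_dir = processed + 'repeat0/0/train/fold_0/_'

# global node id -> type (0 drug, 1 protein, 2 disease, 3 side effect), written by test_preprocess.py
node_types = np.load(processed + 'node_types.npy')

# '.idx.pkl': {drug id: array of metapath instances (rows of global node ids, columns reversed)}
# '.adjlist.pkl': {drug id: [drug id, metapath-based neighbour ids...]}
with open(fold_dir + name + '.idx.pkl', 'rb') as f:
    target_metapaths = pickle.load(f)
with open(fold_dir + name + '.adjlist.pkl', 'rb') as f:
    target_neighbors = pickle.load(f)

drug_id = 0
instances = target_metapaths[drug_id]
print('metapath instances for drug 0:', instances.shape)
if len(instances) > 0:
    print('node types along one instance:', node_types[instances[0]])
print('metapath-based neighbours of drug 0:', target_neighbors[drug_id][:10])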