├── datasets
│   ├── drugbank
│   │   └── raw
│   │       └── drugbank.7z
│   ├── ddi_test
│   │   └── raw
│   │       └── test_smiles_pos.7z
│   └── ddi_train
│       └── raw
│           └── train_smiles_pos.7z
├── readme.md
├── drugbank.py
├── train_val.py
├── ddi2013.py
├── SchNet.py
├── utils.py
├── LICENSE
├── model.py
└── dataset.py

/datasets/drugbank/raw/drugbank.7z:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hehh77/3DGT-DDI/HEAD/datasets/drugbank/raw/drugbank.7z
--------------------------------------------------------------------------------
/datasets/ddi_test/raw/test_smiles_pos.7z:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hehh77/3DGT-DDI/HEAD/datasets/ddi_test/raw/test_smiles_pos.7z
--------------------------------------------------------------------------------
/datasets/ddi_train/raw/train_smiles_pos.7z:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hehh77/3DGT-DDI/HEAD/datasets/ddi_train/raw/train_smiles_pos.7z
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # 3DGT-DDI
2 | 
3 | ## file list
4 | 
5 | - datasets: the data. It contains three folders: the official DDI2013 training and test sets and the DrugBank dataset.
6 | The raw CSV file of each dataset sits in its **raw** subfolder (shipped as a .7z archive that must be extracted first). After the code runs, the pre-processed dataset file **process.pt** is written to the **processed** subfolder.
7 | 
8 | - model: stores the model parameter files saved during training.
9 | 
10 | - runs: stores the log files generated by **tensorboardX**.
11 | 
12 | 
13 | 
14 | ## run code
15 | 
16 | To train the 3DGT-DDI model on the DDI2013 dataset:
17 | 
18 | ```bash
19 | python ddi2013.py --train_root ../datasets/ddi_train/ --train_path train_smiles_pos.csv --test_root ../datasets/ddi_test/ --test_path test_smiles_pos.csv --batch_size 8 --epochs 300 --lr 2e-6 --weight_decay 1e-2 --model_name allenai/scibert_scivocab_uncased --num_class 5 --max_len 128 --emb_dim 64 --cutoff 10.0 --num_layers 6 --hidden_channels 128 --num_filters 128 --num_gaussians 50 --g_out_channels 32
20 | ```
21 | 
22 | To train the 3DGT-DDI model on the DrugBank dataset:
23 | 
24 | ```bash
25 | python drugbank.py --drugbank_root ../datasets/drugbank/ --drugbank_path drugbank.csv --batch_size 16 --epochs 200 --lr 2e-5 --weight_decay 1e-2 --num_class 2 --cutoff 10.0 --num_layers 6 --hidden_channels 128 --num_filters 128 --num_gaussians 50 --g_out_channels 32
26 | ```
27 | 
28 | 
29 | 
30 | ## requirements
31 | 
32 | matplotlib==3.5.1
33 | numpy==1.22.0
34 | pandas==1.3.5
35 | rdkit==2009.Q1-1
36 | scikit_learn==1.0.2
37 | tensorboardX==2.4.1
38 | torch==1.10.1
39 | torch_geometric==2.0.3
40 | torch_scatter==2.0.9
41 | tqdm==4.62.3
42 | transformers==4.15.0
43 | 
--------------------------------------------------------------------------------
/drugbank.py:
--------------------------------------------------------------------------------
1 | import argparse 2 | from train_val import * 3 | 4 | def main(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--drugbank_root', default=None, type=str, required=True, 7 | help='..') 8 | parser.add_argument('--drugbank_path', default=None, type=str, required=True, 9 | help='..') 10 | parser.add_argument('--batch_size', default=16, type=int, required=True, 11 | help='..') 12 | 
parser.add_argument('--epochs', default=2, type=int, required=True, 13 | help='..') 14 | parser.add_argument('--lr', default=2e-5, type=float, required=True, 15 | help='..') 16 | parser.add_argument('--weight_decay', default=1e-2, type=float, required=True, 17 | help='..') 18 | parser.add_argument('--num_class', default=2, type=int, required=True, 19 | help='..') 20 | parser.add_argument('--cutoff', default=10.0, type=float, required=False, 21 | help='..') 22 | parser.add_argument('--num_layers', default=6, type=int, required=False, 23 | help='..') 24 | parser.add_argument('--hidden_channels', default=128, type=int, required=False, 25 | help='..') 26 | parser.add_argument('--num_filters', default=128, type=int, required=False, 27 | help='..') 28 | parser.add_argument('--num_gaussians', default=50, type=int, required=False, 29 | help='..') 30 | parser.add_argument('--g_out_channels', default=2, type=int, required=False, 31 | help='..') 32 | 33 | parser.add_argument('--load_model_path', default=None, type=str, required=False, 34 | help='..') 35 | parser.add_argument('--save_log_path', default="default", type=str, required=False, 36 | help='..') 37 | args = parser.parse_args() 38 | model = myModel_graph_sch_cnn(num_class=args.num_class, 39 | cutoff=args.cutoff, 40 | num_layers=args.num_layers, hidden_channels=args.hidden_channels, 41 | num_filters=args.num_filters, num_gaussians=args.num_gaussians, 42 | g_out_channels=args.g_out_channels) 43 | if args.load_model_path is not None: 44 | save_model = torch.load(args.load_model_path) 45 | model_dict = model.state_dict() 46 | state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()} 47 | model_dict.update(state_dict) 48 | model.load_state_dict(model_dict) 49 | 50 | optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 51 | dataset = drugbankDataset(root=args.drugbank_root, path=args.drugbank_path) 52 | split_idx = dataset.get_idx_split(len(dataset.data.y), int(len(dataset.data.y)*0.8), seed=2) 53 | train_dataset, valid_dataset =dataset[split_idx['train']], dataset[split_idx['valid']] 54 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, follow_batch=['pos1', 'pos2']) 55 | val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, follow_batch=['pos1', 'pos2']) 56 | train_eval_drugbank(model, optimizer, train_loader, val_loader, args.epochs , log_path = args.save_log_path) 57 | if __name__ == '__main__': 58 | main() 59 | 60 | -------------------------------------------------------------------------------- /train_val.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from model import * 3 | from dataset import * 4 | from torch_geometric.loader import DataLoader 5 | import time 6 | import datetime 7 | from tensorboardX import SummaryWriter 8 | 9 | def train_eval(model, optimizer, train_loader, val_loader,test_loader, epochs=2 , log_path = 'default'): 10 | #x = 0.005 11 | x = 0.0005 12 | criterion = FocalLoss(alpha_t=[x, 0.19, 0.13, 0.6, 0.12], gamma=2) 13 | #criterion = FocalLoss(alpha_t=[x, 0.093, 0.076, 0.3375, 0.0958], gamma=4) 14 | writer1 = SummaryWriter('./runs/' + log_path) 15 | 16 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 17 | # model.load_state_dict(checkpoint['model_state_dict']) 18 | model.to(device) 19 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8) 20 | print('-----Training-----') 21 | starttime = 
datetime.datetime.now() 22 | last_epoch_time = starttime 23 | bestf1 = 0 24 | best_epoch = 0 25 | for epoch in range(epochs): 26 | if epoch %5 == 0: 27 | x = x+0.0003 28 | criterion = FocalLoss(alpha_t=[x, 0.19, 0.13, 0.6, 0.12], gamma=2) 29 | endtime = datetime.datetime.now() 30 | print('total run time: ', endtime - starttime) 31 | print('last epoch run time: ', endtime - last_epoch_time) 32 | last_epoch_time = endtime 33 | print('Epoch', epoch) 34 | model.train() 35 | for i, batch_data in enumerate(tqdm(train_loader)): 36 | batch_data = batch_data.to(device) 37 | logits = model(batch_data) 38 | loss = criterion(logits, batch_data.y) 39 | optimizer.zero_grad() 40 | loss.backward() 41 | optimizer.step() 42 | model.eval() 43 | label, pred = get_label_pred_ddi2013(model, val_loader) 44 | f1 = sk_print_p_r(label, pred) 45 | if f1 > bestf1: 46 | bestf1 = f1 47 | best_epoch = epoch 48 | path = './model/'+log_path+'.pt' 49 | torch.save(model.state_dict(), path) 50 | scheduler.step() 51 | writer1.add_scalar('macro_f1', f1, global_step=epoch, walltime=None) 52 | 53 | 54 | def train_eval_drugbank(model, optimizer, train_loader, val_loader, epochs=2, log_path='./'): 55 | criterion = nn.CrossEntropyLoss() 56 | writer1 = SummaryWriter('../runs/' + log_path) 57 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 58 | model.to(device) 59 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7) 60 | print('-----Training-----') 61 | starttime = datetime.datetime.now() 62 | last_epoch_time = starttime 63 | bestacc = 0 64 | for epoch in range(epochs): 65 | endtime = datetime.datetime.now() 66 | print('total run time: ', endtime - starttime) 67 | print('last epoch run time: ', endtime - last_epoch_time) 68 | last_epoch_time = endtime 69 | print('Epoch', epoch) 70 | model.train() 71 | for i, batch_data in enumerate(tqdm(train_loader)): 72 | batch_data = batch_data.to(device) 73 | logits = model(batch_data) 74 | loss = criterion(logits, batch_data.y) 75 | optimizer.zero_grad() 76 | loss.backward() 77 | optimizer.step() 78 | model.eval() 79 | label, pred,prob = get_label_pred_prob(model, val_loader) 80 | print('roc_auc:', sklearn.metrics.roc_auc_score(label, prob)) 81 | area = sklearn.metrics.average_precision_score(label, prob) 82 | print('pr_auc:', area) 83 | acc = sklearn.metrics.accuracy_score(label, pred) 84 | print('acc:', acc) 85 | if acc > bestacc: 86 | bestacc = acc 87 | path = '../model/' + log_path + '.pt' 88 | torch.save(model.state_dict(), path) 89 | scheduler.step() 90 | writer1.add_scalar('acc', acc, global_step=epoch, walltime=None) 91 | 92 | -------------------------------------------------------------------------------- /ddi2013.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from train_val import * 3 | from sklearn.model_selection import KFold 4 | def main(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--train_root', default=None, type=str, required=True, 7 | help='..') 8 | parser.add_argument('--train_path', default=None, type=str, required=True, 9 | help='..') 10 | parser.add_argument('--test_root', default=None, type=str, required=True, 11 | help='..') 12 | parser.add_argument('--test_path', default=None, type=str, required=True, 13 | help='..') 14 | parser.add_argument('--batch_size', default=8, type=int, required=True, 15 | help='..') 16 | parser.add_argument('--epochs', default=2, type=int, required=True, 17 | help='..') 18 | parser.add_argument('--lr', 
default=2e-6, type=float, required=True, 19 | help='..') 20 | parser.add_argument('--weight_decay', default=1e-2, type=float, required=True, 21 | help='..') 22 | parser.add_argument('--model_name', default=None, type=str, required=True, 23 | help='..') 24 | parser.add_argument('--num_class', default=2, type=int, required=True, 25 | help='..') 26 | parser.add_argument('--max_len', default=128, type=int, required=False, 27 | help='..') 28 | parser.add_argument('--emb_dim', default=64, type=int, required=False, 29 | help='..') 30 | parser.add_argument('--cutoff', default=10.0, type=float, required=False, 31 | help='..') 32 | parser.add_argument('--num_layers', default=6, type=int, required=False, 33 | help='..') 34 | parser.add_argument('--hidden_channels', default=128, type=int, required=False, 35 | help='..') 36 | parser.add_argument('--num_filters', default=128, type=int, required=False, 37 | help='..') 38 | parser.add_argument('--num_gaussians', default=50, type=int, required=False, 39 | help='..') 40 | parser.add_argument('--g_out_channels', default=2, type=int, required=False, 41 | help='..') 42 | parser.add_argument('--load_model_path', default=None, type=str, required=False, 43 | help='..') 44 | parser.add_argument('--save_log_path', default='default', type=str, required=False, 45 | help='..') 46 | args = parser.parse_args() 47 | model = myModel_text_graph_pos_cnn_sch(model_name=args.model_name, 48 | num_class=args.num_class, 49 | cutoff=args.cutoff, 50 | num_layers=args.num_layers, hidden_channels=args.hidden_channels, 51 | num_filters=args.num_filters, num_gaussians=args.num_gaussians, 52 | g_out_channels=args.g_out_channels 53 | ) 54 | if args.load_model_path is not None: 55 | print(args.load_model_path) 56 | save_model = torch.load(args.load_model_path) 57 | model_dict = model.state_dict() 58 | state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()} 59 | model_dict.update(state_dict) 60 | model.load_state_dict(model_dict) 61 | optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 62 | kf = KFold(n_splits=5, shuffle=True, random_state=123) 63 | dataset = DDI2013Dataset(root=args.train_root, path=args.train_path,model_name=args.model_name) 64 | for train, valid in kf.split(range(len(dataset.data.y))): 65 | train_dataset, valid_dataset =dataset[list(train)], dataset[list(valid)] 66 | test_dataset = DDI2013Dataset(root=args.test_root, path=args.test_path,model_name=args.model_name) 67 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, follow_batch=['pos1', 'pos2']) 68 | val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, follow_batch=['pos1', 'pos2']) 69 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True, follow_batch=['pos1', 'pos2']) 70 | train_eval(model, optimizer, train_loader, val_loader,test_loader, args.epochs , log_path = args.save_log_path) 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /SchNet.py: -------------------------------------------------------------------------------- 1 | from math import pi as PI 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Embedding, Sequential, Linear 5 | from torch_scatter import scatter 6 | from torch_geometric.nn import radius_graph 7 | 8 | 9 | class update_e(torch.nn.Module): 10 | def __init__(self, hidden_channels, num_filters, num_gaussians, cutoff): 11 | super(update_e, 
self).__init__() 12 | self.cutoff = cutoff 13 | self.lin = Linear(hidden_channels, num_filters, bias=False) 14 | self.mlp = Sequential( 15 | Linear(num_gaussians, num_filters), 16 | ShiftedSoftplus(), 17 | Linear(num_filters, num_filters), 18 | ) 19 | 20 | self.reset_parameters() 21 | 22 | def reset_parameters(self): 23 | torch.nn.init.xavier_uniform_(self.lin.weight) 24 | torch.nn.init.xavier_uniform_(self.mlp[0].weight) 25 | self.mlp[0].bias.data.fill_(0) 26 | torch.nn.init.xavier_uniform_(self.mlp[2].weight) 27 | self.mlp[0].bias.data.fill_(0) 28 | 29 | def forward(self, v, dist, dist_emb, edge_index): 30 | j, _ = edge_index 31 | C = 0.5 * (torch.cos(dist * PI / self.cutoff) + 1.0) 32 | W = self.mlp(dist_emb) * C.view(-1, 1) 33 | v = self.lin(v) 34 | e = v[j] * W 35 | return e 36 | 37 | 38 | class update_v(torch.nn.Module): 39 | def __init__(self, hidden_channels, num_filters): 40 | super(update_v, self).__init__() 41 | self.act = ShiftedSoftplus() 42 | self.lin1 = Linear(num_filters, hidden_channels) 43 | self.lin2 = Linear(hidden_channels, hidden_channels) 44 | 45 | self.reset_parameters() 46 | 47 | def reset_parameters(self): 48 | torch.nn.init.xavier_uniform_(self.lin1.weight) 49 | self.lin1.bias.data.fill_(0) 50 | torch.nn.init.xavier_uniform_(self.lin2.weight) 51 | self.lin2.bias.data.fill_(0) 52 | 53 | def forward(self, v, e, edge_index): 54 | _, i = edge_index 55 | out = scatter(e, i, dim=0) 56 | out = self.lin1(out) 57 | out = self.act(out) 58 | out = self.lin2(out) 59 | return v + out 60 | 61 | 62 | class update_u(torch.nn.Module): 63 | def __init__(self, hidden_channels,out_channels = 1): 64 | super(update_u, self).__init__() 65 | self.lin1 = Linear(hidden_channels, hidden_channels // 2) 66 | self.act = ShiftedSoftplus() 67 | self.lin2 = Linear(hidden_channels // 2, out_channels) 68 | 69 | self.reset_parameters() 70 | 71 | def reset_parameters(self): 72 | torch.nn.init.xavier_uniform_(self.lin1.weight) 73 | self.lin1.bias.data.fill_(0) 74 | torch.nn.init.xavier_uniform_(self.lin2.weight) 75 | self.lin2.bias.data.fill_(0) 76 | 77 | def forward(self, v, batch): 78 | v = self.lin1(v) 79 | v = self.act(v) 80 | v = self.lin2(v) 81 | u = scatter(v, batch, dim=0) 82 | return u 83 | 84 | 85 | class emb(torch.nn.Module): 86 | def __init__(self, start=0.0, stop=5.0, num_gaussians=50): 87 | super(emb, self).__init__() 88 | offset = torch.linspace(start, stop, num_gaussians) 89 | self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2 90 | self.register_buffer('offset', offset) 91 | 92 | def forward(self, dist): 93 | dist = dist.view(-1, 1) - self.offset.view(1, -1) 94 | return torch.exp(self.coeff * torch.pow(dist, 2)) 95 | 96 | 97 | class ShiftedSoftplus(torch.nn.Module): 98 | def __init__(self): 99 | super(ShiftedSoftplus, self).__init__() 100 | self.shift = torch.log(torch.tensor(2.0)).item() 101 | 102 | def forward(self, x): 103 | return F.softplus(x) - self.shift 104 | 105 | 106 | 107 | 108 | class SchNet(torch.nn.Module): 109 | r""" 110 | The re-implementation for SchNet from the `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions" `_ paper 111 | under the 3DGN gramework from `"Spherical Message Passing for 3D Graph Networks" `_ paper. 112 | 113 | Args: 114 | energy_and_force (bool, optional): If set to :obj:`True`, will predict energy and take the negative of the derivative of the energy with respect to the atomic positions as predicted forces. (default: :obj:`False`) 115 | num_layers (int, optional): The number of layers. 
(default: :obj:`6`) 116 | hidden_channels (int, optional): Hidden embedding size. (default: :obj:`128`) 117 | num_filters (int, optional): The number of filters to use. (default: :obj:`128`) 118 | num_gaussians (int, optional): The number of gaussians :math:`\mu`. (default: :obj:`50`) 119 | cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`10.0`). 120 | """ 121 | 122 | def __init__(self, energy_and_force=False, cutoff=10.0, num_layers=6, hidden_channels=128, num_filters=128, 123 | num_gaussians=50,out_channels = 1): 124 | super(SchNet, self).__init__() 125 | 126 | self.energy_and_force = energy_and_force 127 | self.cutoff = cutoff 128 | self.num_layers = num_layers 129 | self.hidden_channels = hidden_channels 130 | self.num_filters = num_filters 131 | self.num_gaussians = num_gaussians 132 | 133 | self.init_v = Embedding(100, hidden_channels) 134 | self.dist_emb = emb(0.0, cutoff, num_gaussians) 135 | 136 | self.update_vs = torch.nn.ModuleList([update_v(hidden_channels, num_filters) for _ in range(num_layers)]) 137 | 138 | self.update_es = torch.nn.ModuleList([ 139 | update_e(hidden_channels, num_filters, num_gaussians, cutoff) for _ in range(num_layers)]) 140 | 141 | self.update_u = update_u(hidden_channels,out_channels) 142 | 143 | self.reset_parameters() 144 | 145 | def reset_parameters(self): 146 | self.init_v.reset_parameters() 147 | for update_e in self.update_es: 148 | update_e.reset_parameters() 149 | for update_v in self.update_vs: 150 | update_v.reset_parameters() 151 | self.update_u.reset_parameters() 152 | 153 | def forward(self, batch_data): 154 | z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch 155 | if self.energy_and_force: 156 | pos.requires_grad_() 157 | 158 | edge_index = radius_graph(pos, r=self.cutoff, batch=batch) 159 | row, col = edge_index 160 | dist = (pos[row] - pos[col]).norm(dim=-1) 161 | dist_emb = self.dist_emb(dist) 162 | 163 | v = self.init_v(z) 164 | 165 | for update_e, update_v in zip(self.update_es, self.update_vs): 166 | e = update_e(v, dist, dist_emb, edge_index) 167 | v = update_v(v, e, edge_index) 168 | u = self.update_u(v, batch) 169 | 170 | return u 171 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import untangle 2 | import pandas as pd 3 | import numpy as np 4 | import os 5 | import re 6 | import random 7 | import torch 8 | 9 | import sklearn 10 | from sklearn.metrics import f1_score, precision_score, recall_score,classification_report 11 | import matplotlib.pyplot as plt 12 | import itertools 13 | from tqdm import tqdm 14 | 15 | 16 | def get_file(paths, drugbank): # 获取文件路径 17 | i = 0 18 | for path in paths: 19 | for root, dirs, files in os.walk(path): 20 | for file in files: 21 | # root = etree.parse(xml_path, parser=etree.XMLParser()) 22 | 23 | obj = untangle.parse(str(path + file)) 24 | obj = obj.document.children 25 | # print(obj) 26 | for sen in obj: 27 | # print(sen['id'],sen['text']) 28 | drug_dict = {} 29 | try: 30 | sen.entity 31 | except: 32 | count = 0 33 | else: 34 | for drug in sen.entity: 35 | drug_dict[drug["id"]] = drug["text"] 36 | try: 37 | sen.pair 38 | except: 39 | count = 0 40 | else: 41 | for pair in sen.pair: 42 | if pair['ddi'] == 'false' and random.random() > 0.1: 43 | continue 44 | 45 | text = sen['text'] 46 | text = text.replace(drug_dict[pair["e1"]], "DRUG1") 47 | text = text.replace(drug_dict[pair["e2"]], "DRUG2") 48 | for drug in 
drug_dict: 49 | text = text.replace(drug_dict[drug], "DRUGOTHER") 50 | drugbank.loc[i, "text"] = text 51 | drugbank.loc[i, "drug1"] = drug_dict[pair["e1"]] 52 | drugbank.loc[i, "drug2"] = drug_dict[pair["e2"]] 53 | drugbank.loc[i, "ddi"] = pair['ddi'] 54 | if pair['ddi'] == 'true': 55 | drugbank.loc[i, "type"] = pair['type'] 56 | i = i + 1 57 | # break 58 | return drugbank 59 | 60 | 61 | def process_data(paths): 62 | drugbank = pd.DataFrame( 63 | columns=["text", "drug1", "drug2", "ddi", "type"]) 64 | drugbank = get_file(paths, drugbank) 65 | drugbank.fillna('Neg', inplace=True) 66 | label_dict = {"Neg": 0, 67 | 'advise': 1, 68 | 'effect': 2, 69 | 'int': 3, 70 | 'mechanism': 4} 71 | drugbank['label'] = drugbank['type'] 72 | for i in range(drugbank.shape[0]): 73 | drugbank['label'][i] = label_dict[drugbank['type'][i]] 74 | 75 | drugbank.to_csv("./train_sample10.csv") 76 | 77 | 78 | def train_test_split(data_df, test_size=0.2, shuffle=True, random_state=None): 79 | if shuffle: 80 | data_df = sklearn.utils.shuffle(data_df, random_state=random_state) 81 | train = data_df[int(len(data_df) * test_size):].reset_index(drop=True) 82 | test = data_df[:int(len(data_df) * test_size)].reset_index(drop=True) 83 | 84 | return train, test 85 | 86 | 87 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 88 | 89 | 90 | def confusion_matrix(preds, labels, conf_matrix): 91 | for p, t in zip(preds, labels): 92 | conf_matrix[p, t] += 1 93 | return conf_matrix 94 | 95 | 96 | from matplotlib.colors import LinearSegmentedColormap 97 | 98 | 99 | # 绘制混淆矩阵 100 | def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues, title_add=0, 101 | polt_lim=0): 102 | ''' 103 | This function prints and plots the confusion matrix. 104 | Normalization can be applied by setting `normalize=True`. 105 | Input 106 | - cm : 计算出的混淆矩阵的值 107 | - classes : 混淆矩阵中每一行每一列对应的列 108 | - normalize : True:显示百分比, False:显示个数 109 | ''' 110 | colors = [] 111 | for l in np.linspace(0, 1, 100): 112 | colors.append((30. / 255, 136. / 255, 229. / 255, l)) 113 | transparent_blue = LinearSegmentedColormap.from_list("transparent_blue", colors) 114 | 115 | if title_add < polt_lim: 116 | return 117 | if normalize: 118 | cm = cm.astype('float') / cm.sum(axis=0)[np.newaxis, :] 119 | print("Normalized confusion matrix") 120 | else: 121 | print('Confusion matrix, without normalization') 122 | print(cm.T) 123 | 124 | plt.imshow(cm.T, interpolation='nearest', cmap=cmap) 125 | if normalize: 126 | plt.title('Normalized ' + title) 127 | else: 128 | plt.title(title) 129 | 130 | plt.colorbar() 131 | tick_marks = np.arange(len(classes)) 132 | plt.xticks(tick_marks, classes, rotation=90) 133 | plt.yticks(tick_marks, classes) 134 | 135 | # 。。。。。。。。。。。。新增代码开始处。。。。。。。。。。。。。。。。 136 | # x,y轴长度一致(问题1解决办法) 137 | plt.axis("equal") 138 | # x轴处理一下,如果x轴或者y轴两边有空白的话(问题2解决办法) 139 | ax = plt.gca() # 获得当前axis 140 | left, right = plt.xlim() # 获得x轴最大最小值 141 | ax.spines['left'].set_position(('data', left)) 142 | ax.spines['right'].set_position(('data', right)) 143 | for edge_i in ['top', 'bottom', 'right', 'left']: 144 | ax.spines[edge_i].set_edgecolor("white") 145 | # 。。。。。。。。。。。。新增代码结束处。。。。。。。。。。。。。。。。 146 | 147 | thresh = cm.max() / 2. 
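# The text colour of each cell in the loop below is chosen against this threshold:
# counts larger than half of the biggest cell are drawn in white so they stay
# readable on the darker blue cells, the rest in black.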
148 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 149 | num = '{:.2f}'.format(cm[i, j]) if normalize else int(cm[i, j]) 150 | plt.text(i, j, num, 151 | verticalalignment='center', 152 | horizontalalignment="center", 153 | color="white" if float(num) > thresh else "black") 154 | plt.tight_layout(pad=1.3) 155 | plt.ylabel('True label') 156 | plt.xlabel('Predicted label') 157 | 158 | plt.savefig('save_img_' + str(title_add) + '.jpg') 159 | plt.close() 160 | 161 | 162 | def plot_cm_and_get_label_pred(mymodel, val_loader, normalize=False, best_f1=0): 163 | mymodel.to(device) 164 | mymodel.eval() 165 | conf_matrix = torch.zeros(5, 5) 166 | types = ['Neg', 'Adv', 'Eff', 'Int', 'Mec'] 167 | ma_f1 = 0 168 | pred, label = [], [] 169 | for i, batch in enumerate(tqdm(val_loader)): 170 | # batch = tuple(t.to(device) for t in batch) 171 | batch = batch.to(device) 172 | # out = mymodel(batch[0], batch[1], batch[2]) 173 | out = mymodel(batch) 174 | # conf_matrix = confusion_matrix(np.argmax(out.detach().cpu().numpy(),axis=1), labels=batch[-1], conf_matrix=conf_matrix) 175 | conf_matrix = confusion_matrix(np.argmax(out.detach().cpu().numpy(), axis=1), labels=batch.y, 176 | conf_matrix=conf_matrix) 177 | pred.extend(np.argmax(out.detach().cpu().numpy(), axis=1).flatten()) 178 | # label.extend(batch[-1].detach().cpu().numpy().flatten()) 179 | label.extend(batch.y.detach().cpu().numpy().flatten()) 180 | ma_p = precision_score(np.array(label), np.array(pred), average='macro') 181 | ma_r = recall_score(np.array(label), np.array(pred), average='macro') 182 | ma_f1 = (2 * ma_p * ma_r) / (ma_p + ma_r) 183 | plot_confusion_matrix(conf_matrix.numpy(), classes=types, normalize=normalize, title_add=ma_f1, polt_lim=best_f1) 184 | return conf_matrix, np.array(label), np.array(pred) 185 | 186 | 187 | def get_label_pred_ddi2013(mymodel, val_loader): 188 | mymodel.to(device) 189 | mymodel.eval() 190 | pred, label = [], [] 191 | for i, batch in enumerate(val_loader): 192 | batch = batch.to(device) 193 | out = mymodel(batch) 194 | pred.extend(np.argmax(out.detach().cpu().numpy(), axis=1).flatten()) 195 | label.extend(batch.y.detach().cpu().numpy().flatten()) 196 | return np.array(label), np.array(pred) 197 | 198 | 199 | def calculate_accuracy(cm): 200 | return torch.trace(cm) / torch.sum(cm) 201 | 202 | 203 | def calculate_recall(cm, idx): 204 | return cm[idx][idx] / cm.sum(axis=0)[idx] 205 | 206 | 207 | def calculate_precision(cm, idx): 208 | return cm[idx][idx] / cm.sum(axis=1)[idx] 209 | 210 | 211 | def print_p_r(cm): 212 | precision_sum = 0 213 | recall_sum = 0 214 | precision = 0 215 | recall = 0 216 | f1_sum = 0 217 | print('accuracy:', calculate_accuracy(cm)) 218 | for i in range(cm.shape[0]): 219 | precision = calculate_precision(cm, i) 220 | recall = calculate_recall(cm, i) 221 | precision_sum += precision 222 | recall_sum += recall 223 | print('label:', i, ' precision:', precision) 224 | print('label:', i, ' recall:', recall) 225 | f1_sum += (precision * recall * 2) / (precision + recall) 226 | macro_precision = precision_sum / cm.shape[0] 227 | macro_recall = recall_sum / cm.shape[0] 228 | print('Macro_precision:', macro_precision) 229 | print('Macro_recall:', macro_recall) 230 | print('Macro_F1:', f1_sum / cm.shape[0]) 231 | 232 | 233 | def sk_print_p_r(label, pred): 234 | ma_p = precision_score(np.array(label), np.array(pred), average='macro') 235 | ma_r = recall_score(np.array(label), np.array(pred), average='macro') 236 | ma_f1 = (2 * ma_p * ma_r) / (ma_p + ma_r) 237 | print('macro 
f1', ma_f1, 238 | 'macro p', ma_p, 239 | 'macro r', ma_r) 240 | print('micro f1', f1_score(label, pred, average='micro')) 241 | #print(classification_report(label,pred,digits=4)) 242 | return ma_f1 243 | 244 | 245 | def get_label_pred_prob(mymodel, val_loader): 246 | mymodel.to(device) 247 | mymodel.eval() 248 | ma_f1 = 0 249 | pred, label = [], [] 250 | prob = [] 251 | for i, batch in enumerate(tqdm(val_loader)): 252 | batch = batch.to(device) 253 | out = mymodel(batch) 254 | prob.extend(torch.sigmoid(out)[:, 1].detach().cpu().numpy().flatten()) 255 | pred.extend(np.argmax(out.detach().cpu().numpy(), axis=1).flatten()) 256 | label.extend(batch.y.detach().cpu().numpy().flatten()) 257 | return np.array(label), np.array(pred), np.array(prob) 258 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoTokenizer 2 | import torch 3 | from torch import nn 4 | from torch.nn.functional import selu 5 | from SchNet import SchNet 6 | import random 7 | 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | 10 | 11 | 12 | 13 | class myModel_text_graph_pos_cnn_sch(nn.Module): 14 | def __init__(self, model_name, hidden_size=768, num_class=2, freeze_bert=False, max_len=128, 15 | emb_dim=64,cutoff = 10.0,num_layers = 6,hidden_channels = 128,num_filters = 128,num_gaussians = 50,g_out_channels = 5): 16 | super(myModel_text_graph_pos_cnn_sch, self).__init__() 17 | self.max_len = max_len 18 | self.emb_dim = emb_dim 19 | self.bert = AutoModel.from_pretrained(model_name, cache_dir='../cache', output_hidden_states=True, 20 | return_dict=True) 21 | if freeze_bert: 22 | for p in self.bert.parameters(): 23 | p.requires_grad = False 24 | self.cnn = CNN() 25 | self.cutoff = cutoff 26 | self.num_layers =num_layers 27 | self.hidden_channels =hidden_channels 28 | self.num_filters =num_filters 29 | self.num_gaussians =num_gaussians 30 | self.model1 = SchNet(energy_and_force=False, cutoff=self.cutoff, num_layers=self.num_layers, 31 | hidden_channels=self.hidden_channels, num_filters=self.num_filters, num_gaussians=self.num_gaussians, 32 | out_channels=g_out_channels) 33 | self.model2 = SchNet(energy_and_force=False, cutoff=self.cutoff, num_layers=self.num_layers, 34 | hidden_channels=self.hidden_channels, num_filters=self.num_filters, num_gaussians=self.num_gaussians, 35 | out_channels=g_out_channels) 36 | 37 | self.fc_g_1 = nn.Sequential( 38 | # nn.Dropout(), 39 | nn.Linear(32, 32 * 2, bias=True), 40 | nn.PReLU(), 41 | nn.Linear(32 * 2, 32, bias=True) 42 | ) 43 | self.fc_g_2 = nn.Sequential( 44 | # nn.Dropout(), 45 | nn.Linear(32, 32 * 2, bias=True), 46 | nn.PReLU(), 47 | nn.Linear(32 * 2, 32, bias=True) 48 | ) 49 | 50 | self.cnn_g = CNN_g(in_channel=2, out_channel=num_class) 51 | 52 | self.emb = nn.Embedding(self.max_len + 1, self.emb_dim) 53 | 54 | self.fc_emb = nn.Sequential( 55 | # nn.Dropout(), 56 | nn.Linear(self.emb_dim * 2, 32 * 2, bias=True), 57 | nn.PReLU(), 58 | nn.Linear(32 * 2, num_class, bias=True) 59 | ) 60 | 61 | self.fc3 = nn.Sequential( 62 | # nn.Dropout(), 63 | nn.Linear(3 * num_class, 32 * 2, bias=True), 64 | nn.PReLU(), 65 | nn.Linear(32 * 2, num_class, bias=True) 66 | ) 67 | 68 | def forward(self, batch_data): 69 | outputs = self.bert(input_ids=batch_data.token_ids.view(-1, self.max_len), 70 | token_type_ids=batch_data.token_type_ids.view(-1, self.max_len), 71 | attention_mask=batch_data.attn_masks.view(-1, self.max_len)) 72 | hidden_states = torch.cat(tuple([outputs.hidden_states[i] for i in [-1, -2, -3, -4, -5, -6]]), 73 | dim=-1).view(outputs.hidden_states[-1].shape[0], -1, 74 | outputs.hidden_states[-1].shape[1], 75 | outputs.hidden_states[-1].shape[-1]) # [bs, seq_len, hidden_dim*6] 76 | logits = self.cnn(hidden_states) 77 | batch_data.pos = batch_data.pos1 78 | batch_data.z = batch_data.z1 79 | batch_data.batch = batch_data.pos1_batch 80 | self.pred1 = self.model1(batch_data) 81 | batch_data.pos = batch_data.pos2 82 | batch_data.z = batch_data.z2 83 | batch_data.batch = batch_data.pos2_batch 84 | self.pred2 = self.model2(batch_data) 85 | self.pred1 = self.fc_g_1(self.pred1) 86 | self.pred2 = self.fc_g_2(self.pred2) 87 | self.pred1 = 
self.pred1.unsqueeze(1) 88 | self.pred2 = self.pred2.unsqueeze(1) 89 | self.pred = torch.cat((self.pred1, self.pred2), 1) 90 | self.pred = self.cnn_g(self.pred) 91 | 92 | # self.pred = (self.pred + 9*logits)/10.0 93 | 94 | drug1_pos = batch_data.drug1_pos 95 | drug2_pos = batch_data.drug2_pos 96 | drug1_pos[drug1_pos == -1] = self.max_len 97 | drug2_pos[drug2_pos == -1] = drug1_pos[drug2_pos == -1] 98 | self.emb1 = self.emb(drug1_pos) 99 | self.emb2 = self.emb(drug2_pos) 100 | self.emb_cat = torch.cat((self.emb1, self.emb2), 1) 101 | self.emb_cat = self.fc_emb(self.emb_cat) 102 | # self.emb_cat = F.softmax(self.emb_cat,dim=1) 103 | 104 | # self.pred = (19*self.pred + self.emb_cat)/20.0 105 | 106 | self.pred_total = torch.cat((logits, self.pred, self.emb_cat), 1) 107 | self.pred_total = self.fc3(self.pred_total) 108 | # if random.random() < 0.01: 109 | # print(logits) 110 | # print(self.pred) 111 | # print(self.emb_cat) 112 | # print(self.pred_total) 113 | return self.pred_total 114 | 115 | 116 | 117 | class myModel_graph_sch_cnn(nn.Module): 118 | def __init__(self,num_class=2,cutoff = 10.0,num_layers = 6,hidden_channels = 128, 119 | num_filters = 128,num_gaussians = 50,g_out_channels = 5): 120 | super(myModel_graph_sch_cnn, self).__init__() 121 | self.cutoff = cutoff 122 | self.num_layers = num_layers 123 | self.hidden_channels = hidden_channels 124 | self.num_filters = num_filters 125 | self.num_gaussians = num_gaussians 126 | self.model1 = SchNet(energy_and_force=False, cutoff=self.cutoff, num_layers=self.num_layers, 127 | hidden_channels=self.hidden_channels, num_filters=self.num_filters, num_gaussians=self.num_gaussians, 128 | out_channels=g_out_channels) 129 | self.model2 = SchNet(energy_and_force=False, cutoff=self.cutoff, num_layers=self.num_layers, 130 | hidden_channels=self.hidden_channels, num_filters=self.num_filters, num_gaussians=self.num_gaussians, 131 | out_channels=g_out_channels) 132 | 133 | self.fc1 = nn.Sequential( 134 | # nn.Dropout(), 135 | nn.Linear(32, 32 * 2, bias=True), 136 | nn.PReLU(), 137 | nn.Linear(32 * 2, 32, bias=True) 138 | ) 139 | self.fc2 = nn.Sequential( 140 | # nn.Dropout(), 141 | nn.Linear(32, 32 * 2, bias=True), 142 | nn.PReLU(), 143 | nn.Linear(32 * 2, 32, bias=True) 144 | ) 145 | 146 | self.cnn = CNN_g(in_channel=2, out_channel=num_class) 147 | 148 | def forward(self, batch_data): 149 | batch_data.pos = batch_data.pos1 150 | batch_data.z = batch_data.z1 151 | batch_data.batch = batch_data.pos1_batch 152 | self.pred1 = self.model1(batch_data) 153 | batch_data.pos = batch_data.pos2 154 | batch_data.z = batch_data.z2 155 | batch_data.batch = batch_data.pos2_batch 156 | self.pred2 = self.model2(batch_data) 157 | 158 | self.pred1 = self.fc1(self.pred1) 159 | self.pred2 = self.fc2(self.pred2) 160 | self.pred1 = self.pred1.unsqueeze(1) 161 | self.pred2 = self.pred2.unsqueeze(1) 162 | self.pred = torch.cat((self.pred1, self.pred2), 1) 163 | self.pred = self.cnn(self.pred) 164 | return self.pred 165 | 166 | 167 | 168 | class CNN(nn.Module): 169 | def __init__(self, in_channel=6, fc1_hid_dim=128 * 768, out_channel=5): 170 | super(CNN, self).__init__() 171 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=(1, 3), padding=[0, 1]) 172 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(1, 3), padding=[0, 1]) 173 | self.conv31 = nn.Conv2d(128, 128, kernel_size=(1, 3), padding=[0, 1]) 174 | self.conv32 = nn.Conv2d(128, 128, kernel_size=(1, 3), padding=[0, 1]) 175 | self.conv4 = nn.Conv2d(128, 256, kernel_size=(1, 3), padding=[0, 1]) 176 | self.fc1 = 
nn.Linear(fc1_hid_dim, 64) 177 | self.fc2 = nn.Linear(64, out_channel) 178 | self.out_channel = out_channel 179 | self.Lrelu = nn.LeakyReLU() 180 | self.bn1 = nn.BatchNorm2d(64) 181 | self.bn2 = nn.BatchNorm2d(128) 182 | self.bn31 = nn.BatchNorm2d(128) 183 | self.bn32 = nn.BatchNorm2d(128) 184 | self.bn4 = nn.BatchNorm2d(256) 185 | def forward(self, x): 186 | 187 | x = self.Lrelu(self.bn1(self.conv1(x))) # 输入 batch size * hiddenLayers * max_len * embedding Length (bs*6*128*768) 输出 bs*64*128*768 188 | x = self.Lrelu(self.bn2(self.conv2(x))) # 输出 bs*128*128*768 189 | res = x 190 | x = self.Lrelu(self.bn31(self.conv31(x))) # 输出 bs*128*128*768 191 | x = self.Lrelu(self.bn32(self.conv32(x))) # 输出 bs*128*128*768 192 | x = res + x 193 | x = self.Lrelu(self.bn4(self.conv4(x))) # 输出 bs*256*128*768 194 | x = self.Lrelu(self.fc1(x.view(x.shape[0], x.shape[1], -1))) # 输出 bs*64*64 195 | x = self.Lrelu(self.fc2(x)) # 输出 bs*64*out_channel 196 | x = F.adaptive_avg_pool2d(x, (1, self.out_channel)).squeeze(dim=-1).squeeze(1) # 平均池化为 bs*out_channel 197 | 198 | return x 199 | 200 | 201 | 202 | 203 | class CNN_g(nn.Module): 204 | def __init__(self, in_channel=2, fc1_hid_dim=256 * 32, out_channel=2): 205 | super(CNN_g, self).__init__() 206 | self.conv1 = nn.Conv1d(in_channel, 64, kernel_size=3, padding=1) 207 | self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1) 208 | self.conv31 = nn.Conv1d(128, 128, kernel_size=3, padding=1) 209 | self.conv32 = nn.Conv1d(128, 128, kernel_size=3, padding=1) 210 | self.conv4 = nn.Conv1d(128, 256, kernel_size=3, padding=1) 211 | self.fc1 = nn.Linear(fc1_hid_dim, 64) 212 | self.fc2 = nn.Linear(64, out_channel) 213 | self.Lrelu = nn.LeakyReLU() 214 | 215 | def forward(self, x): 216 | x = self.Lrelu(self.conv1(x)) # batchsize *2 * 32 变为 batchsize *64 * 32 217 | x = self.Lrelu(self.conv2(x)) # batchsize *128 * 32 218 | res = x 219 | x = self.Lrelu(self.conv31(x)) # batchsize *128 * 32 220 | x = self.Lrelu(self.conv32(x)) # batchsize *128 * 32 221 | x = res + x # batchsize *128 * 32 222 | x = self.Lrelu(self.conv4(x)) # batchsize *256 * 32 223 | x = self.Lrelu(self.fc1(x.view(x.shape[0], -1))) # batchsize * 64 224 | x = self.fc2(x) # batchsize * out_channel 输出通道数代表预测的类别数量 根据任务的分类类别来确定 ddi2013里面是5 drugbank里面是2 225 | 226 | return x 227 | 228 | 229 | 230 | 231 | import torch.nn.functional as F 232 | 233 | 234 | class myModel_text_cnn(nn.Module): 235 | def __init__(self, model_name, hidden_size=768, num_class=2, freeze_bert=False, 236 | max_len=128): # , freeze_bert=False ,model_name): 237 | super(myModel_text_cnn, self).__init__() 238 | self.max_len = max_len 239 | self.bert = AutoModel.from_pretrained(model_name, cache_dir='../cache', output_hidden_states=True, 240 | return_dict=True) 241 | if freeze_bert: 242 | for p in self.bert.parameters(): 243 | p.requires_grad = False 244 | 245 | self.fc1 = nn.Sequential( 246 | nn.Dropout(), 247 | nn.Linear(hidden_size * 6, num_class, bias=False) 248 | ) 249 | self.cnn = CNN() 250 | 251 | def forward(self, batch_data): 252 | outputs = self.bert(input_ids=batch_data.token_ids.view(-1, self.max_len), 253 | token_type_ids=batch_data.token_type_ids.view(-1, self.max_len), 254 | attention_mask=batch_data.attn_masks.view(-1, self.max_len)) 255 | hidden_states = torch.cat(tuple([outputs.hidden_states[i] for i in [-1, -2, -3, -4, -5, -6]]), 256 | dim=-1).view(outputs.hidden_states[-1].shape[0], -1, 257 | outputs.hidden_states[-1].shape[1], 258 | outputs.hidden_states[-1].shape[-1]) # [bs, seq_len, hidden_dim*6] 259 | self.pred = 
self.cnn(hidden_states) 260 | return self.pred 261 | 262 | 263 | 264 | 265 | class myModel_text_pos_cnn(nn.Module): 266 | def __init__(self, model_name, hidden_size=768, num_class=2, freeze_bert=False, max_len=128, 267 | emb_dim=64): # , freeze_bert=False ,model_name): 268 | super(myModel_text_pos_cnn, self).__init__() 269 | self.max_len = max_len 270 | self.emb_dim = emb_dim 271 | self.bert = AutoModel.from_pretrained(model_name, cache_dir='../cache', output_hidden_states=True, 272 | return_dict=True) 273 | if freeze_bert: 274 | for p in self.bert.parameters(): 275 | p.requires_grad = False 276 | 277 | self.fc1 = nn.Sequential( 278 | nn.Dropout(), 279 | nn.Linear(hidden_size * 6, num_class, bias=False) 280 | ) 281 | 282 | self.emb = nn.Embedding(self.max_len + 1, self.emb_dim) 283 | 284 | self.fc_emb = nn.Sequential( 285 | # nn.Dropout(), 286 | nn.Linear(self.emb_dim * 2, 32 * 2, bias=False), 287 | nn.PReLU(), 288 | nn.Linear(32 * 2, num_class, bias=False) 289 | ) 290 | self.cnn = CNN() 291 | 292 | def forward(self, batch_data): 293 | outputs = self.bert(input_ids=batch_data.token_ids.view(-1, self.max_len), 294 | token_type_ids=batch_data.token_type_ids.view(-1, self.max_len), 295 | attention_mask=batch_data.attn_masks.view(-1, self.max_len)) 296 | hidden_states = torch.cat(tuple([outputs.hidden_states[i] for i in [-1, -2, -3, -4, -5, -6]]), 297 | dim=-1).view(outputs.hidden_states[-1].shape[0], -1, 298 | outputs.hidden_states[-1].shape[1], 299 | outputs.hidden_states[-1].shape[-1]) # [bs, seq_len, hidden_dim*6] 300 | logits = self.cnn(hidden_states) 301 | 302 | drug1_pos = batch_data.drug1_pos 303 | drug2_pos = batch_data.drug2_pos 304 | drug1_pos[drug1_pos == -1] = self.max_len 305 | drug2_pos[drug2_pos == -1] = drug1_pos[drug2_pos == -1] 306 | self.emb1 = self.emb(drug1_pos) 307 | self.emb2 = self.emb(drug2_pos) 308 | self.emb_cat = torch.cat((self.emb1, self.emb2), 1) 309 | self.emb_cat = self.fc_emb(self.emb_cat) 310 | 311 | self.pred = (logits + 0.1 * self.emb_cat) / 1.1 312 | return self.pred 313 | 314 | 315 | class FocalLoss: 316 | def __init__(self, alpha_t=None, gamma=0): 317 | """ 318 | :param alpha_t: A list of weights for each class 319 | :param gamma: 320 | """ 321 | self.alpha_t = torch.tensor(alpha_t) if alpha_t else None 322 | self.gamma = gamma 323 | 324 | def __call__(self, outputs, targets): 325 | if self.alpha_t is None and self.gamma == 0: 326 | focal_loss = torch.nn.functional.cross_entropy(outputs, targets) 327 | 328 | elif self.alpha_t is not None and self.gamma == 0: 329 | if self.alpha_t.device != outputs.device: 330 | self.alpha_t = self.alpha_t.to(outputs) 331 | focal_loss = torch.nn.functional.cross_entropy(outputs, targets, 332 | weight=self.alpha_t) 333 | 334 | elif self.alpha_t is None and self.gamma != 0: 335 | ce_loss = torch.nn.functional.cross_entropy(outputs, targets, reduction='none') 336 | p_t = torch.exp(-ce_loss) 337 | focal_loss = ((1 - p_t) ** self.gamma * ce_loss).mean() 338 | 339 | elif self.alpha_t is not None and self.gamma != 0: 340 | if self.alpha_t.device != outputs.device: 341 | self.alpha_t = self.alpha_t.to(outputs) 342 | ce_loss = torch.nn.functional.cross_entropy(outputs, targets, reduction='none') 343 | p_t = torch.exp(-ce_loss) 344 | ce_loss = torch.nn.functional.cross_entropy(outputs, targets, 345 | weight=self.alpha_t, reduction='none') 346 | focal_loss = ((1 - p_t) ** self.gamma * ce_loss).mean() # mean over the batch 347 | 348 | return focal_loss 349 | 
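# A minimal usage sketch of the FocalLoss defined above, assuming random logits and
# labels (the tensors and weights below are illustrative; the alpha values mirror the
# five DDI2013 class weights passed in train_val.py).
if __name__ == '__main__':
    _logits = torch.randn(4, 5)                                    # batch of 4 samples, 5 classes
    _targets = torch.randint(0, 5, (4,))                           # random gold labels
    _criterion = FocalLoss(alpha_t=[0.0005, 0.19, 0.13, 0.6, 0.12], gamma=2)
    print('focal loss:', _criterion(_logits, _targets).item())     # single scalar, averaged over the batch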
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as DATA 2 | from transformers import AdamW, AutoModel, AutoTokenizer 3 | import os 4 | import pandas as pd 5 | from torch_geometric.data import InMemoryDataset 6 | from torch_geometric.data import Data as DATA 7 | import rdkit 8 | from rdkit import Chem 9 | from rdkit.Chem import AllChem 10 | import numpy as np 11 | import torch 12 | from sklearn.utils import shuffle 13 | import torch 14 | import random 15 | 16 | class DDI2013Dataset(InMemoryDataset): 17 | def __init__(self, root='/tmp', 18 | path='', 19 | transform=None, 20 | pre_transform=None, 21 | max_len=128, model_name=''): 22 | 23 | self.max_len = max_len 24 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 25 | 26 | self.path = path 27 | self.pass_list = [] 28 | self.pass_smiles = set() 29 | self.atomType = {'C': 1, 'H': 2, 'O': 3, 'N': 4, 'S': 5, 'Li': 6, 'Mg': 7, 'F': 8, 'K': 9, 'Al': 10, 'Cl': 11, 30 | 'Au': 12, 'Ca': 13, 'Hg': 14, 'Na': 15, 'P': 16, 'Ti': 17, 'Br': 18} 31 | self.NOTINDICT = 19 32 | # root is required for saving the preprocessed data, default is '/tmp' 33 | super(DDI2013Dataset, self).__init__(root, transform, pre_transform) 34 | 35 | if os.path.isfile(self.processed_paths[0]): 36 | print('Pre-processed data found: {}, loading ...'.format(self.processed_paths[0])) 37 | self.data, self.slices = torch.load(self.processed_paths[0]) 38 | else: 39 | print('Pre-processed data {} not found, doing pre-processing...'.format(self.processed_paths[0])) 40 | self.process(root) 41 | self.data, self.slices = torch.load(self.processed_paths[0]) 42 | 43 | @property 44 | def raw_file_names(self): 45 | pass 46 | # return ['some_file_1', 'some_file_2', ...] 47 | 48 | @property 49 | def processed_file_names(self): 50 | return ['process.pt'] 51 | 52 | def download(self): 53 | # Download to `self.raw_dir`. 
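        # The raw CSV is expected to already sit in self.raw_dir (the repository ships
        # it as a .7z archive under datasets/*/raw), so nothing is downloaded here.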
54 | pass 55 | 56 | def _download(self): 57 | pass 58 | 59 | def _process(self): 60 | if not os.path.exists(self.processed_dir): 61 | os.makedirs(self.processed_dir) 62 | 63 | def get_idx_split(self, data_size, train_size, seed): 64 | ids = shuffle(range(data_size), random_state=seed) 65 | train_idx, val_idx = torch.tensor(ids[:train_size]), torch.tensor( 66 | ids[train_size:]) 67 | split_dict = {'train': train_idx, 'valid': val_idx} 68 | return split_dict 69 | 70 | def get_pos_z(self, smile1, i): 71 | # print(smile1) 72 | m1 = rdkit.Chem.MolFromSmiles(smile1) 73 | 74 | if m1 is None: 75 | self.pass_list.append(i) 76 | self.pass_smiles.add(smile1) 77 | return None, None 78 | 79 | if m1.GetNumAtoms() == 1: 80 | self.pass_list.append(i) 81 | if m1.GetNumAtoms() == 1: 82 | self.pass_smiles.add(smile1) 83 | return None, None 84 | m1 = Chem.AddHs(m1) 85 | 86 | ignore_flag1 = 0 87 | ignore1 = False 88 | 89 | while AllChem.EmbedMolecule(m1) == -1: 90 | print('retry') 91 | ignore_flag1 = ignore_flag1 + 1 92 | if ignore_flag1 >= 10: 93 | ignore1 = True 94 | break 95 | if ignore1: 96 | self.pass_list.append(i) 97 | self.pass_smiles.add(smile1) 98 | return None, None 99 | AllChem.MMFFOptimizeMolecule(m1) 100 | m1 = Chem.RemoveHs(m1) 101 | m1_con = m1.GetConformer(id=0) 102 | 103 | pos1 = [] 104 | for j in range(m1.GetNumAtoms()): 105 | pos1.append(list(m1_con.GetAtomPosition(j))) 106 | np_pos1 = np.array(pos1) 107 | ten_pos1 = torch.Tensor(np_pos1) 108 | 109 | z1 = [] 110 | for atom in m1.GetAtoms(): 111 | if self.atomType.__contains__(atom.GetSymbol()): 112 | z = self.atomType[atom.GetSymbol()] 113 | else: 114 | z = self.NOTINDICT 115 | z1.append(z) 116 | 117 | z1 = np.array(z1) 118 | z1 = torch.tensor(z1) 119 | return ten_pos1, z1 120 | 121 | def process(self, root): 122 | df1 = pd.read_csv(root + 'raw/' + self.path) 123 | # df2 = pd.read_csv(root + 'raw/' + 'test.csv') 124 | data_list = [] 125 | data_len = len(df1) 126 | 127 | smile_pos_dict = {} 128 | smile_z_dict = {} 129 | 130 | for i in range(data_len): 131 | print('Converting SMILES to 3Dgraph: {}/{}'.format(i + 1, data_len)) 132 | sent = df1.loc[i, 'text'] 133 | # sent = self.data.loc[idx, 'text'] 134 | encoded_pair = self.tokenizer(sent, 135 | padding='max_length', 136 | truncation=True, 137 | max_length=self.max_len, 138 | return_tensors='pt') 139 | token_ids = encoded_pair['input_ids'].squeeze(0) # tensor of token ids 140 | attn_masks = encoded_pair['attention_mask'].squeeze(0) 141 | # binary tensor with "0" for padded values and "1" for the other values 142 | token_type_ids = encoded_pair['token_type_ids'].squeeze(0) 143 | # binary tensor with "0" for the 1st sentence tokens & "1" for the 2nd sentence tokens 144 | 145 | pos1 = [] 146 | pos2 = [] 147 | smile1 = df1.loc[i, 'smile1'] 148 | smile2 = df1.loc[i, 'smile2'] 149 | if self.pass_smiles.__contains__(smile1) or self.pass_smiles.__contains__(smile2): 150 | self.pass_list.append(i) 151 | continue 152 | 153 | if smile_pos_dict.__contains__(smile1): 154 | ten_pos1 = smile_pos_dict[smile1] 155 | z1 = smile_z_dict[smile1] 156 | else: 157 | ten_pos1, z1 = self.get_pos_z(smile1, i) 158 | if ten_pos1 == None: 159 | continue 160 | else: 161 | smile_pos_dict[smile1] = ten_pos1 162 | smile_z_dict[smile1] = z1 163 | if smile_pos_dict.__contains__(smile2): 164 | ten_pos2 = smile_pos_dict[smile2] 165 | z2 = smile_z_dict[smile2] 166 | else: 167 | ten_pos2, z2 = self.get_pos_z(smile2, i) 168 | if ten_pos2 == None: 169 | continue 170 | else: 171 | smile_pos_dict[smile2] = ten_pos2 172 | 
smile_z_dict[smile2] = z2 173 | 174 | label = df1.loc[i, 'label'] 175 | label = np.array(label) 176 | label = torch.tensor(label) 177 | 178 | drug_pos1 = df1.loc[i, 'pos1'] 179 | drug_pos2 = df1.loc[i, 'pos2'] 180 | drug_pos1 = np.array(drug_pos1) 181 | drug_pos1 = torch.tensor(drug_pos1) 182 | 183 | drug_pos2 = np.array(drug_pos2) 184 | drug_pos2 = torch.tensor(drug_pos2) 185 | data = DATA(pos1=ten_pos1, z1=z1, 186 | y=label, 187 | pos2=ten_pos2, z2=z2, 188 | token_ids=token_ids, 189 | attn_masks=attn_masks, 190 | token_type_ids=token_type_ids, 191 | drug1_pos=drug_pos1, 192 | drug2_pos=drug_pos2 193 | ) 194 | print(data) 195 | 196 | data_list.append(data) 197 | 198 | if self.pre_filter is not None: 199 | data_list = [data for data in data_list if self.pre_filter(data)] 200 | 201 | if self.pre_transform is not None: 202 | data_list = [self.pre_transform(data) for data in data_list] 203 | data, slices = self.collate(data_list) 204 | # save preprocessed data: 205 | torch.save((data, slices), self.processed_paths[0]) 206 | print(self.pass_list) 207 | print(len(self.pass_list)) 208 | print(self.pass_smiles) 209 | print(len(self.pass_smiles)) 210 | 211 | 212 | class drugbankDataset(InMemoryDataset): 213 | def __init__(self, root='/tmp', 214 | path='', 215 | transform=None, 216 | pre_transform=None): 217 | 218 | self.path = path 219 | self.pass_list = [] 220 | self.pass_smiles = set() 221 | self.atomType = {'C': 1, 'H': 2, 'O': 3, 'N': 4, 'S': 5, 'Li': 6, 'Mg': 7, 'F': 8, 'K': 9, 'Al': 10, 'Cl': 11, 222 | 'Au': 12, 'Ca': 13, 'Hg': 14, 'Na': 15, 'P': 16, 'Ti': 17, 'Br': 18} 223 | self.NOTINDICT = 19 224 | # root is required for save preprocessed data, default is '/tmp' 225 | super(drugbankDataset, self).__init__(root, transform, pre_transform) 226 | 227 | if os.path.isfile(self.processed_paths[0]): 228 | print('Pre-processed data found: {}, loading ...'.format(self.processed_paths[0])) 229 | self.data, self.slices = torch.load(self.processed_paths[0]) 230 | else: 231 | print('Pre-processed data {} not found, doing pre-processing...'.format(self.processed_paths[0])) 232 | self.process(root) 233 | self.data, self.slices = torch.load(self.processed_paths[0]) 234 | 235 | @property 236 | def raw_file_names(self): 237 | pass 238 | # return ['some_file_1', 'some_file_2', ...] 239 | 240 | @property 241 | def processed_file_names(self): 242 | return ['process.pt'] 243 | 244 | def download(self): 245 | # Download to `self.raw_dir`. 
246 | pass 247 | 248 | def _download(self): 249 | pass 250 | 251 | def _process(self): 252 | if not os.path.exists(self.processed_dir): 253 | os.makedirs(self.processed_dir) 254 | 255 | def get_idx_split(self, data_size, train_size, seed): 256 | ids = shuffle(range(data_size), random_state=seed) 257 | train_idx, val_idx = torch.tensor(ids[:train_size]), torch.tensor( 258 | ids[train_size:]) 259 | split_dict = {'train': train_idx, 'valid': val_idx} 260 | return split_dict 261 | 262 | def get_pos_z(self, smile1, i): 263 | # print(smile1) 264 | m1 = rdkit.Chem.MolFromSmiles(smile1) 265 | 266 | if m1 is None: 267 | self.pass_list.append(i) 268 | self.pass_smiles.add(smile1) 269 | return None, None 270 | 271 | if m1.GetNumAtoms() == 1: 272 | self.pass_list.append(i) 273 | if m1.GetNumAtoms() == 1: 274 | self.pass_smiles.add(smile1) 275 | return None, None 276 | m1 = Chem.AddHs(m1) 277 | 278 | ignore_flag1 = 0 279 | ignore1 = False 280 | 281 | while AllChem.EmbedMolecule(m1) == -1: 282 | print('retry') 283 | ignore_flag1 = ignore_flag1 + 1 284 | if ignore_flag1 >= 10: 285 | ignore1 = True 286 | break 287 | if ignore1: 288 | self.pass_list.append(i) 289 | self.pass_smiles.add(smile1) 290 | return None, None 291 | AllChem.MMFFOptimizeMolecule(m1) 292 | m1 = Chem.RemoveHs(m1) 293 | m1_con = m1.GetConformer(id=0) 294 | 295 | pos1 = [] 296 | for j in range(m1.GetNumAtoms()): 297 | pos1.append(list(m1_con.GetAtomPosition(j))) 298 | np_pos1 = np.array(pos1) 299 | ten_pos1 = torch.Tensor(np_pos1) 300 | 301 | z1 = [] 302 | for atom in m1.GetAtoms(): 303 | if self.atomType.__contains__(atom.GetSymbol()): 304 | z = self.atomType[atom.GetSymbol()] 305 | else: 306 | z = self.NOTINDICT 307 | z1.append(z) 308 | 309 | z1 = np.array(z1) 310 | z1 = torch.tensor(z1) 311 | return ten_pos1, z1 312 | 313 | def process(self, root): 314 | df1 = pd.read_csv(root + 'raw/' + self.path) 315 | # df2 = pd.read_csv(root + 'raw/' + 'test.csv') 316 | data_list = [] 317 | data_len = len(df1) 318 | 319 | smile_pos_dict = {} 320 | smile_z_dict = {} 321 | 322 | h_to_t_dict = {} 323 | t_to_h_dict = {} 324 | id_set = set() 325 | id_smiles_dict = {} 326 | for i in range(df1.shape[0]): 327 | head = df1.loc[i, 'Drug1_ID'] 328 | tail = df1.loc[i, 'Drug2_ID'] 329 | head_smile = df1.loc[i, 'Drug1'] 330 | tail_smile = df1.loc[i, 'Drug2'] 331 | if h_to_t_dict.__contains__(head): 332 | h_to_t_dict[head].append(tail) 333 | else: 334 | h_to_t_dict[head] = [] 335 | h_to_t_dict[head].append(tail) 336 | 337 | if t_to_h_dict.__contains__(tail): 338 | t_to_h_dict[tail].append(head) 339 | else: 340 | t_to_h_dict[tail] = [] 341 | t_to_h_dict[tail].append(head) 342 | id_smiles_dict[head] = head_smile 343 | id_smiles_dict[tail] = tail_smile 344 | id_set.add(head) 345 | id_set.add(tail) 346 | 347 | 348 | for i in range(data_len): 349 | print('Converting SMILES to 3Dgraph: {}/{}'.format(i + 1, data_len)) 350 | 351 | smile1 = df1.loc[i, 'Drug1'] 352 | smile2 = df1.loc[i, 'Drug2'] 353 | if self.pass_smiles.__contains__(smile1) or self.pass_smiles.__contains__(smile2): 354 | self.pass_list.append(i) 355 | continue 356 | 357 | if smile_pos_dict.__contains__(smile1): 358 | ten_pos1 = smile_pos_dict[smile1] 359 | z1 = smile_z_dict[smile1] 360 | else: 361 | ten_pos1, z1 = self.get_pos_z(smile1, i) 362 | if ten_pos1 == None: 363 | continue 364 | else: 365 | smile_pos_dict[smile1] = ten_pos1 366 | smile_z_dict[smile1] = z1 367 | if smile_pos_dict.__contains__(smile2): 368 | ten_pos2 = smile_pos_dict[smile2] 369 | z2 = smile_z_dict[smile2] 370 | else: 371 | 
ten_pos2, z2 = self.get_pos_z(smile2, i) 372 | if ten_pos2 == None: 373 | continue 374 | else: 375 | smile_pos_dict[smile2] = ten_pos2 376 | smile_z_dict[smile2] = z2 377 | 378 | label = torch.tensor(1) 379 | 380 | data = DATA(pos1=ten_pos1, z1=z1, 381 | y=label, 382 | pos2=ten_pos2, z2=z2, 383 | ) 384 | print(data) 385 | 386 | data_list.append(data) 387 | 388 | if random.random() > 0.5: # 换尾 389 | head = df1.loc[i, 'Drug1_ID'] 390 | tail_set = h_to_t_dict[head] 391 | pes_tail = random.sample(id_set - set(tail_set), 1) 392 | smile2 = id_smiles_dict[pes_tail[0]] 393 | else: 394 | tail = df1.loc[i, 'Drug2_ID'] 395 | head_set = t_to_h_dict[tail] 396 | pes_head = random.sample(id_set - set(head_set), 1) 397 | smile1 = id_smiles_dict[pes_head[0]] 398 | 399 | if self.pass_smiles.__contains__(smile1) or self.pass_smiles.__contains__(smile2): 400 | self.pass_list.append(i) 401 | continue 402 | 403 | if smile_pos_dict.__contains__(smile1): 404 | ten_pos1 = smile_pos_dict[smile1] 405 | z1 = smile_z_dict[smile1] 406 | else: 407 | ten_pos1, z1 = self.get_pos_z(smile1, i) 408 | if ten_pos1 == None: 409 | continue 410 | else: 411 | smile_pos_dict[smile1] = ten_pos1 412 | smile_z_dict[smile1] = z1 413 | if smile_pos_dict.__contains__(smile2): 414 | ten_pos2 = smile_pos_dict[smile2] 415 | z2 = smile_z_dict[smile2] 416 | else: 417 | ten_pos2, z2 = self.get_pos_z(smile2, i) 418 | if ten_pos2 == None: 419 | continue 420 | else: 421 | smile_pos_dict[smile2] = ten_pos2 422 | smile_z_dict[smile2] = z2 423 | 424 | label = torch.tensor(0) 425 | 426 | data = DATA(pos1=ten_pos1, z1=z1, 427 | y=label, 428 | pos2=ten_pos2, z2=z2, 429 | ) 430 | print(data) 431 | 432 | data_list.append(data) 433 | 434 | if self.pre_filter is not None: 435 | data_list = [data for data in data_list if self.pre_filter(data)] 436 | 437 | if self.pre_transform is not None: 438 | data_list = [self.pre_transform(data) for data in data_list] 439 | data, slices = self.collate(data_list) 440 | # save preprocessed data: 441 | torch.save((data, slices), self.processed_paths[0]) 442 | print(self.pass_list) 443 | print(len(self.pass_list)) 444 | print(self.pass_smiles) 445 | print(len(self.pass_smiles)) 446 | 447 | 448 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 449 | 450 | 451 | --------------------------------------------------------------------------------
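
A minimal usage sketch for `DDI2013Dataset` (not part of the repository): it assumes the raw CSV already sits under `<root>/raw/`, that RDKit and the tokenizer weights are installed, and it mirrors how the training scripts batch the paired molecules with `follow_batch`. The root, CSV name, model name, and batch size below are placeholders.

```python
from torch_geometric.loader import DataLoader
from dataset import DDI2013Dataset

# Placeholder paths and model name; adjust to your local layout.
dataset = DDI2013Dataset(root='../datasets/ddi_train/',
                         path='train_smiles_pos.csv',
                         max_len=128,
                         model_name='allenai/scibert_scivocab_uncased')

# 80/20 split over the processed samples.
split = dataset.get_idx_split(len(dataset.data.y), int(len(dataset.data.y) * 0.8), seed=2)

# follow_batch keeps separate atom-to-graph batch vectors for the two molecules.
train_loader = DataLoader(dataset[split['train']], batch_size=8, shuffle=True,
                          follow_batch=['pos1', 'pos2'])

batch = next(iter(train_loader))
print(batch)  # Batch with pos1/z1, pos2/z2, token_ids, attn_masks, y, ...
```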