├── log
│   ├── dir_for_logfile
│   └── checkpoint
│       └── dir_for_model_checkpoint
├── dataset
│   ├── raw
│   │   └── dir_for_rawcsvfile
│   └── processed
│       └── dir_for_processed_dataset
├── img
│   └── Framework.jpg
├── requirements.txt
├── LICENSE
├── TFM
│   ├── Dataset.py
│   ├── utils.py
│   └── model.py
├── README.md
├── molnetdata.py
└── run.py

--------------------------------------------------------------------------------
/log/dir_for_logfile:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataset/raw/dir_for_rawcsvfile:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataset/processed/dir_for_processed_dataset:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/log/checkpoint/dir_for_model_checkpoint:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/img/Framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaojianl/KnoMol/HEAD/img/Framework.jpg

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
python==3.9.10
hyperopt==0.2.7
numpy==1.22.1
torch==1.10.0
torch-cluster==1.6.0
torch-geometric==2.0.4
torch-scatter==2.0.9
torch-sparse==0.6.13
torch-spline-conv==1.2.1
torchaudio==0.10.0
torchvision==0.11.0
scikit-learn==1.1.2
scipy==1.9.1
rdkit==2022.03.2
networkx==2.8.6
pandas==1.4.4

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 AlphaGao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/TFM/Dataset.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from torch_geometric.data import InMemoryDataset
from torch_geometric import data as DATA
import torch


class MolNet(InMemoryDataset):
    def __init__(self, root='dataset', dataset=None, xd=None, y=None, transform=None, pre_transform=None, smile_graph=None):
        # root is required to save the raw and pre-processed data
        super(MolNet, self).__init__(root, transform, pre_transform)
        self.dataset = dataset
        if os.path.isfile(self.processed_paths[0]):
            print('Pre-processed data found: {}, loading ...'.format(self.processed_paths[0]))
            self.data, self.slices = torch.load(self.processed_paths[0])
        else:
            print('Pre-processed data {} not found, doing pre-processing...'.format(self.processed_paths[0]))
            self.process(xd, y, smile_graph)
            self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['raw_file']

    @property
    def processed_file_names(self):
        return [self.dataset + '_pyg.pt']

    def download(self):
        # Download to `self.raw_dir`.
        pass

    def _process(self):
        if not os.path.exists(self.processed_dir):
            os.makedirs(self.processed_dir)

    def process(self, xd, y, smile_graph):
        assert (len(xd) == len(y)), "smiles and labels must be the same length!"
        data_list = []
        data_len = len(xd)
        print('number of data ', data_len)
        for i in range(data_len):
            smiles = xd[i]
            if smiles is not None:
                labels = np.asarray([y[i]])
                leng, features, edge_index, edge_attr, ringm, aromm, alipm, hetem, adj_order_matrix, dis_order_matrix = smile_graph[smiles]
                if len(edge_index) > 0:
                    GCNData = DATA.Data(x=torch.Tensor(features), edge_index=torch.LongTensor(edge_index).transpose(1, 0).contiguous(), edge_attr=torch.Tensor(edge_attr), y=torch.FloatTensor(labels))
                    GCNData.leng = [leng]
                    GCNData.adj = adj_order_matrix
                    GCNData.dis = dis_order_matrix
                    GCNData.ringm = ringm
                    GCNData.aromm = aromm
                    GCNData.alipm = alipm
                    GCNData.hetem = hetem
                    GCNData.smi = smiles
                    data_list.append(GCNData)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        print('Graph construction done. Saving to file.')
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# KnoMol

![](img/Framework.jpg)

This is a PyTorch implementation of the paper: https://pubs.acs.org/doi/10.1021/acs.jcim.4c01092

## Installation
You can create the conda environment with the following command.
```
conda create --name KnoMol --file requirements.txt
```

## Usage

#### 1. Dataset preparation
Put your raw CSV file (`DATASET_NAME.csv`; the first column is `smiles`, followed by the `label` columns) in `dataset/raw/`.
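For example, a single-task classification file might look like the following (hypothetical molecules and labels, shown only to illustrate the expected column layout):
```
smiles,label
CCO,0
c1ccccc1,1
CC(=O)Oc1ccccc1C(=O)O,0
```
Then build the graph dataset: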
```
python run molnetdata.py \
    --moldata DATASET_NAME \   # file name
    --task clas \              # clas: binary classification, reg: regression
    --numtasks 1 \             # number of properties to predict
    --ncpu 10                  # number of CPUs to use
```
This will save the processed dataset in `dataset/processed/`.

#### 2. Hyper-parameter searching
```
python run.py search DATASET_NAME \
    --task clas \              # clas: binary classification, reg: regression
    --numtasks 1 \             # number of properties to predict
    --seed 426 \               # random seed
    --split random_scaffold \  # data splitting method: random_scaffold/balan_scaffold/random
    --max_eval 100 \           # number of hyperparameter settings to try
    --metric rmse \            # metric to optimize: rmse/mae (only for regression)
    --device cuda:0            # which GPU to use
```
This will report the best hyper-parameters and their performance at the end of the log file.

#### 3. Training
Train and save a model with the best hyper-parameters. Here is an example:
```
python run.py train DATASET_NAME \
    --task reg \
    --numtasks 1 \
    --device cuda:0 \
    --batch_size 32 \
    --train_epoch 50 \
    --lr 0.001 \
    --valrate 0.1 \
    --testrate 0.1 \
    --seed 426 \
    --split random_scaffold \
    --fold 1 \
    --dropout 0.05 \
    --attn_layers 2 \
    --output_dim 256 \
    --D 4 \
    --metric rmse
```
This will save the resulting model in `log/checkpoint/xxx.pkl`.

#### 4. Testing
Make predictions with a trained model. Here is an example:
```
python run.py test DATASET_NAME \
    --task clas \
    --numtasks 1 \
    --device cuda:0 \
    --batch_size 32 \
    --attn_layers 2 \
    --output_dim 256 \
    --D 4 \
    --pretrain log/checkpoint/XXXX.pkl
```
This will load the model from `log/checkpoint/` to make predictions, and the results are saved in `log/xxx.csv`.
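For reference, `run.py test` writes its output as `log/Result<DATASET_NAME><task>_test.csv`, with the aggregate metrics repeated down their columns and per-molecule `prediction`/`target` columns alongside them. A minimal sketch for inspecting the file (assuming the single-task classification example above):
```python
import pandas as pd

# run.py appends the task name to the dataset name, hence 'clas' in the file name
df = pd.read_csv('log/ResultDATASET_NAMEclas_test.csv')
print(df[['prediction', 'target']].head())  # per-molecule outputs
print('test AUC:', df['test_auc'].iloc[0])  # aggregate metric (same value in every row)
```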
76 | -------------------------------------------------------------------------------- /TFM/utils.py: -------------------------------------------------------------------------------- 1 | import torch, random, os, math 2 | import torch.nn as nn 3 | import numpy as np 4 | import pandas as pd 5 | import logging 6 | from torch_geometric.data import DataLoader 7 | from TFM.Dataset import MolNet 8 | from rdkit import Chem 9 | from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles 10 | 11 | 12 | def load_data(dataset, batch_size, valid_size, test_size, cpus_per_gpu, task, split, seed=426): 13 | data = MolNet(root='./dataset', dataset=dataset) 14 | 15 | if split == 'balan_scaffold': 16 | trainset, validset, testset = balanscaffold_split(data, valid_size, test_size) 17 | else: 18 | if split == 'random_scaffold': 19 | scaffold = True 20 | elif split == 'random': 21 | scaffold = False 22 | else: 23 | raise ValueError('Invalid split type') 24 | cont = True 25 | while cont: 26 | trainset, validset, testset = randomscaffold_split(data, valid_size, test_size, scaffold=scaffold, seed=seed) 27 | if task == 'clas': 28 | vy = [d.y for d in validset]; ty = [d.y for d in testset] 29 | vy = torch.cat(vy, 0); ty = torch.cat(ty, 0) 30 | if torch.any(torch.mean(vy, 0) == 1) or torch.any(torch.mean(vy, 0) == 0) or torch.any(torch.mean(ty, 0) == 1) or torch.any(torch.mean(ty, 0) == 0): 31 | cont = True 32 | if seed is not None: 33 | seed += 10 34 | else: 35 | cont = False 36 | else: 37 | cont = False 38 | 39 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=cpus_per_gpu, drop_last=False) 40 | valid_loader = DataLoader(validset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=cpus_per_gpu, drop_last=False) 41 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=cpus_per_gpu, drop_last=False) 42 | return train_loader, valid_loader, test_loader 43 | 44 | 45 | def set_seed(seed): 46 | random.seed(seed) 47 | np.random.seed(seed) 48 | torch.manual_seed(seed) 49 | torch.cuda.manual_seed(seed) 50 | torch.cuda.manual_seed_all(seed) 51 | 52 | 53 | class metrics_c(nn.Module): 54 | def __init__(self, acc_f, pre_f, rec_f, f1_f, auc_f): 55 | super(metrics_c, self).__init__() 56 | self.acc_f = acc_f 57 | self.pre_f = pre_f 58 | self.rec_f = rec_f 59 | self.f1_f = f1_f 60 | self.auc_f = auc_f 61 | 62 | def forward(self, out, prob, tar): 63 | if len(out.shape) > 1: 64 | acc, f1, pre, rec, auc = [], [], [], [], [] 65 | 66 | for i in range(out.shape[-1]): 67 | acc_, f1_, pre_, rec_, auc_ = 0, 0, 0, 0, 0 68 | acc_ = self.acc_f(tar[:, i], out[:, i]) 69 | f1_ = self.f1_f(tar[:, i], out[:, i]) 70 | pre_ = self.pre_f(tar[:, i], out[:, i]) 71 | rec_ = self.rec_f(tar[:, i], out[:, i]) 72 | try: 73 | auc_ = self.auc_f(tar[:, i], prob[:, i]) 74 | auc.append(auc_) 75 | except:pass 76 | 77 | acc.append(acc_); f1.append(f1_); pre.append(pre_); rec.append(rec_) 78 | return np.mean(acc), np.mean(f1), np.mean(pre), np.mean(rec), np.mean(auc) 79 | else: 80 | acc = self.acc_f(tar, out) 81 | f1 = self.f1_f(tar, out) 82 | pre = self.pre_f(tar, out) 83 | rec = self.rec_f(tar, out) 84 | auc = self.auc_f(tar, prob) 85 | return acc, f1, pre, rec, auc 86 | 87 | 88 | class metrics_r(nn.Module): 89 | def __init__(self, mae_f, rmse_f, r2_f): 90 | super(metrics_r, self).__init__() 91 | self.mae_f = mae_f 92 | self.rmse_f = rmse_f 93 | self.r2_f = r2_f 94 | 95 | def forward(self, out, tar): 96 | mae, rmse, r2 = 0, 0, 0 97 | if 
self.mae_f is not None: 98 | mae = self.mae_f(tar, out) 99 | 100 | if self.rmse_f is not None: 101 | rmse = self.rmse_f(tar, out, squared=False) 102 | 103 | if self.r2_f is not None: 104 | r2 = self.r2_f(tar, out) 105 | 106 | return mae, rmse, r2, None, None 107 | 108 | 109 | def create_ffn(task, tasks, output_dim, dropout): 110 | if task == 'clas': 111 | act = nn.Sequential( 112 | nn.Dropout(dropout), 113 | nn.Linear(output_dim*2, output_dim), 114 | nn.BatchNorm1d(output_dim), 115 | nn.Dropout(dropout), 116 | nn.ReLU(), 117 | nn.Linear(output_dim, tasks), 118 | nn.Sigmoid()) 119 | elif task == 'reg': 120 | act = nn.Sequential( 121 | nn.Dropout(dropout), 122 | nn.Linear(output_dim, output_dim), 123 | nn.BatchNorm1d(output_dim), 124 | nn.Dropout(dropout), 125 | nn.ReLU(), 126 | nn.Linear(output_dim, tasks)) 127 | else: 128 | raise NameError('task must be reg or clas!') 129 | return act 130 | 131 | 132 | def get_attn_pad_mask(mask): 133 | batch_size, len_q = mask.size(0), mask.size(1) 134 | a = mask.unsqueeze(1).expand(batch_size, len_q, len_q) 135 | pad_attn_mask = a * a.transpose(-1, -2) 136 | return pad_attn_mask.data.eq(0) 137 | 138 | 139 | def balanscaffold_split(data, validrate, testrate): 140 | trainrate = 1 - validrate - testrate 141 | assert trainrate > 0.4 142 | 143 | train_inds, valid_inds, test_inds = [], [], [] 144 | scaffolds = {} 145 | for ind, dat in enumerate(data): 146 | mol = Chem.MolFromSmiles(dat.smi) 147 | scaffold = MurckoScaffoldSmiles(mol=mol, includeChirality=True) 148 | if scaffold not in scaffolds: 149 | scaffolds[scaffold] = [ind] 150 | else: 151 | scaffolds[scaffold].append(ind) 152 | scaffolds = {key: sorted(value) for key, value in scaffolds.items()} 153 | scaffold_sets = [scaffold_set for (scaffold, scaffold_set) in sorted(scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)] 154 | 155 | n_total_valid = round(validrate * len(data)) 156 | n_total_test = round(testrate * len(data)) 157 | for scaffold_set in scaffold_sets: 158 | if (len(valid_inds) + len(scaffold_set) <= n_total_valid) and (len(scaffold_set) < n_total_valid*0.5): 159 | valid_inds.extend(scaffold_set) 160 | elif (len(test_inds) + len(scaffold_set) <= n_total_test) and (len(scaffold_set) < n_total_test*0.5): 161 | test_inds.extend(scaffold_set) 162 | else: 163 | train_inds.extend(scaffold_set) 164 | return data[train_inds], data[valid_inds], data[test_inds] 165 | 166 | 167 | def randomscaffold_split(data, validrate, testrate, scaffold=True, seed=426): 168 | trainrate = 1 - validrate - testrate 169 | assert trainrate > 0.4 170 | lenth = len(data) 171 | g1 = int(lenth*trainrate) 172 | g2 = int(lenth*(trainrate+validrate)) 173 | 174 | if not scaffold: 175 | rng = np.random.RandomState(seed) 176 | random_num = list(range(lenth)) 177 | rng.shuffle(random_num) 178 | data = data[random_num] 179 | return data[:g1], data[g1:g2], data[g2:] 180 | 181 | else: 182 | train_inds, valid_inds, test_inds = [], [], [] 183 | scaffolds = {} 184 | for ind, dat in enumerate(data): 185 | mol = Chem.MolFromSmiles(dat.smi) 186 | scaffold = MurckoScaffoldSmiles(mol=mol, includeChirality=True) 187 | if scaffold not in scaffolds: 188 | scaffolds[scaffold] = [ind] 189 | else: 190 | scaffolds[scaffold].append(ind) 191 | 192 | rng = np.random.RandomState(seed) 193 | scaffold_sets = rng.permutation(np.array(list(scaffolds.values()), dtype=object)) 194 | 195 | n_total_valid = round(validrate * len(data)) 196 | n_total_test = round(testrate * len(data)) 197 | for scaffold_set in scaffold_sets: 198 | if 
len(valid_inds) + len(scaffold_set) <= n_total_valid: 199 | valid_inds.extend(scaffold_set) 200 | elif len(test_inds) + len(scaffold_set) <= n_total_test: 201 | test_inds.extend(scaffold_set) 202 | else: 203 | train_inds.extend(scaffold_set) 204 | 205 | return data[train_inds], data[valid_inds], data[test_inds] 206 | 207 | 208 | def get_logger(filename, verbosity=1, name=None): 209 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING, 3: logging.ERROR} 210 | formatter = logging.Formatter("[%(asctime)s][line:%(lineno)d][%(levelname)s] %(message)s") 211 | logger = logging.getLogger(name) 212 | logger.setLevel(level_dict[verbosity]) 213 | fh = logging.FileHandler(filename, "w") 214 | fh.setFormatter(formatter) 215 | logger.addHandler(fh) 216 | sh = logging.StreamHandler() 217 | sh.setFormatter(formatter) 218 | logger.addHandler(sh) 219 | return logger 220 | -------------------------------------------------------------------------------- /molnetdata.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | import argparse 6 | from multiprocessing import Pool 7 | from rdkit import Chem 8 | from rdkit.Chem.SaltRemover import SaltRemover 9 | from TFM.Dataset import MolNet 10 | remover = SaltRemover() 11 | smile_graph = {} 12 | meta = ['W', 'U', 'Zr', 'He', 'Be', 'Na', 'Mg', 'Al', 'K', 'Ca', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Rb', 'Sr', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'Gd', 'Tb', 'Ho', 'W', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Ac'] 13 | 14 | 15 | def atom_features(atom, use_chirality=True): 16 | res = one_of_k_encoding_unk(atom.GetSymbol(),['C','N','O','F','P','S','Cl','Br','I','B','Si','Unknown']) + \ 17 | one_of_k_encoding(atom.GetDegree(), [1, 2, 3, 4, 5, 6]) + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) + [atom.GetIsAromatic(), atom.IsInRing()] + one_of_k_encoding_unk(atom.GetFormalCharge(), [-1,0,1,3]) + \ 18 | one_of_k_encoding_unk(atom.GetHybridization(), [Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]) 19 | if use_chirality: 20 | try: 21 | res = res + one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S']) + [atom.HasProp('_ChiralityPossible')] 22 | except: 23 | bonds = atom.GetBonds() 24 | for bond in bonds: 25 | if bond.GetBondType() == Chem.rdchem.BondType.DOUBLE and str(bond.GetStereo()) in ["STEREOZ", "STEREOE"]: 26 | res = res + one_of_k_encoding_unk(str(bond.GetStereo()), ["STEREOZ", "STEREOE"]) + [atom.HasProp('_ChiralityPossible')] 27 | if len(res) == 34: 28 | res = res + [False, False] + [atom.HasProp('_ChiralityPossible')] 29 | return np.array(res) # 37 30 | 31 | 32 | def order_gnn_features(bond): 33 | weight = [0.3, 0.4, 0.5, 0.36] 34 | bt = bond.GetBondType() 35 | bond_feats = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC] 36 | 37 | for i, m in enumerate(bond_feats): 38 | if m == True and i != 0: 39 | b = weight[i] 40 | elif m == True and i == 0: 41 | if bond.GetIsConjugated() == True: 42 | b = 0.32 43 | else: 44 | b = 0.3 45 | else:pass 46 | return b 47 | 48 | 49 | def order_tf_features(bond): 50 | weight = [0.8, 0.9, 1., 0.85] 51 | bt = bond.GetBondType() 52 | bond_feats = [bt == Chem.rdchem.BondType.SINGLE, bt == 
Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC] 53 | for i, m in enumerate(bond_feats): 54 | if m == True: 55 | b = weight[i] 56 | return b 57 | 58 | 59 | def one_of_k_encoding(x, allowable_set): 60 | if x not in allowable_set: 61 | raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set)) 62 | return list(map(lambda s: x == s, allowable_set)) 63 | 64 | 65 | def one_of_k_encoding_unk(x, allowable_set): 66 | if x not in allowable_set: 67 | x = allowable_set[-1] 68 | return list(map(lambda s: x == s, allowable_set)) 69 | 70 | 71 | def smiletopyg(smi): 72 | g = nx.Graph() 73 | mol = Chem.MolFromSmiles(smi) 74 | c_size = mol.GetNumAtoms() 75 | 76 | features = [] 77 | hete_mask = [] 78 | for i, atom in enumerate(mol.GetAtoms()): 79 | g.add_node(i) 80 | feature = atom_features(atom) 81 | features.append((feature / sum(feature)).tolist()) 82 | 83 | if atom.GetSymbol() == 'C': 84 | hete_mask.append(0) 85 | else: 86 | hete_mask.append(1) 87 | hete_mask = np.asarray(hete_mask) 88 | 89 | 90 | ssr = Chem.GetSymmSSSR(mol) 91 | ring_masm = np.zeros((c_size)) 92 | arom_masm = np.zeros((c_size)) 93 | alip_masm = np.zeros((c_size)) 94 | aroml = [] 95 | for i in range(0, len(ssr)): 96 | aromring = True 97 | for r in list(ssr[i]): 98 | atom = mol.GetAtomWithIdx(r) 99 | if not atom.GetIsAromatic(): 100 | aromring = False 101 | aroml.append(aromring) 102 | 103 | inter_arom, inter_alip = [], [] 104 | for i in range(0, len(ssr)): 105 | if aroml[i] and (i not in inter_arom): 106 | for r in list(ssr[i]): 107 | ring_masm[r] = i+1 108 | arom_masm[r] = i+1 109 | for j in range(i+1,len(ssr)): 110 | if aroml[j] and bool(set(ssr[i]) & set(ssr[j])): 111 | inter_arom.append(j) 112 | for r in list(ssr[j]): 113 | ring_masm[r] = i+1 114 | arom_masm[r] = i+1 115 | elif (not aroml[i]) and (i not in inter_alip): 116 | for r in list(ssr[i]): 117 | ring_masm[r] = i+1 118 | alip_masm[r] = i+1 119 | for j in range(i+1,len(ssr)): 120 | if (not aroml[j]) and bool(set(ssr[i]) & set(ssr[j])): 121 | inter_alip.append(j) 122 | for r in list(ssr[j]): 123 | ring_masm[r] = i+1 124 | alip_masm[r] = i+1 125 | else:pass 126 | 127 | c = [] 128 | adj_order_matrix = np.eye(c_size) 129 | adj_order_matrix = adj_order_matrix * 0.8 130 | dis_order_matrix = np.zeros((c_size,c_size)) 131 | for bond in mol.GetBonds(): 132 | a1 = bond.GetBeginAtomIdx() 133 | a2 = bond.GetEndAtomIdx() 134 | bfeat = order_gnn_features(bond) 135 | g.add_edge(a1, a2, weight=bfeat) 136 | tfft = order_tf_features(bond) 137 | adj_order_matrix[a1, a2] = tfft 138 | adj_order_matrix[a2, a1] = tfft 139 | if bond.GetIsConjugated(): 140 | c = list(set(c).union(set([a1, a2]))) 141 | 142 | g = g.to_directed() 143 | edge_index = np.array(g.edges).tolist() 144 | 145 | edge_attr = [] 146 | for w in list(g.edges.data('weight')): 147 | edge_attr.append(w[2]) 148 | 149 | for i in range(c_size): 150 | for j in range(i,c_size): 151 | if adj_order_matrix[i, j] == 0 and i != j: 152 | conj = False 153 | paths = list(nx.node_disjoint_paths(g, i, j)) 154 | if len(paths) > 1: 155 | paths = sorted(paths, key=lambda i:len(i),reverse=False) 156 | for path in paths: 157 | if set(path) < set(c): 158 | conj = True 159 | break 160 | if conj: 161 | adj_order_matrix[i, j] = 0.825 162 | adj_order_matrix[j, i] = 0.825 163 | else: 164 | path = paths[0] 165 | dis_order_matrix[i, j] = len(path) - 1 166 | dis_order_matrix[j, i] = len(path) - 1 167 | 168 | g = [c_size, features, edge_index, edge_attr, ring_masm, arom_masm, alip_masm, 
hete_mask, adj_order_matrix, dis_order_matrix] 169 | return [smi, g] 170 | 171 | 172 | def write(res): 173 | smi, g = res 174 | smile_graph[smi] = g 175 | 176 | 177 | if __name__ == '__main__': 178 | parser = argparse.ArgumentParser(description='TransFoxMol') 179 | parser.add_argument('--moldata', type=str, help='dataset name to process') 180 | parser.add_argument('--task', type=str, choices=['clas', 'reg'], help='Binary classification or Regression') 181 | parser.add_argument('--numtasks', type=int, default=1, help='Number of tasks (default: 1). ') 182 | parser.add_argument('--ncpu', type=int, default=4, help='number of cpus to use (default: 4)') 183 | args = parser.parse_args() 184 | 185 | moldata = args.moldata 186 | if moldata in ['esol', 'freesolv', 'lipo', 'qm7', 'qm8', 'qm9']: 187 | task = 'reg' 188 | if moldata == 'qm8': 189 | numtasks = 12 190 | elif moldata == 'qm9': 191 | numtasks = 3 192 | else: 193 | numtasks = 1 194 | elif moldata in ['bbbp', 'sider', 'clintox', 'tox21', 'toxcast', 'bace', 'pcba', 'muv', 'hiv']: 195 | task = 'clas' 196 | if moldata == 'sider': 197 | numtasks = 27 198 | elif moldata == 'clintox': 199 | numtasks = 2 200 | elif moldata == 'tox21': 201 | numtasks = 12 202 | elif moldata == 'toxcast': 203 | numtasks = 617 204 | elif moldata == 'pcba': 205 | numtasks = 128 206 | elif moldata == 'muv': 207 | numtasks = 17 208 | else: 209 | numtasks = 1 210 | else: 211 | task = args.task 212 | numtasks = args.numtasks 213 | 214 | processed_data_file = 'dataset/processed/' + moldata+task + '_pyg.pt' 215 | if not os.path.isfile(processed_data_file): 216 | try: 217 | df = pd.read_csv('./dataset/raw/'+moldata+'.csv') 218 | except: 219 | print('Raw data not found! Put the right raw csvfile in **/dataset/raw/') 220 | try: 221 | compound_iso_smiles = np.array(df['smiles']) 222 | except: 223 | print('The smiles column does not exist') 224 | try: 225 | ic50s = np.array(df.iloc[:, 1:numtasks+1]) 226 | except: 227 | print('Mismatch between number of tasks and .csv file') 228 | #ic50s = -np.log10(np.array(ic50s)) 229 | pool = Pool(args.ncpu) 230 | smis = [] 231 | y = [] 232 | result = [] 233 | 234 | for smi, label in zip(compound_iso_smiles, ic50s): 235 | record = True 236 | mol = Chem.MolFromSmiles(smi) 237 | if mol is not None: 238 | if '.' in smi: 239 | mol = remover.StripMol(mol) 240 | if mol is not None: 241 | if 80 > mol.GetNumAtoms() > 1: 242 | smi = Chem.MolToSmiles(mol) 243 | if '.' 
not in smi: 244 | record = True 245 | else: 246 | record = False 247 | else: 248 | record = False 249 | for ele in meta: 250 | if ele in smi: 251 | record = False 252 | if record: 253 | smis.append(smi) 254 | y.append(label) 255 | result.append(pool.apply_async(smiletopyg, (smi,))) 256 | else: 257 | print(smi) 258 | pool.close() 259 | pool.join() 260 | 261 | for res in result: 262 | smi, g = res.get() 263 | smile_graph[smi] = g 264 | 265 | MolNet(root='./dataset', dataset=moldata+task, xd=smis, y=y, smile_graph=smile_graph) 266 | -------------------------------------------------------------------------------- /TFM/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import GraphConv 4 | from torch_geometric.utils import to_dense_batch 5 | from math import sqrt 6 | from TFM.utils import get_attn_pad_mask, create_ffn 7 | 8 | 9 | class Embed(nn.Module): 10 | def __init__(self, attn_head=4, output_dim=128, d_k=64, attn_layers=4, dropout=0.1, useedge=False, device='cuda:0'): 11 | super(Embed, self).__init__() 12 | self.device = device 13 | self.relu = nn.ReLU() 14 | self.edge = useedge 15 | self.layer_num = attn_layers 16 | self.gnns = nn.ModuleList([GraphConv(37, output_dim) if i == 0 else GraphConv(output_dim, output_dim) for i in range(attn_layers)]) 17 | self.nms = nn.ModuleList([nn.LayerNorm(output_dim) for _ in range(attn_layers)]) 18 | self.dps = nn.ModuleList([nn.Dropout(dropout) for _ in range(attn_layers)]) 19 | self.tfs = nn.ModuleList([Encoder(output_dim, d_k, d_k, 1, attn_head, dropout) for _ in range(attn_layers)]) 20 | 21 | def forward(self, x, edge_index, edge_attr, batch, leng, adj, dis, ring_masm, arom_masm, alip_masm, hete_mask): 22 | if self.edge: 23 | x = self.gnns[0](x, edge_index, edge_weight=edge_attr) 24 | else: 25 | x = self.gnns[0](x, edge_index) 26 | x = self.dps[0](self.nms[0](x)) 27 | x = self.relu(x) 28 | 29 | x_batch, mask = to_dense_batch(x, batch) 30 | 31 | batch_size, max_len, output_dim = x_batch.size() 32 | matrix_pad = torch.zeros((batch_size, max_len, max_len)) 33 | manu_mask_pad = torch.ones((batch_size, 6, max_len, max_len)) 34 | for i, l in enumerate(leng): 35 | adj_ = torch.FloatTensor(adj[i]) 36 | localmask = torch.where((adj_ >= 0.8) & (adj_ != 0.825), torch.ones_like(adj_), torch.zeros_like(adj_)) 37 | cojmask = torch.where(adj_ > 0.8, torch.ones_like(adj_), torch.zeros_like(adj_)) 38 | hete = torch.BoolTensor(hete_mask[i]) 39 | hetemask = hete.unsqueeze(0) * hete.unsqueeze(-1) 40 | ring = torch.IntTensor(ring_masm[i]) 41 | ring1 = ring.unsqueeze(0).repeat(ring.size(0), 1) 42 | ring2 = ring.unsqueeze(-1).repeat(1, ring.size(0)) 43 | ringmask = torch.where((ring1 == ring2) & (ring1 > 0), torch.ones((ring.size(0), ring.size(0))), torch.zeros((ring.size(0),ring.size(0)))) 44 | arom = torch.IntTensor(arom_masm[i]) 45 | arom1 = arom.unsqueeze(0).repeat(arom.size(0), 1) 46 | arom2 = arom.unsqueeze(-1).repeat(1, arom.size(0)) 47 | arommask = torch.where((arom1 == arom2) & (arom1 > 0), torch.ones((arom.size(0), arom.size(0))), torch.zeros((arom.size(0),arom.size(0)))) 48 | alip = torch.IntTensor(alip_masm[i]) 49 | alip1 = alip.unsqueeze(0).repeat(alip.size(0), 1) 50 | alip2 = alip.unsqueeze(-1).repeat(1, alip.size(0)) 51 | alipmask = torch.where((alip1 == alip2) & (alip1 > 0), torch.ones((alip.size(0), alip.size(0))), torch.zeros((alip.size(0),alip.size(0)))) 52 | dis_ = torch.FloatTensor(dis[i]) 53 | dis_ = torch.where(dis_ == 0, dis_, 
1/torch.sqrt(dis_)) 54 | matrix = torch.where(adj_==0, dis_, adj_) 55 | matrix_pad[i, :int(l[0]), :int(l[0])] = matrix 56 | manu_mask_pad[i, :, :int(l[0]), :int(l[0])] = torch.cat([localmask.eq(0).unsqueeze(0),cojmask.eq(0).unsqueeze(0),hetemask.eq(0).unsqueeze(0),ringmask.eq(0).unsqueeze(0),arommask.eq(0).unsqueeze(0),alipmask.eq(0).unsqueeze(0)], 0) 57 | 58 | matrix_pad = matrix_pad.to(self.device) 59 | manu_mask_pad = manu_mask_pad.to(self.device) 60 | 61 | x_batch, matrix = self.tfs[0](x_batch, mask, matrix_pad, manu_mask_pad) 62 | for i in range(1, self.layer_num): 63 | x = torch.masked_select(x_batch, mask.unsqueeze(-1)) 64 | x = x.reshape(-1, output_dim) 65 | if self.edge: 66 | x = self.gnns[i](x, edge_index, edge_weight=edge_attr) 67 | else: 68 | x = self.gnns[i](x, edge_index) 69 | x = self.dps[i](self.nms[i](x)) 70 | 71 | x = self.relu(x) 72 | x_batch, mask = to_dense_batch(x, batch) 73 | x_batch, matrix = self.tfs[i](x_batch, mask, matrix, manu_mask_pad) 74 | 75 | return x_batch 76 | 77 | 78 | class Kno(nn.Module): 79 | def __init__(self, task='reg', tasks=1, attn_head=4, output_dim=128, d_k=64, attn_layers=4, D=16, dropout=0.1, useedge=False, device='cuda:0'): 80 | super(Kno, self).__init__() 81 | self.device = device 82 | self.emb = Embed(attn_head, output_dim, d_k, attn_layers, dropout, useedge, device) 83 | self.task = task 84 | # prediction module 85 | if task == 'clas': 86 | self.w1 = torch.nn.Parameter(torch.FloatTensor(D, output_dim)) 87 | self.w2 = torch.nn.Parameter(torch.FloatTensor(2, D)) 88 | self.th = nn.Tanh() 89 | self.sm = nn.Softmax(-1) 90 | self.bm = nn.BatchNorm1d(2, output_dim) 91 | elif task == 'reg': 92 | self.w1 = nn.Linear(output_dim, 1) 93 | self.sm = nn.Softmax(1) 94 | self.bm = nn.BatchNorm1d(output_dim) 95 | else: 96 | raise NameError('task must be reg or clas!') 97 | self.act = create_ffn(task, tasks, output_dim, dropout) 98 | self.reset_params() 99 | 100 | def reset_params(self): 101 | for weight in self.parameters(): 102 | if len(weight.size()) > 1: 103 | nn.init.xavier_normal_(weight) 104 | 105 | def forward(self, data): 106 | x, edge_index, edge_attr = data.x.to(self.device), data.edge_index.to(self.device), data.edge_attr.to(self.device) # tensor 107 | leng, adj, dis = data.leng, data.adj, data.dis # list 108 | ring_masm, arom_masm, alip_masm, hete_mask = data.ringm, data.aromm, data.alipm, data.hetem 109 | batch = data.batch.to(self.device) 110 | 111 | x_batch = self.emb(x, edge_index, edge_attr, batch, leng, adj, dis, ring_masm, arom_masm, alip_masm, hete_mask) 112 | 113 | if self.task == 'clas': 114 | x_bat = self.th(torch.matmul(self.w1, x_batch.permute(0, 2, 1))) # B O L 115 | x_bat = self.sm(torch.matmul(self.w2, x_bat)) # B X L 116 | x_p = torch.matmul(x_bat, x_batch) # B X D 117 | x_p = self.bm(x_p) 118 | x_p = x_p.reshape(x_p.size(0), x_p.size(1)*x_p.size(2)) 119 | else: 120 | x_p = self.bm(torch.sum(self.sm(self.w1(x_batch)) * x_batch, 1)) 121 | 122 | # prediction 123 | logits = self.act(x_p) 124 | 125 | return logits 126 | 127 | 128 | class ScaledDotProductAttention(nn.Module): 129 | def __init__(self, d_k, dropout): 130 | super(ScaledDotProductAttention, self).__init__() 131 | self.d_k = d_k 132 | self.dp = nn.Dropout(dropout) 133 | self.sm = nn.Softmax(dim=-1) 134 | 135 | def forward(self, Q, K, V, attn_mask): 136 | scores = torch.matmul(Q, K.transpose(-1, -2)) / sqrt(self.d_k) 137 | scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is True. 
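        # Softmax over the masked scores: positions filled with -1e9 receive ~zero attention weight,
        # and dropout is applied to the attention map before it is multiplied with V.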
138 | attn = self.sm(scores) 139 | context = torch.matmul(self.dp(attn), V) # [batch_size, len_q, d_v] 140 | return context, attn 141 | 142 | 143 | class MultiHeadAttention(nn.Module): 144 | def __init__(self, d_model, d_k, d_v, n_heads, dropout): 145 | super(MultiHeadAttention, self).__init__() 146 | self.W_Q = nn.Linear(d_model, d_k*n_heads, bias=False) 147 | self.W_K = nn.Linear(d_model, d_k*n_heads, bias=False) 148 | self.W_V = nn.Linear(d_model, d_v*n_heads, bias=False) 149 | self.W_V2 = nn.Linear(d_model, d_v, bias=False) 150 | self.fc = nn.Linear(d_v*(n_heads+1), d_model, bias=False) 151 | self.nm = nn.LayerNorm(d_model) 152 | self.n_heads = n_heads 153 | self.d_k = d_k 154 | self.d_v = d_v 155 | self.dp = nn.Dropout(p=dropout) 156 | self.sdpa = ScaledDotProductAttention(d_k, dropout) 157 | self.fu = nn.Sequential( 158 | nn.LayerNorm(n_heads+1), 159 | nn.Linear(n_heads+1, 6), 160 | nn.ReLU(), 161 | nn.Linear(6, 1), 162 | nn.Sigmoid()) 163 | 164 | def forward(self, input_Q, input_K, input_V, attn_mask, matrix): 165 | batch_size = input_Q.size(0) 166 | 167 | Q = self.W_Q(input_Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2) # Q: [batch_size, n_heads, max_len, d_k] 168 | K = self.W_K(input_K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2) # K: [batch_size, n_heads, max_len, d_k] 169 | V = self.W_V(input_V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2) # V: [batch_size, n_heads, max_len, d_v] 170 | 171 | context, attn = self.sdpa(Q, K, V, attn_mask) 172 | context2 = torch.matmul(matrix, self.W_V2(input_V)) 173 | 174 | matrix = matrix * self.fu(torch.cat([matrix.unsqueeze(1), attn], 1).transpose(1,3)).squeeze() 175 | 176 | context = context.transpose(1, 2).reshape(batch_size, -1, self.n_heads * self.d_v) 177 | context = torch.cat([context, context2], -1) 178 | output = self.fc(context) # [batch_size, max_len, d_model] 179 | 180 | return self.dp(self.nm(output)), matrix 181 | 182 | 183 | class EncoderLayer(nn.Module): 184 | def __init__(self, d_model, d_k, d_v, n_heads, dropout): 185 | super(EncoderLayer, self).__init__() 186 | self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads, dropout) 187 | self.nm = nn.LayerNorm(d_model) 188 | self.pos_ffn = nn.Sigmoid() 189 | 190 | def forward(self, enc_inputs, attn_mask, matrix): 191 | residual = enc_inputs 192 | enc_outputs, matrix = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, attn_mask, matrix) 193 | enc_outputs = self.pos_ffn(enc_outputs) 194 | return self.nm(enc_outputs+residual), matrix 195 | 196 | 197 | class Encoder(nn.Module): 198 | def __init__(self, d_model, d_k, d_v, n_layers, n_heads, dropout): 199 | super(Encoder, self).__init__() 200 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, n_heads, dropout) for _ in range(n_layers)]) 201 | self.nhead = n_heads 202 | 203 | def forward(self, enc_inputs, mask, matrix, manu_mask_pad): 204 | attn_mask = get_attn_pad_mask(mask) 205 | attn_mask = attn_mask.unsqueeze(1).repeat(1, self.nhead-manu_mask_pad.size(1), 1, 1) # attn_mask : [batch_size, n_heads, max_len, max_len] 206 | attn_mask = torch.cat([attn_mask, manu_mask_pad], 1).bool() 207 | for i, layer in enumerate(self.layers): 208 | enc_inputs, matrix = layer(enc_inputs, attn_mask, matrix) 209 | return enc_inputs, matrix 210 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import torch, argparse 2 | import torch.nn as nn 3 | import 
torch.optim as optim 4 | import numpy as np 5 | import pandas as pd 6 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, mean_absolute_error, mean_squared_error, r2_score 7 | from torch.nn import BCELoss 8 | from torch_geometric.data import DataLoader 9 | from TFM.Dataset import MolNet 10 | from TFM.model import Kno 11 | from TFM.utils import get_logger, metrics_c, metrics_r, set_seed, load_data 12 | from rdkit.Chem.SaltRemover import SaltRemover 13 | import hyperopt 14 | from hyperopt import fmin, hp, Trials 15 | from hyperopt.early_stop import no_progress_loss 16 | import warnings 17 | from datetime import datetime 18 | warnings.filterwarnings("ignore") 19 | remover = SaltRemover() 20 | bad = ['He', 'Be', 'Na', 'Mg', 'Al', 'K', 'Ca', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Rb', 'Sr', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'Gd', 'Tb', 'Ho', 'W', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Ac'] 21 | _use_shared_memory = True 22 | torch.backends.cudnn.benchmark = True 23 | 24 | 25 | def training(model, train_loader, optimizer, loss_f, metric, task, device, mean, stds): 26 | loss_record, record_count = 0., 0. 27 | preds = torch.Tensor([]); tars = torch.Tensor([]) 28 | model.train() 29 | if task == 'clas': 30 | for data in train_loader: 31 | if data.y.size()[0] > 1: 32 | y = data.y.to(device) 33 | logits = model(data) 34 | 35 | loss = loss_f(logits.squeeze(), y.squeeze()) 36 | loss_record += float(loss.item()) 37 | record_count += 1 38 | optimizer.zero_grad() 39 | loss.backward() 40 | nn.utils.clip_grad_value_(model.parameters(), clip_value=2) 41 | optimizer.step() 42 | 43 | pred = logits.detach().cpu() 44 | preds = torch.cat([preds, pred], 0); tars = torch.cat([tars, y.cpu()], 0) 45 | clas = preds > 0.5 46 | acc, f1, pre, rec, auc = metric(clas.squeeze().numpy(), preds.squeeze().numpy(), tars.squeeze().numpy()) 47 | else: 48 | for data in train_loader: 49 | if data.y.size()[0] > 1: 50 | y = data.y.to(device) 51 | 52 | y_ = (y - mean) / (stds+1e-5) 53 | logits = model(data) 54 | 55 | loss = loss_f(logits.squeeze(), y_.squeeze()) 56 | loss_record += float(loss.item()) 57 | record_count += 1 58 | optimizer.zero_grad() 59 | loss.backward() 60 | nn.utils.clip_grad_value_(model.parameters(), clip_value=2) 61 | optimizer.step() 62 | 63 | pred = logits.detach()*stds+mean 64 | preds = torch.cat([preds, pred.cpu()], 0); tars = torch.cat([tars, y.cpu()], 0) 65 | acc, f1, pre, rec, auc = metric(preds.squeeze().numpy(), tars.squeeze().numpy()) 66 | 67 | epoch_loss = loss_record / record_count 68 | return epoch_loss, acc, f1, pre, rec, auc 69 | 70 | 71 | def testing(model, test_loader, loss_f, metric, task, device, mean, stds, resu): 72 | loss_record, record_count = 0., 0. 
73 | preds = torch.Tensor([]); tars = torch.Tensor([]) 74 | model.eval() 75 | with torch.no_grad(): 76 | if task == 'clas': 77 | for data in test_loader: 78 | if data.y.size()[0] > 1: 79 | y = data.y.to(device) 80 | logits = model(data) 81 | 82 | loss = loss_f(logits.squeeze(), y.squeeze()) 83 | loss_record += float(loss.item()) 84 | record_count += 1 85 | 86 | pred = logits.detach().cpu() 87 | preds = torch.cat([preds, pred], 0); tars = torch.cat([tars, y.cpu()], 0) 88 | preds, tars = preds.squeeze().numpy(), tars.squeeze().numpy() 89 | clas = preds > 0.5 90 | acc, f1, pre, rec, auc = metric(clas, preds, tars) 91 | else: 92 | for data in test_loader: 93 | if data.y.size()[0] > 1: 94 | y = data.y.to(device) 95 | 96 | y_ = (y - mean) / (stds+1e-5) 97 | logits = model(data) 98 | 99 | loss = loss_f(logits.squeeze(), y_.squeeze()) 100 | loss_record += float(loss.item()) 101 | record_count += 1 102 | 103 | pred = logits.detach()*stds+mean 104 | preds = torch.cat([preds, pred.cpu()], 0); tars = torch.cat([tars, y.cpu()], 0) 105 | preds, tars = preds.squeeze().numpy(), tars.squeeze().numpy() 106 | acc, f1, pre, rec, auc = metric(preds, tars) 107 | 108 | epoch_loss = loss_record / record_count 109 | if resu: 110 | return epoch_loss, acc, f1, pre, rec, auc, preds, tars 111 | else: 112 | return epoch_loss, acc, f1, pre, rec, auc 113 | 114 | 115 | def main(tasks, task, dataset, device, train_epoch, seed, fold, batch_size, rate, split, modelpath, logger, lr, attn_head, output_dim, attn_layers, dropout, mean, stds, D, useedge, met, savem): 116 | logger.info('Dataset: {} task: {} train_epoch: {}'.format(dataset, task, train_epoch)) 117 | d_k, seed_ = round(output_dim/attn_head), seed 118 | 119 | fold_result = [[], []] 120 | if task == 'clas': 121 | loss_f = BCELoss().to(device) 122 | metric = metrics_c(accuracy_score, precision_score, recall_score, f1_score, roc_auc_score) 123 | 124 | for fol in range(1, fold+1): 125 | best_val_auc, best_test_auc = 0., 0. 
126 | if seed is not None: 127 | seed_ = seed + fol-1 128 | set_seed(seed_) 129 | model = Kno(task, tasks, attn_head, output_dim, d_k, attn_layers, D, dropout, useedge, device).to(device) 130 | optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.1) 131 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6, last_epoch=-1) 132 | 133 | train_loader, valid_loader, test_loader = load_data(dataset, batch_size, rate[0], rate[1], 0, task, split, seed_) 134 | logger.info('Dataset: {} Fold: {:<4d}'.format(moldata, fol)) 135 | 136 | for i in range(1,train_epoch+1): 137 | train_loss, train_acc, train_f1, train_pre, train_rec, train_auc = training(model, train_loader, optimizer, loss_f, metric, task, device, mean, stds) 138 | if sche: 139 | scheduler.step() 140 | valid_loss, valid_acc, valid_f1, valid_pre, valid_rec, valid_auc = testing(model, valid_loader, loss_f, metric, task, device, mean, stds, False) 141 | 142 | logger.info('Dataset: {} Epoch: {:<3d} train_loss: {:.4f} train_auc: {:.4f}'.format(dataset ,i, train_loss, train_auc)) 143 | logger.info('Dataset: {} Epoch: {:<3d} valid_loss: {:.4f} valid_auc: {:.4f}'.format(dataset, i, valid_loss, valid_auc)) 144 | 145 | if valid_auc > best_val_auc: 146 | best_val_auc = valid_auc 147 | if savem: 148 | model_save_path = modelpath + '{}_{}_{}.pkl'.format(dataset, i, round(valid_auc, 4)) 149 | torch.save(model.state_dict(), model_save_path) 150 | test_loss, test_acc, test_f1, test_pre, test_rec, test_auc = testing(model, test_loader, loss_f, metric, task, device, mean, stds, False) 151 | logger.info('Dataset: {} Epoch: {:<3d} test__loss: {:.4f} test__auc: {:.4f}'.format(dataset, i, test_loss, test_auc)) 152 | best_test_auc = test_auc 153 | fold_result[0].append(best_val_auc) 154 | fold_result[1].append(best_test_auc) 155 | logger.info('Dataset: {} Fold: {} best_val_auc: {:.4f} best_test_auc: {:.4f}'.format(dataset, fol, best_val_auc, best_test_auc)) 156 | logger.info('Dataset: {} Fold result: {}'.format(dataset, fold_result)) 157 | return fold_result 158 | 159 | else: 160 | if met == 'mae': 161 | loss_f = nn.L1Loss().to(device) 162 | else: 163 | loss_f = nn.MSELoss().to(device) 164 | metric = metrics_r(mean_absolute_error, mean_squared_error, r2_score) 165 | 166 | for fol in range(1, fold+1): 167 | best_val_rmse, best_test_rmse = 9999., 9999. 
168 | if seed is not None: 169 | seed_ = seed + fol-1 170 | set_seed(seed_) 171 | model = Kno(task, tasks, attn_head, output_dim, d_k, attn_layers, D, dropout, useedge, device).to(device) 172 | optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.1) 173 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6, last_epoch=-1) 174 | 175 | train_loader, valid_loader, test_loader = load_data(dataset, batch_size, rate[0], rate[1], 0, task, split, seed_) 176 | logger.info('Dataset: {} Fold: {:<4d}'.format(moldata, fol)) 177 | 178 | for i in range(1, train_epoch+1): 179 | train_loss, train_mae, train_rmse, train_r2, _, _ = training(model, train_loader, optimizer, loss_f, metric, task, device, mean, stds) 180 | if sche: 181 | scheduler.step() 182 | valid_loss, valid_mae, valid_rmse, valid_r2, _, _ = testing(model, valid_loader, loss_f, metric, task, device, mean, stds, False) 183 | logger.info('Dataset: {} Epoch: {:<3d} train_loss: {:.4f} train_mae: {:.4f} train_rmse: {:.4f}'.format(dataset, i, train_loss, train_mae, train_rmse)) 184 | logger.info('Dataset: {} Epoch: {:<3d} valid_loss: {:.4f} valid_mae: {:.4f} valid_rmse: {:.4f}'.format(dataset, i, valid_loss, valid_mae, valid_rmse)) 185 | 186 | if met == 'rmse': 187 | if valid_rmse < best_val_rmse: 188 | best_val_rmse = valid_rmse 189 | if savem: 190 | model_save_path = modelpath + '{}_{}_{}.pkl'.format(dataset, i, round(valid_rmse,4)) 191 | torch.save(model.state_dict(), model_save_path) 192 | test_loss, test_mae, test_rmse, test_r2, _, _ = testing(model, test_loader, loss_f, metric, task, device, mean, stds, False) 193 | logger.info('Dataset: {} Epoch: {:<3d} test_loss: {:.4f} test_rmse: {:.4f}'.format(dataset, i, test_loss, test_rmse)) 194 | best_test_rmse = test_rmse 195 | 196 | elif met == 'mae': 197 | if valid_mae < best_val_rmse: 198 | best_val_rmse = valid_mae 199 | if savem: 200 | model_save_path = modelpath + '{}_{}_{}.pkl'.format(dataset, i, round(valid_rmse,4)) 201 | torch.save(model.state_dict(), model_save_path) 202 | test_loss, test_mae, test_rmse, test_r2, _, _ = testing(model, test_loader, loss_f, metric, task, device, mean, stds, False) 203 | logger.info('Dataset: {} Epoch: {:<3d} test_loss: {:.4f} test_mae: {:.4f}'.format(dataset, i, test_loss, test_mae)) 204 | best_test_rmse = test_mae 205 | else: 206 | raise ValueError('regression metric must be rmse or mae') 207 | fold_result[0].append(best_val_rmse) 208 | fold_result[1].append(best_test_rmse) 209 | logger.info('Dataset: {} Fold: {} best_val_{}: {:.4f} best_test_{}: {:.4f}'.format(dataset, fol, met, best_val_rmse, met, best_test_rmse)) 210 | logger.info('Dataset: {} Fold result: {}'.format(dataset, fold_result)) 211 | return fold_result 212 | 213 | 214 | def test(tasks, task, dataset, device, seed, batch_size, logger, attn_head, output_dim, attn_layers, dropout, pretrain, mean, stds, D, useedge, met): 215 | logger.info('Dataset: {} task: {} testing:'.format(dataset, task)) 216 | d_k = round(output_dim/attn_head) 217 | if seed is not None: 218 | set_seed(seed) 219 | model = Kno(task, tasks, attn_head, output_dim, d_k, attn_layers, D, dropout, useedge, device).to(device) 220 | state_dict = torch.load(pretrain) 221 | model.load_state_dict(state_dict) 222 | 223 | data = MolNet(root='./dataset', dataset=dataset) 224 | loader = DataLoader(data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=0, drop_last=False) 225 | if task == 'clas': 226 | loss_f = BCELoss().to(device) 227 | metric = metrics_c(accuracy_score, 
precision_score, recall_score, f1_score, roc_auc_score) 228 | loss, acc, f1, pre, rec, auc, preds, tars = testing(model, loader, loss_f, metric, task, device, mean, stds, True) 229 | logger.info('Dataset: {} test_loss: {:.4f} test_acc: {:.4f} test_f1: {:.4f} test_auc: {:.4f} test_pre: {:.4f} test_rec: {:.4f}'.format(dataset, loss, acc, f1, auc, pre, rec)) 230 | results = { 231 | 'test_loss': loss, 232 | 'test_acc': acc, 233 | 'test_f1': f1, 234 | 'test_pre': pre, 235 | 'test_rec': rec, 236 | 'test_auc': auc, 237 | } 238 | df_prediction = pd.DataFrame({'prediction': preds}) 239 | df_target = pd.DataFrame({'target': tars}) 240 | df_single_values = pd.DataFrame({k: [v] for k, v in results.items()}) 241 | 242 | for col in df_single_values.columns: 243 | df_single_values[col] = df_single_values[col].reindex(df_prediction.index, method='ffill') 244 | 245 | df = pd.concat([df_single_values, df_prediction, df_target], axis=1) 246 | df.to_csv('log/Result'+moldata+'_test.csv', index=False) 247 | else: 248 | if met == 'mae': 249 | loss_f = nn.L1Loss().to(device) 250 | else: 251 | loss_f = nn.MSELoss().to(device) 252 | metric = metrics_r(mean_absolute_error, mean_squared_error, r2_score) 253 | loss, mae, rmse, r2, _, _, preds, tars= testing(model, loader, loss_f, metric, task, device, mean, stds, True) 254 | logger.info('Dataset: {} test_loss: {:.4f} test_mae: {:.4f} test_rmse: {:.4f} test_r2: {:.4f}'.format(dataset, loss, mae, rmse, r2)) 255 | results = { 256 | 'test_loss': loss, 257 | 'test_mae': mae, 258 | 'test_rmse': rmse, 259 | 'test_r2': r2, 260 | } 261 | df_prediction = pd.DataFrame({'prediction': preds}) 262 | df_target = pd.DataFrame({'target': tars}) 263 | df_single_values = pd.DataFrame({k: [v] for k, v in results.items()}) 264 | 265 | for col in df_single_values.columns: 266 | df_single_values[col] = df_single_values[col].reindex(df_prediction.index, method='ffill') 267 | 268 | df = pd.concat([df_single_values, df_prediction, df_target], axis=1) 269 | df.to_csv('log/Result'+moldata+'_test.csv', index=False) 270 | 271 | 272 | def psearch(params): 273 | logger.info('Optimizing Hyperparameters') 274 | fold_result = main(params['tasks'],params['task'],params['moldata'],params['device'],params['train_epoch'],params['seed'],params['fold'],params['batch_size'],params['rate'],params['split'],params['modelpath'],params['logger'],params['lr'],params['attn_head'],params['output_dim'],params['attn_layers'],params['dropout'],params['mean'], params['std'], params['D'], params['useedge'], params['metric'], False) 275 | if task == 'reg': 276 | valid_res = np.mean(fold_result[1]) 277 | else: 278 | valid_res = -np.mean(fold_result[1]) 279 | return valid_res 280 | 281 | 282 | if __name__ == '__main__': 283 | parser = argparse.ArgumentParser(description='TransFoxMol') 284 | parser.add_argument('mode', type=str, choices=['train', 'test', 'search'], help='train, test or hyperparameter_search') 285 | parser.add_argument('moldata', type=str, help='Dataset name') 286 | parser.add_argument('--task', type=str, choices=['clas', 'reg'], help='Classification or Regression') 287 | parser.add_argument('--numtasks', type=int, default=1, help='Number of tasks (default: 1).') 288 | parser.add_argument('--device', type=str, default='cuda:0', help='Which gpu to use if any (default: cuda:0)') 289 | parser.add_argument('--batch_size', type=int, default=32, help='Input batch size for training (default: 32)') 290 | parser.add_argument('--train_epoch', type=int, default=50, help='Number of epochs to train (default: 50)') 291 | 
parser.add_argument('--max_eval', type=int, default=100, help='Number hyperparameter settings to try (default: 100)') 292 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 293 | parser.add_argument('--valrate', type=float, default=0.1, help='valid rate (default: 0.1)') 294 | parser.add_argument('--testrate', type=float, default=0.1, help='test rate (default: 0.1)') 295 | parser.add_argument('--fold', type=int, default=3, help='Number of folds for cross validation (default: 3)') 296 | parser.add_argument('--dropout', type=float, default=0.05, help='dropout ratio') 297 | parser.add_argument('--split', type =str, default='random_scaffold', help = 'random_scaffold/balan_scaffold/random (default: random_scaffold)') 298 | parser.add_argument('--attn_layers', type=int, default=2, help='Number of feature learning layers') 299 | parser.add_argument('--output_dim', type=int, default=256, help='Hidden size of embedding layer') 300 | parser.add_argument('--D', type=int, default=4, help='Hidden size of readout layer') 301 | parser.add_argument('--seed', type=int, help = "Seed for splitting the dataset") 302 | parser.add_argument('--pretrain', type=str, help = "Path of retrained weights") 303 | parser.add_argument('--metric', type=str, choices=['rmse', 'mae'], help='Metric to evaluate the regression performance') 304 | args = parser.parse_args() 305 | 306 | device = torch.device(args.device) 307 | moldata = args.moldata 308 | attn_head = 10 309 | max_eval = args.max_eval 310 | rate = [args.valrate, args.testrate] 311 | useedge = False 312 | sche = True 313 | 314 | if moldata in ['esol', 'freesolv', 'lipo', 'qm7', 'qm8', 'qm9']: 315 | task = 'reg' 316 | if moldata == 'qm8': 317 | numtasks = 12 318 | elif moldata == 'qm9': 319 | numtasks = 3 320 | else: 321 | numtasks = 1 322 | elif moldata in ['bbbp', 'sider', 'clintox', 'tox21', 'toxcast', 'bace', 'pcba', 'muv', 'hiv']: 323 | task = 'clas' 324 | if moldata == 'sider': 325 | numtasks = 27 326 | elif moldata == 'clintox': 327 | numtasks = 2 328 | useedge = True 329 | elif moldata == 'tox21': 330 | numtasks = 12 331 | elif moldata == 'toxcast': 332 | numtasks = 617 333 | elif moldata == 'pcba': 334 | numtasks = 128 335 | elif moldata == 'muv': 336 | numtasks = 17 337 | else: 338 | numtasks = 1 339 | else: 340 | task = args.task 341 | numtasks = args.numtasks 342 | 343 | logf = 'log/{}_{}_{}_{}.log'.format(moldata, args.task, args.split, args.mode) 344 | modelpath = 'log/checkpoint/' 345 | logger = get_logger(logf) 346 | 347 | logger.info("Arguments:") 348 | for arg in vars(args): 349 | logger.info(f"{arg}: {getattr(args, arg)}") 350 | 351 | moldata += task 352 | try: 353 | data = MolNet(root='./dataset', dataset=moldata) 354 | except: 355 | raise ValueError('Process the dataset first!') 356 | length = len(data) 357 | if task == 'clas': 358 | mean, std = None, None 359 | else: 360 | max_eval = 50 361 | if numtasks > 1: 362 | ys = np.asarray([d.y.numpy() for d in data]) 363 | mean, std = np.mean(ys, 0), np.std(ys, 0) 364 | mean, stds = torch.FloatTensor(mean).to(device), torch.FloatTensor(std).to(device) 365 | else: 366 | ys = np.asarray([d.y.item() for d in data]) 367 | mean, std = np.mean(ys, 0), np.std(ys, 0) 368 | 369 | dps = [0.05, 0.1] 370 | if args.mode == 'search': 371 | trials = Trials() 372 | if length < 500: 373 | batch_size = 8 374 | attn_head = 8 375 | elif length < 5000: 376 | batch_size = 32 377 | else: 378 | batch_size = 256 379 | max_eval = 50 380 | if task == 'clas': 381 | dps = [0.1, 0.2, 0.3] 382 | if 
args.moldata == 'bbbp': 383 | lrs = [1e-4, 5e-5, 1e-5] 384 | sche = False 385 | else: 386 | lrs = [1e-2, 5e-3, 1e-3] 387 | 388 | parm_space = { # search space of param 389 | 'tasks': numtasks, 390 | 'task': task, 391 | 'moldata': moldata, 392 | 'mean': mean, 393 | 'std': std, 394 | 'device': args.device, 395 | 'modelpath': modelpath, 396 | 'logger': logger, 397 | 'useedge': useedge, 398 | 'seed': args.seed, 399 | 'fold': args.fold, 400 | 'metric': args.metric, 401 | 'rate': rate, 402 | 'split': args.split, 403 | 'train_epoch': args.train_epoch, 404 | 'attn_head': attn_head, 405 | 'output_dim': hp.choice('output_dim', [128, 256]), 406 | 'attn_layers': hp.choice('attn_layers', [1, 2, 3, 4]), 407 | 'dropout': hp.choice('dropout', dps), 408 | 'lr': hp.choice('lr', lrs), 409 | 'D': hp.choice('D', [2, 4, 6, 8, 12, 16]), 410 | 'batch_size': batch_size 411 | } 412 | param_mappings = { 413 | 'output_dim': [128, 256], 414 | 'attn_layers': [1, 2, 3, 4], 415 | 'dropout': dps, 416 | 'lr': lrs, 417 | 'D': [2, 4, 6, 8, 12, 16] 418 | } 419 | best = fmin(fn=psearch, space=parm_space, algo=hyperopt.tpe.suggest, max_evals=max_eval, trials=trials, early_stop_fn=no_progress_loss(int(max_eval/2))) 420 | best_values = {k: param_mappings[k][v] if k in param_mappings else v for k, v in best.items()} 421 | ys = [t['result']['loss'] for t in trials.trials] 422 | logger.info('Dataset {} Hyperopt Results: {}'.format(moldata, ys)) 423 | logger.info('Dataset {} Best Params: {}'.format(moldata, best_values)) 424 | logger.info('Dataset {} Best Perform: {}'.format(moldata, np.min(ys))) 425 | 426 | elif args.mode == 'train': 427 | logger.info('Training') 428 | fold_result = main(numtasks, task, moldata, device, args.train_epoch, args.seed, args.fold, args.batch_size, rate, args.split, modelpath, logger, args.lr, attn_head, args.output_dim, args.attn_layers, args.dropout, mean, std, args.D, useedge, args.metric, True) 429 | elif args.mode == 'test': 430 | assert (args.pretrain is not None) 431 | fold_result = test(numtasks, task, moldata, device, args.seed, args.batch_size, logger, attn_head, args.output_dim, args.attn_layers, args.dropout, args.pretrain, mean, std, args.D, useedge, args.metric) 432 | else:pass 433 | --------------------------------------------------------------------------------