├── .gitignore
├── README.md
├── Synthesis.ipynb
├── Synthesis.py
└── scripts
    ├── .DS_Store
    ├── Test.py
    ├── Train.py
    ├── __init__.py
    ├── dataset.py
    ├── get_edit.py
    ├── stopper.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
*__pycache__/*

.ipynb_checkpoints
*.ipynb_checkpoints/*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LocalTransform
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)[![DOI](https://zenodo.org/badge/443246460.svg)](https://zenodo.org/badge/latestdoi/443246460)
Implementation of organic reactivity prediction with LocalTransform, developed by Prof. Yousung Jung's group at KAIST (now moved to SNU, contact: yousung.jung@snu.ac.kr).

![LocalTransform](https://i.imgur.com/9SA50iK.jpg)

## Code and license removal announcement (2025.03.28)
Part of the code and the license file have been removed.

## Model size decrease announcement (2022.10.31)
We slightly modified the model architecture to decrease the model size from 59 MB to 36.4 MB, so the model can be uploaded to the GitHub repo, by decreasing the size of the bond features from 512 to 256 through bond_net (see `scripts/model.py` for more detail). This modification also accelerates the training process.
We also fixed a few parts of the code to enable smooth execution on CPU.

## Contents

- [Developer](#developer)
- [OS Requirements](#os-requirements)
- [Python Dependencies](#python-dependencies)
- [Installation Guide](#installation-guide)
- [Reproduce the results](#reproduce-the-results)
- [Demo and human benchmark results](#demo-and-human-benchmark-results)
- [Publication](#publication)
- [License](#license)

## Developer
Shuan Chen (shuan.micc@gmail.com)

## OS Requirements
This repository has been tested on both **Linux** and **Windows** operating systems.

## Python Dependencies
* Python (version >= 3.6)
* Numpy (version >= 1.16.4)
* PyTorch (version >= 1.0.0)
* RDKit (version >= 2019)
* DGL (version >= 0.5.2)
* DGLLife (version >= 0.2.6)

## Installation Guide
Create a virtual environment to run the code of LocalTransform.
Make sure to install PyTorch with the CUDA version that fits your device.
This process usually takes a few minutes to complete.
```
git clone https://github.com/kaist-amsg/LocalTransform.git
cd LocalTransform
conda create -c conda-forge -n rdenv python=3.6 -y
conda activate rdenv
conda install pytorch cudatoolkit=11.3 -c pytorch -y
conda install -c conda-forge rdkit -y
conda install -c dglteam dgl-cuda11.3
pip install dgllife
```

## Reproduce the results
### [1] Download the raw data of USPTO-480k dataset
Download the data from https://github.com/wengong-jin/nips17-rexgen/blob/master/USPTO/ and move the data to `./data/USPTO_480k/`.

### [2] Data preprocessing
A two-step data preprocessing is needed to train the LocalTransform model.

#### 1) Local reaction template derivation
First go to the data processing folder
```
cd preprocessing
```
and extract the reaction templates.
```
python Extract_from_train_data.py
```
This will give you three files:
(1) real_templates.csv (reaction templates for real bonds)
(2) virtual_templates.csv (reaction templates for imaginary bonds)
(3) template_infos.csv (including the hydrogen change, charge change and action information)
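
For reference, these files are consumed the same way `Synthesis.py` loads them; the sketch below only illustrates the expected columns (`Class` and `Template` in the template files; `edit_site`, `change_H`, `change_C` and `change_S` in `template_infos.csv`), with paths assuming the preprocessing above.
```
import pandas as pd

# Each template entry is underscore-joined as template_Hcode_Ccode_Scode_action (see Synthesis.py)
template_df = pd.read_csv('data/USPTO_480k/real_templates.csv')
template_dict = {template_df['Class'][i]: template_df['Template'][i].split('_') for i in template_df.index}
print('loaded %s real templates' % len(template_dict))

# The info table stores edit sites and hydrogen/charge changes as python literals
template_infos = pd.read_csv('data/USPTO_480k/template_infos.csv')
infos = {template_infos['Template'][i]: {'edit_site': eval(template_infos['edit_site'][i]),
                                         'change_H': eval(template_infos['change_H'][i])}
         for i in template_infos.index}
```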

#### 2) Assign the derived templates to raw data
By running
```
python Run_preprocessing.py
```
you can get four preprocessed files:
(1) preprocessed_train.csv
(2) preprocessed_valid.csv
(3) preprocessed_test.csv
(4) labeled_data.csv
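
As a quick sanity check on the output, `labeled_data.csv` can be inspected directly; the columns used below (`Split`, `Reactants`, `Reagents`, `Mask`) are the ones `scripts/dataset.py` reads for training, and the path assumes the default dataset location.
```
import pandas as pd

df = pd.read_csv('data/USPTO_480k/labeled_data.csv')
print(df['Split'].value_counts())                   # sizes of the train/valid/test partitions
print(df[['Reactants', 'Reagents', 'Mask']].head())
```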


### [3] Train LocalTransform model
Go to the main scripts folder
```
cd ../scripts
```
and run the following to train the model with reagents separated or not (default: False)
```
python Train.py -sep True
```
The trained model will be saved at `LocalTransform/models/LocalTransform_sep.pth`
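
The checkpoint is a plain PyTorch file written by `scripts/stopper.py`, so it can be inspected without the rest of the pipeline; the `'model_state_dict'` and `'timestep'` keys below follow `stopper.py`, and the path assumes training finished as above.
```
import torch

checkpoint = torch.load('models/LocalTransform_sep.pth', map_location='cpu')
print(checkpoint['timestep'])                # epoch at which the best model was saved
print(len(checkpoint['model_state_dict']))   # number of saved parameter tensors
```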

### [4] Test LocalTransform model
To test the model on the test set, simply run
```
python Test.py -sep True
```
to get the raw prediction file saved at `LocalTransform/outputs/raw_prediction/LocalTransform_sep.txt`
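
Each line of this file is tab-separated: a test id, the reactant and reagent SMILES, then `top_num` predictions formatted as `(bond_type, bond, template_class, score)` by `write_edits` in `scripts/get_edit.py`. A minimal parsing sketch (path as above):
```
with open('outputs/raw_prediction/LocalTransform_sep.txt') as f:
    header = f.readline().rstrip('\n').split('\t')
    first = f.readline().rstrip('\n').split('\t')

test_id, reactants, reagents = first[:3]
predictions = first[3:]     # each entry looks like '(v, [3, 5], 21, 0.987)'
print(header[:3], predictions[0])
```
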
Finally you can get the products of each prediction by decoding the raw prediction file
```
python Decode_predictions.py -sep True
```
The decoded products will be saved at
`LocalTransform/outputs/decoded_prediction/LocalTransform_sep.txt`

### [5] Exact match accuracy calculation
By using
```
python Calculate_topk_accuracy.py -m sep
```
the top-k accuracy will be calculated from the files generated at step [4].

## Demo and human benchmark results
See `Synthesis.ipynb` for running instructions and expected output. Human benchmark results are also shown at the end of the notebook.
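
If you prefer a script over the notebook, the `localtransform` class in `Synthesis.py` wraps model loading and inference; below is a minimal usage sketch (the input SMILES is an arbitrary illustration, and the class falls back to CPU automatically when no GPU is available):
```
from Synthesis import localtransform

model = localtransform(dataset='USPTO_480k', device='cuda:0')
reactants = 'CC(=O)O.OCC'   # arbitrary example input
results_df, results_dict = model.predict_product(reactants, topk=5, verbose=1)
print(results_dict[reactants]['Top-1']['product'])
```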


## Publication
[A Generalized Template-Based Graph Neural Network for Accurate Organic Reactivity Prediction, Nat Mach Intell 2022](https://www.nature.com/articles/s42256-022-00526-z)

## License
This project is covered under the **Apache 2.0 License**.

--------------------------------------------------------------------------------
/Synthesis.py:
--------------------------------------------------------------------------------
from itertools import permutations
import pandas as pd
import json
from rdkit import Chem

import torch
from torch import nn
import sklearn

import dgl
from dgllife.utils import smiles_to_bigraph, WeaveAtomFeaturizer, CanonicalBondFeaturizer
from functools import partial

from scripts.dataset import combine_reactants, get_bonds, get_adm
from scripts.utils import init_featurizer, load_model, pad_atom_distance_matrix, predict
from scripts.get_edit import get_bg_partition, combined_edit
from LocalTemplate.template_collector import Collector

atom_types = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
              'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
              'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb',
              'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Ta', 'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir',
              'Ce', 'Gd', 'Ga', 'Cs']

def demap(smiles):
    mol = Chem.MolFromSmiles(smiles)
    [atom.SetAtomMapNum(0) for atom in mol.GetAtoms()]
    return Chem.MolToSmiles(mol)

class localtransform():
    def __init__(self, dataset, device='cuda:0'):
        self.data_dir = 'data/%s' % dataset
        self.config_path = 'data/configs/default_config'
        self.model_path = 'models/LocalTransform_%s.pth' % dataset
        self.device = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
        self.args = {'data_dir': self.data_dir, 'model_path': self.model_path, 'config_path': self.config_path, 'device': self.device, 'mode': 'test'}
        self.template_dicts, self.template_infos = self.load_templates()
        self.model, self.graph_function = self.init_model()

    def load_templates(self):
        template_dicts = {}
        for site in ['real', 'virtual']:
            template_df = pd.read_csv('%s/%s_templates.csv' % (self.data_dir, site))
            template_dict = {template_df['Class'][i]: template_df['Template'][i].split('_') for i in template_df.index}
            print('loaded %s %s templates' % (len(template_dict), site))
            template_dicts[site[0]] = template_dict
        template_infos = pd.read_csv('%s/template_infos.csv' % self.data_dir)
        template_infos = {template_infos['Template'][i]: {
            'edit_site': eval(template_infos['edit_site'][i]),
            'change_H': eval(template_infos['change_H'][i]),
            'change_C': eval(template_infos['change_C'][i]),
            'change_S': eval(template_infos['change_S'][i])} for i in template_infos.index}
        return template_dicts, template_infos

    def init_model(self):
        self.args = init_featurizer(self.args)
        model = load_model(self.args)
        model.eval()
        smiles_to_graph = partial(smiles_to_bigraph, add_self_loop=True)
        node_featurizer = WeaveAtomFeaturizer(atom_types=atom_types)
        edge_featurizer = CanonicalBondFeaturizer(self_loop=True)
        graph_function = lambda s: smiles_to_graph(s, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer, canonical_atom_order=False)
        return model, graph_function

    def make_inference(self, reactant_list, topk=5):
        fgraphs = []
        dgraphs = []
        for smiles in reactant_list:
            mol = Chem.MolFromSmiles(smiles)
            fgraph = self.graph_function(smiles)
            dgraph = {'atom_distance_matrix': get_adm(mol), 'bonds': get_bonds(smiles)}
            dgraph['v_bonds'], dgraph['r_bonds'] = dgraph['bonds']
            fgraphs.append(fgraph)
            dgraphs.append(dgraph)
        bg = dgl.batch(fgraphs)
        bg.set_n_initializer(dgl.init.zero_initializer)
        bg.set_e_initializer(dgl.init.zero_initializer)
        adm_lists = [graph['atom_distance_matrix'] for graph in dgraphs]
        adms = pad_atom_distance_matrix(adm_lists)
        bonds_dicts = {'virtual': [torch.from_numpy(graph['v_bonds']).long() for graph in dgraphs], 'real': [torch.from_numpy(graph['r_bonds']).long() for graph in dgraphs]}

        with torch.no_grad():
            pred_VT, pred_RT, _, _, pred_VI, pred_RI, attentions = predict(self.args, self.model, bg, adms, bonds_dicts)
            pred_VT = nn.Softmax(dim=1)(pred_VT)
            pred_RT = nn.Softmax(dim=1)(pred_RT)
            v_sep, r_sep = get_bg_partition(bg, bonds_dicts)
            start_v, start_r = 0, 0
            predictions = []
            for i, reactant in enumerate(reactant_list):
                end_v, end_r = v_sep[i], r_sep[i]
                virtual_bonds, real_bonds = bonds_dicts['virtual'][i].numpy(), bonds_dicts['real'][i].numpy()
                pred_vi, pred_ri = pred_VI[i].cpu(), pred_RI[i].cpu()
                pred_v, pred_r = pred_VT[start_v:end_v], pred_RT[start_r:end_r]
                prediction = combined_edit(virtual_bonds, real_bonds, pred_vi, pred_ri, pred_v, pred_r, topk*10)
                predictions.append(prediction)
                start_v = end_v
                start_r = end_r
        return predictions

    def predict_product(self, reactant_list, topk=5, verbose=0):
        if isinstance(reactant_list, str):
            reactant_list = [reactant_list]
        predictions = self.make_inference(reactant_list, topk)
        results_df = {'Reactants': []}
        results_dict = {}
        for k in range(topk):
            results_df['Top-%d' % (k+1)] = []

        for reactant, prediction in zip(reactant_list, predictions):
            pred_types, pred_sites, scores = prediction
            collector = Collector(reactant, self.template_infos, 'nan', False, verbose=verbose > 1)
            for k, (pred_type, pred_site, score) in enumerate(zip(pred_types, pred_sites, scores)):
                template, H_code, C_code, S_code, action = self.template_dicts[pred_type][pred_site[1]]
                pred_site = pred_site[0]
                if verbose > 0:
                    print('%dth prediction:' % (k+1), template, action, pred_site, score)
                collector.collect(template, H_code, C_code, S_code, action, pred_site, score)
                if len(collector.predictions) >= topk:
                    break
            sorted_predictions = [k for k, v in sorted(collector.predictions.items(), key=lambda item: -item[1]['score'])]
            results_df['Reactants'].append(Chem.MolFromSmiles(reactant))
            results_dict[reactant] = {}
            for k in range(topk):
                p = sorted_predictions[k] if len(sorted_predictions) > k else ''
                results_dict[reactant]['Top-%d' % (k+1)] = collector.predictions[p]
                results_dict[reactant]['Top-%d' % (k+1)]['product'] = p
                results_df['Top-%d' % (k+1)].append(Chem.MolFromSmiles(p))

        results_df = pd.DataFrame(results_df)
        return results_df, results_dict

# Legacy standalone version of predict_product, kept for reference (note that it
# references `reactant`, which is not defined in this scope).
# def predict_products(self, args, reactant_list, model, graph_functions, template_dicts, template_infos, product=None, reagents='nan', top_k=5, collect_n=100, verbose=0, sep=False):
#     model.eval()
#     if reagents != 'nan':
#         smiles = reactant + '.' + reagents
#     else:
#         smiles = reactant
#     dglgraph = graph_functions(smiles)
#     adms = pad_atom_distance_matrix([get_adm(Chem.MolFromSmiles(smiles))])
#     v_bonds, r_bonds = get_bonds(smiles)
#     bonds_dicts = {'virtual': [torch.from_numpy(v_bonds).long()], 'real': [torch.from_numpy(r_bonds).long()]}
#     with torch.no_grad():
#         pred_VT, pred_RT, _, _, pred_VI, pred_RI, attentions = predict(args, model, dglgraph, adms, bonds_dicts)
#         pred_v = nn.Softmax(dim=1)(pred_VT)
#         pred_r = nn.Softmax(dim=1)(pred_RT)
#         pred_vi = pred_VI[0].cpu()
#         pred_ri = pred_RI[0].cpu()
#         pred_types, pred_sites, pred_scores = combined_edit(v_bonds, r_bonds, pred_vi, pred_ri, pred_v, pred_r, collect_n)
#
#     collector = Collector(reactant, template_infos, reagents, sep, verbose=verbose > 1)
#     for k, (pred_type, pred_site, score) in enumerate(zip(pred_types, pred_sites, pred_scores)):
#         template, H_code, C_code, S_code, action = template_dicts[pred_type][pred_site[1]]
#         pred_site = pred_site[0]
#         if verbose > 0:
#             print('%sth prediction:' % k, template, action, pred_site, score)
#         collector.collect(template, H_code, C_code, S_code, action, pred_site, score)
#         if len(collector.predictions) >= top_k:
#             break
#     sort_predictions = [k for k, v in sorted(collector.predictions.items(), key=lambda item: -item[1]['score'])]
#
#     reactant = demap(reactant)
#     if product != None:
#         correct_at = False
#         product = demap(product)
#
#     results_dict = {'Reactants': demap(reactant)}
#     results_df = pd.DataFrame({'Reactants': [Chem.MolFromSmiles(reactant)]})
#     for k, p in enumerate(sort_predictions):
#         results_dict['Top-%d' % (k+1)] = collector.predictions[p]
#         results_dict['Top-%d' % (k+1)]['product'] = p
#         results_df['Top-%d' % (k+1)] = [Chem.MolFromSmiles(p)]
#         if product != None:
#             if set(p.split('.')).intersection(set(product.split('.'))):
#                 correct_at = k+1
#
#     if product != None:
#         results_df['Correct at'] = correct_at
#     return results_df, results_dict

--------------------------------------------------------------------------------
/scripts/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaist-amsg/LocalTransform/1b763f20e4d1df560d15aab2a61291fe0c50fae3/scripts/.DS_Store
--------------------------------------------------------------------------------
/scripts/Test.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser

import torch
import sklearn
import torch.nn as nn

from utils import init_featurizer, mkdir_p, get_configure, load_model, load_dataloader, predict
from get_edit import write_edits

def main(args):
    if args['model_name'] == 'default':
        if args['sep']:
            args['model_name'] = 'LocalTransform_sep.pth'
        else:
            args['model_name'] = 'LocalTransform_mix.pth'
    else:
        args['model_name'] = 'LocalTransform_%s.pth' % args['model_name']

    args['model_path'] = '../models/%s' % args['model_name']
    args['config_path'] = '../data/configs/%s' % args['config']
    args['data_dir'] = '../data/%s' % args['dataset']
    args['result_path'] = '../outputs/raw_prediction/%s' % args['model_name'].replace('.pth', '.txt')
    mkdir_p('../outputs')
    mkdir_p('../outputs/raw_prediction')
    args = init_featurizer(args)
    model = load_model(args)
    test_loader = load_dataloader(args)

    write_edits(args, model, test_loader)
    return

if __name__ == '__main__':
    parser = ArgumentParser('Testing arguments')
    parser.add_argument('-g', '--gpu', default='cuda:0', help='GPU device to use')
    parser.add_argument('-d', '--dataset', default='USPTO_480k', help='Dataset to use')
    parser.add_argument('-m', '--model-name', default='default', help='Model to use')
    parser.add_argument('-c', '--config', default='default_config', help='Configuration of model')
    parser.add_argument('-b', '--batch-size', default=32, help='Batch size of dataloader')
    parser.add_argument('-k', '--top_num', default=100, help='Num. of predictions to write')
    parser.add_argument('-s', '--sep', default=False, help='Test the model with reagents separated or not')
    parser.add_argument('-nw', '--num-workers', type=int, default=0, help='Number of processes for data loading')
    args = parser.parse_args().__dict__
    args['mode'] = 'test'
    args['device'] = torch.device(args['gpu']) if torch.cuda.is_available() else torch.device('cpu')
    print('Using device %s' % args['device'])
    main(args)
--------------------------------------------------------------------------------
/scripts/Train.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser

import torch
import sklearn
import torch.nn as nn

from utils import init_featurizer, mkdir_p, get_configure, load_model, load_dataloader, predict, get_reactive_template_labels

def mask_loss(loss_criterion1, loss_criterion2, pred_v, pred_r, true_v, true_r, vmask, rmask):
    # bonds with mask 0 (unlabeled) are zeroed out of both loss terms
    vmask, rmask = vmask.double(), rmask.double()
    vloss = (loss_criterion1(pred_v, true_v) * (vmask != 0)).float().mean()
    rloss = (loss_criterion2(pred_r, true_r) * (rmask != 0)).float().mean()
    return vloss + rloss

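# Illustration of the masking above (toy numbers, not from the repo): with
# per-bond losses [0.7, 0.2, 0.9] and vmask [1, 0, 1], the middle bond is
# dropped and vloss = mean([0.7, 0.0, 0.9]) = 0.533. This element-wise product
# is possible because the criterions are built with reduction='none' in
# utils.weight_loss.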

def run_a_train_epoch(args, epoch, model, data_loader, loss_criterions, optimizer):
    model.train()
    train_R_loss = 0
    train_T_loss = 0
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, adm_lists, bonds_dicts, true_VT, true_RT, masks = batch_data
        if len(smiles) == 1:
            print('Skip problematic graph')
            continue
        pred_VT, pred_RT, pred_VR, pred_RR, pred_VI, pred_RI, _ = predict(args, model, bg, adm_lists, bonds_dicts)
        true_VT, true_VR, mask_V = get_reactive_template_labels(true_VT, masks, pred_VI)
        true_RT, true_RR, mask_R = get_reactive_template_labels(true_RT, masks, pred_RI)
        true_VT, true_RT, true_VR, true_RR, mask_V, mask_R = true_VT.to(args['device']), true_RT.to(args['device']), true_VR.to(args['device']), true_RR.to(args['device']), mask_V.to(args['device']), mask_R.to(args['device'])

        R_loss = mask_loss(loss_criterions[0], loss_criterions[0], pred_VR, pred_RR, true_VR, true_RR, mask_V, mask_R)
        T_loss = mask_loss(loss_criterions[1], loss_criterions[2], pred_VT, pred_RT, true_VT, true_RT, mask_V, mask_R)
        loss = R_loss + T_loss
        train_R_loss += R_loss.item()
        train_T_loss += T_loss.item()
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args['max_clip'])
        optimizer.step()
        if batch_id % args['print_every'] == 0:
            print('\repoch %d/%d, batch %d/%d, reactive loss %.4f, template loss %.4f' % (epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), R_loss.item(), T_loss.item()), end='', flush=True)

    print('\nepoch %d/%d, train reactive loss %.4f, template loss %.4f' % (epoch + 1, args['num_epochs'], train_R_loss/(batch_id + 1), train_T_loss/(batch_id + 1)))

def run_an_eval_epoch(args, model, data_loader, loss_criterions):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, adm_lists, bonds_dicts, true_VT, true_RT, masks = batch_data
            if len(smiles) == 1:
                print('Skip problematic graph')
                continue
            pred_VT, pred_RT, pred_VR, pred_RR, pred_VI, pred_RI, _ = predict(args, model, bg, adm_lists, bonds_dicts)
            true_VT, true_VR, mask_V = get_reactive_template_labels(true_VT, masks, pred_VI)
            true_RT, true_RR, mask_R = get_reactive_template_labels(true_RT, masks, pred_RI)
            true_VT, true_RT, true_VR, true_RR, mask_V, mask_R = true_VT.to(args['device']), true_RT.to(args['device']), true_VR.to(args['device']), true_RR.to(args['device']), mask_V.to(args['device']), mask_R.to(args['device'])
            loss = mask_loss(loss_criterions[1], loss_criterions[2], pred_VT, pred_RT, true_VT, true_RT, mask_V, mask_R)
            val_loss += loss.item()
    return val_loss/(batch_id + 1)


def main(args):
    if args['model_name'] == 'default':
        if args['sep']:
            args['model_name'] = 'LocalTransform_sep.pth'
        else:
            args['model_name'] = 'LocalTransform_mix.pth'
    else:
        args['model_name'] = '%s.pth' % args['model_name']

    args['model_path'] = '../models/' + args['model_name']
    args['config_path'] = '../data/configs/%s' % args['config']
    args['data_dir'] = '../data/%s' % args['dataset']
    mkdir_p('../models')
    args = init_featurizer(args)
    model, loss_criterions, optimizer, scheduler, stopper = load_model(args)
    train_loader, val_loader, test_loader = load_dataloader(args)
    for epoch in range(args['num_epochs']):
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterions, optimizer)
        val_loss = run_an_eval_epoch(args, model, val_loader, loss_criterions)
        early_stop = stopper.step(val_loss, model)
        scheduler.step()
        print('epoch %d/%d, validation loss: %.4f' % (epoch + 1, args['num_epochs'], val_loss))
        print('epoch %d/%d, Best loss: %.4f' % (epoch + 1, args['num_epochs'], stopper.best_score))
        if early_stop:
            print('Early stopped!!')
            break

    stopper.load_checkpoint(model)
    test_loss = run_an_eval_epoch(args, model, test_loader, loss_criterions)
    print('test loss: %.4f' % test_loss)

if __name__ == '__main__':
    parser = ArgumentParser('Training arguments')
    parser.add_argument('-g', '--gpu', default='cuda:0', help='GPU device to use')
    parser.add_argument('-d', '--dataset', default='USPTO_480k', help='Dataset to use')
    parser.add_argument('-c', '--config', default='default_config', help='Configuration of model')
    parser.add_argument('-b', '--batch-size', default=16, help='Batch size of dataloader')
    parser.add_argument('-n', '--num-epochs', type=int, default=20, help='Maximum number of epochs for training')
    parser.add_argument('-m', '--model-name', type=str, default='default', help='Model name')
    parser.add_argument('-p', '--patience', type=int, default=3, help='Patience for early stopping')
    parser.add_argument('-w', '--negative-weight', type=float, default=0.5, help='Loss weight for negative labels')
    parser.add_argument('-s', '--sep', default=False, help='Train the model with reagents separated or not')
    parser.add_argument('-cl', '--max-clip', type=int, default=20, help='Maximum gradient norm for clipping')
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4, help='Learning rate of optimizer')
    parser.add_argument('-l2', '--weight-decay', type=float, default=1e-6, help='Weight decay of optimizer')
    parser.add_argument('-ss', '--schedule-step', type=float, default=6, help='Step size of learning scheduler')
    parser.add_argument('-nw', '--num-workers', type=int, default=0, help='Number of processes for data loading')
    parser.add_argument('-pe', '--print-every', type=int, default=20, help='Print the training progress every X mini-batches')
    args = parser.parse_args().__dict__
    args['mode'] = 'train'
    args['device'] = torch.device(args['gpu']) if torch.cuda.is_available() else torch.device('cpu')
    print('Using device %s' % args['device'], 'separate reagents: %s' % args['sep'])
    main(args)
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaist-amsg/LocalTransform/1b763f20e4d1df560d15aab2a61291fe0c50fae3/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/dataset.py:
--------------------------------------------------------------------------------
import os, pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem

import torch
import sklearn
import dgl
import dgl.backend as F
from dgl.data.utils import save_graphs, load_graphs

def combine_reactants(reactant, reagent):
    if str(reagent) == 'nan':
        smiles = reactant
    else:
        smiles = '.'.join([reactant, reagent])
    return smiles

def get_bonds(smiles):
    mol = Chem.MolFromSmiles(smiles)
    A = [a for a in range(mol.GetNumAtoms())]
    B = []
    for atom in mol.GetAtoms():
        others = []
        bonds = atom.GetBonds()
        for bond in bonds:
            atoms = [bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()]
            other = [a for a in atoms if a != atom.GetIdx()][0]
            others.append(other)
        b = [(atom.GetIdx(), other) for other in sorted(others)]
        B += b
    V = []
    for a in A:
        V += [(a, b) for b in A if a != b and (a, b) not in B]
    return np.array(V), np.array(B)

def get_adm(mol, max_distance=6):
    mol_size = mol.GetNumAtoms()
    distance_matrix = np.ones((mol_size, mol_size)) * max_distance + 1
    dm = Chem.GetDistanceMatrix(mol)
    dm[dm > 100] = -1  # remote (different molecule)
    dm[dm > max_distance] = max_distance  # remote (same molecule)
    dm[dm == -1] = max_distance + 1
    distance_matrix[:dm.shape[0], :dm.shape[1]] = dm
    return distance_matrix
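
# Example (hypothetical input, for illustration): for ethanol 'CCO', get_bonds
# returns real (bonded) pairs B = [(0, 1), (1, 0), (1, 2), (2, 1)] and virtual
# (non-bonded) pairs V = [(0, 2), (2, 0)]; get_adm then clips topological
# distances at max_distance, with max_distance + 1 marking atom pairs that sit
# in different molecules of the same SMILES.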

class USPTODataset(object):
    def __init__(self, args, mol_to_graph, node_featurizer, edge_featurizer, load=True, log_every=1000):
        df = pd.read_csv('%s/labeled_data.csv' % args['data_dir'])
        self.train_ids = df.index[df['Split'] == 'train'].values
        self.val_ids = df.index[df['Split'] == 'valid'].values
        self.test_ids = df.index[df['Split'] == 'test'].values
        self.reactants = df['Reactants'].tolist()
        self.reagents = df['Reagents'].tolist()
        self.masks = df['Mask'].tolist()

        self.sep = args['sep']
        self.fgraph_path = '../data/saved_graphs/full_%s_fgraph.bin' % args['dataset']
        if self.sep:
            self.labels = [eval(t) for t in df['Labels_sep']]
            self.dgraph_path = '../data/saved_graphs/full_%s_dgraph_sep.pkl' % args['dataset']  # stored as numpy matrices for faster processing speed
        else:
            self.labels = [eval(t) for t in df['Labels_mix']]
            self.dgraph_path = '../data/saved_graphs/full_%s_dgraph_mix.pkl' % args['dataset']
        self._pre_process(mol_to_graph, node_featurizer, edge_featurizer)

    def _pre_process(self, mol_to_graph, node_featurizer, edge_featurizer):
        self.fgraphs_exist = os.path.exists(self.fgraph_path)
        self.dgraphs_exist = os.path.exists(self.dgraph_path)
        self.fgraphs = []
        self.dgraphs = []

        if not self.fgraphs_exist or not self.dgraphs_exist:
            for (s1, s2) in tqdm(zip(self.reactants, self.reagents), total=len(self.reactants), desc='Building dgl graphs...'):
                smiles = combine_reactants(s1, s2)  # s1.s2
                mol = Chem.MolFromSmiles(smiles)
                if not self.fgraphs_exist:
                    fgraph = mol_to_graph(mol, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer, canonical_atom_order=False)
                    self.fgraphs.append(fgraph)
                if not self.dgraphs_exist:
                    if self.sep:
                        self.dgraphs.append({'atom_distance_matrix': get_adm(mol), 'bonds': get_bonds(s1)})
                    else:
                        self.dgraphs.append({'atom_distance_matrix': get_adm(mol), 'bonds': get_bonds(smiles)})

        if self.fgraphs_exist:
            print('Loading feature graphs from %s...' % self.fgraph_path)
            self.fgraphs, _ = load_graphs(self.fgraph_path)
        else:
            save_graphs(self.fgraph_path, self.fgraphs)

        if self.dgraphs_exist:
            print('Loading dense graphs from %s...' % self.dgraph_path)
            with open(self.dgraph_path, 'rb') as f:
                self.dgraphs = pickle.load(f)
        else:
            with open(self.dgraph_path, 'wb') as f:
                pickle.dump(self.dgraphs, f)

    def __getitem__(self, item):
        dgraph = self.dgraphs[item]
        dgraph['v_bonds'], dgraph['r_bonds'] = dgraph['bonds']
        return self.reactants[item], self.reagents[item], self.fgraphs[item], dgraph, self.labels[item], self.masks[item]

    def __len__(self):
        return len(self.reactants)

class USPTOTestDataset(object):
    def __init__(self, args, mol_to_graph, node_featurizer, edge_featurizer, load=True, log_every=1000):
        df = pd.read_csv('%s/preprocessed_test.csv' % args['data_dir'])
        self.reactants = df['Reactants'].tolist()
        self.reagents = df['Reagents'].tolist()
        self.sep = args['sep']
        self.fgraph_path = '../data/saved_graphs/test_%s_fgraph.bin' % args['dataset']
        if self.sep:
            self.dgraph_path = '../data/saved_graphs/test_%s_dgraph_sep.pkl' % args['dataset']  # stored as numpy matrices for faster processing speed
        else:
            self.dgraph_path = '../data/saved_graphs/test_%s_dgraph_mix.pkl' % args['dataset']
        self._pre_process(mol_to_graph, node_featurizer, edge_featurizer)

    def _pre_process(self, mol_to_graph, node_featurizer, edge_featurizer):
        self.fgraphs_exist = os.path.exists(self.fgraph_path)
        self.dgraphs_exist = os.path.exists(self.dgraph_path)
        self.fgraphs = []
        self.dgraphs = []

        if not self.fgraphs_exist or not self.dgraphs_exist:
            for (s1, s2) in tqdm(zip(self.reactants, self.reagents), total=len(self.reactants), desc='Building dgl graphs...'):
                smiles = combine_reactants(s1, s2)  # s1.s2
                mol = Chem.MolFromSmiles(smiles)  # built outside the branch below so the dense-graph branch can use it too
                if not self.fgraphs_exist:
                    fgraph = mol_to_graph(mol, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer, canonical_atom_order=False)
                    self.fgraphs.append(fgraph)
                if not self.dgraphs_exist:
                    if self.sep:
                        self.dgraphs.append({'atom_distance_matrix': get_adm(mol), 'bonds': get_bonds(s1)})
                    else:
                        self.dgraphs.append({'atom_distance_matrix': get_adm(mol), 'bonds': get_bonds(smiles)})

        if self.fgraphs_exist:
            print('Loading feature graphs from %s...' % self.fgraph_path)
            self.fgraphs, _ = load_graphs(self.fgraph_path)
        else:
            save_graphs(self.fgraph_path, self.fgraphs)

        if self.dgraphs_exist:
            print('Loading dense graphs from %s...' % self.dgraph_path)
            with open(self.dgraph_path, 'rb') as f:
                self.dgraphs = pickle.load(f)
        else:
            with open(self.dgraph_path, 'wb') as f:
                pickle.dump(self.dgraphs, f)

    def __getitem__(self, item):
        dgraph = self.dgraphs[item]
        dgraph['v_bonds'], dgraph['r_bonds'] = dgraph['bonds']
        return self.reactants[item], self.reagents[item], self.fgraphs[item], dgraph

    def __len__(self):
        return len(self.reactants)

--------------------------------------------------------------------------------
/scripts/get_edit.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import pandas as pd

import time
import torch
import torch.nn as nn
import dgl

from utils import predict

def get_id_template(a, class_n):
    edit_idx = a//class_n
    template = a%class_n
    return (edit_idx, template)

def logit2pred(out, top_num):
    class_n = out.size(-1)
    readout = out.cpu().detach().numpy()
    readout = readout.reshape(-1)
    output_rank = np.flip(np.argsort(readout))
    output_rank = [r for r in output_rank if get_id_template(r, class_n)[1] != 0][:top_num]
    selected_site = [get_id_template(a, class_n) for a in output_rank]
    selected_proba = [readout[a] for a in output_rank]

    return selected_site, selected_proba

def get_bond_site(graph, etype):
    # unused helper: returns the (src, dst) atom pairs for edges of the given type
    atom_pair_list = torch.transpose(graph.adjacency_matrix(etype=etype).coalesce().indices(), 0, 1).numpy()
    return atom_pair_list

def combined_edit(virtual_bonds, real_bonds, pred_vi, pred_ri, pred_v, pred_r, top_num):
    pred_site_v, pred_score_v = logit2pred(pred_v, top_num)
    pred_site_r, pred_score_r = logit2pred(pred_r, top_num)
    pooled_virtual_bonds = [virtual_bonds[idx] for idx in pred_vi]
    pooled_real_bonds = [real_bonds[idx] for idx in pred_ri]
    pred_site_v = [(list(pooled_virtual_bonds[pred_site]), pred_temp) for pred_site, pred_temp in pred_site_v]
    pred_site_r = [(list(pooled_real_bonds[pred_site]), pred_temp) for pred_site, pred_temp in pred_site_r]

    pred_sites = pred_site_v + pred_site_r
    pred_types = ['v'] * top_num + ['r'] * top_num
    pred_scores = pred_score_v + pred_score_r
    pred_ranks = np.flip(np.argsort(pred_scores))[:top_num]

    pred_types = [pred_types[r] for r in pred_ranks]
    pred_sites = [pred_sites[r] for r in pred_ranks]
    pred_scores = [pred_scores[r] for r in pred_ranks]
    return pred_types, pred_sites, pred_scores

def get_bg_partition(bg, bonds_dicts):
    gs = dgl.unbatch(bg)
    v_sep = [0]
    r_sep = [0]
    for g, v_bonds, r_bonds in zip(gs, bonds_dicts['virtual'], bonds_dicts['real']):
        pooling_size = g.num_nodes()
        n_v_bonds = len(v_bonds)
        if n_v_bonds < pooling_size:
            v_sep.append(v_sep[-1] + n_v_bonds)
        else:
            v_sep.append(v_sep[-1] + pooling_size)
        n_r_bonds = len(r_bonds)
        if n_r_bonds < pooling_size:
            r_sep.append(r_sep[-1] + n_r_bonds)
        else:
            r_sep.append(r_sep[-1] + pooling_size)
    return v_sep[1:], r_sep[1:]

def write_edits(args, model, test_loader):
    model.eval()
    with open(args['result_path'], 'w') as f:
        f.write('Test_id\tReactants\tReagents\t%s\n' % '\t'.join(['Prediction %s' % (i+1) for i in range(args['top_num'])]))
        with torch.no_grad():
            for batch_id, data in enumerate(test_loader):
                reactants, reagents, bg, adms, bonds_dicts = data
                pred_VT, pred_RT, _, _, pred_VI, pred_RI, _ = predict(args, model, bg, adms, bonds_dicts)
                pred_VT = nn.Softmax(dim=1)(pred_VT)
                pred_RT = nn.Softmax(dim=1)(pred_RT)
                v_sep, r_sep = get_bg_partition(bg, bonds_dicts)
                start_v = 0
                start_r = 0
                print('\rWriting test molecule batch %s/%s' % (batch_id, len(test_loader)), end='', flush=True)
                for i, (reactant, reagent) in enumerate(zip(reactants, reagents)):
                    end_v, end_r = v_sep[i], r_sep[i]
                    virtual_bonds, real_bonds = bonds_dicts['virtual'][i].numpy(), bonds_dicts['real'][i].numpy()
                    pred_vi, pred_ri = pred_VI[i].cpu(), pred_RI[i].cpu()
                    pred_v, pred_r = pred_VT[start_v:end_v], pred_RT[start_r:end_r]

                    pred_types, pred_sites, pred_scores = combined_edit(virtual_bonds, real_bonds, pred_vi, pred_ri, pred_v, pred_r, args['top_num'])
                    test_id = (batch_id * args['batch_size']) + i
                    f.write('%s\t%s\t%s\t%s\n' % (test_id, reactant, reagent, '\t'.join(['(%s, %s, %s, %.3f)' % (pred_types[i], pred_sites[i][0], pred_sites[i][1], pred_scores[i]) for i in range(args['top_num'])])))
                    start_v = end_v
                    start_r = end_r
    print()
    return
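
# Worked example (toy numbers) for get_id_template/logit2pred: scores for
# 2 candidate sites x 3 template classes are flattened, so index a maps back to
# (site, template) = (a // class_n, a % class_n), and template 0 ('no edit') is
# skipped:
#   >>> logit2pred(torch.tensor([[0.1, 2.0, 0.3], [0.5, 0.2, 1.5]]), top_num=2)
#   sites [(0, 1), (1, 2)] with scores 2.0 and 1.5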
--------------------------------------------------------------------------------
/scripts/stopper.py:
--------------------------------------------------------------------------------
# modified from dgllife (https://lifesci.dgl.ai/_modules/dgllife/utils/early_stop.html)
import datetime
import torch

class EarlyStopping(object):
    """Early stop tracker

    Save a model checkpoint when observing a performance improvement on
    the validation set, and early stop if no improvement has been
    observed for a particular number of epochs.

    Parameters
    ----------
    args : dict
        Provides ``patience`` (number of epochs to wait for an improvement),
        ``model_path`` (filename for storing the model checkpoint; if None,
        a filename starting with ``early_stop`` is generated from the current
        time) and ``device`` (device the checkpoint is mapped to on reload).
    mode : str
        * 'higher': a higher metric suggests a better model
        * 'lower': a lower metric suggests a better model

    Examples
    --------
    Below gives a demo for a fake training process.

    >>> import torch
    >>> import torch.nn as nn
    >>> from torch.nn import MSELoss
    >>> from torch.optim import Adam
    >>> from dgllife.utils import EarlyStopping

    >>> model = nn.Linear(1, 1)
    >>> criterion = MSELoss()
    >>> # For MSE, the lower, the better
    >>> stopper = EarlyStopping(mode='lower', filename='test.pth')
    >>> optimizer = Adam(params=model.parameters(), lr=1e-3)

    >>> for epoch in range(1000):
    >>>     x = torch.randn(1, 1) # Fake input
    >>>     y = torch.randn(1, 1) # Fake label
    >>>     pred = model(x)
    >>>     loss = criterion(y, pred)
    >>>     optimizer.zero_grad()
    >>>     loss.backward()
    >>>     optimizer.step()
    >>>     early_stop = stopper.step(loss.detach().data, model)
    >>>     if early_stop:
    >>>         break

    >>> # Load the final parameters saved by the model
    >>> stopper.load_checkpoint(model)
    """
    def __init__(self, args, mode='lower'):
        patience = args['patience']
        filename = args['model_path']
        device = args['device']

        if filename is None:
            dt = datetime.datetime.now()
            filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
                dt.date(), dt.hour, dt.minute, dt.second)

        assert mode in ['higher', 'lower']
        self.mode = mode
        if self.mode == 'higher':
            self._check = self._check_higher
        else:
            self._check = self._check_lower

        self.patience = patience
        self.counter = 0
        self.timestep = 0
        self.filename = filename
        self.device = device
        self.best_score = None
        self.early_stop = False

    def _check_higher(self, score, prev_best_score):
        """Check if the new score is higher than the previous best score.

        Parameters
        ----------
        score : float
            New score.
        prev_best_score : float
            Previous best score.

        Returns
        -------
        bool
            Whether the new score is higher than the previous best score.
        """
        return score > prev_best_score

    def _check_lower(self, score, prev_best_score):
        """Check if the new score is lower than the previous best score.

        Parameters
        ----------
        score : float
            New score.
        prev_best_score : float
            Previous best score.

        Returns
        -------
        bool
            Whether the new score is lower than the previous best score.
        """
        return score < prev_best_score

    def step(self, score, model):
        """Update based on a new score.

        The new score is typically model performance on the validation set
        for a new epoch.

        Parameters
        ----------
        score : float
            New score.
        model : nn.Module
            Model instance.

        Returns
        -------
        bool
            Whether an early stop should be performed.
        """
        self.timestep += 1
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif self._check(score, self.best_score):
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(
                f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def save_checkpoint(self, model):
        '''Saves model when the metric on the validation set gets improved.

        Parameters
        ----------
        model : nn.Module
            Model instance.
        '''
        torch.save({'model_state_dict': model.state_dict(),
                    'timestep': self.timestep}, self.filename)

    def load_checkpoint(self, model):
        '''Load the latest checkpoint

        Parameters
        ----------
        model : nn.Module
            Model instance.
        '''
        model.load_state_dict(torch.load(self.filename, map_location=self.device)['model_state_dict'])
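
if __name__ == '__main__':
    # Smoke test (assumed toy values): unlike the dgllife original shown in the
    # docstring, this modified class reads patience, model_path and device from
    # an args dict, as done in scripts/utils.py.
    import torch.nn as nn
    demo_args = {'patience': 2, 'model_path': 'demo_early_stop.pth', 'device': torch.device('cpu')}
    stopper = EarlyStopping(demo_args, mode='lower')
    model = nn.Linear(1, 1)
    for val_loss in [1.0, 0.8, 0.9, 0.95]:  # one improvement, then stagnation
        if stopper.step(val_loss, model):
            print('early stop at best score %.2f' % stopper.best_score)
            break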
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
import torch
import sklearn
import dgl
import errno
import json
import os, sys
import time
import numpy as np
import pandas as pd
from functools import partial

import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam, lr_scheduler

from rdkit import Chem
from dgl.data.utils import Subset
from dgllife.utils import WeaveAtomFeaturizer, CanonicalBondFeaturizer, mol_to_bigraph


from models import LocalTransform
from dataset import USPTODataset, USPTOTestDataset, combine_reactants
from stopper import EarlyStopping

def init_featurizer(args):
    atom_types = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
                  'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
                  'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb',
                  'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Ta', 'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir',
                  'Ce', 'Gd', 'Ga', 'Cs']
    args['node_featurizer'] = WeaveAtomFeaturizer(atom_types=atom_types)
    args['edge_featurizer'] = CanonicalBondFeaturizer(self_loop=True)
    return args

def get_configure(args):
    with open('%s.json' % args['config_path'], 'r') as f:
        config = json.load(f)
    config['Template_rn'] = len(pd.read_csv('%s/real_templates.csv' % args['data_dir']))
    config['Template_vn'] = len(pd.read_csv('%s/virtual_templates.csv' % args['data_dir']))
    args['Template_rn'] = config['Template_rn']
    args['Template_vn'] = config['Template_vn']
    config['in_node_feats'] = args['node_featurizer'].feat_size()
    config['in_edge_feats'] = args['edge_featurizer'].feat_size()
    return config

def mkdir_p(path):
    try:
        os.makedirs(path)
        print('Created directory %s' % path)
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            print('Directory %s already exists.' % path)
        else:
            raise

def load_dataloader(args):
    if args['mode'] == 'train':
        dataset = USPTODataset(args,
                               mol_to_graph=partial(mol_to_bigraph, add_self_loop=True),
                               node_featurizer=args['node_featurizer'],
                               edge_featurizer=args['edge_featurizer'])

        train_set, val_set, test_set = Subset(dataset, dataset.train_ids), Subset(dataset, dataset.val_ids), Subset(dataset, dataset.test_ids)

        train_loader = DataLoader(dataset=train_set, batch_size=args['batch_size'], shuffle=True,
                                  collate_fn=collate_molgraphs, num_workers=args['num_workers'])
        val_loader = DataLoader(dataset=val_set, batch_size=args['batch_size'],
                                collate_fn=collate_molgraphs, num_workers=args['num_workers'])
        test_loader = DataLoader(dataset=test_set, batch_size=args['batch_size'],
                                 collate_fn=collate_molgraphs, num_workers=args['num_workers'])
        return train_loader, val_loader, test_loader

    elif args['mode'] == 'test':
        test_set = USPTOTestDataset(args,
                                    mol_to_graph=partial(mol_to_bigraph, add_self_loop=True),
                                    node_featurizer=args['node_featurizer'],
                                    edge_featurizer=args['edge_featurizer'])
        test_loader = DataLoader(dataset=test_set, batch_size=args['batch_size'],
                                 collate_fn=collate_molgraphs_test, num_workers=args['num_workers'])

    elif args['mode'] == 'analyze':
        # note: requires USPTOAnaynizeDataset and collate_molgraphs_analysis, which are not included in this repository
        test_set = USPTOAnaynizeDataset(args,
                                        mol_to_graph=partial(mol_to_bigraph, add_self_loop=True),
                                        node_featurizer=args['node_featurizer'],
                                        edge_featurizer=args['edge_featurizer'])
        test_loader = DataLoader(dataset=test_set, batch_size=args['batch_size'],
                                 collate_fn=collate_molgraphs_analysis, num_workers=args['num_workers'])
    return test_loader

def weight_loss(class_num, args, reduction='none'):
    weights = torch.ones(class_num+1)
    weights[0] = args['negative_weight']
    return nn.CrossEntropyLoss(weight=weights.to(args['device']), reduction=reduction)

def load_model(args):
    exp_config = get_configure(args)
    model = LocalTransform(
        node_in_feats=exp_config['in_node_feats'],
        edge_in_feats=exp_config['in_edge_feats'],
        node_out_feats=exp_config['node_out_feats'],
        edge_hidden_feats=exp_config['edge_hidden_feats'],
        num_step_message_passing=exp_config['num_step_message_passing'],
        attention_heads=exp_config['attention_heads'],
        attention_layers=exp_config['attention_layers'],
        Template_rn=exp_config['Template_rn'],
        Template_vn=exp_config['Template_vn'])
    model = model.to(args['device'])
    print('Parameters of loaded LocalTransform:')
    print(exp_config)

    if args['mode'] == 'train':
        loss_criterions = [weight_loss(1, args), weight_loss(exp_config['Template_vn'], args), weight_loss(exp_config['Template_rn'], args)]

        optimizer = Adam(model.parameters(), lr=args['learning_rate'], weight_decay=args['weight_decay'])
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args['schedule_step'])
        if os.path.exists(args['model_path']):
            user_answer = input('%s exists, want to (a) overwrite (b) continue from checkpoint (c) make a new model? ' % args['model_path'])
            if user_answer == 'a':
                stopper = EarlyStopping(args)
                print('Overwriting the existing model and training a new model...')
            elif user_answer == 'b':
                stopper = EarlyStopping(args)
                stopper.load_checkpoint(model)
                print('Training from the existing model checkpoint...')
            elif user_answer == 'c':
                model_name = input('Enter new model name: ')
                args['model_path'] = '../models/%s.pth' % model_name
                stopper = EarlyStopping(args)
                print('Training a new model %s.pth' % model_name)
            else:
                print('Input error: please enter a, b or c to specify the model name')
                try:
                    sys.exit(0)
                except SystemExit:
                    os._exit(0)
        else:
            stopper = EarlyStopping(args)
        return model, loss_criterions, optimizer, scheduler, stopper

    else:
        model.load_state_dict(torch.load(args['model_path'], map_location=args['device'])['model_state_dict'])
        return model

def pad_atom_distance_matrix(adm_list):
    max_size = max([adm.shape[0] for adm in adm_list])
    adm_list = [torch.tensor(np.pad(adm, (0, max_size - adm.shape[0]), 'maximum')).unsqueeze(0).long() for adm in adm_list]
    return torch.cat(adm_list, dim=0)

def make_labels(dgraphs, labels, masks):
    vtemplate_labels = []
    rtemplate_labels = []
    for graph, label, m in zip(dgraphs, labels, masks):
        vtemplate_label, rtemplate_label = [0]*len(graph['v_bonds']), [0]*len(graph['r_bonds'])
        if m == 1:
            for l in label:
                edit_type = l[0]
                edit_idx = l[1]
                edit_template = l[2]
                if edit_type == 'v':
                    vtemplate_label[edit_idx] = edit_template
                else:
                    rtemplate_label[edit_idx] = edit_template

        vtemplate_labels.append(vtemplate_label)
        rtemplate_labels.append(rtemplate_label)
    return vtemplate_labels, rtemplate_labels

def get_reactive_template_labels(all_template_labels, masks, top_idxs):
    template_labels = []
    clipped_masks = []
    for i, idxs in enumerate(top_idxs):
        template_labels += [all_template_labels[i][idx] for idx in idxs]
        clipped_masks += [masks[i] for idx in idxs]
    reactive_labels = [int(y > 0) for y in template_labels]
    return torch.LongTensor(template_labels), torch.LongTensor(reactive_labels), torch.LongTensor(clipped_masks)

def collate_molgraphs(data):
    reactants, reagents, fgraphs, dgraphs, labels, masks = map(list, zip(*data))
    true_VT, true_RT = make_labels(dgraphs, labels, masks)
    bg = dgl.batch(fgraphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    adm_lists = [graph['atom_distance_matrix'] for graph in dgraphs]
    adms = pad_atom_distance_matrix(adm_lists)
    bonds_dicts = {'virtual': [torch.from_numpy(graph['v_bonds']).long() for graph in dgraphs], 'real': [torch.from_numpy(graph['r_bonds']).long() for graph in dgraphs]}
    return reactants, bg, adms, bonds_dicts, true_VT, true_RT, masks

def collate_molgraphs_test(data):
    reactants, reagents, fgraphs, dgraphs = map(list, zip(*data))
    bg = dgl.batch(fgraphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    adm_lists = [graph['atom_distance_matrix'] for graph in dgraphs]
    adms = pad_atom_distance_matrix(adm_lists)
    bonds_dicts = {'virtual': [torch.from_numpy(graph['v_bonds']).long() for graph in dgraphs], 'real': [torch.from_numpy(graph['r_bonds']).long() for graph in dgraphs]}
    return reactants, reagents, bg, adms, bonds_dicts
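
# Illustration (toy shapes) of pad_atom_distance_matrix, used by both collate
# functions above: each matrix is padded with its own maximum value ('maximum'
# pad mode, i.e. the remote distance) up to the largest molecule in the batch,
# then everything is stacked into one LongTensor:
#   >>> pad_atom_distance_matrix([np.zeros((2, 2)), np.ones((3, 3))]).shape
#   torch.Size([2, 3, 3])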


def predict(args, model, bg, adms, bonds_dicts):
    adms = adms.to(args['device'])
    bg = bg.to(args['device'])
    node_feats = bg.ndata.pop('h').to(args['device'])
    edge_feats = bg.edata.pop('e').to(args['device'])
    return model(bg, adms, bonds_dicts, node_feats, edge_feats)
--------------------------------------------------------------------------------