├── .gitignore
├── README.md
├── Synthesis.ipynb
├── Synthesis.py
└── scripts
├── Test.py
├── Train.py
├── __init__.py
├── dataset.py
├── get_edit.py
├── stopper.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *__pycache__/*
3 |
4 | .ipynb_checkpoints
5 | *.ipynb_checkpoints/*
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LocalTransform
2 | [License: Apache-2.0](https://opensource.org/licenses/Apache-2.0) [DOI](https://zenodo.org/badge/latestdoi/443246460)
3 | Implementation of organic reactivity prediction with LocalTransform, developed by Prof. Yousung Jung's group at KAIST (now moved to SNU, contact: yousung.jung@snu.ac.kr).
4 | 
5 |
6 | ## Code and license removal announcement (2025.03.28)
7 | Part of the code and the license have been removed.
8 |
9 | ## Model size decrease announcement (2022.10.31)
10 | We slightly modified the model architecture to decrease the model size from 59MB to 36.4MB, so the model can be uploaded to the GitHub repo, by decreasing the size of the bond feature from 512 to 256 through bond_net (see `scripts/model.py` for more detail). This modification also accelerates the training process.
11 | We also fixed a few parts of the code to enable smooth execution on CPU.
12 |
13 | ## Contents
14 |
15 | - [Developer](#developer)
16 | - [OS Requirements](#os-requirements)
17 | - [Python Dependencies](#python-dependencies)
18 | - [Installation Guide](#installation-guide)
19 | - [Reproduce the results](#reproduce-the-results)
20 | - [Demo and human benchmark results](#demo-and-human-benchmark-results)
21 | - [Publication](#publication)
22 | - [License](#license)
23 |
24 | ## Developer
25 | Shuan Chen (shuan.micc@gmail.com)
26 |
27 | ## OS Requirements
28 | This repository has been tested on both **Linux** and **Windows** operating systems.
29 |
30 | ## Python Dependencies
31 | * Python (version >= 3.6)
32 | * Numpy (version >= 1.16.4)
33 | * PyTorch (version >= 1.0.0)
34 | * RDKit (version >= 2019)
35 | * DGL (version >= 0.5.2)
36 | * DGLLife (version >= 0.2.6)
37 |
38 | ## Installation Guide
39 | Create a virtual environment to run the code of LocalTransform.
40 | Make sure to install PyTorch with the CUDA version that fits your device.
41 | This process usually takes a few minutes to complete.
42 | ```
43 | git clone https://github.com/kaist-amsg/LocalTransform.git
44 | cd LocalTransform
45 | conda create -c conda-forge -n rdenv python=3.6 -y
46 | conda activate rdenv
47 | conda install pytorch cudatoolkit=11.3 -c pytorch -y
48 | conda install -c conda-forge rdkit -y
49 | conda install -c dglteam dgl-cuda11.3
50 | pip install dgllife
51 | ```
52 |
53 | ## Reproduce the results
54 | ### [1] Download the raw data of USPTO-480k dataset
55 | Download the data from https://github.com/wengong-jin/nips17-rexgen/blob/master/USPTO/ and move the data to `./data/USPTO_480k/`.
56 |
57 | ### [2] Data preprocessing
58 | A two-step data preprocessing is needed to train the LocalTransform model.
59 |
60 | #### 1) Local reaction template derivation
61 | First go to the data processing folder
62 | ```
63 | cd preprocessing
64 | ```
65 | and extract the reaction templates.
66 | ```
67 | python Extract_from_train_data.py
68 | ```
69 | This will give you three files:
70 | (1) real_templates.csv (reaction templates for real bonds)
71 | (2) virtual_templates.csv (reaction templates for imaginary bonds)
72 | (3) template_infos.csv (hydrogen change, charge change, and action information)
73 |
74 | #### 2) Assign the derived templates to raw data
75 | By running
76 | ```
77 | python Run_preprocessing.py
78 | ```
79 | you can get four preprocessed files:
80 | (1) preprocessed_train.csv
81 | (2) preprocessed_valid.csv
82 | (3) preprocessed_test.csv
83 | (4) labeled_data.csv
84 |
85 |
86 | ### [3] Train LocalTransform model
87 | Go to the main scripts folder
88 | ```
89 | cd ../scripts
90 | ```
91 | and run the following to train the model with reagents separated or not (default: False):
92 | ```
93 | python Train.py -sep True
94 | ```
95 | The trained model will be saved at `LocalTransform/models/LocalTransform_sep.pth`.
96 |
97 | ### [4] Test LocalTransform model
98 | To evaluate the model on the test set, simply run
99 | ```
100 | python Test.py -sep True
101 | ```
102 | to get the raw prediction file saved at `LocalTransform/outputs/raw_prediction/LocalTransform_sep.txt`.
103 | Finally, you can get the products of each prediction by decoding the raw prediction file:
104 | ```
105 | python Decode_predictions.py -sep True
106 | ```
107 | The decoded products will be saved at
108 | `LocalTransform/outputs/decoded_prediction/LocalTransform_sep.txt`
109 |
110 | ### [5] Exact match accuracy calculation
111 | By using
112 | ```
113 | python Calculate_topk_accuracy.py -m sep
114 | ```
115 | the top-k accuracy will be calculated from the files generated at step [4].
116 |
117 | ## Demo and human benchmark results
118 | See `Synthesis.ipynb` for running instructions and expected output. Human benchmark results are also shown at the end of the notebook.
119 |
120 |
121 | ## Publication
122 | [A Generalized Template-Based Graph Neural Network for Accurate Organic Reactivity Prediction, Nat Mach Intell 2022](https://www.nature.com/articles/s42256-022-00526-z)
123 |
124 | ## License
125 | This project is covered under the **Apache 2.0 License**.
126 |
--------------------------------------------------------------------------------
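A quick way to verify the environment from the Installation Guide above (a minimal sketch, not part of the repo; the version floors follow the Python Dependencies list, and CUDA is only expected on GPU machines):

```python
# Sanity check for the LocalTransform environment (sketch).
import torch
import dgl
import dgllife
from rdkit import rdBase

print('PyTorch:', torch.__version__)      # >= 1.0.0 expected
print('DGL:', dgl.__version__)            # >= 0.5.2 expected
print('DGLLife:', dgllife.__version__)    # >= 0.2.6 expected
print('RDKit:', rdBase.rdkitVersion)      # >= 2019 expected
print('CUDA available:', torch.cuda.is_available())
```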
/Synthesis.py:
--------------------------------------------------------------------------------
1 | from itertools import permutations
2 | import pandas as pd
3 | import json
4 | from rdkit import Chem
5 |
6 | import torch
7 | from torch import nn
8 | import sklearn
9 |
10 | import dgl
11 | from dgllife.utils import smiles_to_bigraph, WeaveAtomFeaturizer, CanonicalBondFeaturizer
12 | from functools import partial
13 |
14 | from scripts.dataset import combine_reactants, get_bonds, get_adm
15 | from scripts.utils import init_featurizer, load_model, pad_atom_distance_matrix, predict
16 | from scripts.get_edit import get_bg_partition, combined_edit
17 | from LocalTemplate.template_collector import Collector
18 |
19 | atom_types = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
20 | 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
21 | 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb',
22 | 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Ta', 'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir',
23 | 'Ce', 'Gd', 'Ga', 'Cs']
24 |
25 | def demap(smiles):
26 | mol = Chem.MolFromSmiles(smiles)
27 | [atom.SetAtomMapNum(0) for atom in mol.GetAtoms()]
28 | return Chem.MolToSmiles(mol)
29 |
30 | class localtransform():
31 | def __init__(self, dataset, device='cuda:0'):
32 | self.data_dir = 'data/%s' % dataset
33 | self.config_path = 'data/configs/default_config'
34 | self.model_path = 'models/LocalTransform_%s.pth' % dataset
35 | self.device = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
36 | self.args = {'data_dir': self.data_dir, 'model_path': self.model_path, 'config_path': self.config_path, 'device': self.device, 'mode': 'test'}
37 | self.template_dicts, self.template_infos = self.load_templates()
38 | self.model, self.graph_function = self.init_model()
39 |
40 | def load_templates(self):
41 | template_dicts = {}
42 | for site in ['real', 'virtual']:
43 | template_df = pd.read_csv('%s/%s_templates.csv' % (self.data_dir, site))
44 | template_dict = {template_df['Class'][i]: template_df['Template'][i].split('_') for i in template_df.index}
45 | print ('loaded %s %s templates' % (len(template_dict), site))
46 | template_dicts[site[0]] = template_dict
47 | template_infos = pd.read_csv('%s/template_infos.csv' % self.data_dir)
48 | template_infos = {template_infos['Template'][i]: {
49 | 'edit_site': eval(template_infos['edit_site'][i]),
50 | 'change_H': eval(template_infos['change_H'][i]),
51 | 'change_C': eval(template_infos['change_C'][i]),
52 | 'change_S': eval(template_infos['change_S'][i])} for i in template_infos.index}
53 | return template_dicts, template_infos
54 |
55 | def init_model(self):
56 | self.args = init_featurizer(self.args)
57 | model = load_model(self.args)
58 | model.eval()
59 | smiles_to_graph = partial(smiles_to_bigraph, add_self_loop=True)
60 | node_featurizer = WeaveAtomFeaturizer(atom_types=atom_types)
61 | edge_featurizer = CanonicalBondFeaturizer(self_loop=True)
62 | graph_function = lambda s: smiles_to_graph(s, node_featurizer = node_featurizer, edge_featurizer = edge_featurizer, canonical_atom_order = False)
63 | return model, graph_function
64 |
65 | def make_inference(self, reactant_list, topk=5):
66 | fgraphs = []
67 | dgraphs = []
68 | for smiles in reactant_list:
69 | mol = Chem.MolFromSmiles(smiles)
70 | fgraph = self.graph_function(smiles)
71 | dgraph = {'atom_distance_matrix': get_adm(mol), 'bonds':get_bonds(smiles)}
72 | dgraph['v_bonds'], dgraph['r_bonds'] = dgraph['bonds']
73 | fgraphs.append(fgraph)
74 | dgraphs.append(dgraph)
75 | bg = dgl.batch(fgraphs)
76 | bg.set_n_initializer(dgl.init.zero_initializer)
77 | bg.set_e_initializer(dgl.init.zero_initializer)
78 | adm_lists = [graph['atom_distance_matrix'] for graph in dgraphs]
79 | adms = pad_atom_distance_matrix(adm_lists)
80 | bonds_dicts = {'virtual': [torch.from_numpy(graph['v_bonds']).long() for graph in dgraphs], 'real': [torch.from_numpy(graph['r_bonds']).long() for graph in dgraphs]}
81 |
82 | with torch.no_grad():
83 | pred_VT, pred_RT, _, _, pred_VI, pred_RI, attentions = predict(self.args, self.model, bg, adms, bonds_dicts)
84 | pred_VT = nn.Softmax(dim=1)(pred_VT)
85 | pred_RT = nn.Softmax(dim=1)(pred_RT)
86 | v_sep, r_sep = get_bg_partition(bg, bonds_dicts)
87 | start_v, start_r = 0, 0
88 | predictions = []
89 |         for i, reactant in enumerate(reactant_list):
90 | end_v, end_r = v_sep[i], r_sep[i]
91 | virtual_bonds, real_bonds = bonds_dicts['virtual'][i].numpy(), bonds_dicts['real'][i].numpy()
92 | pred_vi, pred_ri = pred_VI[i].cpu(), pred_RI[i].cpu()
93 | pred_v, pred_r = pred_VT[start_v:end_v], pred_RT[start_r:end_r]
94 | prediction = combined_edit(virtual_bonds, real_bonds, pred_vi, pred_ri, pred_v, pred_r, topk*10)
95 | predictions.append(prediction)
96 | start_v = end_v
97 | start_r = end_r
98 | return predictions
99 |
100 | def predict_product(self, reactant_list, topk=5, verbose=0):
101 | if isinstance(reactant_list, str):
102 | reactant_list = [reactant_list]
103 | predictions = self.make_inference(reactant_list, topk)
104 | results_df = {'Reactants' : []}
105 | results_dict = {}
106 | for k in range(topk):
107 | results_df['Top-%d' % (k+1)] = []
108 |
109 | for reactant, prediction in zip(reactant_list, predictions):
110 | pred_types, pred_sites, scores = prediction
111 | collector = Collector(reactant, self.template_infos, 'nan', False, verbose = verbose > 1)
112 | for k, (pred_type, pred_site, score) in enumerate(zip(pred_types, pred_sites, scores)):
113 | template, H_code, C_code, S_code, action = self.template_dicts[pred_type][pred_site[1]]
114 | pred_site = pred_site[0]
115 | if verbose > 0:
116 | print ('%dth prediction:' % (k+1), template, action, pred_site, score)
117 | collector.collect(template, H_code, C_code, S_code, action, pred_site, score)
118 | if len(collector.predictions) >= topk:
119 | break
120 | sorted_predictions = [k for k, v in sorted(collector.predictions.items(), key=lambda item: -item[1]['score'])]
121 | results_df['Reactants'].append(Chem.MolFromSmiles(reactant))
122 | results_dict[reactant] = {}
123 |             for k in range(topk):
124 |                 p = sorted_predictions[k] if len(sorted_predictions) > k else ''
125 |                 # guard: fewer than topk products leaves p == '', which is not a key in collector.predictions
126 |                 results_dict[reactant]['Top-%d' % (k+1)] = dict(collector.predictions[p], product=p) if p else {'product': ''}
127 |                 results_df['Top-%d' % (k+1)].append(Chem.MolFromSmiles(p))
128 |
129 | results_df = pd.DataFrame(results_df)
130 | return results_df, results_dict
131 |
132 | # def predict_products(self, args, reactant_list, model, graph_functions, template_dicts, template_infos, product = None, reagents = 'nan', top_k = 5, collect_n = 100, verbose = 0, sep = False):
133 | # model.eval()
134 | # if reagents != 'nan':
135 | # smiles = reactant + '.' + reagents
136 | # else:
137 | # smiles = reactant
138 | # dglgraph = graph_functions(smiles)
139 | # adms = pad_atom_distance_matrix([get_adm(Chem.MolFromSmiles(smiles))])
140 | # v_bonds, r_bonds = get_bonds(smiles)
141 | # bonds_dicts = {'virtual': [torch.from_numpy(v_bonds).long()], 'real': [torch.from_numpy(r_bonds).long()]}
142 | # with torch.no_grad():
143 | # pred_VT, pred_RT, _, _, pred_VI, pred_RI, attentions = predict(args, model, dglgraph, adms, bonds_dicts)
144 | # pred_v = nn.Softmax(dim=1)(pred_VT)
145 | # pred_r = nn.Softmax(dim=1)(pred_RT)
146 | # pred_vi = pred_VI[0].cpu()
147 | # pred_ri = pred_RI[0].cpu()
148 | # pred_types, pred_sites, pred_scores = combined_edit(v_bonds, r_bonds, pred_vi, pred_ri, pred_v, pred_r, collect_n)
149 |
150 | # collector = Collector(reactant, template_infos, reagents, sep, verbose = verbose > 1)
151 | # for k, (pred_type, pred_site, score) in enumerate(zip(pred_types, pred_sites, pred_scores)):
152 | # template, H_code, C_code, S_code, action = template_dicts[pred_type][pred_site[1]]
153 | # pred_site = pred_site[0]
154 | # if verbose > 0:
155 | # print ('%sth prediction:' % k, template, action, pred_site, score)
156 | # collector.collect(template, H_code, C_code, S_code, action, pred_site, score)
157 | # if len(collector.predictions) >= top_k:
158 | # break
159 | # sort_predictions = [k for k, v in sorted(collector.predictions.items(), key=lambda item: -item[1]['score'])]
160 |
161 | # reactant = demap(reactant)
162 | # if product != None:
163 | # correct_at = False
164 | # product = demap(product)
165 |
166 | # results_dict = {'Reactants' : demap(reactant)}
167 | # results_df = pd.DataFrame({'Reactants' : [Chem.MolFromSmiles(reactant)]})
168 | # for k, p in enumerate(sort_predictions):
169 | # results_dict['Top-%d' % (k+1)] = collector.predictions[p]
170 | # results_dict['Top-%d' % (k+1)]['product'] = p
171 | # results_df['Top-%d' % (k+1)] = [Chem.MolFromSmiles(p)]
172 | # if product != None:
173 | # if set(p.split('.')).intersection(set(product.split('.'))):
174 | # correct_at = k+1
175 |
176 | # if product != None:
177 | # results_df['Correct at'] = correct_at
178 | # return results_df, results_dict
179 |
--------------------------------------------------------------------------------
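For reference, a minimal usage sketch of the `localtransform` class above, mirroring what `Synthesis.ipynb` does. It assumes the complete repository (including the removed model code), template files under `data/USPTO_480k/`, and a matching checkpoint at `models/LocalTransform_USPTO_480k.pth`; the reactant SMILES is a hypothetical example:

```python
# Sketch only: run from the repository root with a trained checkpoint in place.
from Synthesis import localtransform

predictor = localtransform(dataset='USPTO_480k', device='cuda:0')

# predict_product accepts a single SMILES string or a list of them
reactants = ['CCOC(=O)c1ccc(N)cc1.CC(=O)Cl']  # hypothetical input
results_df, results_dict = predictor.predict_product(reactants, topk=5, verbose=0)

# results_dict maps each reactant SMILES to its Top-k products and scores
print(results_dict[reactants[0]]['Top-1']['product'])
```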
/scripts/Test.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import torch
4 | import sklearn
5 | import torch.nn as nn
6 |
7 | from utils import init_featurizer, mkdir_p, get_configure, load_model, load_dataloader, predict
8 | from get_edit import write_edits
9 |
10 | def main(args):
11 | if args['model_name'] == 'default':
12 | if args['sep']:
13 | args['model_name'] = 'LocalTransform_sep.pth'
14 | else:
15 | args['model_name'] = 'LocalTransform_mix.pth'
16 | else:
17 | args['model_name'] = 'LocalTransform_%s.pth' % args['model_name']
18 |
19 | args['model_path'] = '../models/%s' % args['model_name']
20 | args['config_path'] = '../data/configs/%s' % args['config']
21 | args['data_dir'] = '../data/%s' % args['dataset']
22 | args['result_path'] = '../outputs/raw_prediction/%s' % args['model_name'].replace('.pth', '.txt')
23 | mkdir_p('../outputs')
24 | mkdir_p('../outputs/raw_prediction')
25 | args = init_featurizer(args)
26 | model = load_model(args)
27 | test_loader = load_dataloader(args)
28 |
29 | write_edits(args, model, test_loader)
30 | return
31 |
32 | if __name__ == '__main__':
33 |     parser = ArgumentParser('Testing arguments')
34 | parser.add_argument('-g', '--gpu', default='cuda:0', help='GPU device to use')
35 | parser.add_argument('-d', '--dataset', default='USPTO_480k', help='Dataset to use')
36 | parser.add_argument('-m', '--model-name', default='default', help='Model to use')
37 | parser.add_argument('-c', '--config', default='default_config', help='Configuration of model')
38 |     parser.add_argument('-b', '--batch-size', type=int, default=32, help='Batch size of dataloader')
39 |     parser.add_argument('-k', '--top_num', type=int, default=100, help='Num. of predictions to write')
40 |     parser.add_argument('-s', '--sep', type=lambda s: str(s).lower() in ('true', '1'), default=False, help='Test the model trained with reagents separated or not')
41 | parser.add_argument('-nw', '--num-workers', type=int, default=0, help='Number of processes for data loading')
42 | args = parser.parse_args().__dict__
43 | args['mode'] = 'test'
44 | args['device'] = torch.device(args['gpu']) if torch.cuda.is_available() else torch.device('cpu')
45 | print ('Using device %s' % args['device'])
46 | main(args)
--------------------------------------------------------------------------------
/scripts/Train.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import torch
4 | import sklearn
5 | import torch.nn as nn
6 |
7 | from utils import init_featurizer, mkdir_p, get_configure, load_model, load_dataloader, predict, get_reactive_template_labels
8 |
9 | def mask_loss(loss_criterion1, loss_criterion2, pred_v, pred_r, true_v, true_r, vmask, rmask):
10 |     vmask, rmask = vmask.double(), rmask.double()  # note the calls: the bare methods would make (mask != 0) always True
11 | vloss = (loss_criterion1(pred_v, true_v) * (vmask != 0)).float().mean()
12 | rloss = (loss_criterion2(pred_r, true_r) * (rmask != 0)).float().mean()
13 | return vloss + rloss
14 |
15 | def run_a_train_epoch(args, epoch, model, data_loader, loss_criterions, optimizer):
16 | model.train()
17 | train_R_loss = 0
18 | train_T_loss = 0
19 | for batch_id, batch_data in enumerate(data_loader):
20 | smiles, bg, adm_lists, bonds_dicts, true_VT, true_RT, masks = batch_data
21 | if len(smiles) == 1:
22 | print ('Skip problematic graph')
23 | continue
24 | pred_VT, pred_RT, pred_VR, pred_RR, pred_VI, pred_RI, _ = predict(args, model, bg, adm_lists, bonds_dicts)
25 | true_VT, true_VR, mask_V = get_reactive_template_labels(true_VT, masks, pred_VI)
26 | true_RT, true_RR, mask_R = get_reactive_template_labels(true_RT, masks, pred_RI)
27 | true_VT, true_RT, true_VR, true_RR, mask_V, mask_R = true_VT.to(args['device']), true_RT.to(args['device']), true_VR.to(args['device']), true_RR.to(args['device']), mask_V.to(args['device']), mask_R.to(args['device'])
28 |
29 | R_loss = mask_loss(loss_criterions[0], loss_criterions[0], pred_VR, pred_RR, true_VR, true_RR, mask_V, mask_R)
30 | T_loss = mask_loss(loss_criterions[1], loss_criterions[2], pred_VT, pred_RT, true_VT, true_RT, mask_V, mask_R)
31 | loss = R_loss + T_loss
32 | train_R_loss += R_loss.item()
33 | train_T_loss += T_loss.item()
34 | optimizer.zero_grad()
35 | loss.backward()
36 | nn.utils.clip_grad_norm_(model.parameters(), args['max_clip'])
37 | optimizer.step()
38 | if batch_id % args['print_every'] == 0:
39 | print('\repoch %d/%d, batch %d/%d, reactive loss %.4f, template loss %.4f' % (epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), R_loss.item(), T_loss.item()), end='', flush=True)
40 |
41 |     print('\nepoch %d/%d, train reactive loss %.4f, template loss %.4f' % (epoch + 1, args['num_epochs'], train_R_loss/(batch_id+1), train_T_loss/(batch_id+1)))
42 |
43 | def run_an_eval_epoch(args, model, data_loader, loss_criterions):
44 | model.eval()
45 | val_loss = 0
46 | with torch.no_grad():
47 | for batch_id, batch_data in enumerate(data_loader):
48 | smiles, bg, adm_lists, bonds_dicts, true_VT, true_RT, masks = batch_data
49 | if len(smiles) == 1:
50 | print ('Skip problematic graph')
51 | continue
52 | pred_VT, pred_RT, pred_VR, pred_RR, pred_VI, pred_RI, _ = predict(args, model, bg, adm_lists, bonds_dicts)
53 | true_VT, true_VR, mask_V = get_reactive_template_labels(true_VT, masks, pred_VI)
54 | true_RT, true_RR, mask_R = get_reactive_template_labels(true_RT, masks, pred_RI)
55 | true_VT, true_RT, true_VR, true_RR, mask_V, mask_R = true_VT.to(args['device']), true_RT.to(args['device']), true_VR.to(args['device']), true_RR.to(args['device']), mask_V.to(args['device']), mask_R.to(args['device'])
56 | loss = mask_loss(loss_criterions[1], loss_criterions[2], pred_VT, pred_RT, true_VT, true_RT, mask_V, mask_R)
57 | val_loss += loss.item()
58 |     return val_loss/(batch_id+1)
59 |
60 |
61 | def main(args):
62 | if args['model_name'] == 'default':
63 | if args['sep']:
64 | args['model_name'] = 'LocalTransform_sep.pth'
65 | else:
66 | args['model_name'] = 'LocalTransform_mix.pth'
67 | else:
68 | args['model_name'] = '%s.pth' % args['model_name']
69 |
70 | args['model_path'] = '../models/' + args['model_name']
71 | args['config_path'] = '../data/configs/%s' % args['config']
72 | args['data_dir'] = '../data/%s' % args['dataset']
73 | mkdir_p('../models')
74 | args = init_featurizer(args)
75 | model, loss_criterions, optimizer, scheduler, stopper = load_model(args)
76 | train_loader, val_loader, test_loader = load_dataloader(args)
77 | for epoch in range(args['num_epochs']):
78 | run_a_train_epoch(args, epoch, model, train_loader, loss_criterions, optimizer)
79 | val_loss = run_an_eval_epoch(args, model, val_loader, loss_criterions)
80 | early_stop = stopper.step(val_loss, model)
81 | scheduler.step()
82 | print('epoch %d/%d, validation loss: %.4f' % (epoch + 1, args['num_epochs'], val_loss))
83 | print('epoch %d/%d, Best loss: %.4f' % (epoch + 1, args['num_epochs'], stopper.best_score))
84 | if early_stop:
85 | print ('Early stopped!!')
86 | break
87 |
88 | stopper.load_checkpoint(model)
89 | test_loss = run_an_eval_epoch(args, model, test_loader, loss_criterions)
90 | print('test loss: %.4f' % test_loss)
91 |
92 | if __name__ == '__main__':
93 |     parser = ArgumentParser('Training arguments')
94 | parser.add_argument('-g', '--gpu', default='cuda:0', help='GPU device to use')
95 | parser.add_argument('-d', '--dataset', default='USPTO_480k', help='Dataset to use')
96 | parser.add_argument('-c', '--config', default='default_config', help='Configuration of model')
97 |     parser.add_argument('-b', '--batch-size', type=int, default=16, help='Batch size of dataloader')
98 | parser.add_argument('-n', '--num-epochs', type=int, default=20, help='Maximum number of epochs for training')
99 | parser.add_argument('-m', '--model-name', type=str, default='default', help='Model name')
100 | parser.add_argument('-p', '--patience', type=int, default=3, help='Patience for early stopping')
101 | parser.add_argument('-w', '--negative-weight', type=float, default=0.5, help='Loss weight for negative labels')
102 |     parser.add_argument('-s', '--sep', type=lambda s: str(s).lower() in ('true', '1'), default=False, help='Train the model with reagents separated or not')
103 |     parser.add_argument('-cl', '--max-clip', type=int, default=20, help='Maximum gradient norm for clipping')
104 | parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4, help='Learning rate of optimizer')
105 | parser.add_argument('-l2', '--weight-decay', type=float, default=1e-6, help='Weight decay of optimizer')
106 |     parser.add_argument('-ss', '--schedule-step', type=int, default=6, help='Step size of the learning rate scheduler')
107 | parser.add_argument('-nw', '--num-workers', type=int, default=0, help='Number of processes for data loading')
108 | parser.add_argument('-pe', '--print-every', type=int, default=20, help='Print the training progress every X mini-batches')
109 | args = parser.parse_args().__dict__
110 | args['mode'] = 'train'
111 | args['device'] = torch.device(args['gpu']) if torch.cuda.is_available() else torch.device('cpu')
112 |     print ('Using device %s' % args['device'], 'separate reagents: %s' % args['sep'])
113 | main(args)
--------------------------------------------------------------------------------
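A toy illustration of the masked loss used in `run_a_train_epoch` above (a sketch with made-up tensors): `weight_loss` in `scripts/utils.py` builds `nn.CrossEntropyLoss(reduction='none')`, so `mask_loss` receives one loss value per bond and zeroes out the entries whose mask is 0 before averaging:

```python
import torch
import torch.nn as nn

# per-bond cross entropy, as produced by weight_loss(..., reduction='none')
criterion = nn.CrossEntropyLoss(reduction='none')

pred = torch.randn(4, 3)            # 4 bonds, 3 template classes (toy values)
true = torch.tensor([0, 2, 1, 0])   # toy template labels
mask = torch.tensor([1, 1, 0, 1])   # third bond is excluded from the loss

# same pattern as mask_loss: zero the masked entries, then average
loss = (criterion(pred, true) * (mask != 0)).float().mean()
print(loss)
```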
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaist-amsg/LocalTransform/1b763f20e4d1df560d15aab2a61291fe0c50fae3/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/dataset.py:
--------------------------------------------------------------------------------
1 | import os, pickle
2 | import numpy as np
3 | import pandas as pd
4 | from tqdm import tqdm
5 | from rdkit import Chem
6 |
7 | import torch
8 | import sklearn
9 | import dgl
10 | import dgl.backend as F
11 | from dgl.data.utils import save_graphs, load_graphs
12 |
13 | def combine_reactants(reactant, reagent):
14 | if str(reagent) == 'nan':
15 | smiles = reactant
16 | else:
17 | smiles = '.'.join([reactant, reagent])
18 | return smiles
19 |
20 | def get_bonds(smiles):
21 | mol = Chem.MolFromSmiles(smiles)
22 | A = [a for a in range(mol.GetNumAtoms())]
23 | B = []
24 | for atom in mol.GetAtoms():
25 | others = []
26 | bonds = atom.GetBonds()
27 | for bond in bonds:
28 | atoms = [bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()]
29 | other = [a for a in atoms if a != atom.GetIdx()][0]
30 | others.append(other)
31 | b = [(atom.GetIdx(), other) for other in sorted(others)]
32 | B += b
33 | V = []
34 | for a in A:
35 | V += [(a,b) for b in A if a != b and (a,b) not in B]
36 | return np.array(V), np.array(B)
37 |
38 | def get_adm(mol, max_distance = 6):
39 | mol_size = mol.GetNumAtoms()
40 |     distance_matrix = np.ones((mol_size, mol_size)) * (max_distance + 1)
41 | dm = Chem.GetDistanceMatrix(mol)
42 | dm[dm > 100] = -1 # remote (different molecule)
43 | dm[dm > max_distance] = max_distance # remote (same molecule)
44 | dm[dm == -1] = max_distance + 1
45 | distance_matrix[:dm.shape[0],:dm.shape[1]] = dm
46 | return distance_matrix
47 |
48 | class USPTODataset(object):
49 | def __init__(self, args, mol_to_graph, node_featurizer, edge_featurizer, load=True, log_every=1000):
50 | df = pd.read_csv('%s/labeled_data.csv' % args['data_dir'])
51 | self.train_ids = df.index[df['Split'] == 'train'].values
52 | self.val_ids = df.index[df['Split'] == 'valid'].values
53 | self.test_ids = df.index[df['Split'] == 'test'].values
54 | self.reactants = df['Reactants'].tolist()
55 | self.reagents = df['Reagents'].tolist()
56 | self.masks = df['Mask'].tolist()
57 |
58 | self.sep = args['sep']
59 | self.fgraph_path = '../data/saved_graphs/full_%s_fgraph.bin' % args['dataset']
60 | if self.sep:
61 | self.labels = [eval(t) for t in df['Labels_sep']]
62 |             self.dgraph_path = '../data/saved_graphs/full_%s_dgraph_sep.pkl' % args['dataset'] # stored as numpy matrices for faster processing
63 | else:
64 | self.labels = [eval(t) for t in df['Labels_mix']]
65 | self.dgraph_path = '../data/saved_graphs/full_%s_dgraph_mix.pkl' % args['dataset']
66 | self._pre_process(mol_to_graph, node_featurizer, edge_featurizer)
67 |
68 | def _pre_process(self, mol_to_graph, node_featurizer, edge_featurizer):
69 | self.fgraphs_exist = os.path.exists(self.fgraph_path)
70 | self.dgraphs_exist = os.path.exists(self.dgraph_path)
71 | self.fgraphs = []
72 | self.dgraphs = []
73 |
74 | if not self.fgraphs_exist or not self.dgraphs_exist:
75 | for (s1, s2) in tqdm(zip(self.reactants, self.reagents), total=len(self.reactants), desc='Building dgl graphs...'):
76 | smiles = combine_reactants(s1, s2) # s1.s2
77 | mol = Chem.MolFromSmiles(smiles)
78 | if not self.fgraphs_exist:
79 | fgraph = mol_to_graph(mol, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer, canonical_atom_order=False)
80 | self.fgraphs.append(fgraph)
81 | if not self.dgraphs_exist:
82 | if self.sep:
83 | self.dgraphs.append({'atom_distance_matrix': get_adm(mol), 'bonds':get_bonds(s1)})
84 | else:
85 | self.dgraphs.append({'atom_distance_matrix': get_adm(mol), 'bonds':get_bonds(smiles)})
86 |
87 | if self.fgraphs_exist:
88 |             print ('Loading feature graphs from %s...' % self.fgraph_path)
89 | self.fgraphs, _ = load_graphs(self.fgraph_path)
90 | else:
91 | save_graphs(self.fgraph_path, self.fgraphs)
92 |
93 | if self.dgraphs_exist:
94 | print ('Loading dense graphs from %s...' % self.dgraph_path)
95 | with open(self.dgraph_path, 'rb') as f:
96 | self.dgraphs = pickle.load(f)
97 | else:
98 | with open(self.dgraph_path, 'wb') as f:
99 | pickle.dump(self.dgraphs, f)
100 |
101 | def __getitem__(self, item):
102 | dgraph = self.dgraphs[item]
103 | dgraph['v_bonds'], dgraph['r_bonds'] = dgraph['bonds']
104 | return self.reactants[item], self.reagents[item], self.fgraphs[item], dgraph, self.labels[item], self.masks[item]
105 |
106 | def __len__(self):
107 | return len(self.reactants)
108 |
109 | class USPTOTestDataset(object):
110 | def __init__(self, args, mol_to_graph, node_featurizer, edge_featurizer, load=True, log_every=1000):
111 | df = pd.read_csv('%s/preprocessed_test.csv' % args['data_dir'])
112 | self.reactants = df['Reactants'].tolist()
113 | self.reagents = df['Reagents'].tolist()
114 | self.sep = args['sep']
115 | self.fgraph_path = '../data/saved_graphs/test_%s_fgraph.bin' % args['dataset']
116 | if self.sep:
117 |             self.dgraph_path = '../data/saved_graphs/test_%s_dgraph_sep.pkl' % args['dataset'] # stored as numpy matrices for faster processing
118 | else:
119 | self.dgraph_path = '../data/saved_graphs/test_%s_dgraph_mix.pkl' % args['dataset']
120 | self._pre_process(mol_to_graph, node_featurizer, edge_featurizer)
121 |
122 |
123 | def _pre_process(self, mol_to_graph, node_featurizer, edge_featurizer):
124 | self.fgraphs_exist = os.path.exists(self.fgraph_path)
125 | self.dgraphs_exist = os.path.exists(self.dgraph_path)
126 | self.fgraphs = []
127 | self.dgraphs = []
128 |
129 | if not self.fgraphs_exist or not self.dgraphs_exist:
130 | for (s1, s2) in tqdm(zip(self.reactants, self.reagents), total=len(self.reactants), desc='Building dgl graphs...'):
131 | smiles = combine_reactants(s1, s2) # s1.s2
132 |                 mol = Chem.MolFromSmiles(smiles)  # moved out of the branch below; get_adm(mol) also needs it
133 |                 if not self.fgraphs_exist:
134 |                     fgraph = mol_to_graph(mol, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer, canonical_atom_order=False)
135 |                     self.fgraphs.append(fgraph)
136 | if not self.dgraphs_exist:
137 | if self.sep:
138 | self.dgraphs.append({'atom_distance_matrix': get_adm(mol), 'bonds':get_bonds(s1)})
139 | else:
140 | self.dgraphs.append({'atom_distance_matrix': get_adm(mol), 'bonds':get_bonds(smiles)})
141 |
142 | if self.fgraphs_exist:
143 |             print ('Loading feature graphs from %s...' % self.fgraph_path)
144 | self.fgraphs, _ = load_graphs(self.fgraph_path)
145 | else:
146 | save_graphs(self.fgraph_path, self.fgraphs)
147 |
148 | if self.dgraphs_exist:
149 | print ('Loading dense graphs from %s...' % self.dgraph_path)
150 | with open(self.dgraph_path, 'rb') as f:
151 | self.dgraphs = pickle.load(f)
152 | else:
153 | with open(self.dgraph_path, 'wb') as f:
154 | pickle.dump(self.dgraphs, f)
155 |
156 | def __getitem__(self, item):
157 | dgraph = self.dgraphs[item]
158 | dgraph['v_bonds'], dgraph['r_bonds'] = dgraph['bonds']
159 | return self.reactants[item], self.reagents[item], self.fgraphs[item], dgraph
160 |
161 | def __len__(self):
162 | return len(self.reactants)
163 |
164 |
--------------------------------------------------------------------------------
/scripts/get_edit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 |
5 | import time
6 | import torch
7 | import torch.nn as nn
8 | import dgl
9 |
10 | from utils import predict
11 |
12 | def get_id_template(a, class_n):
13 | edit_idx = a//class_n
14 | template = a%class_n
15 | return (edit_idx, template)
16 |
17 | def logit2pred(out, top_num):
18 | class_n = out.size(-1)
19 | readout = out.cpu().detach().numpy()
20 | readout = readout.reshape(-1)
21 | output_rank = np.flip(np.argsort(readout))
22 | output_rank = [r for r in output_rank if get_id_template(r, class_n)[1] != 0][:top_num]
23 | selected_site = [get_id_template(a, class_n) for a in output_rank]
24 | selected_proba = [readout[a] for a in output_rank]
25 |
26 | return selected_site, selected_proba
27 |
32 | def combined_edit(virtual_bonds, real_bonds, pred_vi, pred_ri, pred_v, pred_r, top_num):
33 | pred_site_v, pred_score_v = logit2pred(pred_v, top_num)
34 | pred_site_r, pred_score_r = logit2pred(pred_r, top_num)
35 | pooled_virtual_bonds = [virtual_bonds[idx] for idx in pred_vi]
36 | pooled_real_bonds = [real_bonds[idx] for idx in pred_ri]
37 | pred_site_v = [(list(pooled_virtual_bonds[pred_site]), pred_temp) for pred_site, pred_temp in pred_site_v]
38 | pred_site_r = [(list(pooled_real_bonds[pred_site]), pred_temp) for pred_site, pred_temp in pred_site_r]
39 |
40 | pred_sites = pred_site_v + pred_site_r
41 | pred_types = ['v'] * top_num + ['r'] * top_num
42 | pred_scores = pred_score_v + pred_score_r
43 | pred_ranks = np.flip(np.argsort(pred_scores))[:top_num]
44 |
45 | pred_types = [pred_types[r] for r in pred_ranks]
46 | pred_sites = [pred_sites[r] for r in pred_ranks]
47 | pred_scores = [pred_scores[r] for r in pred_ranks]
48 | return pred_types, pred_sites, pred_scores
49 |
50 | def get_bg_partition(bg, bonds_dicts):
51 | gs = dgl.unbatch(bg)
52 | v_sep = [0]
53 | r_sep = [0]
54 | for g, v_bonds, r_bonds in zip(gs, bonds_dicts['virtual'], bonds_dicts['real']):
55 | pooling_size = g.num_nodes()
56 | n_v_bonds = len(v_bonds)
57 | if n_v_bonds < pooling_size:
58 | v_sep.append(v_sep[-1] + n_v_bonds)
59 | else:
60 | v_sep.append(v_sep[-1] + pooling_size)
61 | n_r_bonds = len(r_bonds)
62 | if n_r_bonds < pooling_size:
63 | r_sep.append(r_sep[-1] + n_r_bonds)
64 | else:
65 | r_sep.append(r_sep[-1] + pooling_size)
66 | return v_sep[1:], r_sep[1:]
67 |
68 | def write_edits(args, model, test_loader):
69 | model.eval()
70 | with open(args['result_path'], 'w') as f:
71 | f.write('Test_id\tReactants\tReagents\t%s\n' % '\t'.join(['Prediction %s' % (i+1) for i in range(args['top_num'])]))
72 | with torch.no_grad():
73 | for batch_id, data in enumerate(test_loader):
74 | reactants, reagents, bg, adms, bonds_dicts = data
75 | pred_VT, pred_RT, _, _, pred_VI, pred_RI, _ = predict(args, model, bg, adms, bonds_dicts)
76 | pred_VT = nn.Softmax(dim=1)(pred_VT)
77 | pred_RT = nn.Softmax(dim=1)(pred_RT)
78 | v_sep, r_sep = get_bg_partition(bg, bonds_dicts)
79 | start_v = 0
80 | start_r = 0
81 |                 print('\rWriting test molecule batch %s/%s' % (batch_id+1, len(test_loader)), end='', flush=True)
82 | for i, (reactant, reagent) in enumerate(zip(reactants, reagents)):
83 | end_v, end_r = v_sep[i], r_sep[i]
84 | virtual_bonds, real_bonds = bonds_dicts['virtual'][i].numpy(), bonds_dicts['real'][i].numpy()
85 | pred_vi, pred_ri = pred_VI[i].cpu(), pred_RI[i].cpu()
86 | pred_v, pred_r = pred_VT[start_v:end_v], pred_RT[start_r:end_r]
87 |
88 | pred_types, pred_sites, pred_scores = combined_edit(virtual_bonds, real_bonds, pred_vi, pred_ri, pred_v, pred_r, args['top_num'])
89 | test_id = (batch_id * args['batch_size']) + i
90 | f.write('%s\t%s\t%s\t%s\n' % (test_id, reactant, reagent, '\t'.join(['(%s, %s, %s, %.3f)' % (pred_types[i], pred_sites[i][0], pred_sites[i][1], pred_scores[i]) for i in range(args['top_num'])])))
91 | start_v = end_v
92 | start_r = end_r
93 | print ()
94 | return
95 |
--------------------------------------------------------------------------------
/scripts/stopper.py:
--------------------------------------------------------------------------------
1 | # modified from dgllife (https://lifesci.dgl.ai/_modules/dgllife/utils/early_stop.html)
2 | import datetime
3 | import torch
4 |
5 | class EarlyStopping(object):
6 | """Early stop tracker
7 |
8 | Save model checkpoint when observing a performance improvement on
9 | the validation set and early stop if improvement has not been
10 | observed for a particular number of epochs.
11 |
12 |     Parameters
13 |     ----------
14 |     args : dict
15 |         Configuration dictionary. The keys used here are:
16 |
17 |         * ``patience`` (int): early stopping happens if we do not observe
18 |           performance improvement for ``patience`` consecutive epochs.
19 |         * ``model_path`` (str or None): filename for storing the model
20 |           checkpoint. If None, a filename starting with ``early_stop``
21 |           is generated from the current time.
22 |         * ``device``: device used as ``map_location`` when reloading
23 |           the checkpoint.
24 |     mode : str
25 |         * 'higher': Higher metric suggests a better model
26 |         * 'lower': Lower metric suggests a better model
30 |
31 | Examples
32 | --------
33 | Below gives a demo for a fake training process.
34 |
35 | >>> import torch
36 | >>> import torch.nn as nn
37 | >>> from torch.nn import MSELoss
38 | >>> from torch.optim import Adam
39 | >>> from dgllife.utils import EarlyStopping
40 |
41 | >>> model = nn.Linear(1, 1)
42 | >>> criterion = MSELoss()
43 | >>> # For MSE, the lower, the better
44 | >>> stopper = EarlyStopping(mode='lower', filename='test.pth')
45 | >>> optimizer = Adam(params=model.parameters(), lr=1e-3)
46 |
47 | >>> for epoch in range(1000):
48 | >>> x = torch.randn(1, 1) # Fake input
49 | >>> y = torch.randn(1, 1) # Fake label
50 | >>> pred = model(x)
51 | >>> loss = criterion(y, pred)
52 | >>> optimizer.zero_grad()
53 | >>> loss.backward()
54 | >>> optimizer.step()
55 | >>> early_stop = stopper.step(loss.detach().data, model)
56 | >>> if early_stop:
57 | >>> break
58 |
59 | >>> # Load the final parameters saved by the model
60 | >>> stopper.load_checkpoint(model)
61 | """
62 | def __init__(self, args, mode='lower'):
63 | patience=args['patience']
64 | filename=args['model_path']
65 | device=args['device']
66 |
67 | if filename is None:
68 | dt = datetime.datetime.now()
69 | filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
70 | dt.date(), dt.hour, dt.minute, dt.second)
71 |
72 | assert mode in ['higher', 'lower']
73 | self.mode = mode
74 | if self.mode == 'higher':
75 | self._check = self._check_higher
76 | else:
77 | self._check = self._check_lower
78 |
79 | self.patience = patience
80 | self.counter = 0
81 | self.timestep = 0
82 | self.filename = filename
83 | self.device = device
84 | self.best_score = None
85 | self.early_stop = False
86 |
87 | def _check_higher(self, score, prev_best_score):
88 | """Check if the new score is higher than the previous best score.
89 |
90 | Parameters
91 | ----------
92 | score : float
93 | New score.
94 | prev_best_score : float
95 | Previous best score.
96 |
97 | Returns
98 | -------
99 | bool
100 | Whether the new score is higher than the previous best score.
101 | """
102 | return score > prev_best_score
103 |
104 | def _check_lower(self, score, prev_best_score):
105 | """Check if the new score is lower than the previous best score.
106 |
107 | Parameters
108 | ----------
109 | score : float
110 | New score.
111 | prev_best_score : float
112 | Previous best score.
113 |
114 | Returns
115 | -------
116 | bool
117 | Whether the new score is lower than the previous best score.
118 | """
119 | return score < prev_best_score
120 |
121 | def step(self, score, model):
122 | """Update based on a new score.
123 |
124 | The new score is typically model performance on the validation set
125 | for a new epoch.
126 |
127 | Parameters
128 | ----------
129 | score : float
130 | New score.
131 | model : nn.Module
132 | Model instance.
133 |
134 | Returns
135 | -------
136 | bool
137 | Whether an early stop should be performed.
138 | """
139 | self.timestep += 1
140 | if self.best_score is None:
141 | self.best_score = score
142 | self.save_checkpoint(model)
143 | elif self._check(score, self.best_score):
144 | self.best_score = score
145 | self.save_checkpoint(model)
146 | self.counter = 0
147 | else:
148 | self.counter += 1
149 | print(
150 | f'EarlyStopping counter: {self.counter} out of {self.patience}')
151 | if self.counter >= self.patience:
152 | self.early_stop = True
153 | return self.early_stop
154 |
155 |
156 | def save_checkpoint(self, model):
157 | '''Saves model when the metric on the validation set gets improved.
158 |
159 | Parameters
160 | ----------
161 | model : nn.Module
162 | Model instance.
163 | '''
164 | torch.save({'model_state_dict': model.state_dict(),
165 | 'timestep': self.timestep}, self.filename)
166 |
167 |
168 | def load_checkpoint(self, model):
169 | '''Load the latest checkpoint
170 |
171 | Parameters
172 | ----------
173 | model : nn.Module
174 | Model instance.
175 | '''
176 | model.load_state_dict(torch.load(self.filename, map_location=self.device)['model_state_dict'])
--------------------------------------------------------------------------------
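Since this variant of `EarlyStopping` takes the training `args` dictionary rather than the keyword arguments shown in the dgllife docstring above, a minimal construction sketch (with hypothetical values) looks like:

```python
import torch
from stopper import EarlyStopping  # run from scripts/

args = {'patience': 3,                       # epochs without improvement before stopping
        'model_path': '../models/demo.pth',  # hypothetical checkpoint path
        'device': torch.device('cpu')}
stopper = EarlyStopping(args, mode='lower')  # lower validation loss is better
```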
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sklearn
3 | import dgl
4 | import errno
5 | import json
6 | import os, sys
7 | import time
8 | import numpy as np
9 | import pandas as pd
10 | from functools import partial
11 |
12 | import torch.nn as nn
13 | from torch.utils.data import DataLoader
14 | from torch.optim import Adam, lr_scheduler
15 |
16 | from rdkit import Chem
17 | from dgl.data.utils import Subset
18 | from dgllife.utils import WeaveAtomFeaturizer, CanonicalBondFeaturizer, mol_to_bigraph
19 |
20 |
21 | from models import LocalTransform
22 | from dataset import USPTODataset, USPTOTestDataset, combine_reactants
23 | from stopper import EarlyStopping
24 |
25 | def init_featurizer(args):
26 | atom_types = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
27 | 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
28 | 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb',
29 | 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Ta', 'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir',
30 | 'Ce', 'Gd', 'Ga', 'Cs']
31 | args['node_featurizer'] = WeaveAtomFeaturizer(atom_types = atom_types)
32 | args['edge_featurizer'] = CanonicalBondFeaturizer(self_loop=True)
33 | return args
34 |
35 | def get_configure(args):
36 | with open('%s.json' % args['config_path'], 'r') as f:
37 | config = json.load(f)
38 | config['Template_rn'] = len(pd.read_csv('%s/real_templates.csv' % args['data_dir']))
39 | config['Template_vn'] = len(pd.read_csv('%s/virtual_templates.csv' % args['data_dir']))
40 | args['Template_rn'] = config['Template_rn']
41 | args['Template_vn'] = config['Template_vn']
42 | config['in_node_feats'] = args['node_featurizer'].feat_size()
43 | config['in_edge_feats'] = args['edge_featurizer'].feat_size()
44 | return config
45 |
46 | def mkdir_p(path):
47 | try:
48 | os.makedirs(path)
49 | print('Created directory %s'% path)
50 | except OSError as exc:
51 | if exc.errno == errno.EEXIST and os.path.isdir(path):
52 | print('Directory %s already exists.' % path)
53 | else:
54 | raise
55 |
56 | def load_dataloader(args):
57 | if args['mode'] == 'train':
58 | dataset = USPTODataset(args,
59 | mol_to_graph=partial(mol_to_bigraph, add_self_loop=True),
60 | node_featurizer=args['node_featurizer'],
61 | edge_featurizer=args['edge_featurizer'])
62 |
63 | train_set, val_set, test_set = Subset(dataset, dataset.train_ids), Subset(dataset, dataset.val_ids), Subset(dataset, dataset.test_ids)
64 |
65 | train_loader = DataLoader(dataset=train_set, batch_size=args['batch_size'], shuffle=True,
66 | collate_fn=collate_molgraphs, num_workers=args['num_workers'])
67 | val_loader = DataLoader(dataset=val_set, batch_size=args['batch_size'],
68 | collate_fn=collate_molgraphs, num_workers=args['num_workers'])
69 | test_loader = DataLoader(dataset=test_set, batch_size=args['batch_size'],
70 | collate_fn=collate_molgraphs, num_workers=args['num_workers'])
71 | return train_loader, val_loader, test_loader
72 |
73 | elif args['mode'] == 'test':
74 | test_set = USPTOTestDataset(args,
75 | mol_to_graph=partial(mol_to_bigraph, add_self_loop=True),
76 | node_featurizer=args['node_featurizer'],
77 | edge_featurizer=args['edge_featurizer'])
78 | test_loader = DataLoader(dataset=test_set, batch_size=args['batch_size'],
79 | collate_fn=collate_molgraphs_test, num_workers=args['num_workers'])
80 |
81 | elif args['mode'] == 'analyze':
82 | test_set = USPTOAnaynizeDataset(args,
83 | mol_to_graph=partial(mol_to_bigraph, add_self_loop=True),
84 | node_featurizer=args['node_featurizer'],
85 | edge_featurizer=args['edge_featurizer'])
86 | test_loader = DataLoader(dataset=test_set, batch_size=args['batch_size'],
87 | collate_fn=collate_molgraphs_analysis, num_workers=args['num_workers'])
88 | return test_loader
89 |
90 | def weight_loss(class_num, args, reduction = 'none'):
91 | weights=torch.ones(class_num+1)
92 | weights[0] = args['negative_weight']
93 | return nn.CrossEntropyLoss(weight = weights.to(args['device']), reduction = reduction)
94 |
95 | def load_model(args):
96 | exp_config = get_configure(args)
97 | model = LocalTransform(
98 | node_in_feats=exp_config['in_node_feats'],
99 | edge_in_feats=exp_config['in_edge_feats'],
100 | node_out_feats=exp_config['node_out_feats'],
101 | edge_hidden_feats=exp_config['edge_hidden_feats'],
102 | num_step_message_passing=exp_config['num_step_message_passing'],
103 | attention_heads = exp_config['attention_heads'],
104 | attention_layers = exp_config['attention_layers'],
105 | Template_rn = exp_config['Template_rn'],
106 | Template_vn = exp_config['Template_vn'])
107 | model = model.to(args['device'])
108 | print ('Parameters of loaded LocalTransform:')
109 | print (exp_config)
110 |
111 | if args['mode'] == 'train':
112 | loss_criterions = [weight_loss(1, args), weight_loss(exp_config['Template_vn'], args), weight_loss(exp_config['Template_rn'], args)]
113 |
114 | optimizer = Adam(model.parameters(), lr = args['learning_rate'], weight_decay = args['weight_decay'])
115 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args['schedule_step'])
116 | if os.path.exists(args['model_path']):
117 |         user_answer = input('%s exists, want to (a) overwrite (b) continue from checkpoint (c) make a new model? ' % args['model_path'])
118 | if user_answer == 'a':
119 | stopper = EarlyStopping(args)
120 |             print ('Overwriting existing model and training a new model...')
121 | elif user_answer == 'b':
122 | stopper = EarlyStopping(args)
123 | stopper.load_checkpoint(model)
124 |             print ('Training from existing model checkpoint...')
125 | elif user_answer == 'c':
126 | model_name = input('Enter new model name: ')
127 | args['model_path'] = '../models/%s.pth' % model_name
128 | stopper = EarlyStopping(args)
129 | print ('Training a new model %s.pth' % model_name)
130 | else:
131 |             print ("Input error: please enter a, b or c to choose an option")
132 | try:
133 | sys.exit(0)
134 | except SystemExit:
135 | os._exit(0)
136 | else:
137 | stopper = EarlyStopping(args)
138 | return model, loss_criterions, optimizer, scheduler, stopper
139 |
140 | else:
141 | model.load_state_dict(torch.load(args['model_path'], map_location=args['device'])['model_state_dict'])
142 | return model
143 |
144 | def pad_atom_distance_matrix(adm_list):
145 | max_size = max([adm.shape[0] for adm in adm_list])
146 | adm_list = [torch.tensor(np.pad(adm, (0, max_size - adm.shape[0]), 'maximum')).unsqueeze(0).long() for adm in adm_list]
147 | return torch.cat(adm_list, dim = 0)
148 |
149 | def make_labels(dgraphs, labels, masks):
150 | vtemplate_labels = []
151 | rtemplate_labels = []
152 | for graph, label, m in zip(dgraphs, labels, masks):
153 | vtemplate_label, rtemplate_label = [0]*len(graph['v_bonds']), [0]*len(graph['r_bonds'])
154 | if m == 1:
155 | for l in label:
156 | edit_type = l[0]
157 | edit_idx = l[1]
158 | edit_template = l[2]
159 | if edit_type == 'v':
160 | vtemplate_label[edit_idx] = edit_template
161 | else:
162 | rtemplate_label[edit_idx] = edit_template
163 |
164 | vtemplate_labels.append(vtemplate_label)
165 | rtemplate_labels.append(rtemplate_label)
166 | return vtemplate_labels, rtemplate_labels
167 |
168 | def get_reactive_template_labels(all_template_labels, masks, top_idxs):
169 | template_labels = []
170 | clipped_masks = []
171 | for i, idxs in enumerate(top_idxs):
172 | template_labels += [all_template_labels[i][idx] for idx in idxs]
173 | clipped_masks += [masks[i] for idx in idxs]
174 | reactive_labels = [int(y > 0) for y in template_labels]
175 | return torch.LongTensor(template_labels), torch.LongTensor(reactive_labels), torch.LongTensor(clipped_masks)
176 |
177 | def collate_molgraphs(data):
178 | reactants, reagents, fgraphs, dgraphs, labels, masks = map(list, zip(*data))
179 | true_VT, true_RT = make_labels(dgraphs, labels, masks)
180 | bg = dgl.batch(fgraphs)
181 | bg.set_n_initializer(dgl.init.zero_initializer)
182 | bg.set_e_initializer(dgl.init.zero_initializer)
183 | adm_lists = [graph['atom_distance_matrix'] for graph in dgraphs]
184 | adms = pad_atom_distance_matrix(adm_lists)
185 | bonds_dicts = {'virtual': [torch.from_numpy(graph['v_bonds']).long() for graph in dgraphs], 'real': [torch.from_numpy(graph['r_bonds']).long() for graph in dgraphs]}
186 | return reactants, bg, adms, bonds_dicts, true_VT, true_RT, masks
187 |
188 | def collate_molgraphs_test(data):
189 | reactants, reagents, fgraphs, dgraphs = map(list, zip(*data))
190 | bg = dgl.batch(fgraphs)
191 | bg.set_n_initializer(dgl.init.zero_initializer)
192 | bg.set_e_initializer(dgl.init.zero_initializer)
193 | adm_lists = [graph['atom_distance_matrix'] for graph in dgraphs]
194 | adms = pad_atom_distance_matrix(adm_lists)
195 | bonds_dicts = {'virtual': [torch.from_numpy(graph['v_bonds']).long() for graph in dgraphs], 'real': [torch.from_numpy(graph['r_bonds']).long() for graph in dgraphs]}
196 | return reactants, reagents, bg, adms, bonds_dicts
197 |
198 |
199 | def predict(args, model, bg, adms, bonds_dicts):
200 | adms = adms.to(args['device'])
201 | bg = bg.to(args['device'])
202 | node_feats = bg.ndata.pop('h').to(args['device'])
203 | edge_feats = bg.edata.pop('e').to(args['device'])
204 | return model(bg, adms, bonds_dicts, node_feats, edge_feats)
205 |
--------------------------------------------------------------------------------
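A toy illustration of `pad_atom_distance_matrix` above (a sketch, with the same logic inlined so it runs without the removed model code): each distance matrix in a batch is padded to the largest molecule's size with `np.pad(..., 'maximum')`, then the matrices are stacked into one long tensor:

```python
import numpy as np
import torch

adms = [np.random.randint(0, 8, (n, n)) for n in (3, 5)]  # toy distance matrices

max_size = max(adm.shape[0] for adm in adms)
padded = torch.cat([torch.tensor(np.pad(adm, (0, max_size - adm.shape[0]), 'maximum'))
                        .unsqueeze(0).long() for adm in adms], dim=0)
print(padded.shape)  # torch.Size([2, 5, 5])
```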