├── .gitignore ├── LICENSE ├── README.md ├── gln ├── __init__.py ├── common │ ├── __init__.py │ ├── cmd_args.py │ ├── consts.py │ ├── evaluate.py │ ├── mol_utils.py │ ├── reactor.py │ └── torch_util.py ├── data_process │ ├── __init__.py │ ├── build_all_reactions.py │ ├── build_raw_template.py │ ├── clean_uspto.py │ ├── data_info.py │ ├── dump_graphs.py │ ├── filter_template.py │ ├── find_centers.py │ ├── get_canonical_smarts.py │ ├── get_canonical_smiles.py │ ├── step0.0_run_get_cano_smiles.sh │ ├── step0.1_run_raw_template_extract.sh │ ├── step1_filter_template.sh │ ├── step2_run_get_cano_smarts.sh │ ├── step3_run_find_centers.sh │ ├── step4_run_find_all_reactions.sh │ ├── step5_prun_dump_neg_graphs.sh │ └── step5_run_dump_graphs.sh ├── graph_logic │ ├── __init__.py │ ├── graph_feat.py │ ├── logic_net.py │ ├── run_test_graph_feat.sh │ └── soft_logic.py ├── mods │ ├── __init__.py │ ├── mol_gnn │ │ ├── __init__.py │ │ ├── gnn_family │ │ │ ├── __init__.py │ │ │ ├── ggnn.py │ │ │ ├── mean_field.py │ │ │ ├── morganfp.py │ │ │ ├── mpnn.py │ │ │ ├── s2v.py │ │ │ └── utils.py │ │ ├── mg_clib │ │ │ ├── Makefile │ │ │ ├── __init__.py │ │ │ ├── default_atoms.txt │ │ │ ├── include │ │ │ │ ├── config.h │ │ │ │ ├── mg_clib.h │ │ │ │ └── mol_utils.h │ │ │ ├── mg_lib.py │ │ │ └── src │ │ │ │ ├── lib │ │ │ │ ├── config.cpp │ │ │ │ └── mol_utils.cpp │ │ │ │ └── mg_clib.cpp │ │ ├── mol_utils.py │ │ └── torch_util.py │ ├── rdchiral │ │ ├── README.md │ │ ├── __init__.py │ │ ├── bonds.py │ │ ├── chiral.py │ │ ├── clean.py │ │ ├── initialization.py │ │ ├── main.py │ │ ├── template_extractor.py │ │ └── utils.py │ └── torchext │ │ ├── __init__.py │ │ ├── jagged_ops.py │ │ └── src │ │ ├── extlib.cpp │ │ ├── extlib.h │ │ ├── extlib_cuda.cpp │ │ ├── extlib_cuda.h │ │ ├── extlib_cuda_kernels.cu │ │ └── extlib_cuda_kernels.h ├── test │ ├── __init__.py │ ├── main_test.py │ ├── model_inference.py │ ├── report_test_stats.py │ ├── test_all.sh │ └── test_single.sh └── training │ ├── __init__.py │ ├── data_gen.py │ ├── main.py │ └── scripts │ └── run_mf.sh ├── pylintrc └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | logwrite 3 | .DS_Store 4 | .module_log 5 | build/ 6 | *.tar.gz 7 | .vscode/ 8 | __pycache__/ 9 | dropbox 10 | dropbox/ 11 | *.so 12 | *.pkl 13 | 14 | *.ipynb_checkpoints/ 15 | 16 | *-result/ 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hanjun Dai, Chengtao Li, Connor Coley, Bo Dai, Le Song 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GLN
2 | Implementation of Retrosynthesis Prediction with Conditional Graph Logic Network
3 |
4 | https://arxiv.org/abs/2001.01408
5 |
6 | # Setup
7 |
8 | - ## Install package
9 |
10 | This package requires **rdkit** and **pytorch**.
11 |
12 | For rdkit, we tested with [2019_03_3](https://github.com/rdkit/rdkit/releases/tag/Release_2019_03_3). Please build the package from source and set up `PYTHONPATH` and `LD_LIBRARY_PATH` properly, as this package requires the dynamic lib built from rdkit.
13 |
14 | For pytorch, we tested versions 1.2.0 through 1.3.1. Please don't use older versions, as this package contains customized C++ ops that rely on specific toolchains from pytorch.
15 |
16 | After the above preparation, simply navigate to the project root folder and install:
17 |
18 |     cd GLN
19 |     pip install -e .
20 |
21 | Note that by default the CUDA ops will not be enabled on Mac OSX.
22 |
23 | - ## Dropbox
24 |
25 | We provide the raw datasets, the cooked (preprocessed) datasets, and the trained model dumps in a dropbox folder:
26 |
27 | [`https://www.dropbox.com/sh/6ideflxcakrak10/AADTbFBC0F8ax55-z-EDgrIza?dl=0`](https://www.dropbox.com/sh/6ideflxcakrak10/AADTbFBC0F8ax55-z-EDgrIza?dl=0)
28 |
29 | The cooked dataset is pretty large, so you can also download only the raw datasets and use the scripts provided in this repo for preprocessing. You don't need a Dropbox account to download, and the result folder doesn't have to live in your Dropbox. The only thing needed is a symbolic link named `dropbox` placed in the right location.
30 |
31 | Finally, the folder structure will look like this:
32 |
33 | ```
34 | GLN
35 | |___gln  # source code
36 | |   |___common  # common implementations
37 | |   |___...
38 | |
39 | |___setup.py
40 | |
41 | |___dropbox  # data, trained model dumps and cooked data; this can be a symbolic link
42 |     |___schneider50k  # raw data
43 |     |   |__raw_train.csv
44 |     |   |__...
45 |     |
46 |     |___cooked_schneider50k  # cooked data
47 |     |___schneider50k.ckpt  # trained model dump
48 |     ...
49 | ```
50 |
51 | - **Remark: full USPTO dataset**
52 |
53 | We also released the cleaned USPTO dataset used in the paper via the above dropbox link (see the `uspto_multi` folder under the dropbox folder). The script for cleaning and de-duplication can be found at `gln/data_process/clean_uspto.py`.
54 | The USPTO version is `1976_Sep2016_USPTOgrants_smiles.rsmi` (also available via the above dropbox link). If you run `clean_uspto.py` on this raw rsmi file, you should get the same data split as used in the paper.
55 |
56 | # Preprocessing
57 |
58 | If you downloaded the cooked data in the previous step, you can simply skip this step.
59 |
60 | Our paper mainly focuses on the schneider50k dataset. The raw data and data split are the same as https://github.com/connorcoley/retrosim/blob/master/retrosim/data/get_data.py
61 |
62 | Note that for both the reaction-type-unknown and type-given experiments, we go through the same steps below.
The only difference is the dataset name: **schneider50k** (type unknown) vs. **typed_schneider50k** (type given).
63 |
64 | First go to the data processing folder:
65 | ```
66 | cd gln/data_process
67 | ```
68 | Then run the scripts in the following order:
69 |
70 | - **Get canonical smiles**
71 |
72 | Please specify the dataset name accordingly.
73 | ```
74 | ./step0.0_run_get_cano_smiles.sh
75 | ```
76 |
77 | - **Extract raw templates**
78 |
79 | Please specify the dataset name, as well as the number of CPU threads (the more the better).
80 | ```
81 | ./step0.1_run_raw_template_extract.sh
82 | ```
83 |
84 | 1. **Filter template**
85 |
86 | One can filter out uncommon templates using this script. **We didn't filter out any template in our paper**, though a careful selection of templates may further improve the performance.
87 |
88 | Please specify the dataset name, the template name (an arbitrary one is fine), and optionally the minimum number of occurrences a template needs in order to be included (default=1).
89 |
90 | ```
91 | ./step1_filter_template.sh
92 | ```
93 |
94 | 2. **Get subgraph SMARTS**
95 |
96 | Please specify the dataset name and the template name (same as step 1).
97 | ```
98 | ./step2_run_get_cano_smarts.sh
99 | ```
100 |
101 | 3. **Get feasible centers**
102 |
103 | Please specify the dataset name, the template name (same as step 1), and the number of CPUs (the more the better).
104 | ```
105 | ./step3_run_find_centers.sh
106 | ```
107 |
108 | 4. **Get the support for the graphical model**
109 |
110 | Please specify the dataset name, the template name (same as step 1), and the number of CPUs (the more the better). A 40-core machine gets this done in about 15 minutes.
111 | ```
112 | ./step4_run_find_all_reactions.sh
113 | ```
114 |
115 | 5. **Get graph feature dumps**
116 |
117 | Please specify the dataset name and the template name (same as step 1).
118 | ```
119 | ./step5_run_dump_graphs.sh
120 | ```
121 |
122 | # Training
123 |
124 | To train the model from scratch, first navigate to the training script folder:
125 | ```
126 | cd gln/training/scripts
127 | ```
128 | Then run the default script `run_mf.sh` with the dataset name.
129 | - To run the type-unknown model, use `./run_mf.sh schneider50k`
130 | - To run the type-conditional model, use `./run_mf.sh typed_schneider50k`
131 |
132 | Usually ~10 x 3000 iterations are enough to get reasonable results that match the numbers in the paper.
133 | You are also welcome to tune any hyper-parameters or configurations in the script.
134 |
135 | # Test
136 |
137 | First navigate to the test folder:
138 | ```
139 | cd gln/test
140 | ```
141 | 1. **Reproducing results in the paper**
142 |
143 | To test the existing model dumps from the dropbox, use the following commands:
144 | - To test the pretrained type-unknown model, use `./test_single.sh schneider50k`
145 | - To test the pretrained type-conditional model, use `./test_single.sh typed_schneider50k`
146 |
147 | You can also test any single model you like, by pointing the `-model_for_test` argument at the corresponding model dump.
148 |
149 | 2. **Pick the model trained from scratch with the best validation performance**
150 |
151 | The best model is picked according to validation loss.
To do so, first get the performance of all model dumps:
152 | - `./test_all.sh schneider50k YOUR_MODEL_DUMP_ROOT`
153 | - `python report_test_stats.py YOUR_MODEL_DUMP_ROOT`
154 |
155 |
156 | # Reference
157 |
158 | If you find our paper/code useful, please consider citing our paper:
159 |
160 |     @inproceedings{dai2019retrosynthesis,
161 |       title={Retrosynthesis Prediction with Conditional Graph Logic Network},
162 |       author={Dai, Hanjun and Li, Chengtao and Coley, Connor and Dai, Bo and Song, Le},
163 |       booktitle={Advances in Neural Information Processing Systems},
164 |       pages={8870--8880},
165 |       year={2019}
166 |     }
167 |
168 | The original version of rdchiral comes from https://github.com/connorcoley/rdchiral; this repo provides a wrapper over it. Please also cite the corresponding paper if possible:
169 |
170 |     @article{coley2019rdchiral,
171 |       title={RDChiral: An RDKit Wrapper for Handling Stereochemistry in Retrosynthetic Template Extraction and Application},
172 |       author={Coley, Connor W and Green, William H and Jensen, Klavs F},
173 |       journal={Journal of chemical information and modeling},
174 |       publisher={ACS Publications}
175 |     }
176 |
--------------------------------------------------------------------------------
/gln/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hanjun-Dai/GLN/b5bd7b181a61a8289cc1d1a33825b2c417bed0ef/gln/__init__.py
--------------------------------------------------------------------------------
/gln/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hanjun-Dai/GLN/b5bd7b181a61a8289cc1d1a33825b2c417bed0ef/gln/common/__init__.py
--------------------------------------------------------------------------------
/gln/common/cmd_args.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 |
5 | import argparse
6 | import os
7 | import pickle as cp
8 |
9 | cmd_opt = argparse.ArgumentParser(description='Argparser for retrosyn_graph')
10 | cmd_opt.add_argument('-save_dir', default='.', help='result output root')
11 | cmd_opt.add_argument('-dropbox', default=None, help='dropbox folder')
12 | cmd_opt.add_argument('-cooked_root', default=None, help='cooked data root folder')
13 | cmd_opt.add_argument('-init_model_dump', default=None, help='model dump')
14 | cmd_opt.add_argument('-data_name', default=None, help='dataset name')
15 | cmd_opt.add_argument('-tpl_name', default=None, help='template name')
16 | cmd_opt.add_argument('-tpl_min_cnt', default=0, type=int, help='template min cnt (for filtering)')
17 | cmd_opt.add_argument('-phase', default=None, help='phase')
18 | cmd_opt.add_argument('-is_training', default=True, type=eval, help='is training')
19 | cmd_opt.add_argument('-split_mode', default='single', help='single/multi/ignore')
20 |
21 | cmd_opt.add_argument('-bn', default=True, type=eval, help='using bn?')
22 | cmd_opt.add_argument('-file_for_eval', default=None, help='file for evaluation')
23 | cmd_opt.add_argument('-model_for_eval', default=None, help='model for evaluation')
24 | cmd_opt.add_argument('-num_cores', default=1, type=int, help='# cpu cores')
25 | cmd_opt.add_argument('-num_parts', default=1, type=int, help='num of parts to split')
26 |
27 | cmd_opt.add_argument('-part_id', default=0, type=int, help='part id')
28 |
cmd_opt.add_argument('-epochs2save', default=1, type=int, help='epochs to save')
29 | cmd_opt.add_argument('-max_neg_reacts', default=0, type=int, help='max neg')
30 | cmd_opt.add_argument('-part_num', default=0, type=int, help='part num')
31 | cmd_opt.add_argument('-eval_func', default='acc', help='acc/mix_f1')
32 |
33 | cmd_opt.add_argument('-neg_sample', default='local', help='local/all')
34 | cmd_opt.add_argument('-num_data_proc', default=0, type=int, help='num of data processes')
35 | cmd_opt.add_argument('-topk', default=1, type=int, help='topk eval')
36 | cmd_opt.add_argument('-neg_num', default=-1, type=int, help='num of negative samples')
37 | cmd_opt.add_argument('-beam_size', default=1, type=int, help='beam search size')
38 | cmd_opt.add_argument('-gm', default='mean_field', help='choose gnn module')
39 | cmd_opt.add_argument('-fp_degree', default=0, type=int, help='fingerprint? [>0, 0]')
40 |
41 | cmd_opt.add_argument('-latent_dim', default=64, type=int, help='latent dim of gnn')
42 | cmd_opt.add_argument('-embed_dim', default=128, type=int, help='embedding dim of gnn')
43 |
44 | cmd_opt.add_argument('-mlp_hidden', default=256, type=int, help='hidden dims in mlp')
45 | cmd_opt.add_argument('-seed', default=19260817, type=int, help='seed')
46 |
47 | cmd_opt.add_argument('-max_lv', default=3, type=int, help='# layers of gnn')
48 | cmd_opt.add_argument('-eval_start_idx', default=0, type=int, help='model idx for eval')
49 |
50 | cmd_opt.add_argument('-ggnn_update_type', default='gru', help='use gru or mlp for update state')
51 | cmd_opt.add_argument('-msg_agg_type', default='sum', help='how to aggregate the message')
52 | cmd_opt.add_argument('-att_type', default='inner_prod', help='mlp/inner_prod/bilinear')
53 |
54 | cmd_opt.add_argument('-readout_agg_type', default='sum', help='how to aggregate all node embeddings')
55 | cmd_opt.add_argument('-logic_net', default='gpath', help='gpath/mlp')
56 |
57 | cmd_opt.add_argument('-node_dims', default='128', help='hidden dims for node update')
58 | cmd_opt.add_argument('-edge_dims', default='128', help='hidden dims for edge update')
59 | cmd_opt.add_argument('-act_func', default='tanh', help='default activation function')
60 | cmd_opt.add_argument('-gnn_out', default='last', help='last/gru/sum/mean')
61 | cmd_opt.add_argument('-act_last', default=True, type=eval, help='activation of last embedding layer')
62 | cmd_opt.add_argument('-subg_enc', default='mean_field', help='subgraph embedding method')
63 | cmd_opt.add_argument('-tpl_enc', default='deepset', help='template embedding method')
64 |
65 | cmd_opt.add_argument('-neg_local', default=False, type=eval, help='local or global neg reaction?')
66 |
67 | cmd_opt.add_argument('-gnn_share_param', default=False, type=eval, help='share params across layers')
68 | cmd_opt.add_argument('-learning_rate', default=1e-3, type=float, help='learning rate')
69 | cmd_opt.add_argument('-grad_clip', default=5, type=float, help='clip gradient')
70 | cmd_opt.add_argument('-dropout', default=0, type=float, help='dropout')
71 | cmd_opt.add_argument('-fp_dim', default=2048, type=int, help='dim of fp')
72 | cmd_opt.add_argument('-gen_method', default='none', help='none/uniform/weighted')
73 |
74 | cmd_opt.add_argument('-test_during_train', default=False, type=eval, help='do fast testing during training')
75 | cmd_opt.add_argument('-test_mode', default='model', help='model/file')
76 | cmd_opt.add_argument('-num_epochs', default=10000, type=int, help='number of training epochs')
77 |
cmd_opt.add_argument('-epochs_per_part', default=1, type=int, help='number of epochs per part') 78 | cmd_opt.add_argument('-iters_per_val', default=1000, type=int, help='number of iterations per evaluation') 79 | cmd_opt.add_argument('-batch_size', default=64, type=int, help='batch size for training') 80 | cmd_opt.add_argument('-retro_during_train', type=eval, default=False, help='doing retrosynthesis during training?') 81 | 82 | 83 | cmd_args, _ = cmd_opt.parse_known_args() 84 | 85 | if cmd_args.save_dir is not None: 86 | if not os.path.isdir(cmd_args.save_dir): 87 | os.makedirs(cmd_args.save_dir) 88 | 89 | from gln.mods.rdchiral.main import rdchiralReaction, rdchiralReactants, rdchiralRun 90 | 91 | print(cmd_args) 92 | -------------------------------------------------------------------------------- /gln/common/consts.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | 6 | import argparse 7 | import logging 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | t_float = torch.float32 13 | np_float = np.float32 14 | str_float = "float32" 15 | 16 | opts = argparse.ArgumentParser(description='gpu option') 17 | opts.add_argument('-gpu', type=int, default=-1, help='-1: cpu; 0 - ?: specific gpu index') 18 | 19 | args, _ = opts.parse_known_args() 20 | if torch.cuda.is_available() and args.gpu >= 0: 21 | DEVICE = torch.device('cuda:' + str(args.gpu)) 22 | print('use gpu indexed: %d' % args.gpu) 23 | else: 24 | DEVICE = torch.device('cpu') 25 | print('use cpu') 26 | 27 | -------------------------------------------------------------------------------- /gln/common/evaluate.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | 6 | import rdkit 7 | from rdkit import Chem 8 | 9 | def canonicalize(smiles): 10 | try: 11 | tmp = Chem.MolFromSmiles(smiles) 12 | except: 13 | print('no mol') 14 | return smiles 15 | if tmp is None: 16 | return smiles 17 | tmp = Chem.RemoveHs(tmp) 18 | [a.ClearProp('molAtomMapNumber') for a in tmp.GetAtoms()] 19 | return Chem.MolToSmiles(tmp) 20 | 21 | def get_weighted_f1(seq_pred, seq_gnd): 22 | if seq_pred is None or seq_gnd is None: 23 | return 0.0 24 | pred = set(seq_pred.split('.')) 25 | gnd = set(seq_gnd.split('.')) 26 | 27 | t = pred.intersection(gnd) 28 | w = len(t) / float(len(gnd)) 29 | precision = len(t) / float(len(pred)) 30 | recall = len(t) / float(len(gnd)) 31 | if precision + recall == 0.0: 32 | return 0.0 33 | return 2 * precision * recall * w / (precision + recall) 34 | 35 | def get_score(pred, gnd, score_type): 36 | x = canonicalize(gnd) 37 | y = canonicalize(pred) 38 | if score_type == 'mix_f1': 39 | f1 = get_weighted_f1(y, x) 40 | score = 0.75 * f1 + 0.25 * (x == y) 41 | elif score_type == 'acc': 42 | score = 1.0 * (x == y) 43 | else: 44 | raise NotImplementedError 45 | return score 46 | -------------------------------------------------------------------------------- /gln/common/mol_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import rdkit 7 | from rdkit import Chem 8 | 9 | def cano_smiles(smiles): 10 | try: 11 | tmp = 
Chem.MolFromSmiles(smiles) 12 | if tmp is None: 13 | return None, smiles 14 | tmp = Chem.RemoveHs(tmp) 15 | if tmp is None: 16 | return None, smiles 17 | [a.ClearProp('molAtomMapNumber') for a in tmp.GetAtoms()] 18 | return tmp, Chem.MolToSmiles(tmp) 19 | except: 20 | return None, smiles 21 | 22 | 23 | def cano_smarts(smarts): 24 | tmp = Chem.MolFromSmarts(smarts) 25 | if tmp is None: 26 | return None, smarts 27 | [a.ClearProp('molAtomMapNumber') for a in tmp.GetAtoms()] 28 | cano = Chem.MolToSmarts(tmp) 29 | if '[[se]]' in cano: # strange parse error 30 | cano = smarts 31 | return tmp, cano 32 | 33 | 34 | def smarts_has_useless_parentheses(smarts): 35 | if len(smarts) == 0: 36 | return False 37 | if smarts[0] != '(' or smarts[-1] != ')': 38 | return False 39 | cnt = 1 40 | for i in range(1, len(smarts)): 41 | if smarts[i] == '(': 42 | cnt += 1 43 | if smarts[i] == ')': 44 | cnt -= 1 45 | if cnt == 0: 46 | if i + 1 != len(smarts): 47 | return False 48 | return True 49 | -------------------------------------------------------------------------------- /gln/common/reactor.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | 6 | import rdkit 7 | from rdkit import Chem 8 | from gln.common.cmd_args import rdchiralReaction, rdchiralReactants, rdchiralRun 9 | 10 | class _Reactor(object): 11 | 12 | def __init__(self): 13 | self.rxn_cooked = {} 14 | self.src_cooked = {} 15 | self.cached_results = {} 16 | 17 | def get_rxn(self, rxn): 18 | p, a, r = rxn.split('>') 19 | if '.' in p: # we assume the product has only one molecule 20 | if p[0] != '(': 21 | p = '('+p+')' 22 | rxn = '>'.join((p, a, r)) 23 | if not rxn in self.rxn_cooked: 24 | try: 25 | t = rdchiralReaction(rxn) 26 | except: 27 | t = None 28 | self.rxn_cooked[rxn] = t 29 | return self.rxn_cooked[rxn] 30 | 31 | def get_src(self, smiles): 32 | if not smiles in self.src_cooked: 33 | self.src_cooked[smiles] = rdchiralReactants(smiles) 34 | return self.src_cooked[smiles] 35 | 36 | def run_reaction(self, src, template): 37 | key = (src, template) 38 | if key in self.cached_results: 39 | return self.cached_results[key] 40 | rxn = self.get_rxn(template) 41 | src = self.get_src(src) 42 | if rxn is None or src is None: 43 | return None 44 | try: 45 | outcomes = rdchiralRun(rxn, src) 46 | self.cached_results[key] = outcomes 47 | except: 48 | self.cached_results[key] = None 49 | return self.cached_results[key] 50 | 51 | 52 | Reactor = _Reactor() -------------------------------------------------------------------------------- /gln/common/torch_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | from torch.nn.parameter import Parameter 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | 13 | from gln.mods.mol_gnn.torch_util import MLP 14 | 15 | class Lambda(nn.Module): 16 | 17 | def __init__(self, f): 18 | super(Lambda, self).__init__() 19 | self.f = f 20 | 21 | def forward(self, x): 22 | return self.f(x) 23 | 24 | class Swish(nn.Module): 25 | 26 | def __init__(self): 27 | super(Swish, self).__init__() 28 | self.beta = nn.Parameter(torch.tensor(1.0)) 29 | 30 | def forward(self, x): 31 | return x * torch.sigmoid(self.beta * x) 
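# Hedged usage sketch (not part of the original file): Swish is a standard
# nn.Module with a single learnable scalar beta, so a quick shape check under
# these assumptions looks like:
#   act = Swish()
#   y = act(torch.randn(4, 8))  # elementwise gate; y.shape == (4, 8)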
32 | 33 | 34 | NONLINEARITIES = { 35 | "tanh": nn.Tanh(), 36 | "relu": nn.ReLU(), 37 | "softplus": nn.Softplus(), 38 | "sigmoid": nn.Sigmoid(), 39 | "elu": nn.ELU(), 40 | "swish": Swish(), 41 | "square": Lambda(lambda x: x**2), 42 | "identity": Lambda(lambda x: x), 43 | } -------------------------------------------------------------------------------- /gln/data_process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hanjun-Dai/GLN/b5bd7b181a61a8289cc1d1a33825b2c417bed0ef/gln/data_process/__init__.py -------------------------------------------------------------------------------- /gln/data_process/build_all_reactions.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import os 7 | import rdkit 8 | from rdkit import Chem 9 | import random 10 | import csv 11 | import sys 12 | from itertools import chain 13 | from collections import defaultdict 14 | from gln.common.cmd_args import cmd_args 15 | from gln.data_process.data_info import DataInfo, load_train_reactions 16 | from tqdm import tqdm 17 | from gln.common.reactor import Reactor 18 | from collections import Counter 19 | 20 | import multiprocessing 21 | from rdkit import rdBase 22 | rdBase.DisableLog('rdApp.error') 23 | rdBase.DisableLog('rdApp.warning') 24 | 25 | 26 | def find_tpls(cur_task): 27 | idx, (rxn_type, rxn) = cur_task 28 | reactants, _, raw_prod = rxn.split('>') 29 | 30 | prod = DataInfo.get_cano_smiles(raw_prod) 31 | 32 | if not (rxn_type, prod) in DataInfo.prod_center_maps: 33 | return None 34 | reactants = DataInfo.get_cano_smiles(reactants) 35 | prod_center_cand_idx = DataInfo.prod_center_maps[(rxn_type, prod)] 36 | 37 | neg_reactants = set() 38 | pos_tpl_idx = {} 39 | tot_tpls = 0 40 | for center_idx in prod_center_cand_idx: 41 | c = DataInfo.prod_cano_smarts[center_idx] 42 | assert c in DataInfo.unique_tpl_of_prod_center 43 | 44 | tpl_indices = DataInfo.unique_tpl_of_prod_center[c][rxn_type] 45 | tot_tpls += len(tpl_indices) 46 | for tpl_idx in tpl_indices: 47 | cur_t, tpl = DataInfo.unique_templates[tpl_idx] 48 | assert cur_t == rxn_type 49 | pred_mols = Reactor.run_reaction(prod, tpl) 50 | if pred_mols is None or len(pred_mols) == 0: 51 | continue 52 | for pred in pred_mols: 53 | if pred != reactants: 54 | neg_reactants.add(pred) 55 | else: 56 | pos_tpl_idx[tpl_idx] = (len(tpl_indices), len(pred_mols)) 57 | return (idx, pos_tpl_idx, neg_reactants) 58 | 59 | 60 | def get_writer(fname, header): 61 | f = open(os.path.join(cmd_args.save_dir, 'np-%d' % cmd_args.num_parts, fname), 'w') 62 | writer = csv.writer(f) 63 | writer.writerow(header) 64 | return f, writer 65 | 66 | 67 | if __name__ == '__main__': 68 | random.seed(cmd_args.seed) 69 | np.random.seed(cmd_args.seed) 70 | DataInfo.init(cmd_args.dropbox, cmd_args) 71 | 72 | fn_pos = lambda idx: get_writer('pos_tpls-part-%d.csv' % idx, ['tpl_idx', 'pos_tpl_idx', 'num_tpl_compete', 'num_react_compete']) 73 | fn_neg = lambda idx: get_writer('neg_reacts-part-%d.csv' % idx, ['sample_idx', 'neg_reactants']) 74 | 75 | if cmd_args.num_parts <= 0: 76 | num_parts = cmd_args.num_cores 77 | DataInfo.load_cooked_part('train', load_graphs=False) 78 | else: 79 | num_parts = cmd_args.num_parts 80 | 81 | train_reactions = load_train_reactions(cmd_args) 82 | n_train = len(train_reactions) 83 | part_size = n_train // num_parts + 1 
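# Worked example (illustrative numbers, not from the paper): with
# n_train = 50016 and num_parts = 10, part_size = 50016 // 10 + 1 = 5002;
# the min(...) in idx_range below clips the final part back to n_train.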
84 | 85 | if cmd_args.part_num > 0: 86 | prange = range(cmd_args.part_id, cmd_args.part_id + cmd_args.part_num) 87 | else: 88 | prange = range(num_parts) 89 | for pid in prange: 90 | f_pos, writer_pos = fn_pos(pid) 91 | f_neg, writer_neg = fn_neg(pid) 92 | if cmd_args.num_parts > 0: 93 | DataInfo.load_cooked_part('train', part=pid, load_graphs=False) 94 | part_tasks = [] 95 | idx_range = list(range(pid * part_size, min((pid + 1) * part_size, n_train))) 96 | for i in idx_range: 97 | part_tasks.append((i, train_reactions[i])) 98 | 99 | pool = multiprocessing.Pool(cmd_args.num_cores) 100 | for result in tqdm(pool.imap_unordered(find_tpls, part_tasks), total=len(idx_range)): 101 | if result is None: 102 | continue 103 | idx, pos_tpl_idx, neg_reactions = result 104 | idx = str(idx) 105 | neg_keys = neg_reactions 106 | 107 | if cmd_args.max_neg_reacts > 0: 108 | neg_keys = list(neg_keys) 109 | random.shuffle(neg_keys) 110 | neg_keys = neg_keys[:cmd_args.max_neg_reacts] 111 | for pred in neg_keys: 112 | writer_neg.writerow([idx, pred]) 113 | for key in pos_tpl_idx: 114 | nt, np = pos_tpl_idx[key] 115 | writer_pos.writerow([idx, key, nt, np]) 116 | f_pos.flush() 117 | f_neg.flush() 118 | f_pos.close() 119 | f_neg.close() 120 | pool.close() 121 | pool.join() 122 | -------------------------------------------------------------------------------- /gln/data_process/build_raw_template.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import csv 7 | import os 8 | from tqdm import tqdm 9 | import pickle as cp 10 | import multiprocessing 11 | from gln.common.cmd_args import cmd_args 12 | from gln.mods.rdchiral.template_extractor import extract_from_reaction 13 | 14 | def get_writer(fname, header): 15 | output_name = os.path.join(cmd_args.save_dir, fname) 16 | fout = open(output_name, 'w') 17 | writer = csv.writer(fout) 18 | writer.writerow(header) 19 | return fout, writer 20 | 21 | def get_tpl(task): 22 | idx, row_idx, rxn_smiles = task 23 | react, reagent, prod = rxn_smiles.split('>') 24 | reaction = {'_id': row_idx, 'reactants': react, 'products': prod} 25 | template = extract_from_reaction(reaction) 26 | return idx, template 27 | 28 | if __name__ == '__main__': 29 | fname = os.path.join(cmd_args.dropbox, cmd_args.data_name, 'raw_train.csv') 30 | with open(fname, 'r') as f: 31 | reader = csv.reader(f) 32 | header = next(reader) 33 | rows = [row for row in reader] 34 | 35 | pool = multiprocessing.Pool(cmd_args.num_cores) 36 | tasks = [] 37 | for idx, row in tqdm(enumerate(rows)): 38 | row_idx, _, rxn_smiles = row 39 | tasks.append((idx, row_idx, rxn_smiles)) 40 | 41 | fout, writer = get_writer('proc_train_singleprod.csv', ['id', 'class', 'rxn_smiles', 'retro_templates']) 42 | fout_failed, failed_writer = get_writer('failed_template.csv', ['id', 'class', 'rxn_smiles', 'err_msg']) 43 | 44 | for result in tqdm(pool.imap_unordered(get_tpl, tasks), total=len(tasks)): 45 | idx, template = result 46 | row_idx, rxn_type, rxn_smiles = rows[idx] 47 | 48 | if 'reaction_smarts' in template: 49 | writer.writerow([row_idx, rxn_type, rxn_smiles, template['reaction_smarts']]) 50 | fout.flush() 51 | else: 52 | failed_writer.writerow([row_idx, rxn_type, rxn_smiles, template['err_msg']]) 53 | fout_failed.flush() 54 | 55 | fout.close() 56 | fout_failed.close() 57 | 
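# Hedged usage sketch (hypothetical mapped reaction, not from the dataset):
# the extractor consumes one atom-mapped reaction split into reactants/products.
#   rxn = '[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][OH:6]>>[CH3:1][C:2](=[O:3])[O:6][CH3:5]'
#   react, _, prod = rxn.split('>')
#   tpl = extract_from_reaction({'_id': 0, 'reactants': react, 'products': prod})
#   # on success, tpl['reaction_smarts'] holds the retro template string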
--------------------------------------------------------------------------------
/gln/data_process/clean_uspto.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 |
5 | import numpy as np
6 | import random
7 | import csv
8 | import os
9 | import sys
10 | import re
11 | from tqdm import tqdm
12 | from rdkit import Chem
13 | import pickle as cp
14 |
15 |
16 | def get_rxn_smiles(prod, reactants):
17 |     prod_smi = Chem.MolToSmiles(prod, True)
18 |
19 |     # Get rid of reactants when they don't contribute to this prod
20 |     prod_maps = set(re.findall(r'\:([0-9]+)\]', prod_smi))
21 |     reactants_smi_list = []
22 |     for mol in reactants:
23 |         if mol is None:
24 |             continue
25 |         used = False
26 |         for a in mol.GetAtoms():
27 |             if a.HasProp('molAtomMapNumber'):
28 |                 if a.GetProp('molAtomMapNumber') in prod_maps:
29 |                     used = True
30 |                 else:
31 |                     a.ClearProp('molAtomMapNumber')
32 |         if used:
33 |             reactants_smi_list.append(Chem.MolToSmiles(mol, True))
34 |
35 |     reactants_smi = '.'.join(reactants_smi_list)
36 |     return '{}>>{}'.format(reactants_smi, prod_smi)
37 |
38 |
39 | if __name__ == '__main__':
40 |     seed = 19260817
41 |     np.random.seed(seed)
42 |     random.seed(seed)
43 |     fname = sys.argv[1]
44 |     split_mode = 'multi'  # single or multi
45 |
46 |     pt = re.compile(r':(\d+)]')
47 |     cnt = 0
48 |     clean_list = []
49 |     set_rxn = set()
50 |     num_single = 0
51 |     num_multi = 0
52 |     bad_mapping = 0
53 |     bad_prod = 0
54 |     missing_map = 0
55 |     raw_num = 0
56 |     with open(fname, 'r') as f:
57 |         reader = csv.reader(f, delimiter='\t')
58 |         header = next(reader)
59 |         print(header)
60 |         pbar = tqdm(reader)
61 |         bad_rxn = 0
62 |         for row in pbar:
63 |             rxn_smiles = row[header.index('ReactionSmiles')]
64 |             all_reactants, reagents, prods = rxn_smiles.split('>')
65 |             all_reactants = all_reactants.split()[0]  # remove ' |f:1...'
66 |             prods = prods.split()[0]  # remove ' |f:1...'
67 |             if '.' in prods:
68 |                 num_multi += 1
69 |             else:
70 |                 num_single += 1
71 |             if split_mode == 'single' and '.' in prods:  # multiple prods
72 |                 continue
73 |             rids = ','.join(sorted(re.findall(pt, all_reactants)))
74 |             pids = ','.join(sorted(re.findall(pt, prods)))
75 |             if rids != pids:  # mapping is not 1:1
76 |                 bad_mapping += 1
77 |                 continue
78 |             reactants = [Chem.MolFromSmiles(smi) for smi in all_reactants.split('.')]
79 |
80 |             for sub_prod in prods.split('.'):
81 |                 mol_prod = Chem.MolFromSmiles(sub_prod)
82 |                 if mol_prod is None:  # rdkit is not able to parse the product
83 |                     bad_prod += 1
84 |                     continue
85 |                 # Make sure all have atom mapping
86 |                 if not all([a.HasProp('molAtomMapNumber') for a in mol_prod.GetAtoms()]):
87 |                     missing_map += 1
88 |                     continue
89 |
90 |                 raw_num += 1
91 |                 rxn_smiles = get_rxn_smiles(mol_prod, reactants)
92 |                 if not rxn_smiles in set_rxn:
93 |                     clean_list.append((row[header.index('PatentNumber')], rxn_smiles))
94 |                     set_rxn.add(rxn_smiles)
95 |                 pbar.set_description('select: %d, dup: %d' % (len(clean_list), raw_num))
96 |     print('# clean', len(clean_list))
97 |     print('single', num_single, 'multi', num_multi)
98 |     print('bad mapping', bad_mapping)
99 |     print('bad prod', bad_prod)
100 |     print('missing map', missing_map)
101 |     print('raw extracted', raw_num)
102 |
103 |     random.shuffle(clean_list)
104 |
105 |     num_val = num_test = int(len(clean_list) * 0.1)
106 |
107 |     out_folder = '.'
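    # Split note: num_val = num_test = 10% of the deduplicated reactions, so
    # the remaining ~80% become the training split written below.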
108 | for phase in ['val', 'test', 'train']: 109 | fout = os.path.join(out_folder, 'raw_%s.csv' % phase) 110 | with open(fout, 'w') as f: 111 | writer = csv.writer(f) 112 | writer.writerow(['id', 'reactants>reagents>production']) 113 | 114 | if phase == 'val': 115 | r = range(num_val) 116 | elif phase == 'test': 117 | r = range(num_val, num_val + num_test) 118 | else: 119 | r = range(num_val + num_test, len(clean_list)) 120 | for i in r: 121 | rxn_smiles = clean_list[i][1].split('>') 122 | result = [] 123 | for r in rxn_smiles: 124 | if len(r.strip()): 125 | r = r.split()[0] 126 | result.append(r) 127 | rxn_smiles = '>'.join(result) 128 | writer.writerow([clean_list[i][0], rxn_smiles]) 129 | -------------------------------------------------------------------------------- /gln/data_process/data_info.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | from tqdm import tqdm 6 | import csv 7 | import os 8 | import pickle as cp 9 | from collections import defaultdict 10 | import numpy as np 11 | 12 | from gln.common.mol_utils import cano_smarts, cano_smiles 13 | from gln.common.cmd_args import cmd_args 14 | from gln.common.evaluate import canonicalize 15 | from gln.common.mol_utils import smarts_has_useless_parentheses 16 | from gln.mods.mol_gnn.mol_utils import SmilesMols, SmartsMols 17 | 18 | 19 | def load_bin_feats(dropbox, args): 20 | print('loading smiles feature dump') 21 | file_root = os.path.join(dropbox, 'cooked_' + args.data_name, 'tpl-%s' % args.tpl_name) 22 | SmartsMols.set_fp_degree(args.fp_degree) 23 | load_feats = args.subg_enc != 'ecfp' or args.tpl_enc != 'onehot' 24 | load_fp = args.subg_enc == 'ecfp' 25 | SmartsMols.load_dump(os.path.join(file_root, 'graph_smarts'), load_feats=load_feats, load_fp=load_fp) 26 | SmilesMols.set_fp_degree(args.fp_degree) 27 | SmilesMols.load_dump(os.path.join(file_root, '../graph_smiles'), load_feats=args.gm != 'ecfp', load_fp=args.gm == 'ecfp') 28 | 29 | 30 | def load_center_maps(fname): 31 | prod_center_maps = {} 32 | with open(fname, 'r') as f: 33 | reader = csv.reader(f) 34 | header = next(reader) 35 | 36 | for row in tqdm(reader): 37 | smiles, rxn_type, indices = row 38 | indices = [int(t) for t in indices.split()] 39 | prod_center_maps[(rxn_type, smiles)] = indices 40 | avg_sizes = [len(prod_center_maps[key]) for key in prod_center_maps] 41 | print('average # centers per mol:', np.mean(avg_sizes)) 42 | return prod_center_maps 43 | 44 | 45 | def load_train_reactions(args): 46 | train_reactions = [] 47 | raw_data_root = os.path.join(args.dropbox, args.data_name) 48 | with open(os.path.join(raw_data_root, 'raw_train.csv'), 'r') as f: 49 | reader = csv.reader(f) 50 | header = next(reader) 51 | pos = header.index('reactants>reagents>production') if 'reactants>reagents>production' in header else -1 52 | c_idx = header.index('class') 53 | for row in reader: 54 | train_reactions.append((row[c_idx], row[pos])) 55 | print('# raw train loaded', len(train_reactions)) 56 | return train_reactions 57 | 58 | 59 | class DataInfo(object): 60 | 61 | @classmethod 62 | def load_cooked_part(cls, phase, part, load_graphs=True): 63 | args = cls.args 64 | load_feats = args.gm != 'ecfp' 65 | load_fp = not load_feats 66 | if cls.cur_part is not None and cls.cur_part == part: 67 | return 68 | file_root = os.path.join(args.dropbox, 'cooked_' + args.data_name, 'tpl-%s' % args.tpl_name, 'np-%d' % args.num_parts) 
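# Assumed layout (matching the writers in find_centers.py / build_all_reactions.py):
# file_root points at the per-part dumps, e.g. cooked_<data>/tpl-<name>/np-<k>/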
69 | assert phase == 'train' 70 | # load neg reactant features 71 | if load_graphs and args.retro_during_train: 72 | if cls.cur_part is not None: 73 | SmilesMols.remove_dump(os.path.join(file_root, 'neg_graphs-part-%d' % cls.cur_part)) 74 | SmilesMols.load_dump(os.path.join(file_root, 'neg_graphs-part-%d' % part), additive=True, load_feats=load_feats, load_fp=load_fp) 75 | 76 | if args.gen_method != 'none': # load pos-tpl map 77 | print('loading positive tpls') 78 | cls.train_pos_maps = defaultdict(list) 79 | fname = 'pos_tpls-part-%d.csv' % part 80 | with open(os.path.join(file_root, fname), 'r') as f: 81 | reader = csv.reader(f) 82 | header = next(reader) 83 | for row in reader: 84 | tpl_idx = int(row[0]) 85 | cls.train_pos_maps[tpl_idx].append((int(row[1]), int(row[2]))) 86 | print('# pos tpls', len(cls.train_pos_maps)) 87 | for key in cls.train_pos_maps: 88 | pos = cls.train_pos_maps[key] 89 | weights = np.array([1.0 / float(x[1]) for x in pos]) 90 | weights /= np.sum(weights) 91 | tpls = [x[0] for x in pos] 92 | cls.train_pos_maps[key] = (tpls, weights) 93 | else: 94 | cls.train_pos_maps = None 95 | 96 | if args.retro_during_train: # load negative reactions 97 | print('loading negative reactions') 98 | cls.neg_reacts_ids = {} 99 | cls.neg_reacts_list = [] 100 | cls.neg_reactions_all = defaultdict(set) 101 | fname = 'neg_reacts.csv' if part is None else 'neg_reacts-part-%d.csv' % part 102 | with open(os.path.join(file_root, fname), 'r') as f: 103 | reader = csv.reader(f) 104 | header = next(reader) 105 | for row in tqdm(reader): 106 | sample_idx, reacts = row 107 | if not reacts in cls.neg_reacts_ids: 108 | idx = len(cls.neg_reacts_ids) 109 | cls.neg_reacts_ids[reacts] = idx 110 | cls.neg_reacts_list.append(reacts) 111 | idx = cls.neg_reacts_ids[reacts] 112 | cls.neg_reactions_all[int(row[0])].add(idx) 113 | for key in cls.neg_reactions_all: 114 | cls.neg_reactions_all[key] = list(cls.neg_reactions_all[key]) 115 | 116 | cls.prod_center_maps = {} 117 | print('loading training prod center maps') 118 | fname = 'train-prod_center_maps-part-%d.csv' % part 119 | fname = os.path.join(file_root, fname) 120 | cls.prod_center_maps = load_center_maps(fname) 121 | 122 | cls.cur_part = part 123 | 124 | @classmethod 125 | def init(cls, dropbox, args): 126 | cls.args = args 127 | cls.args.dropbox = dropbox 128 | file_root = os.path.join(dropbox, 'cooked_' + args.data_name, 'tpl-%s' % args.tpl_name) 129 | print('loading data info from', file_root) 130 | 131 | # load training 132 | tpl_file = os.path.join(file_root, 'templates.csv') 133 | 134 | cls.unique_templates = set() 135 | print('loading templates') 136 | with open(tpl_file, 'r') as f: 137 | reader = csv.reader(f) 138 | header = next(reader) 139 | tpl_idx = header.index('retro_templates') 140 | rt_idx = header.index('class') 141 | for row in tqdm(reader): 142 | tpl = row[tpl_idx] 143 | center, r_a, r_c = tpl.split('>') 144 | if smarts_has_useless_parentheses(center): 145 | center = center[1:-1] 146 | tpl = '>'.join([center, r_a, r_c]) 147 | rxn_type = row[rt_idx] 148 | cls.unique_templates.add((rxn_type, tpl)) 149 | cls.unique_templates = sorted(list(cls.unique_templates)) 150 | cls.idx_of_template = {} 151 | for i, tpl in enumerate(cls.unique_templates): 152 | cls.idx_of_template[tpl] = i 153 | print('# unique templates', len(cls.unique_templates)) 154 | 155 | with open(os.path.join(file_root, '../cano_smiles.pkl'), 'rb') as f: 156 | cls.smiles_cano_map = cp.load(f) 157 | 158 | with open(os.path.join(file_root, 'cano_smarts.pkl'), 'rb') as f: 
159 | cls.smarts_cano_map = cp.load(f) 160 | 161 | with open(os.path.join(file_root, 'prod_cano_smarts.txt'), 'r') as f: 162 | cls.prod_cano_smarts = [row.strip() for row in f.readlines()] 163 | 164 | cls.prod_smarts_idx = {} 165 | for i in range(len(cls.prod_cano_smarts)): 166 | cls.prod_smarts_idx[cls.prod_cano_smarts[i]] = i 167 | 168 | cls.unique_tpl_of_prod_center = defaultdict(lambda: defaultdict(list)) 169 | for i, row in enumerate(cls.unique_templates): 170 | rxn_type, tpl = row 171 | center = tpl.split('>')[0] 172 | cano_center = cls.smarts_cano_map[center] 173 | cls.unique_tpl_of_prod_center[cano_center][rxn_type].append(i) 174 | 175 | cls.cur_part = None 176 | 177 | @classmethod 178 | def get_cano_smiles(cls, smiles): 179 | if smiles in cls.smiles_cano_map: 180 | return cls.smiles_cano_map[smiles] 181 | ans = canonicalize(smiles) 182 | cls.smiles_cano_map[smiles] = ans 183 | return ans 184 | 185 | @classmethod 186 | def get_cano_smarts(cls, smarts): 187 | if smarts in cls.smarts_cano_map: 188 | return cls.smarts_cano_map[smarts] 189 | ans = cano_smarts(smarts)[1] 190 | cls.smarts_cano_map[smarts] = ans 191 | return ans 192 | -------------------------------------------------------------------------------- /gln/data_process/dump_graphs.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import rdkit 7 | from rdkit import Chem 8 | import csv 9 | import sys 10 | import os 11 | from tqdm import tqdm 12 | import pickle as cp 13 | from collections import defaultdict 14 | from gln.common.cmd_args import cmd_args 15 | from gln.mods.mol_gnn.mol_utils import SmartsMols, SmilesMols 16 | 17 | 18 | if __name__ == '__main__': 19 | file_root = os.path.join(cmd_args.dropbox, 'cooked_' + cmd_args.data_name, 'tpl-%s' % cmd_args.tpl_name) 20 | if cmd_args.fp_degree > 0: 21 | SmilesMols.set_fp_degree(cmd_args.fp_degree) 22 | SmartsMols.set_fp_degree(cmd_args.fp_degree) 23 | 24 | if cmd_args.retro_during_train: 25 | part_folder = os.path.join(file_root, 'np-%d' % cmd_args.num_parts) 26 | if cmd_args.part_num > 0: 27 | prange = range(cmd_args.part_id, cmd_args.part_id + cmd_args.part_num) 28 | else: 29 | prange = range(cmd_args.num_parts) 30 | for pid in prange: 31 | with open(os.path.join(part_folder, 'neg_reacts-part-%d.csv' % pid), 'r') as f: 32 | reader = csv.reader(f) 33 | header = next(reader) 34 | for row in tqdm(reader): 35 | reacts = row[-1] 36 | for t in reacts.split('.'): 37 | SmilesMols.get_mol_graph(t) 38 | SmilesMols.get_mol_graph(reacts) 39 | SmilesMols.save_dump(os.path.join(part_folder, 'neg_graphs-part-%d' % pid)) 40 | SmilesMols.clear() 41 | sys.exit() 42 | 43 | with open(os.path.join(file_root, '../cano_smiles.pkl'), 'rb') as f: 44 | smiles_cano_map = cp.load(f) 45 | 46 | with open(os.path.join(file_root, 'prod_cano_smarts.txt'), 'r') as f: 47 | prod_cano_smarts = [row.strip() for row in f.readlines()] 48 | 49 | with open(os.path.join(file_root, 'react_cano_smarts.txt'), 'r') as f: 50 | react_cano_smarts = [row.strip() for row in f.readlines()] 51 | 52 | 53 | for mol in tqdm(smiles_cano_map): 54 | SmilesMols.get_mol_graph(smiles_cano_map[mol]) 55 | SmilesMols.save_dump(os.path.join(cmd_args.save_dir, '../graph_smiles')) 56 | 57 | for smarts in tqdm(prod_cano_smarts + react_cano_smarts): 58 | SmartsMols.get_mol_graph(smarts) 59 | SmartsMols.save_dump(cmd_args.save_dir + '/graph_smarts') 60 | 
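# Hedged usage sketch (an assumption, mirroring load_bin_feats in data_info.py):
# the dumps written above are read back before training, roughly as
#   SmilesMols.set_fp_degree(2)
#   SmilesMols.load_dump(os.path.join(file_root, '../graph_smiles'),
#                        load_feats=True, load_fp=False)
#   g = SmilesMols.get_mol_graph('CCO')  # now served from the cached dump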
-------------------------------------------------------------------------------- /gln/data_process/filter_template.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import os 6 | import sys 7 | import csv 8 | from tqdm import tqdm 9 | from gln.common.cmd_args import cmd_args 10 | from collections import Counter, defaultdict 11 | 12 | 13 | if __name__ == '__main__': 14 | proc_file = os.path.join(cmd_args.save_dir, '../proc_train_singleprod.csv') 15 | 16 | unique_tpls = Counter() 17 | tpl_types = defaultdict(set) 18 | with open(proc_file, 'r') as f: 19 | reader = csv.reader(f) 20 | header = next(reader) 21 | print(header) 22 | for row in tqdm(reader): 23 | tpl = row[header.index('retro_templates')] 24 | rxn_type = row[header.index('class')] 25 | tpl_types[tpl].add(rxn_type) 26 | unique_tpls[tpl] += 1 27 | 28 | print('total # templates', len(unique_tpls)) 29 | 30 | used_tpls = [] 31 | for x in unique_tpls: 32 | if unique_tpls[x] >= cmd_args.tpl_min_cnt: 33 | used_tpls.append(x) 34 | print('num templates after filtering', len(used_tpls)) 35 | 36 | out_file = os.path.join(cmd_args.save_dir, 'templates.csv') 37 | with open(out_file, 'w') as f: 38 | writer = csv.writer(f) 39 | writer.writerow(['class', 'retro_templates']) 40 | for x in used_tpls: 41 | for t in tpl_types[x]: 42 | writer.writerow([t, x]) 43 | -------------------------------------------------------------------------------- /gln/data_process/find_centers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import rdkit 7 | from rdkit import Chem 8 | import csv 9 | import os 10 | from tqdm import tqdm 11 | import pickle as cp 12 | from collections import defaultdict 13 | from gln.common.cmd_args import cmd_args 14 | from gln.common.mol_utils import cano_smarts, cano_smiles, smarts_has_useless_parentheses 15 | import multiprocessing 16 | 17 | 18 | def find_edges(task): 19 | idx, rxn_type, smiles = task 20 | smiles = smiles_cano_map[smiles] 21 | 22 | mol = Chem.MolFromSmiles(smiles) 23 | if mol is None: 24 | return idx, rxn_type, smiles, None 25 | list_centers = [] 26 | for i, (sm_center, center_mol) in enumerate(prod_center_mols): 27 | if center_mol is None: 28 | continue 29 | if not rxn_type in smarts_type_set[sm_center]: 30 | continue 31 | if mol.HasSubstructMatch(center_mol): 32 | list_centers.append(str(i)) 33 | if len(list_centers) == 0: 34 | return idx, rxn_type, smiles, None 35 | centers = ' '.join(list_centers) 36 | return idx, rxn_type, smiles, centers 37 | 38 | 39 | if __name__ == '__main__': 40 | with open(os.path.join(cmd_args.save_dir, '../cano_smiles.pkl'), 'rb') as f: 41 | smiles_cano_map = cp.load(f) 42 | 43 | with open(os.path.join(cmd_args.save_dir, 'cano_smarts.pkl'), 'rb') as f: 44 | smarts_cano_map = cp.load(f) 45 | 46 | with open(os.path.join(cmd_args.save_dir, 'prod_cano_smarts.txt'), 'r') as f: 47 | prod_cano_smarts = [row.strip() for row in f.readlines()] 48 | 49 | prod_center_mols = [] 50 | for sm in tqdm(prod_cano_smarts): 51 | prod_center_mols.append((sm, Chem.MolFromSmarts(sm))) 52 | 53 | print('num of prod centers', len(prod_center_mols)) 54 | print('num of smiles', len(smiles_cano_map)) 55 | 56 | csv_file = os.path.join(cmd_args.save_dir, 'templates.csv') 57 | 58 
| smarts_type_set = defaultdict(set) 59 | with open(csv_file, 'r') as f: 60 | reader = csv.reader(f) 61 | header = next(reader) 62 | tpl_idx = header.index('retro_templates') 63 | type_idx = header.index('class') 64 | for row in reader: 65 | rxn_type = row[type_idx] 66 | template = row[tpl_idx] 67 | sm_prod, _, _ = template.split('>') 68 | if smarts_has_useless_parentheses(sm_prod): 69 | sm_prod = sm_prod[1:-1] 70 | sm_prod = smarts_cano_map[sm_prod] 71 | smarts_type_set[sm_prod].add(rxn_type) 72 | 73 | if cmd_args.num_parts <= 0: 74 | num_parts = cmd_args.num_cores 75 | else: 76 | num_parts = cmd_args.num_parts 77 | 78 | pool = multiprocessing.Pool(cmd_args.num_cores) 79 | 80 | raw_data_root = os.path.join(cmd_args.dropbox, cmd_args.data_name) 81 | for out_phase in ['train', 'val', 'test']: 82 | csv_file = os.path.join(raw_data_root, 'raw_%s.csv' % out_phase) 83 | 84 | rxn_smiles = [] 85 | with open(csv_file, 'r') as f: 86 | reader = csv.reader(f) 87 | header = next(reader) 88 | rxn_idx = header.index('reactants>reagents>production') 89 | type_idx = header.index('class') 90 | for row in tqdm(reader): 91 | rxn_smiles.append((row[type_idx], row[rxn_idx])) 92 | 93 | part_size = min(len(rxn_smiles) // num_parts + 1, len(rxn_smiles)) 94 | 95 | for pid in range(num_parts): 96 | idx_range = range(pid * part_size, min((pid + 1) * part_size, len(rxn_smiles))) 97 | 98 | local_results = [None] * len(idx_range) 99 | 100 | tasks = [] 101 | for i, idx in enumerate(idx_range): 102 | rxn_type, rxn = rxn_smiles[idx] 103 | reactants, _, prod = rxn.split('>') 104 | tasks.append((i, rxn_type, prod)) 105 | for result in tqdm(pool.imap_unordered(find_edges, tasks), total=len(tasks)): 106 | i, rxn_type, smiles, centers = result 107 | local_results[i] = (rxn_type, smiles, centers) 108 | out_folder = os.path.join(cmd_args.save_dir, 'np-%d' % num_parts) 109 | if not os.path.isdir(out_folder): 110 | os.makedirs(out_folder) 111 | fout = open(os.path.join(out_folder, '%s-prod_center_maps-part-%d.csv' % (out_phase, pid)), 'w') 112 | writer = csv.writer(fout) 113 | writer.writerow(['smiles', 'class', 'centers']) 114 | 115 | for i in range(len(local_results)): 116 | rxn_type, smiles, centers = local_results[i] 117 | if centers is not None: 118 | writer.writerow([smiles, rxn_type, centers]) 119 | fout.close() 120 | -------------------------------------------------------------------------------- /gln/data_process/get_canonical_smarts.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import rdkit 7 | from rdkit import Chem 8 | import csv 9 | import os 10 | from tqdm import tqdm 11 | import pickle as cp 12 | from collections import defaultdict 13 | from gln.common.cmd_args import cmd_args 14 | from gln.common.mol_utils import cano_smarts, cano_smiles, smarts_has_useless_parentheses 15 | 16 | 17 | def process_centers(): 18 | prod_cano_smarts = set() 19 | react_cano_smarts = set() 20 | 21 | smarts_cano_map = {} 22 | pbar = tqdm(retro_templates) 23 | for template in pbar: 24 | sm_prod, _, sm_react = template.split('>') 25 | if smarts_has_useless_parentheses(sm_prod): 26 | sm_prod = sm_prod[1:-1] 27 | 28 | smarts_cano_map[sm_prod] = cano_smarts(sm_prod)[1] 29 | prod_cano_smarts.add(smarts_cano_map[sm_prod]) 30 | 31 | for r_smarts in sm_react.split('.'): 32 | smarts_cano_map[r_smarts] = cano_smarts(r_smarts)[1] 33 | 
react_cano_smarts.add(smarts_cano_map[r_smarts]) 34 | pbar.set_description('# prod centers: %d, # react centers: %d' % (len(prod_cano_smarts), len(react_cano_smarts))) 35 | print('# prod centers: %d, # react centers: %d' % (len(prod_cano_smarts), len(react_cano_smarts))) 36 | 37 | with open(os.path.join(cmd_args.save_dir, 'prod_cano_smarts.txt'), 'w') as f: 38 | for s in prod_cano_smarts: 39 | f.write('%s\n' % s) 40 | with open(os.path.join(cmd_args.save_dir, 'react_cano_smarts.txt'), 'w') as f: 41 | for s in react_cano_smarts: 42 | f.write('%s\n' % s) 43 | with open(os.path.join(cmd_args.save_dir, 'cano_smarts.pkl'), 'wb') as f: 44 | cp.dump(smarts_cano_map, f, cp.HIGHEST_PROTOCOL) 45 | 46 | 47 | if __name__ == '__main__': 48 | tpl_file = os.path.join(cmd_args.save_dir, 'templates.csv') 49 | 50 | retro_templates = [] 51 | 52 | with open(tpl_file, 'r') as f: 53 | reader = csv.reader(f) 54 | header = next(reader) 55 | 56 | for row in tqdm(reader): 57 | retro_templates.append(row[header.index('retro_templates')]) 58 | 59 | raw_data_root = os.path.join(cmd_args.dropbox, cmd_args.data_name) 60 | rxn_smiles = [] 61 | for phase in ['train', 'val', 'test']: 62 | csv_file = os.path.join(raw_data_root, 'raw_%s.csv' % phase) 63 | with open(csv_file, 'r') as f: 64 | reader = csv.reader(f) 65 | header = next(reader) 66 | rxn_idx = header.index('reactants>reagents>production') 67 | for row in tqdm(reader): 68 | rxn_smiles.append(row[rxn_idx]) 69 | 70 | process_centers() 71 | -------------------------------------------------------------------------------- /gln/data_process/get_canonical_smiles.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import rdkit 7 | from rdkit import Chem 8 | import csv 9 | import os 10 | from tqdm import tqdm 11 | import pickle as cp 12 | from collections import defaultdict 13 | from gln.common.cmd_args import cmd_args 14 | from gln.common.mol_utils import cano_smarts, cano_smiles, smarts_has_useless_parentheses 15 | 16 | 17 | def process_smiles(): 18 | all_symbols = set() 19 | 20 | smiles_cano_map = {} 21 | for rxn in tqdm(rxn_smiles): 22 | reactants, _, prod = rxn.split('>') 23 | mols = reactants.split('.') + [prod] 24 | for sm in mols: 25 | m, cano_sm = cano_smiles(sm) 26 | if m is not None: 27 | for a in m.GetAtoms(): 28 | all_symbols.add((a.GetAtomicNum(), a.GetSymbol())) 29 | if sm in smiles_cano_map: 30 | assert smiles_cano_map[sm] == cano_sm 31 | else: 32 | smiles_cano_map[sm] = cano_sm 33 | print('num of smiles', len(smiles_cano_map)) 34 | set_mols = set() 35 | for s in smiles_cano_map: 36 | set_mols.add(smiles_cano_map[s]) 37 | print('# unique smiles', len(set_mols)) 38 | with open(os.path.join(cmd_args.save_dir, 'cano_smiles.pkl'), 'wb') as f: 39 | cp.dump(smiles_cano_map, f, cp.HIGHEST_PROTOCOL) 40 | print('# unique atoms:', len(all_symbols)) 41 | all_symbols = sorted(list(all_symbols)) 42 | with open(os.path.join(cmd_args.save_dir, 'atom_list.txt'), 'w') as f: 43 | for a in all_symbols: 44 | f.write('%d\n' % a[0]) 45 | 46 | 47 | if __name__ == '__main__': 48 | 49 | raw_data_root = os.path.join(cmd_args.dropbox, cmd_args.data_name) 50 | rxn_smiles = [] 51 | for phase in ['train', 'val', 'test']: 52 | csv_file = os.path.join(raw_data_root, 'raw_%s.csv' % phase) 53 | with open(csv_file, 'r') as f: 54 | reader = csv.reader(f) 55 | header = next(reader) 56 | rxn_idx = 
header.index('reactants>reagents>production') 57 | for row in tqdm(reader): 58 | rxn_smiles.append(row[rxn_idx]) 59 | 60 | process_smiles() 61 | 62 | -------------------------------------------------------------------------------- /gln/data_process/step0.0_run_get_cano_smiles.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | dropbox=../../dropbox 5 | data=schneider50k 6 | 7 | save_dir=$dropbox/cooked_$data 8 | 9 | 10 | python get_canonical_smiles.py \ 11 | -dropbox $dropbox \ 12 | -data_name $data \ 13 | -save_dir $save_dir 14 | -------------------------------------------------------------------------------- /gln/data_process/step0.1_run_raw_template_extract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | dropbox=../../dropbox 5 | data=schneider50k 6 | num_cores=4 7 | 8 | save_dir=$dropbox/cooked_${data} 9 | 10 | if [ ! -e $save_dir ]; 11 | then 12 | mkdir -p $save_dir 13 | fi 14 | 15 | python build_raw_template.py \ 16 | -dropbox $dropbox \ 17 | -data_name $data \ 18 | -save_dir $save_dir \ 19 | -num_cores $num_cores \ 20 | -------------------------------------------------------------------------------- /gln/data_process/step1_filter_template.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | dropbox=../../dropbox 5 | data=schneider50k 6 | version=default 7 | 8 | save_dir=$dropbox/cooked_${data}/tpl-$version 9 | 10 | if [ ! -e $save_dir ]; 11 | then 12 | mkdir -p $save_dir 13 | fi 14 | 15 | python filter_template.py \ 16 | -dropbox $dropbox \ 17 | -data_name $data \ 18 | -tpl_name $version \ 19 | -save_dir $save_dir \ 20 | -------------------------------------------------------------------------------- /gln/data_process/step2_run_get_cano_smarts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | dropbox=../../dropbox 5 | data=schneider50k 6 | tpl=default 7 | 8 | save_dir=$dropbox/cooked_$data/tpl-$tpl 9 | 10 | 11 | python get_canonical_smarts.py \ 12 | -dropbox $dropbox \ 13 | -data_name $data \ 14 | -save_dir $save_dir \ 15 | -tpl_name $tpl \ 16 | -------------------------------------------------------------------------------- /gln/data_process/step3_run_find_centers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dropbox=../../dropbox 4 | data=schneider50k 5 | tpl=default 6 | num_cores=8 7 | num_parts=1 8 | 9 | save_dir=$dropbox/cooked_$data/tpl-$tpl 10 | 11 | python find_centers.py \ 12 | -dropbox $dropbox \ 13 | -data_name $data \ 14 | -tpl_name $tpl \ 15 | -save_dir $save_dir \ 16 | -num_cores $num_cores \ 17 | -num_parts $num_parts \ 18 | $@ 19 | -------------------------------------------------------------------------------- /gln/data_process/step4_run_find_all_reactions.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dropbox=../../dropbox 4 | data_name=schneider50k 5 | tpl_name=default 6 | num_cores=4 7 | num_parts=1 8 | 9 | save_dir=$dropbox/cooked_$data_name/tpl-$tpl_name 10 | 11 | 12 | python build_all_reactions.py \ 13 | -dropbox $dropbox \ 14 | -phase cooking \ 15 | -data_name $data_name \ 16 | -save_dir $save_dir \ 17 | -tpl_name $tpl_name \ 18 | -f_atoms $dropbox/cooked_$data_name/atom_list.txt \ 19 | -num_cores $num_cores \ 20 | -num_parts $num_parts \ 21 | -gpu -1 \ 22 | $@ 23 | 24 | 
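# Hedged usage note (not in the original script): extra flags are forwarded via
# $@, so a single shard of the support set can be built with, e.g.:
#   ./step4_run_find_all_reactions.sh -part_id 0 -part_num 1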
-------------------------------------------------------------------------------- /gln/data_process/step5_prun_dump_neg_graphs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dropbox='../../dropbox' 4 | data=uspto_full 5 | tpl=new 6 | tpl_min_cnt=4 7 | fp_degree=2 8 | num_parts=40 9 | cooked_root=../../cooked_data 10 | 11 | save_dir=$cooked_root/$data/tpl-$tpl-mincnt-$tpl_min_cnt 12 | 13 | N=$num_parts 14 | ( 15 | for ((s=0;s<40;s+=1)) 16 | do 17 | ((i=i%N)); ((i++==0)) && wait 18 | 19 | python dump_graphs.py \ 20 | -dropbox $dropbox \ 21 | -data_name $data \ 22 | -tpl_min_cnt $tpl_min_cnt \ 23 | -cooked_root $cooked_root \ 24 | -tpl_name $tpl \ 25 | -save_dir $save_dir \ 26 | -f_atoms $cooked_root/$data/atom_list.txt \ 27 | -num_parts $num_parts \ 28 | -part_id $s \ 29 | -part_num 1 \ 30 | -fp_degree $fp_degree \ 31 | -retro_during_train True \ 32 | & 33 | 34 | done 35 | ) 36 | -------------------------------------------------------------------------------- /gln/data_process/step5_run_dump_graphs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dropbox=../../dropbox 4 | data=schneider50k 5 | tpl=default 6 | fp_degree=2 7 | num_parts=1 8 | 9 | save_dir=$dropbox/cooked_$data/tpl-$tpl 10 | 11 | for r in False True; do 12 | 13 | python dump_graphs.py \ 14 | -dropbox $dropbox \ 15 | -data_name $data \ 16 | -tpl_name $tpl \ 17 | -save_dir $save_dir \ 18 | -f_atoms $dropbox/cooked_$data/atom_list.txt \ 19 | -num_parts $num_parts \ 20 | -fp_degree $fp_degree \ 21 | -retro_during_train $r \ 22 | $@ 23 | 24 | done 25 | -------------------------------------------------------------------------------- /gln/graph_logic/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | from gln.mods.mol_gnn.gnn_family import EmbedMeanField, GGNN, MPNN, MorganFp, S2vMeanFieldV2 6 | from gln.mods.mol_gnn.mg_clib import NUM_NODE_FEATS, NUM_EDGE_FEATS 7 | 8 | 9 | def get_gnn(args, gm=None): 10 | if gm is None: 11 | gm = args.gm 12 | if gm == 'mean_field': 13 | gnn = EmbedMeanField(latent_dim=args.latent_dim, 14 | output_dim=args.embed_dim, 15 | num_node_feats=NUM_NODE_FEATS, 16 | num_edge_feats=NUM_EDGE_FEATS, 17 | max_lv=args.max_lv, 18 | act_func=args.act_func, 19 | readout_agg=args.readout_agg_type, 20 | act_last=args.act_last, 21 | dropout=args.dropout) 22 | elif gm == 's2v_v2': 23 | gnn = S2vMeanFieldV2(latent_dim=args.latent_dim, 24 | output_dim=args.embed_dim, 25 | num_node_feats=NUM_NODE_FEATS, 26 | num_edge_feats=NUM_EDGE_FEATS, 27 | max_lv=args.max_lv, 28 | act_func=args.act_func, 29 | readout_agg=args.readout_agg_type, 30 | act_last=args.act_last, 31 | out_method=args.gnn_out, 32 | bn=args.bn, 33 | dropout=args.dropout) 34 | elif gm == 'ggnn': 35 | gnn = GGNN(node_state_dim=args.latent_dim, 36 | output_dims=[args.embed_dim], 37 | edge_hidden_sizes=[args.latent_dim], 38 | num_node_feats=NUM_NODE_FEATS, 39 | num_edge_feats=NUM_EDGE_FEATS, 40 | max_lv=args.max_lv, 41 | msg_aggregate_type=args.msg_agg_type, 42 | readout_agg=args.readout_agg_type, 43 | share_params=args.gnn_share_param, 44 | act_func=args.act_func, 45 | dropout=args.dropout) 46 | elif gm == 'mpnn': 47 | gnn = MPNN(latent_dim=args.latent_dim, 48 | output_dim=args.embed_dim, 49 | num_node_feats=NUM_NODE_FEATS, 50 | num_edge_feats=NUM_EDGE_FEATS, 51 | 
max_lv=args.max_lv, 52 | msg_aggregate_type=args.msg_agg_type, 53 | act_func=args.act_func, 54 | dropout=args.dropout) 55 | elif gm == 'ecfp': 56 | gnn = MorganFp(feat_dim=args.fp_dim, 57 | hidden_size=args.embed_dim, 58 | num_hidden=1, 59 | feat_mode='dense', 60 | act_func=args.act_func, 61 | dropout=args.dropout) 62 | else: 63 | raise NotImplementedError 64 | return gnn 65 | -------------------------------------------------------------------------------- /gln/graph_logic/graph_feat.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import os 7 | import rdkit 8 | from rdkit import Chem 9 | import csv 10 | from gln.mods.mol_gnn.torch_util import MLP 11 | from gln.mods.mol_gnn.gnn_family.utils import get_agg 12 | from gln.mods.mol_gnn.mol_utils import SmartsMols, SmilesMols 13 | from gln.common.consts import DEVICE 14 | from gln.data_process.data_info import DataInfo 15 | from gln.mods.mol_gnn.mg_clib import NUM_NODE_FEATS, NUM_EDGE_FEATS 16 | 17 | from gln.graph_logic import get_gnn 18 | 19 | import torch 20 | from torch.autograd import Variable 21 | from torch.nn.parameter import Parameter 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | import torch.optim as optim 25 | 26 | 27 | class TempFeaturizer(nn.Module): 28 | def __init__(self, args): 29 | super(TempFeaturizer, self).__init__() 30 | 31 | 32 | class DeepsetTempFeaturizer(TempFeaturizer): 33 | def __init__(self, args): 34 | super(DeepsetTempFeaturizer, self).__init__(args) 35 | self.prod_gnn = get_gnn(args) 36 | self.react_gnn = get_gnn(args) 37 | self.reactants_agg = get_agg('sum') 38 | 39 | self.readout = MLP(args.embed_dim * 2, [args.mlp_hidden, args.embed_dim], nonlinearity='relu') 40 | 41 | def forward(self, template_list): 42 | list_prod = [] 43 | list_rxtants = [] 44 | rxtant_indices = [] 45 | for i, temp in enumerate(template_list): 46 | prod, _, reacts = temp.split('>') 47 | prod = DataInfo.get_cano_smarts(prod) 48 | list_prod.append(SmartsMols.get_mol_graph(prod)) 49 | reacts = reacts.split('.') 50 | for r in reacts: 51 | r = DataInfo.get_cano_smarts(r) 52 | list_rxtants.append(SmartsMols.get_mol_graph(r)) 53 | 54 | rxtant_indices += [i] * len(reacts) 55 | 56 | prods, _ = self.prod_gnn(list_prod) 57 | rxtants, _ = self.react_gnn(list_rxtants) 58 | rxtants = F.relu(rxtants) 59 | 60 | rxtant_indices = torch.LongTensor(rxtant_indices).to(DEVICE) 61 | rxtants = self.reactants_agg(rxtants, rxtant_indices.view(-1, 1).expand(-1, rxtants.shape[1]), 62 | dim=0, dim_size=len(template_list)) 63 | 64 | feats = torch.cat((prods, rxtants), dim=1) 65 | out = self.readout(feats) 66 | return out 67 | 68 | 69 | class DeepsetReactFeaturizer(nn.Module): 70 | 71 | def __init__(self, args): 72 | super(DeepsetReactFeaturizer, self).__init__() 73 | 74 | self.react_gnn = get_gnn(args) 75 | self.reactants_agg = get_agg('sum') 76 | 77 | def forward(self, reaction_list): 78 | list_prod = [] 79 | list_cata = [] 80 | list_rxtants = [] 81 | rxtant_indices = [] 82 | for i, react in enumerate(reaction_list): 83 | reactants, cata, prod = react.split('>') 84 | list_prod.append(SmilesMols.get_mol_graph(prod)) 85 | list_cata.append(SmilesMols.get_mol_graph(cata)) 86 | reactants = reactants.split('.') 87 | for r in reactants: 88 | list_rxtants.append(SmilesMols.get_mol_graph(r)) 89 | 90 | rxtant_indices += [i] * len(reactants) 91 | 92 | rxtants, _ = 
self.react_gnn(list_rxtants) 93 | 94 | rxtant_indices = torch.LongTensor(rxtant_indices).to(DEVICE) 95 | rxtants = self.reactants_agg(rxtants, rxtant_indices.view(-1, 1).expand(-1, rxtants.shape[1]), 96 | dim=0, dim_size=len(reaction_list)) 97 | return rxtants 98 | -------------------------------------------------------------------------------- /gln/graph_logic/logic_net.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import os 7 | import rdkit 8 | from rdkit import Chem 9 | import csv 10 | from gln.mods.mol_gnn.mol_utils import SmartsMols, SmilesMols 11 | from gln.common.consts import DEVICE, t_float 12 | from torch_scatter import scatter_max, scatter_add, scatter_mean 13 | from gln.graph_logic.graph_feat import get_gnn 14 | from gln.graph_logic.soft_logic import OnehotEmbedder, CenterProbCalc, ActiveProbCalc, ReactionProbCalc 15 | 16 | import torch 17 | from gln.mods.torchext import jagged_log_softmax 18 | 19 | from torch.autograd import Variable 20 | from torch.nn.parameter import Parameter 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import torch.optim as optim 24 | from gln.data_process.data_info import DataInfo 25 | 26 | 27 | class GraphPath(nn.Module): 28 | def __init__(self, args): 29 | super(GraphPath, self).__init__() 30 | 31 | # predict the template 32 | self.tpl_fwd_predicate = ActiveProbCalc(args) 33 | # predict the center 34 | self.prod_center_predicate = CenterProbCalc(args) 35 | # predict the entire reaction 36 | self.reaction_predicate = ReactionProbCalc(args) 37 | self.retro_during_train = args.retro_during_train 38 | 39 | def forward(self, samples): 40 | prods = [] 41 | list_of_list_centers = [] 42 | list_of_list_tpls = [] 43 | list_of_list_reacts = [] 44 | 45 | for sample in samples: 46 | prods.append(SmilesMols.get_mol_graph(sample.prod)) 47 | 48 | list_centers = [sample.center] + sample.neg_centers 49 | list_of_list_centers.append([SmartsMols.get_mol_graph(c) for c in list_centers]) 50 | 51 | list_tpls = [sample.template] + sample.neg_tpls 52 | list_of_list_tpls.append(list_tpls) 53 | if self.retro_during_train: 54 | list_reacts = [sample.reaction] + sample.neg_reactions 55 | list_of_list_reacts.append(list_reacts) 56 | 57 | center_log_prob = self.prod_center_predicate(prods, list_of_list_centers) 58 | tpl_log_prob = self.tpl_fwd_predicate(prods, list_of_list_tpls) 59 | 60 | loss = -torch.mean(center_log_prob) - torch.mean(tpl_log_prob) 61 | if self.retro_during_train: 62 | react_log_prob = self.reaction_predicate(prods, list_of_list_reacts) 63 | loss = loss - torch.mean(react_log_prob) 64 | 65 | return loss 66 | -------------------------------------------------------------------------------- /gln/graph_logic/run_test_graph_feat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dropbox=../../dropbox 4 | data_name=schneider50k 5 | tpl_name=default_tpl 6 | 7 | python graph_feat.py \ 8 | -dropbox $dropbox \ 9 | -data_name $data_name \ 10 | -tpl_name $tpl_name \ 11 | -f_atoms $dropbox/cooked_$data_name/atom_list.txt \ 12 | $@ 13 | 14 | -------------------------------------------------------------------------------- /gln/graph_logic/soft_logic.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import 
absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import os 7 | import rdkit 8 | from rdkit import Chem 9 | import csv 10 | from gln.mods.mol_gnn.mol_utils import SmartsMols 11 | from gln.mods.mol_gnn.torch_util import MLP, glorot_uniform 12 | 13 | from gln.common.consts import DEVICE 14 | from gln.data_process.data_info import DataInfo 15 | from torch_scatter import scatter_max, scatter_add, scatter_mean 16 | from gln.mods.torchext import jagged_log_softmax 17 | from gln.graph_logic import get_gnn 18 | from gln.graph_logic.graph_feat import DeepsetTempFeaturizer, DeepsetReactFeaturizer 19 | 20 | import torch 21 | from torch.autograd import Variable 22 | from torch.nn.parameter import Parameter 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | import torch.optim as optim 26 | 27 | 28 | def jagged_forward(list_graph, list_of_list_cand, graph_enc, cand_enc, att_func, list_target_pos=None, normalize=True): 29 | graph_embed = graph_enc(list_graph) 30 | 31 | flat_cands = [] 32 | rep_indices = [] 33 | prefix_sum = [] 34 | offset = 0 35 | for i, l in enumerate(list_of_list_cand): 36 | for c in l: 37 | flat_cands.append(c) 38 | rep_indices += [i] * len(l) 39 | offset += len(l) 40 | prefix_sum.append(offset) 41 | 42 | cand_embed = cand_enc(flat_cands) 43 | rep_indices = torch.LongTensor(rep_indices).to(DEVICE) 44 | prefix_sum = torch.LongTensor(prefix_sum).to(DEVICE) 45 | 46 | graph_embed = torch.gather(graph_embed, 0, rep_indices.view(-1, 1).expand(-1, graph_embed.shape[1])) 47 | 48 | logits = att_func(graph_embed, cand_embed) 49 | if normalize: 50 | log_prob = jagged_log_softmax(logits, prefix_sum) 51 | else: 52 | log_prob = logits 53 | 54 | if list_target_pos is None: 55 | return log_prob 56 | 57 | offset = 0 58 | target_pos = [] 59 | for i, l in enumerate(list_of_list_cand): 60 | idx = list_target_pos[i] 61 | target_pos.append(offset + idx) 62 | offset += len(l) 63 | target_pos = torch.LongTensor(target_pos).to(DEVICE) 64 | 65 | return log_prob[target_pos] 66 | 67 | 68 | class ReactionProbCalc(nn.Module): 69 | def __init__(self, args): 70 | super(ReactionProbCalc, self).__init__() 71 | 72 | self.prod_enc = get_gnn(args) 73 | self.react_enc = DeepsetReactFeaturizer(args) 74 | if args.att_type == 'inner_prod': 75 | self.att_func = lambda x, y: torch.sum(x * y, dim=1).view(-1) 76 | elif args.att_type == 'mlp': 77 | self.pred = MLP(2 * args.embed_dim, [args.mlp_hidden, 1], nonlinearity='relu') 78 | self.att_func = lambda x, y: self.pred(torch.cat((x, y), dim=1)).view(-1) 79 | elif args.att_type == 'bilinear': 80 | self.bilin = nn.Bilinear(args.embed_dim, args.embed_dim, 1) 81 | self.att_func = lambda x, y: self.bilin(x, y).view(-1) 82 | else: 83 | raise NotImplementedError 84 | 85 | def forward(self, list_mols, list_of_list_reactions, list_target_pos=None): 86 | if list_target_pos is None: 87 | list_target_pos = [0] * len(list_mols) 88 | 89 | log_prob = jagged_forward(list_mols, list_of_list_reactions, 90 | graph_enc=lambda x: self.prod_enc(x)[0], 91 | cand_enc=lambda x: self.react_enc(x), 92 | att_func=self.att_func, 93 | list_target_pos=list_target_pos) 94 | return log_prob 95 | 96 | def inference(self, list_mols, list_of_list_reactions): 97 | logits = jagged_forward(list_mols, list_of_list_reactions, 98 | graph_enc=lambda x: self.prod_enc(x)[0], 99 | cand_enc=lambda x: self.react_enc(x), 100 | att_func=self.att_func, 101 | list_target_pos=None, 102 | normalize=False) 103 | return logits 104 | 105 | 106 | class OnehotEmbedder(nn.Module): 107 | 
def __init__(self, list_keys, fn_getkey, embed_size): 108 | super(OnehotEmbedder, self).__init__() 109 | self.key_idx = {} 110 | for i, key in enumerate(list_keys): 111 | self.key_idx[key] = i 112 | self.embed_size = embed_size 113 | self.fn_getkey = fn_getkey 114 | self.embedding = nn.Embedding(len(list_keys) + 1, embed_size) 115 | glorot_uniform(self) 116 | 117 | def forward(self, list_objs): 118 | indices = [] 119 | for obj in list_objs: 120 | key = self.fn_getkey(obj) 121 | if key is None: 122 | indices.append(len(self.key_idx)) 123 | else: 124 | indices.append(self.key_idx[key]) 125 | indices = torch.LongTensor(indices).to(DEVICE) 126 | return self.embedding(indices) 127 | 128 | 129 | class ActiveProbCalc(nn.Module): 130 | def __init__(self, args): 131 | super(ActiveProbCalc, self).__init__() 132 | self.prod_enc = get_gnn(args) 133 | if args.tpl_enc == 'deepset': 134 | self.tpl_enc = DeepsetTempFeaturizer(args) 135 | elif args.tpl_enc == 'onehot': 136 | self.tpl_enc = OnehotEmbedder(list_keys=DataInfo.unique_templates, 137 | fn_getkey=lambda x: x, 138 | embed_size=args.embed_dim) 139 | else: 140 | raise NotImplementedError 141 | if args.att_type == 'inner_prod': 142 | self.att_func = lambda x, y: torch.sum(x * y, dim=1).view(-1) 143 | elif args.att_type == 'mlp': 144 | self.pred = MLP(2 * args.embed_dim, [args.mlp_hidden, 1], nonlinearity='relu') 145 | self.att_func = lambda x, y: self.pred(torch.cat((x, y), dim=1)).view(-1) 146 | elif args.att_type == 'bilinear': 147 | self.bilin = nn.Bilinear(args.embed_dim, args.embed_dim, 1) 148 | self.att_func = lambda x, y: self.bilin(x, y).view(-1) 149 | else: 150 | raise NotImplementedError 151 | 152 | def forward(self, list_mols, list_of_list_templates, list_target_pos=None): 153 | if list_target_pos is None: 154 | list_target_pos = [0] * len(list_mols) 155 | log_prob = jagged_forward(list_mols, list_of_list_templates, 156 | graph_enc=lambda x: self.prod_enc(x)[0], 157 | cand_enc=lambda x: self.tpl_enc(x), 158 | att_func=self.att_func, 159 | list_target_pos=list_target_pos) 160 | 161 | return log_prob 162 | 163 | def inference(self, list_mols, list_of_list_templates): 164 | logits = jagged_forward(list_mols, list_of_list_templates, 165 | graph_enc=lambda x: self.prod_enc(x)[0], 166 | cand_enc=lambda x: self.tpl_enc(x), 167 | att_func=self.att_func, 168 | list_target_pos=None, 169 | normalize=False) 170 | return logits 171 | 172 | 173 | class CenterProbCalc(nn.Module): 174 | def __init__(self, args): 175 | super(CenterProbCalc, self).__init__() 176 | self.prod_enc = get_gnn(args) 177 | if args.subg_enc == 'onehot': 178 | self.prod_center_enc = OnehotEmbedder(list_keys=DataInfo.prod_cano_smarts, 179 | fn_getkey=lambda m: m.name if m is not None else None, 180 | embed_size=args.embed_dim) 181 | self.prod_embed_func = lambda x: self.prod_center_enc(x) 182 | else: 183 | self.prod_center_enc = get_gnn(args, gm=args.subg_enc) 184 | self.prod_embed_func = lambda x: self.prod_center_enc(x)[0] 185 | if args.att_type == 'inner_prod': 186 | self.att_func = lambda x, y: torch.sum(x * y, dim=1).view(-1) 187 | elif args.att_type == 'mlp': 188 | self.pred = MLP(2 * args.embed_dim, [args.mlp_hidden, 1], nonlinearity=args.act_func) 189 | self.att_func = lambda x, y: self.pred(torch.cat((x, y), dim=1)).view(-1) 190 | elif args.att_type == 'bilinear': 191 | self.bilin = nn.Bilinear(args.embed_dim, args.embed_dim, 1) 192 | self.att_func = lambda x, y: self.bilin(x, y).view(-1) 193 | else: 194 | raise NotImplementedError 195 | 196 | def forward(self, list_mols, 
list_of_list_centers, list_target_pos=None): 197 | if list_target_pos is None: 198 | list_target_pos = [0] * len(list_mols) 199 | log_prob = jagged_forward(list_mols, list_of_list_centers, 200 | graph_enc=lambda x: self.prod_enc(x)[0], 201 | cand_enc=lambda x: self.prod_embed_func(x), 202 | att_func=self.att_func, 203 | list_target_pos=list_target_pos) 204 | 205 | return log_prob 206 | 207 | def inference(self, list_mols, list_of_list_centers): 208 | logits = jagged_forward(list_mols, list_of_list_centers, 209 | graph_enc=lambda x: self.prod_enc(x)[0], 210 | cand_enc=lambda x: self.prod_embed_func(x), 211 | att_func=self.att_func, 212 | list_target_pos=None, 213 | normalize=False) 214 | return logits 215 | -------------------------------------------------------------------------------- /gln/mods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hanjun-Dai/GLN/b5bd7b181a61a8289cc1d1a33825b2c417bed0ef/gln/mods/__init__.py -------------------------------------------------------------------------------- /gln/mods/mol_gnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hanjun-Dai/GLN/b5bd7b181a61a8289cc1d1a33825b2c417bed0ef/gln/mods/mol_gnn/__init__.py -------------------------------------------------------------------------------- /gln/mods/mol_gnn/gnn_family/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | 6 | from gln.mods.mol_gnn.gnn_family.mean_field import EmbedMeanField 7 | from gln.mods.mol_gnn.gnn_family.s2v import S2vMeanFieldV2 8 | from gln.mods.mol_gnn.gnn_family.mpnn import MPNN 9 | from gln.mods.mol_gnn.gnn_family.ggnn import GGNN 10 | from gln.mods.mol_gnn.gnn_family.morganfp import MorganFp 11 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/gnn_family/ggnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from gln.mods.mol_gnn.gnn_family.utils import GNNEmbedding, prepare_gnn, get_agg 10 | from torch_geometric.nn import NNConv, Set2Set 11 | from gln.mods.mol_gnn.torch_util import MLP, NONLINEARITIES 12 | from torch_geometric.nn.conv import MessagePassing 13 | 14 | 15 | class GGNNConv(MessagePassing): 16 | def __init__(self, node_state_dim, num_edge_feats, edge_hidden_sizes, 17 | aggr='add', act_func='elu'): 18 | if aggr == 'sum': 19 | aggr = 'add' 20 | super(GGNNConv, self).__init__(aggr=aggr) 21 | 22 | self.node_state_dim = node_state_dim 23 | self.num_edge_feats = num_edge_feats 24 | self.edge_hidden_sizes = edge_hidden_sizes 25 | self.message_net = MLP(input_dim=self.node_state_dim * 2 + num_edge_feats, 26 | hidden_dims=self.edge_hidden_sizes, 27 | nonlinearity=act_func, 28 | act_last=act_func) 29 | self.cell = nn.GRUCell(self.node_state_dim, 30 | self.node_state_dim) 31 | 32 | def forward(self, x, edge_index, edge_features): 33 | prop_out = self.propagate(edge_index, x=x, edge_features=edge_features) 34 | new_states = self.cell(prop_out, x) 35 | return new_states 36 | 37 | def message(self, x_i, x_j, edge_features): 38 | if edge_features is None: 39 | 
edge_inputs = torch.cat((x_i, x_j), dim=-1) 40 | else: 41 | edge_inputs = torch.cat((x_i, x_j, edge_features), dim=-1) 42 | return self.message_net(edge_inputs) 43 | 44 | 45 | class GGNN(GNNEmbedding): 46 | def __init__(self, node_state_dim, output_dims, edge_hidden_sizes, 47 | num_node_feats, num_edge_feats, max_lv=3, msg_aggregate_type='sum', 48 | readout_agg='sum', share_params=False, act_func='elu', out_method='last', dropout=None): 49 | if output_dims is None: 50 | embed_dim = node_state_dim 51 | else: 52 | if isinstance(output_dims, str): 53 | embed_dim = int(output_dims.split('-')[-1]) 54 | else: 55 | embed_dim = output_dims[-1] 56 | super(GGNN, self).__init__(embed_dim, dropout) 57 | self.out_method = out_method 58 | self.node_state_dim = node_state_dim 59 | if isinstance(edge_hidden_sizes, str): 60 | edge_hidden_sizes += '-%d' % node_state_dim 61 | else: 62 | edge_hidden_sizes += [node_state_dim] 63 | lm_layer = lambda: GGNNConv(node_state_dim, num_edge_feats, edge_hidden_sizes, 64 | aggr=msg_aggregate_type, act_func=act_func) 65 | if share_params: 66 | self.ggnn_layer = lm_layer() 67 | self.layers = [lambda x: self.ggnn_layer(x)] * max_lv 68 | else: 69 | self.layers = [lm_layer() for _ in range(max_lv)] 70 | self.layers = nn.ModuleList(self.layers) 71 | self.max_lv = max_lv 72 | self.node2hidden = nn.Linear(num_node_feats, node_state_dim) 73 | self.readout_agg = get_agg(readout_agg) 74 | 75 | self.readout_funcs = [] 76 | if output_dims is None: 77 | for i in range(self.max_lv + 1): 78 | self.readout_funcs.append(lambda x: x) 79 | else: 80 | for i in range(self.max_lv + 1): 81 | mlp = MLP(input_dim=node_state_dim, 82 | hidden_dims=output_dims, 83 | nonlinearity=act_func, 84 | act_last=act_func) 85 | self.readout_funcs.append(mlp) 86 | if self.out_method == 'last': 87 | break 88 | self.readout_funcs = nn.ModuleList(self.readout_funcs) 89 | if self.out_method == 'gru': 90 | self.final_cell = nn.GRUCell(self.embed_dim, self.embed_dim) 91 | 92 | def get_feat(self, graph_list): 93 | node_feat, edge_feat, edge_from_idx, edge_to_idx, g_idx = prepare_gnn(graph_list, self.is_cuda()) 94 | edge_index = [edge_from_idx, edge_to_idx] 95 | edge_index = torch.stack(edge_index) 96 | node_states = self.node2hidden(node_feat) 97 | init_embed = self.readout_funcs[-1](node_states) 98 | outs = self.readout_agg(init_embed, g_idx, dim=0, dim_size=len(graph_list)) 99 | for i in range(self.max_lv): 100 | layer = self.layers[i] 101 | new_states = layer(node_states, edge_index, edge_feat) 102 | node_states = new_states 103 | 104 | if self.out_method == 'last': 105 | continue 106 | 107 | out_states = self.readout_funcs[i](node_states) 108 | 109 | graph_embed = self.readout_agg(out_states, g_idx, 110 | dim=0, dim_size=len(graph_list)) 111 | if self.out_method == 'gru': 112 | outs = self.final_cell(graph_embed, outs) 113 | else: 114 | outs += graph_embed 115 | 116 | if self.out_method == 'last': 117 | out_states = self.readout_funcs[0](node_states) 118 | 119 | graph_embed = self.readout_agg(out_states, g_idx, 120 | dim=0, dim_size=len(graph_list)) 121 | return graph_embed, (g_idx, out_states) 122 | else: 123 | if self.out_method == 'mean': 124 | outs /= self.max_lv + 1 125 | return outs, None 126 | 127 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/gnn_family/mean_field.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ 
import division 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from gln.mods.mol_gnn.gnn_family.utils import GNNEmbedding, prepare_gnn, get_agg, ReadoutNet 10 | from torch_geometric.nn.conv import MessagePassing 11 | from torch_geometric.nn import NNConv, Set2Set 12 | from gln.mods.mol_gnn.torch_util import MLP, NONLINEARITIES 13 | from torch_scatter import scatter_add 14 | 15 | 16 | class _MeanFieldLayer(MessagePassing): 17 | def __init__(self, latent_dim): 18 | super(_MeanFieldLayer, self).__init__() 19 | self.conv_params = nn.Linear(latent_dim, latent_dim) 20 | 21 | def forward(self, x, edge_index): 22 | x = x.unsqueeze(-1) if x.dim() == 1 else x 23 | return self.propagate(edge_index, x=x) 24 | 25 | def update(self, aggr_out): 26 | return self.conv_params(aggr_out) 27 | 28 | 29 | class EmbedMeanField(GNNEmbedding): 30 | def __init__(self, latent_dim, output_dim, num_node_feats, num_edge_feats, max_lv=3, act_func='tanh', readout_agg='sum', share_params=True, act_last=True, dropout=None): 31 | if output_dim > 0: 32 | embed_dim = output_dim 33 | else: 34 | embed_dim = latent_dim 35 | super(EmbedMeanField, self).__init__(embed_dim, dropout) 36 | self.latent_dim = latent_dim 37 | self.output_dim = output_dim 38 | self.num_node_feats = num_node_feats 39 | self.num_edge_feats = num_edge_feats 40 | 41 | self.max_lv = max_lv 42 | self.act_func = NONLINEARITIES[act_func] 43 | self.w_n2l = nn.Linear(num_node_feats, latent_dim) 44 | if num_edge_feats > 0: 45 | self.w_e2l = nn.Linear(num_edge_feats, latent_dim) 46 | 47 | lm_layer = lambda: _MeanFieldLayer(latent_dim) 48 | if share_params: 49 | self.conv_layer = lm_layer() 50 | self.conv_layers = [lambda x, y: self.conv_layer(x, y) for _ in range(max_lv)] 51 | else: 52 | conv_layers = [lm_layer() for _ in range(max_lv)] 53 | self.conv_layers = nn.ModuleList(conv_layers) 54 | 55 | self.readout_net = ReadoutNet(node_state_dim=latent_dim, 56 | output_dim=output_dim, 57 | max_lv=max_lv, 58 | act_func=act_func, 59 | out_method='last', 60 | readout_agg=readout_agg, 61 | act_last=act_last, 62 | bn=False) 63 | 64 | def get_feat(self, graph_list): 65 | node_feat, edge_feat, edge_from_idx, edge_to_idx, g_idx = prepare_gnn(graph_list, self.is_cuda()) 66 | input_node_linear = self.w_n2l(node_feat) 67 | input_message = input_node_linear 68 | if edge_feat is not None: 69 | input_edge_linear = self.w_e2l(edge_feat) 70 | e2npool_input = scatter_add(input_edge_linear, edge_to_idx, dim=0, dim_size=node_feat.shape[0]) 71 | input_message += e2npool_input 72 | input_potential = self.act_func(input_message) 73 | 74 | cur_message_layer = input_potential 75 | all_embeds = [cur_message_layer] 76 | edge_index = [edge_from_idx, edge_to_idx] 77 | edge_index = torch.stack(edge_index) 78 | for lv in range(self.max_lv): 79 | node_linear = self.conv_layers[lv](cur_message_layer, edge_index) 80 | merged_linear = node_linear + input_message 81 | cur_message_layer = self.act_func(merged_linear) 82 | all_embeds.append(cur_message_layer) 83 | 84 | return self.readout_net(all_embeds, g_idx, len(graph_list)) 85 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/gnn_family/morganfp.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from 
gln.mods.mol_gnn.gnn_family.utils import GNNEmbedding 10 | from gln.mods.mol_gnn.torch_util import MLP, NONLINEARITIES 11 | from torch_scatter import scatter_add 12 | from torch_sparse import spmm 13 | 14 | 15 | class MorganFp(GNNEmbedding): 16 | def __init__(self, feat_dim, hidden_size, num_hidden, feat_mode='dense', act_func='elu', dropout=0): 17 | super(MorganFp, self).__init__(hidden_size, dropout) 18 | self.feat_mode = feat_mode 19 | self.feat_dim = feat_dim 20 | if self.feat_mode == 'dense': 21 | self.mlp = MLP(input_dim=feat_dim, 22 | hidden_dims=[hidden_size] * num_hidden, 23 | nonlinearity=act_func, 24 | dropout=dropout, 25 | act_last=act_func) 26 | else: 27 | self.input_linear = nn.Linear(feat_dim, hidden_size) 28 | if num_hidden > 1: 29 | self.mlp = MLP(input_dim=hidden_size, 30 | hidden_dims=[hidden_size] * (num_hidden - 1), 31 | nonlinearity=act_func, 32 | dropout=dropout, 33 | act_last=act_func) 34 | else: 35 | self.mlp = lambda x: x 36 | 37 | def get_fp(self, graph_list): 38 | feat_indices = [] 39 | row_indices = [] 40 | for i, mol in enumerate(graph_list): 41 | feat = [t % self.feat_dim for t in mol.fingerprints] 42 | row_indices += [i] * len(feat) 43 | feat_indices += feat 44 | assert len(row_indices) == len(feat_indices) 45 | sp_indices = torch.LongTensor([row_indices, feat_indices]) 46 | vals = torch.ones(len(row_indices), dtype=torch.float32) 47 | 48 | if self.is_cuda(): 49 | sp_indices = sp_indices.cuda() 50 | vals = vals.cuda() 51 | 52 | if self.feat_mode == 'dense': 53 | sp_feat = torch.sparse.FloatTensor(sp_indices, vals, torch.Size([len(graph_list), self.feat_dim])) 54 | sp_feat = sp_feat.to_dense() 55 | return sp_feat 56 | else: 57 | return sp_indices, vals 58 | 59 | def get_feat(self, graph_list): 60 | if self.feat_mode == 'dense': 61 | dense_feat = self.get_fp(graph_list) 62 | else: 63 | sp_indices, vals = self.get_fp(graph_list) 64 | w = self.input_linear.weight 65 | b = self.input_linear.bias 66 | dense_feat = spmm(sp_indices, vals, len(graph_list), w.transpose(0, 1)) + b 67 | 68 | return self.mlp(dense_feat), None 69 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/gnn_family/mpnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from gln.mods.mol_gnn.gnn_family.utils import GNNEmbedding, prepare_gnn 10 | from torch_geometric.nn import NNConv, Set2Set 11 | from gln.mods.mol_gnn.torch_util import MLP, NONLINEARITIES 12 | 13 | 14 | class MPNN(GNNEmbedding): 15 | def __init__(self, latent_dim, output_dim, num_node_feats, num_edge_feats, max_lv=3, 16 | act_func='elu', msg_aggregate_type='mean', dropout=None): 17 | if output_dim > 0: 18 | embed_dim = output_dim 19 | else: 20 | embed_dim = latent_dim 21 | super(MPNN, self).__init__(embed_dim, dropout) 22 | if msg_aggregate_type == 'sum': 23 | msg_aggregate_type = 'add' 24 | self.max_lv = max_lv 25 | self.readout = nn.Linear(2 * latent_dim, self.embed_dim) 26 | self.lin0 = torch.nn.Linear(num_node_feats, latent_dim) 27 | net = MLP(input_dim=num_edge_feats, 28 | hidden_dims=[128, latent_dim * latent_dim], 29 | nonlinearity=act_func) 30 | self.conv = NNConv(latent_dim, latent_dim, net, aggr=msg_aggregate_type, root_weight=False) 31 | 32 | self.act_func = NONLINEARITIES[act_func] 33 | self.gru = nn.GRU(latent_dim, latent_dim) 
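# Set2Set pooling below returns a [num_graphs, 2 * latent_dim] readout, hence the 2 * latent_dim input of self.readout above.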
34 | self.set2set = Set2Set(latent_dim, processing_steps=3) 35 | 36 | def get_feat(self, graph_list): 37 | node_feat, edge_feat, edge_from_idx, edge_to_idx, g_idx = prepare_gnn(graph_list, self.is_cuda()) 38 | out = self.act_func(self.lin0(node_feat)) 39 | h = out.unsqueeze(0) 40 | edge_index = [edge_from_idx, edge_to_idx] 41 | edge_index = torch.stack(edge_index) 42 | for lv in range(self.max_lv): 43 | m = self.act_func(self.conv(out, edge_index, edge_feat)) 44 | out, h = self.gru(m.unsqueeze(0), h) 45 | out = out.squeeze(0) 46 | out = self.set2set(out, g_idx) 47 | out = self.readout(out) 48 | 49 | return out, None 50 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/gnn_family/s2v.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from gln.mods.mol_gnn.gnn_family.utils import GNNEmbedding, prepare_gnn, get_agg, ReadoutNet 10 | from torch_geometric.nn.conv import MessagePassing 11 | from torch_geometric.nn import NNConv, Set2Set 12 | from gln.mods.mol_gnn.torch_util import MLP, NONLINEARITIES 13 | from torch_scatter import scatter_add 14 | 15 | 16 | class _MeanFieldLayer(MessagePassing): 17 | def __init__(self, latent_dim): 18 | super(_MeanFieldLayer, self).__init__() 19 | self.conv_params = nn.Linear(latent_dim, latent_dim) 20 | 21 | def forward(self, x, edge_index): 22 | x = x.unsqueeze(-1) if x.dim() == 1 else x 23 | return self.propagate(edge_index, x=x) 24 | 25 | def update(self, aggr_out): 26 | return self.conv_params(aggr_out) 27 | 28 | 29 | class S2vMeanFieldV2(GNNEmbedding): 30 | def __init__(self, latent_dim, output_dim, num_node_feats, num_edge_feats, max_lv=3, act_func='relu', readout_agg='sum', act_last=True, out_method='last', bn=True, dropout=None): 31 | if output_dim > 0: 32 | embed_dim = output_dim 33 | else: 34 | embed_dim = latent_dim 35 | super(S2vMeanFieldV2, self).__init__(embed_dim, dropout) 36 | self.latent_dim = latent_dim 37 | self.output_dim = output_dim 38 | self.num_node_feats = num_node_feats 39 | self.num_edge_feats = num_edge_feats 40 | self.bn = bn 41 | self.max_lv = max_lv 42 | self.act_func = NONLINEARITIES[act_func] 43 | self.w_n2l = nn.Linear(num_node_feats, latent_dim) 44 | if num_edge_feats > 0: 45 | self.w_e2l = [nn.Linear(num_edge_feats, latent_dim) for _ in range(self.max_lv + 1)] 46 | self.w_e2l = nn.ModuleList(self.w_e2l) 47 | lm_layer = lambda: _MeanFieldLayer(latent_dim) 48 | conv_layers = [lm_layer() for _ in range(max_lv)] 49 | self.conv_layers = nn.ModuleList(conv_layers) 50 | self.conv_l2 = [nn.Linear(latent_dim, latent_dim) for _ in range(self.max_lv)] 51 | self.conv_l2 = nn.ModuleList(self.conv_l2) 52 | 53 | self.readout_net = ReadoutNet(node_state_dim=latent_dim, 54 | output_dim=output_dim, 55 | max_lv=max_lv, 56 | act_func=act_func, 57 | out_method='last', 58 | readout_agg=readout_agg, 59 | act_last=act_last, 60 | bn=bn) 61 | if self.bn: 62 | msg_bn = [nn.BatchNorm1d(latent_dim) for _ in range(self.max_lv + 1)] 63 | hidden_bn = [nn.BatchNorm1d(latent_dim) for _ in range(self.max_lv)] 64 | self.msg_bn = nn.ModuleList(msg_bn) 65 | self.hidden_bn = nn.ModuleList(hidden_bn) 66 | else: 67 | self.msg_bn = [lambda x: x for _ in range(self.max_lv + 1)] 68 | self.hidden_bn = [lambda x: x for _ in range(self.max_lv)] 69 | 70 | def get_feat(self, 
graph_list): 71 | node_feat, edge_feat, edge_from_idx, edge_to_idx, g_idx = prepare_gnn(graph_list, self.is_cuda()) 72 | input_node_linear = self.w_n2l(node_feat) 73 | input_message = input_node_linear 74 | if edge_feat is not None: 75 | input_edge_linear = self.w_e2l[0](edge_feat) 76 | e2npool_input = scatter_add(input_edge_linear, edge_to_idx, dim=0, dim_size=node_feat.shape[0]) 77 | input_message += e2npool_input 78 | input_potential = self.act_func(input_message) 79 | input_potential = self.msg_bn[0](input_potential) 80 | 81 | cur_message_layer = input_potential 82 | all_embeds = [cur_message_layer] 83 | edge_index = [edge_from_idx, edge_to_idx] 84 | edge_index = torch.stack(edge_index) 85 | for lv in range(self.max_lv): 86 | node_linear = self.conv_layers[lv](cur_message_layer, edge_index) 87 | edge_linear = self.w_e2l[lv + 1](edge_feat) 88 | e2npool_input = scatter_add(edge_linear, edge_to_idx, dim=0, dim_size=node_linear.shape[0]) 89 | merged_hidden = self.act_func(node_linear + e2npool_input) 90 | merged_hidden = self.hidden_bn[lv](merged_hidden) 91 | residual_out = self.conv_l2[lv](merged_hidden) + cur_message_layer 92 | cur_message_layer = self.act_func(residual_out) 93 | cur_message_layer = self.msg_bn[lv + 1](cur_message_layer) 94 | all_embeds.append(cur_message_layer) 95 | return self.readout_net(all_embeds, g_idx, len(graph_list)) 96 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/gnn_family/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import numpy as np 6 | import torch 7 | import random 8 | from functools import partial 9 | from torch.autograd import Variable 10 | from torch.nn.parameter import Parameter 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | from tqdm import tqdm 15 | from torch_geometric.nn.conv import MessagePassing 16 | 17 | from torch_scatter import scatter_add, scatter_mean 18 | from torch_scatter import scatter_max as orig_smax 19 | from torch_scatter import scatter_min as orig_smin 20 | 21 | from gln.mods.mol_gnn.mg_clib.mg_lib import MGLIB 22 | from gln.mods.mol_gnn.torch_util import MLP, NONLINEARITIES 23 | 24 | 25 | class GNNEmbedding(nn.Module): 26 | def __init__(self, embed_dim, dropout=None): 27 | super(GNNEmbedding, self).__init__() 28 | self.embed_dim = embed_dim 29 | if dropout is not None and dropout > 0: 30 | self.dropout = nn.Dropout(p=dropout) 31 | else: 32 | self.dropout = lambda x: x 33 | 34 | def is_cuda(self): 35 | return next(self.parameters()).is_cuda 36 | 37 | def forward(self, graph_list): 38 | selected = [] 39 | sublist = [] 40 | for i, g in enumerate(graph_list): 41 | if g is not None: 42 | selected.append(i) 43 | sublist.append(g) 44 | if len(sublist): 45 | embed, nodes_info = self.get_feat(sublist) 46 | embed = self.dropout(embed) 47 | if nodes_info is not None: 48 | g_idx, node_embed = nodes_info 49 | node_embed = self.dropout(node_embed) 50 | nodes_info = (g_idx, node_embed) 51 | else: 52 | embed = None 53 | nodes_info = None 54 | if len(sublist) == len(graph_list): 55 | return embed, nodes_info 56 | 57 | full_embed = torch.zeros(len(graph_list), self.embed_dim, dtype=torch.float32) 58 | if self.is_cuda(): 59 | full_embed = full_embed.cuda() 60 | if embed is not None: 61 | full_embed[selected] = embed 62 | return full_embed, None 63 | 64 | def get_feat(self, graph_list): 65 | raise NotImplementedError 66 | 
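# Note on GNNEmbedding.forward: None entries in graph_list are skipped during
# encoding; their rows in the returned [len(graph_list), embed_dim] tensor stay
# zero, and per-node embeddings are only propagated when every graph in the
# batch is present.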
67 | 68 | class ReadoutNet(nn.Module): 69 | def __init__(self, node_state_dim, output_dim, max_lv, act_func, out_method, readout_agg, act_last, bn): 70 | super(ReadoutNet, self).__init__() 71 | 72 | self.out_method = out_method 73 | self.max_lv = max_lv 74 | self.readout_agg = get_agg(readout_agg) 75 | self.act_last = act_last 76 | self.act_func = NONLINEARITIES[act_func] 77 | self.readout_funcs = [] 78 | self.bn = bn 79 | if output_dim is None: 80 | self.embed_dim = node_state_dim 81 | for i in range(self.max_lv + 1): 82 | self.readout_funcs.append(lambda x: x) 83 | else: 84 | self.embed_dim = output_dim 85 | for i in range(self.max_lv + 1): 86 | self.readout_funcs.append(nn.Linear(node_state_dim, output_dim)) 87 | if self.out_method == 'last': 88 | break 89 | self.readout_funcs = nn.ModuleList(self.readout_funcs) 90 | 91 | if self.out_method == 'gru': 92 | self.final_cell = nn.GRUCell(self.embed_dim, self.embed_dim) 93 | if self.bn: 94 | out_bn = [nn.BatchNorm1d(self.embed_dim) for _ in range(self.max_lv + 1)] 95 | self.out_bn = nn.ModuleList(out_bn) 96 | 97 | def forward(self, list_node_states, g_idx, num_graphs): 98 | assert len(list_node_states) == self.max_lv + 1 99 | if self.out_method == 'last': 100 | out_states = self.readout_funcs[0](list_node_states[-1]) 101 | if self.act_last: 102 | out_states = self.act_func(out_states) 103 | graph_embed = self.readout_agg(out_states, g_idx, dim=0, dim_size=num_graphs) 104 | return graph_embed, (g_idx, out_states) 105 | 106 | list_node_embed = [self.readout_funcs[i](list_node_states[i]) for i in range(self.max_lv + 1)] 107 | if self.act_last: 108 | list_node_embed = [self.act_func(e) for e in list_node_embed] 109 | if self.bn: 110 | list_node_embed = [self.out_bn[i](e) for i, e in enumerate(list_node_embed)] 111 | list_graph_embed = [self.readout_agg(e, g_idx, dim=0, dim_size=num_graphs) for e in list_node_embed] 112 | out_embed = list_graph_embed[0] 113 | 114 | for i in range(1, self.max_lv + 1): 115 | if self.out_method == 'gru': 116 | out_embed = self.final_cell(list_graph_embed[i], out_embed) 117 | elif self.out_method == 'sum' or self.out_method == 'mean': 118 | out_embed += list_graph_embed[i] 119 | else: 120 | raise NotImplementedError 121 | 122 | if self.out_method == 'mean': 123 | out_embed /= self.max_lv + 1 124 | 125 | return out_embed, (None, None) 126 | 127 | 128 | def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0): 129 | return orig_smax(src, index, dim, out, dim_size, fill_value)[0] 130 | 131 | 132 | def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0): 133 | return orig_smin(src, index, dim, out, dim_size, fill_value)[0] 134 | 135 | 136 | def get_agg(agg_type): 137 | if agg_type == 'sum': 138 | return scatter_add 139 | elif agg_type == 'mean': 140 | return scatter_mean 141 | elif agg_type == 'max': 142 | return scatter_max 143 | elif agg_type == 'min': 144 | return scatter_min 145 | else: 146 | raise NotImplementedError 147 | 148 | 149 | def prepare_gnn(graph_list, is_cuda): 150 | node_feat, edge_feat = MGLIB.PrepareBatchFeature(graph_list) 151 | if is_cuda: 152 | node_feat = node_feat.cuda() 153 | edge_feat = edge_feat.cuda() 154 | edge_to_idx, edge_from_idx, g_idx = MGLIB.PrepareIndices(graph_list) 155 | if is_cuda: 156 | edge_to_idx = edge_to_idx.cuda() 157 | edge_from_idx = edge_from_idx.cuda() 158 | g_idx = g_idx.cuda() 159 | return node_feat, edge_feat, edge_from_idx, edge_to_idx, g_idx 160 | -------------------------------------------------------------------------------- 
/gln/mods/mol_gnn/mg_clib/Makefile: -------------------------------------------------------------------------------- 1 | dir_guard = @mkdir -p $(@D) 2 | FIND := find 3 | CXX := g++ 4 | 5 | CXXFLAGS += -Wall -O3 -std=c++11 6 | LDFLAGS += -lm 7 | 8 | include_dirs = ./include 9 | 10 | CXXFLAGS += $(addprefix -I,$(include_dirs)) -Wno-unused-local-typedef 11 | CXXFLAGS += -fPIC 12 | cpp_files = $(shell $(FIND) src/lib -name "*.cpp" -print | rev | cut -d"/" -f1 | rev) 13 | cxx_obj_files = $(subst .cpp,.o,$(cpp_files)) 14 | 15 | objs = $(addprefix build/lib/,$(cxx_obj_files)) 16 | DEPS = $(objs:.o=.d) 17 | 18 | target = build/dll/libmolgnn.so 19 | target_dep = $(addsuffix .d,$(target)) 20 | 21 | .PRECIOUS: build/lib/%.o 22 | 23 | all: $(target) 24 | 25 | build/dll/libmolgnn.so : src/mg_clib.cpp $(objs) 26 | $(dir_guard) 27 | $(CXX) -shared $(CXXFLAGS) -MMD -o $@ $(filter %.cpp %.o, $^) $(LDFLAGS) 28 | 29 | DEPS += $(target_dep) 30 | 31 | build/lib/%.o: src/lib/%.cpp 32 | $(dir_guard) 33 | $(CXX) $(CXXFLAGS) -MMD -c -o $@ $(filter %.cpp, $^) 34 | 35 | clean: 36 | rm -rf build 37 | 38 | -include $(DEPS) 39 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/mg_clib/__init__.py: -------------------------------------------------------------------------------- 1 | from .mg_lib import MGLIB 2 | 3 | NUM_NODE_FEATS = MGLIB.NUM_NODE_FEATS 4 | NUM_EDGE_FEATS = MGLIB.NUM_EDGE_FEATS 5 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/mg_clib/default_atoms.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 7 3 | 8 4 | 16 5 | 9 6 | 14 7 | 15 8 | 17 9 | 35 10 | 12 11 | 11 12 | 20 13 | 26 14 | 33 15 | 13 16 | 53 17 | 5 18 | 23 19 | 19 20 | 81 21 | 70 22 | 51 23 | 50 24 | 47 25 | 85 26 | 27 27 | 34 28 | 22 29 | 30 30 | 1 31 | 3 32 | 32 33 | 29 34 | 79 35 | 28 36 | 48 37 | 49 38 | 25 39 | 40 40 | 64 41 | 78 42 | 80 43 | 82 -------------------------------------------------------------------------------- /gln/mods/mol_gnn/mg_clib/include/config.h: -------------------------------------------------------------------------------- 1 | #ifndef cfg_H 2 | #define cfg_H 3 | 4 | #include <iostream> 5 | #include <fstream> 6 | #include <cstring> 7 | #include <cstdlib> 8 | #include <string> 9 | 10 | typedef float Dtype; 11 | 12 | struct cfg 13 | { 14 | static int num_atom_types; 15 | static int nodefeat_dim; 16 | static int edgefeat_dim; 17 | 18 | static void LoadParams(const int argc, const char** argv) 19 | { 20 | for (int i = 1; i < argc; i += 2) 21 | { 22 | if (strcmp(argv[i], "-num_atom_types") == 0) 23 | num_atom_types = atoi(argv[i + 1]); 24 | if (strcmp(argv[i], "-nodefeat_dim") == 0) 25 | nodefeat_dim = atoi(argv[i + 1]); 26 | if (strcmp(argv[i], "-edgefeat_dim") == 0) 27 | edgefeat_dim = atoi(argv[i + 1]); 28 | } 29 | std::cerr << "====== begin of gnn_clib configuration ======" << std::endl; 30 | std::cerr << "| num_atom_types = " << num_atom_types << std::endl; 31 | std::cerr << "| nodefeat_dim = " << nodefeat_dim << std::endl; 32 | std::cerr << "| edgefeat_dim = " << edgefeat_dim << std::endl; 33 | std::cerr << "====== end of gnn_clib configuration ======" << std::endl; 34 | } 35 | }; 36 | 37 | #endif -------------------------------------------------------------------------------- /gln/mods/mol_gnn/mg_clib/include/mg_clib.h: -------------------------------------------------------------------------------- 1 | #ifndef MG_CLIB_H 2 | #define MG_CLIB_H 3 | 4 | #include "config.h" 5 | 6 | extern "C" int Init(const int argc, 
const char **argv); 7 | 8 | extern "C" int PrepareBatchFeature(const int num_graphs, 9 | const int *num_nodes, 10 | const int *num_edges, 11 | void** list_node_feats, 12 | void** list_edge_feats, 13 | Dtype* node_input, 14 | Dtype* edge_input); 15 | 16 | extern "C" int PrepareIndices(const int num_graphs, 17 | const int *num_nodes, 18 | const int *num_edges, 19 | void **list_of_edge_pairs, 20 | long long* edge_to_idx, 21 | long long* edge_from_idx, 22 | long long* g_idx); 23 | 24 | #endif -------------------------------------------------------------------------------- /gln/mods/mol_gnn/mg_clib/include/mol_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef MOL_UTILS_H 2 | #define MOL_UTILS_H 3 | 4 | #include <map> 5 | #include <string> 6 | 7 | #include "config.h" 8 | 9 | struct MolFeat 10 | { 11 | static void InitIdxMap(); 12 | 13 | static void ParseAtomFeat(Dtype* arr, int feat); 14 | 15 | static void ParseEdgeFeat(Dtype* arr, int feat); 16 | 17 | static std::map<int, int> atom_idx_map; 18 | }; 19 | 20 | 21 | #endif -------------------------------------------------------------------------------- /gln/mods/mol_gnn/mg_clib/mg_lib.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import numpy as np 3 | import os 4 | import sys 5 | try: 6 | import torch 7 | except: 8 | print('no torch loaded') 9 | 10 | class _mg_lib(object): 11 | 12 | def __init__(self, sys_args): 13 | dir_path = os.path.dirname(os.path.realpath(__file__)) 14 | self.lib = ctypes.CDLL('%s/build/dll/libmolgnn.so' % dir_path) 15 | 16 | atom_file = '%s/default_atoms.txt' % dir_path 17 | for i in range(len(sys_args)): 18 | if sys_args[i] == '-f_atoms': 19 | atom_file = sys_args[i + 1] 20 | atom_nums = [] 21 | with open(atom_file, 'r') as f: 22 | for row in f: 23 | atom_nums.append(int(row.strip())) 24 | 25 | self.lib.PrepareIndices.restype = ctypes.c_int 26 | self.lib.PrepareBatchFeature.restype = ctypes.c_int 27 | 28 | self.NUM_EDGE_FEATS = 7 29 | self.NUM_NODE_FEATS = len(atom_nums) + 23 30 | 31 | self.atom_idx_map = {} 32 | for i in range(len(atom_nums)): 33 | self.atom_idx_map[atom_nums[i]] = i 34 | 35 | args = 'this -num_atom_types %d -nodefeat_dim %d -edgefeat_dim %d' % (len(atom_nums), self.NUM_NODE_FEATS, self.NUM_EDGE_FEATS) 36 | args = args.split() 37 | if sys.version_info[0] > 2: 38 | args = [arg.encode() for arg in args] # str -> bytes for each element in args 39 | 40 | arr = (ctypes.c_char_p * len(args))() 41 | arr[:] = args 42 | self.lib.Init(len(args), arr) 43 | 44 | def PrepareIndices(self, graph_list): 45 | edgepair_list = (ctypes.c_void_p * len(graph_list))() 46 | list_num_nodes = np.zeros((len(graph_list), ), dtype=np.int32) 47 | list_num_edges = np.zeros((len(graph_list), ), dtype=np.int32) 48 | for i in range(len(graph_list)): 49 | if type(graph_list[i].edge_pairs) is ctypes.c_void_p: 50 | edgepair_list[i] = graph_list[i].edge_pairs 51 | elif type(graph_list[i].edge_pairs) is np.ndarray: 52 | edgepair_list[i] = ctypes.c_void_p(graph_list[i].edge_pairs.ctypes.data) 53 | else: 54 | raise NotImplementedError 55 | 56 | list_num_nodes[i] = graph_list[i].num_nodes 57 | list_num_edges[i] = graph_list[i].num_edges 58 | total_num_nodes = np.sum(list_num_nodes) 59 | total_num_edges = np.sum(list_num_edges) 60 | 61 | edge_to_idx = torch.LongTensor(total_num_edges * 2) 62 | edge_from_idx = torch.LongTensor(total_num_edges * 2) 63 | g_idx = torch.LongTensor(total_num_nodes) 64 | self.lib.PrepareIndices(len(graph_list), 65 | 
ctypes.c_void_p(list_num_nodes.ctypes.data), 66 | ctypes.c_void_p(list_num_edges.ctypes.data), 67 | ctypes.cast(edgepair_list, ctypes.c_void_p), 68 | ctypes.c_void_p(edge_to_idx.numpy().ctypes.data), 69 | ctypes.c_void_p(edge_from_idx.numpy().ctypes.data), 70 | ctypes.c_void_p(g_idx.numpy().ctypes.data)) 71 | return edge_to_idx, edge_from_idx, g_idx 72 | 73 | def PrepareBatchFeature(self, molgraph_list): 74 | n_graphs = len(molgraph_list) 75 | c_node_list = (ctypes.c_void_p * n_graphs)() 76 | c_edge_list = (ctypes.c_void_p * n_graphs)() 77 | list_num_nodes = np.zeros((n_graphs, ), dtype=np.int32) 78 | list_num_edges = np.zeros((n_graphs, ), dtype=np.int32) 79 | 80 | for i in range(n_graphs): 81 | mol = molgraph_list[i] 82 | c_node_list[i] = ctypes.c_void_p(mol.atom_feats.ctypes.data) 83 | c_edge_list[i] = ctypes.c_void_p(mol.bond_feats.ctypes.data) 84 | list_num_nodes[i] = mol.num_nodes 85 | list_num_edges[i] = mol.num_edges 86 | 87 | torch_node_feat = torch.zeros(np.sum(list_num_nodes), self.NUM_NODE_FEATS) 88 | torch_edge_feat = torch.zeros(np.sum(list_num_edges) * 2, self.NUM_EDGE_FEATS) 89 | 90 | node_feat = torch_node_feat.numpy() 91 | edge_feat = torch_edge_feat.numpy() 92 | 93 | self.lib.PrepareBatchFeature(n_graphs, 94 | ctypes.c_void_p(list_num_nodes.ctypes.data), 95 | ctypes.c_void_p(list_num_edges.ctypes.data), 96 | ctypes.cast(c_node_list, ctypes.c_void_p), 97 | ctypes.cast(c_edge_list, ctypes.c_void_p), 98 | ctypes.c_void_p(node_feat.ctypes.data), 99 | ctypes.c_void_p(edge_feat.ctypes.data)) 100 | 101 | return torch_node_feat, torch_edge_feat 102 | 103 | 104 | dll_path = '%s/build/dll/libmolgnn.so' % os.path.dirname(os.path.realpath(__file__)) 105 | if os.path.exists(dll_path): 106 | MGLIB = _mg_lib(sys.argv) 107 | else: 108 | MGLIB = None 109 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/mg_clib/src/lib/config.cpp: -------------------------------------------------------------------------------- 1 | #include "config.h" 2 | 3 | 4 | int cfg::num_atom_types = 0; 5 | int cfg::nodefeat_dim = 0; 6 | int cfg::edgefeat_dim = 0; 7 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/mg_clib/src/lib/mol_utils.cpp: -------------------------------------------------------------------------------- 1 | #include "mol_utils.h" 2 | #include "config.h" 3 | 4 | #include <cassert> 5 | 6 | void MolFeat::ParseAtomFeat(Dtype* arr, int feat) 7 | { 8 | // atom_idx_map 9 | int t = feat & ((1 << 8) - 1); 10 | arr[t] = 1.0; 11 | feat >>= 8; 12 | int base_idx = cfg::num_atom_types + 1; 13 | 14 | // getDegree 15 | int mask = (1 << 4) - 1; 16 | t = feat & mask; 17 | arr[base_idx + t] = 1.0; 18 | feat >>= 4; 19 | base_idx += 8; 20 | 21 | // getTotalNumHs 22 | t = feat & mask; 23 | arr[base_idx + t] = 1.0; 24 | feat >>= 4; 25 | base_idx += 6; 26 | 27 | // getImplicitValence 28 | t = feat & mask; 29 | arr[base_idx + t] = 1.0; 30 | feat >>= 4; 31 | base_idx += 7; 32 | 33 | // getIsAromatic 34 | if (feat & mask) 35 | arr[base_idx] = 1.0; 36 | assert(base_idx + 1 == cfg::nodefeat_dim); 37 | } 38 | 39 | 40 | void MolFeat::ParseEdgeFeat(Dtype* arr, int feat) 41 | { 42 | int mask = (1 << 8) - 1; 43 | // getBondType 44 | arr[feat & mask] = 1.0; 45 | feat >>= 8; 46 | // getIsConjugated 47 | if (feat & mask) 48 | arr[4] = 1.0; 49 | feat >>= 8; 50 | // is ring 51 | int t = feat & mask; 52 | if (t == 2) 53 | arr[6] = 1.0; 54 | else if (feat & mask) 55 | arr[5] = 1.0; 56 | } 57 | 
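ParseAtomFeat above is the decoding side of a small bit-packing scheme: get_atom_feat in gln/mods/mol_gnn/mol_utils.py (shown after mg_clib.cpp below) packs the atom-type index into the low 8 bits, then 4-bit fields for degree, total H count, and implicit valence, with the aromatic flag on top. The one-hot blocks written here span (num_atom_types + 1) + 8 + 6 + 7 + 1 = num_atom_types + 23 slots, matching NUM_NODE_FEATS = len(atom_nums) + 23 in mg_lib.py. A sketch of the unpacking in Python, with a hypothetical helper name:

```python
def unpack_atom_feat(feat):
    """Hypothetical inverse of get_atom_feat / MolFeat::ParseAtomFeat."""
    atom_type = feat & 0xFF         # index into the atom vocabulary (8 bits)
    feat >>= 8
    degree = feat & 0xF             # GetDegree() (4 bits, one-hot block of 8)
    feat >>= 4
    num_hs = feat & 0xF             # GetTotalNumHs(); 5 marks "not sanitized" (block of 6)
    feat >>= 4
    valence = feat & 0xF            # GetImplicitValence(); 6 marks out-of-range (block of 7)
    feat >>= 4
    is_aromatic = bool(feat & 0xF)  # GetIsAromatic() flag (1 slot)
    return atom_type, degree, num_hs, valence, is_aromatic
```

The asymmetric block sizes follow from the markers in get_atom_feat: H counts 0-4 plus the unsanitized marker 5 give 6 slots, and valences 0-5 plus the overflow marker 6 give 7.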
-------------------------------------------------------------------------------- /gln/mods/mol_gnn/mg_clib/src/mg_clib.cpp: -------------------------------------------------------------------------------- 1 | #include "mg_clib.h" 2 | #include "config.h" 3 | #include "mol_utils.h" 4 | #include <cassert> 5 | #include <cstdio> 6 | #include <cstring> 7 | #include <vector> 8 | 9 | int Init(const int argc, const char **argv) 10 | { 11 | cfg::LoadParams(argc, argv); 12 | return 0; 13 | } 14 | 15 | 16 | int PrepareBatchFeature(const int num_graphs, 17 | const int *num_nodes, 18 | const int *num_edges, 19 | void** list_node_feats, 20 | void** list_edge_feats, 21 | Dtype* node_input, 22 | Dtype* edge_input) 23 | { 24 | Dtype* ptr = node_input; 25 | for (int i = 0; i < num_graphs; ++i) 26 | { 27 | int* node_feats = static_cast<int*>(list_node_feats[i]); 28 | for (int j = 0; j < num_nodes[i]; ++j) 29 | { 30 | MolFeat::ParseAtomFeat(ptr, node_feats[j]); 31 | ptr += cfg::nodefeat_dim; 32 | } 33 | } 34 | 35 | ptr = edge_input; 36 | for (int i = 0; i < num_graphs; ++i) 37 | { 38 | int* edge_feats = static_cast<int*>(list_edge_feats[i]); 39 | for (int j = 0; j < num_edges[i] * 2; j += 2) 40 | { 41 | // two directions have the same feature 42 | MolFeat::ParseEdgeFeat(ptr, edge_feats[j / 2]); 43 | ptr += cfg::edgefeat_dim; 44 | MolFeat::ParseEdgeFeat(ptr, edge_feats[j / 2]); 45 | ptr += cfg::edgefeat_dim; 46 | } 47 | } 48 | 49 | return 0; 50 | } 51 | 52 | 53 | int PrepareIndices(const int num_graphs, 54 | const int *num_nodes, 55 | const int *num_edges, 56 | void **list_of_edge_pairs, 57 | long long* edge_to_idx, 58 | long long* edge_from_idx, 59 | long long* g_idx) 60 | { 61 | int offset = 0; 62 | int cur_edge = 0; 63 | for (int i = 0; i < num_graphs; ++i) 64 | { 65 | int *edge_pairs = static_cast<int*>(list_of_edge_pairs[i]); 66 | for (int j = 0; j < num_edges[i] * 2; j += 2) 67 | { 68 | int x = offset + edge_pairs[j]; 69 | int y = offset + edge_pairs[j + 1]; 70 | edge_to_idx[cur_edge] = y; 71 | edge_from_idx[cur_edge] = x; 72 | cur_edge += 1; 73 | edge_to_idx[cur_edge] = x; 74 | edge_from_idx[cur_edge] = y; 75 | cur_edge += 1; 76 | } 77 | for (int j = 0; j < num_nodes[i]; ++j) 78 | g_idx[offset + j] = i; 79 | offset += num_nodes[i]; 80 | } 81 | return 0; 82 | } 83 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/mol_utils.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import numpy as np 3 | import os 4 | import sys 5 | from tqdm import tqdm 6 | import rdkit 7 | from rdkit import Chem 8 | from rdkit.Chem import AllChem 9 | import struct 10 | import pickle as cp 11 | from gln.mods.mol_gnn.mg_clib.mg_lib import MGLIB 12 | 13 | 14 | def get_atom_feat(atom, sanitized): 15 | # getIsAromatic 16 | feat = int(atom.GetIsAromatic()) 17 | feat <<= 4 18 | # getImplicitValence 19 | v = atom.GetImplicitValence() 20 | if 0 <= v <= 5: 21 | feat |= v 22 | else: 23 | feat |= 6 24 | feat <<= 4 25 | # getTotalNumHs 26 | if sanitized: 27 | h = atom.GetTotalNumHs() 28 | if h <= 4: 29 | feat |= h 30 | else: 31 | feat |= 4 32 | else: 33 | feat |= 5 34 | feat <<= 4 35 | # getDegree 36 | feat |= atom.GetDegree() 37 | feat <<= 8 38 | x = atom.GetAtomicNum() 39 | if x in MGLIB.atom_idx_map: 40 | feat |= MGLIB.atom_idx_map[x] 41 | else: 42 | feat |= len(MGLIB.atom_idx_map) 43 | assert feat >= 0 44 | return feat 45 | 46 | 47 | def get_bond_feat(bond, sanitized): 48 | bt = bond.GetBondType() 49 | t = 0 50 | if bt == rdkit.Chem.rdchem.BondType.SINGLE: 51 | t = 0 52 | elif bt == 
rdkit.Chem.rdchem.BondType.DOUBLE: 53 | t = 1 54 | elif bt == rdkit.Chem.rdchem.BondType.TRIPLE: 55 | t = 2 56 | elif bt == rdkit.Chem.rdchem.BondType.AROMATIC: 57 | t = 3 58 | feat = 2 59 | if sanitized: 60 | feat = 1 if bond.GetOwningMol().GetRingInfo().NumBondRings(bond.GetIdx()) > 0 else 0 61 | feat <<= 8 62 | feat |= int(bond.GetIsConjugated()) 63 | feat <<= 8 64 | feat |= t 65 | assert feat >= 0 66 | return feat 67 | 68 | 69 | class MolGraph(object): 70 | 71 | def __init__(self, name, sanitized, *, mol=None, num_nodes=None, num_edges=None, atom_feats=None, bond_feats=None, edge_pairs=None): 72 | self.name = name 73 | self.sanitized = sanitized 74 | if num_nodes is None: 75 | assert mol is not None 76 | self.num_nodes = mol.GetNumAtoms() 77 | self.num_edges = mol.GetNumBonds() 78 | 79 | self.atom_feats = np.zeros((self.num_nodes, ), dtype=np.int32) 80 | for i, atom in enumerate(mol.GetAtoms()): 81 | self.atom_feats[i] = get_atom_feat(atom, self.sanitized) 82 | 83 | self.bond_feats = np.zeros((self.num_edges, ), dtype=np.int32) 84 | self.edge_pairs = np.zeros((self.num_edges * 2, ), dtype=np.int32) 85 | 86 | for i, bond in enumerate(mol.GetBonds()): 87 | self.bond_feats[i] = get_bond_feat(bond, self.sanitized) 88 | x = bond.GetBeginAtomIdx() 89 | y = bond.GetEndAtomIdx() 90 | self.edge_pairs[i * 2] = x 91 | self.edge_pairs[i * 2 + 1] = y 92 | else: 93 | self.num_nodes = num_nodes 94 | self.num_edges = num_edges 95 | 96 | self.atom_feats = np.array(atom_feats, dtype=np.int32) if atom_feats is not None else None 97 | self.bond_feats = np.array(bond_feats, dtype=np.int32) if bond_feats is not None else None 98 | self.edge_pairs = np.array(edge_pairs, dtype=np.int32) if edge_pairs is not None else None 99 | self.fingerprints = None 100 | self.fp_info = None 101 | 102 | 103 | class _MolHolder(object): 104 | 105 | def __init__(self, sanitized): 106 | self.sanitized = sanitized 107 | self.dict_molgraph = {} 108 | self.fp_degree = 0 109 | self.fp_info = False 110 | self.null_graphs = set() 111 | 112 | def set_fp_degree(self, degree, fp_info=False): 113 | self.fp_degree = degree 114 | self.fp_info = fp_info 115 | 116 | def _get_inv(self, m): 117 | if self.sanitized: 118 | return None 119 | feats = [] 120 | for a in m.GetAtoms(): 121 | f = (a.GetAtomicNum(), a.GetDegree(), a.GetFormalCharge()) 122 | f = ctypes.c_uint32(hash(f)).value 123 | feats.append(f) 124 | return feats 125 | 126 | def new_mol(self, name): 127 | if self.sanitized: 128 | mol = Chem.MolFromSmiles(name) 129 | else: 130 | mol = Chem.MolFromSmarts(name) 131 | if mol is None: 132 | return None 133 | else: 134 | mg = MolGraph(name, self.sanitized, mol=mol) 135 | if self.fp_degree > 0: 136 | bi = {} if self.fp_info else None 137 | feat = AllChem.GetMorganFingerprint(mol, self.fp_degree, bitInfo=bi, invariants=self._get_inv(mol)) 138 | on_bits = list(feat.GetNonzeroElements().keys()) 139 | mg.fingerprints = on_bits 140 | mg.fp_info = bi 141 | return mg 142 | 143 | def get_mol_graph(self, name): 144 | if name is None or len(name.strip()) == 0: 145 | return None 146 | if name in self.null_graphs: 147 | return None 148 | if not name in self.dict_molgraph: 149 | mg = self.new_mol(name) 150 | if mg is None: 151 | self.null_graphs.add(name) 152 | return None 153 | else: 154 | self.dict_molgraph[name] = mg 155 | return self.dict_molgraph[name] 156 | 157 | def clear(self): 158 | self.dict_molgraph = {} 159 | 160 | def save_dump(self, prefix): 161 | with open(prefix + '.names', 'w') as f: 162 | for key in self.dict_molgraph: 163 | 
f.write('%s\n' % key) 164 | 165 | with open(prefix + '.bin', 'wb') as f: 166 | n_graphs = len(self.dict_molgraph) 167 | # write total number of mols 168 | f.write(struct.pack('=i', n_graphs)) 169 | # save all the size info 170 | list_num_nodes = [None] * n_graphs 171 | list_num_edges = [None] * n_graphs 172 | for i, key in enumerate(self.dict_molgraph): 173 | mol = self.dict_molgraph[key] 174 | list_num_nodes[i] = mol.num_nodes 175 | list_num_edges[i] = mol.num_edges 176 | 177 | f.write(struct.pack('=%di' % n_graphs, *list_num_nodes)) 178 | f.write(struct.pack('=%di' % n_graphs, *list_num_edges)) 179 | 180 | for key in tqdm(self.dict_molgraph): 181 | mol = self.dict_molgraph[key] 182 | f.write(struct.pack('=%di' % mol.num_nodes, *(mol.atom_feats.tolist()))) 183 | f.write(struct.pack('=%di' % mol.num_edges, *(mol.bond_feats.tolist()))) 184 | f.write(struct.pack('=%di' % (mol.num_edges * 2), *(mol.edge_pairs.tolist()))) 185 | 186 | if self.fp_degree > 0: 187 | if self.fp_info: 188 | with open(prefix + '.fp%d_info' % self.fp_degree, 'wb') as f: 189 | for key in self.dict_molgraph: 190 | mol = self.dict_molgraph[key] 191 | assert mol.fp_info is not None 192 | cp.dump(mol.fp_info, f, cp.HIGHEST_PROTOCOL) 193 | else: 194 | with open(prefix + '.fp%d' % self.fp_degree, 'wb') as f: 195 | num_fps = [None] * n_graphs 196 | for i, key in enumerate(self.dict_molgraph): 197 | mol = self.dict_molgraph[key] 198 | num_fps[i] = len(mol.fingerprints) 199 | f.write(struct.pack('=%di' % n_graphs, *num_fps)) 200 | for i, key in enumerate(self.dict_molgraph): 201 | mol = self.dict_molgraph[key] 202 | fp = mol.fingerprints 203 | f.write(struct.pack('=%dI' % len(fp), *fp)) 204 | 205 | print('%d molecules saved' % n_graphs) 206 | print('total # nodes', np.sum(list_num_nodes)) 207 | print('total # edges', np.sum(list_num_edges)) 208 | 209 | def remove_dump(self, prefix): 210 | print('mol_holder unloading', prefix) 211 | names = [] 212 | with open(prefix + '.names', 'r') as f: 213 | for row in f: 214 | names.append(row.strip()) 215 | [self.dict_molgraph.pop(x, None) for x in names] 216 | 217 | def load_dump(self, prefix, additive=False, load_feats=True, load_fp=True): 218 | print('mol_holder loading', prefix) 219 | if not additive: 220 | self.dict_molgraph = {} 221 | names = [] 222 | with open(prefix + '.names', 'r') as f: 223 | for row in f: 224 | names.append(row.strip()) 225 | self.dict_molgraph[names[-1]] = MolGraph(names[-1], self.sanitized, num_nodes=-1, num_edges=-1) 226 | if load_feats: 227 | print('loading binary features') 228 | with open(prefix + '.bin', 'rb') as f: 229 | n_graphs = struct.unpack('=i', f.read(4))[0] 230 | assert n_graphs == len(names) 231 | list_num_nodes = struct.unpack('=%di' % n_graphs, f.read(4 * n_graphs)) 232 | list_num_edges = struct.unpack('=%di' % n_graphs, f.read(4 * n_graphs)) 233 | 234 | for i in tqdm(range(n_graphs)): 235 | mol = self.dict_molgraph[names[i]] 236 | mol.num_nodes = n = list_num_nodes[i] 237 | mol.num_edges = m = list_num_edges[i] 238 | 239 | mol.atom_feats = np.array(struct.unpack('=%di' % n, f.read(4 * n)), dtype=np.int32) 240 | mol.bond_feats = np.array(struct.unpack('=%di' % m, f.read(4 * m)), dtype=np.int32) 241 | mol.edge_pairs = np.array(struct.unpack('=%di' % (2 * m), f.read(4 * 2 * m)), dtype=np.int32) 242 | 243 | if self.fp_degree > 0 and load_fp: 244 | print('loading fingerprints') 245 | if self.fp_info: 246 | with open(prefix + '.fp%d_info' % self.fp_degree, 'rb') as f: 247 | for name in tqdm(names): 248 | d = cp.load(f) 249 | mol = 
self.dict_molgraph[name] 250 | mol.fp_info = d 251 | mol.fingerprints = list(d.keys()) 252 | else: 253 | n_graphs = len(names) 254 | with open(prefix + '.fp%d' % self.fp_degree, 'rb') as f: 255 | num_fps = struct.unpack('=%di' % n_graphs, f.read(4 * n_graphs)) 256 | for i, key in tqdm(enumerate(names)): 257 | mol = self.dict_molgraph[key] 258 | mol.fingerprints = struct.unpack('=%dI' % num_fps[i], f.read(4 * num_fps[i])) 259 | 260 | print('done with fp loading') 261 | 262 | SmartsMols = _MolHolder(sanitized=False) 263 | SmilesMols = _MolHolder(sanitized=True) 264 | -------------------------------------------------------------------------------- /gln/mods/mol_gnn/torch_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | from torch.nn.parameter import Parameter 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | import numpy as np 13 | 14 | class Lambda(nn.Module): 15 | 16 | def __init__(self, f): 17 | super(Lambda, self).__init__() 18 | self.f = f 19 | 20 | def forward(self, x): 21 | return self.f(x) 22 | 23 | 24 | class Swish(nn.Module): 25 | 26 | def __init__(self): 27 | super(Swish, self).__init__() 28 | self.beta = nn.Parameter(torch.tensor(1.0)) 29 | 30 | def forward(self, x): 31 | return x * torch.sigmoid(self.beta * x) 32 | 33 | 34 | NONLINEARITIES = { 35 | "tanh": nn.Tanh(), 36 | "relu": nn.ReLU(), 37 | "softplus": nn.Softplus(), 38 | "sigmoid": nn.Sigmoid(), 39 | "elu": nn.ELU(), 40 | "swish": Swish(), 41 | "square": Lambda(lambda x: x**2), 42 | "identity": Lambda(lambda x: x), 43 | } 44 | 45 | 46 | class MLP(nn.Module): 47 | def __init__(self, input_dim, hidden_dims, nonlinearity='elu', act_last=None, bn=False, dropout=-1): 48 | super(MLP, self).__init__() 49 | self.act_last = act_last 50 | self.nonlinearity = nonlinearity 51 | self.input_dim = input_dim 52 | self.bn = bn 53 | 54 | if isinstance(hidden_dims, str): 55 | hidden_dims = list(map(int, hidden_dims.split("-"))) 56 | assert len(hidden_dims) 57 | hidden_dims = [input_dim] + hidden_dims 58 | self.output_size = hidden_dims[-1] 59 | 60 | list_layers = [] 61 | 62 | for i in range(1, len(hidden_dims)): 63 | list_layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i])) 64 | if i + 1 < len(hidden_dims): # not the last layer 65 | if self.bn: 66 | bnorm_layer = nn.BatchNorm1d(hidden_dims[i]) 67 | list_layers.append(bnorm_layer) 68 | list_layers.append(NONLINEARITIES[self.nonlinearity]) 69 | if dropout > 0: 70 | list_layers.append(nn.Dropout(dropout)) 71 | else: 72 | if act_last is not None: 73 | list_layers.append(NONLINEARITIES[act_last]) 74 | 75 | self.main = nn.Sequential(*list_layers) 76 | 77 | def forward(self, z): 78 | x = self.main(z) 79 | return x 80 | 81 | 82 | def _glorot_uniform(t): 83 | if len(t.size()) == 2: 84 | fan_in, fan_out = t.size() 85 | elif len(t.size()) == 3: 86 | # out_ch, in_ch, kernel for Conv 1 87 | fan_in = t.size()[1] * t.size()[2] 88 | fan_out = t.size()[0] * t.size()[2] 89 | else: 90 | fan_in = np.prod(t.size()) 91 | fan_out = np.prod(t.size()) 92 | 93 | limit = np.sqrt(6.0 / (fan_in + fan_out)) 94 | t.uniform_(-limit, limit) 95 | 96 | 97 | def _param_init(m): 98 | if isinstance(m, Parameter): 99 | _glorot_uniform(m.data) 100 | elif isinstance(m, nn.Linear): 101 | m.bias.data.zero_() 102 | _glorot_uniform(m.weight.data) 103 | 
elif isinstance(m, nn.Embedding): 104 | _glorot_uniform(m.weight.data) 105 | 106 | 107 | def glorot_uniform(m): 108 | for p in m.modules(): 109 | if isinstance(p, nn.ParameterList) or isinstance(p, nn.ModuleList): 110 | for pp in p: 111 | _param_init(pp) 112 | else: 113 | _param_init(p) 114 | 115 | for name, p in m.named_parameters(): 116 | if not '.' in name: # top-level parameters 117 | _param_init(p) -------------------------------------------------------------------------------- /gln/mods/rdchiral/README.md: -------------------------------------------------------------------------------- 1 | # rdchiral 2 | Wrapper for RDKit's RunReactants to improve stereochemistry handling 3 | 4 | See ```rdchiral/main.py``` for a brief description of expected behavior and a few basic examples of how to use the wrapper. 5 | 6 | See ```rdchiral/test/test_rdchiral.py``` for a small set of test cases described [here](https://chemrxiv.org/articles/RDChiral_An_RDKit_Wrapper_for_Handling_Stereochemistry_in_Retrosynthetic_Template_Extraction_and_Application/7949024) 7 | -------------------------------------------------------------------------------- /gln/mods/rdchiral/__init__.py: -------------------------------------------------------------------------------- 1 | from rdkit import RDLogger 2 | lg = RDLogger.logger() 3 | lg.setLevel(RDLogger.CRITICAL) -------------------------------------------------------------------------------- /gln/mods/rdchiral/bonds.py: -------------------------------------------------------------------------------- 1 | import rdkit.Chem as Chem 2 | import rdkit.Chem.AllChem as AllChem 3 | from rdkit.Chem.rdchem import ChiralType, BondType, BondDir, BondStereo 4 | from gln.mods.rdchiral.utils import vprint, PLEVEL 5 | 6 | BondDirOpposite = {AllChem.BondDir.ENDUPRIGHT: AllChem.BondDir.ENDDOWNRIGHT, 7 | AllChem.BondDir.ENDDOWNRIGHT: AllChem.BondDir.ENDUPRIGHT, 8 | AllChem.BondDir.NONE: AllChem.BondDir.NONE} 9 | BondDirLabel = {AllChem.BondDir.ENDUPRIGHT: '\\', 10 | AllChem.BondDir.ENDDOWNRIGHT: '/'} 11 | 12 | def bond_dirs_by_mapnum(mol): 13 | bond_dirs_by_mapnum = {} 14 | for b in mol.GetBonds(): 15 | i = None; j = None 16 | if b.GetBeginAtom().GetAtomMapNum(): 17 | i = b.GetBeginAtom().GetAtomMapNum() 18 | if b.GetEndAtom().GetAtomMapNum(): 19 | j = b.GetEndAtom().GetAtomMapNum() 20 | if i is None or j is None or b.GetBondDir() == BondDir.NONE: 21 | continue 22 | bond_dirs_by_mapnum[(i, j)] = b.GetBondDir() 23 | bond_dirs_by_mapnum[(j, i)] = BondDirOpposite[b.GetBondDir()] 24 | return bond_dirs_by_mapnum 25 | 26 | def enumerate_possible_cistrans_defs(template_r, \ 27 | labeling_func=lambda a: a.GetAtomMapNum()): 28 | ''' 29 | This function is meant to take a reactant template and fully enumerate 30 | all the ways in which different double-bonds can have their cis/trans 31 | chirality specified (based on labeling_func). This is necessary because 32 | double-bond chirality cannot be specified using cis/trans (global properties) 33 | but must be done using ENDUPRIGHT and ENDDOWNRIGHT for the attached single 34 | bonds (local properties). Now, the next issue is that on each side of the 35 | double bond, only one of the single bond directions must be specified, and 36 | that direction can be using either atom order. 
e.g., 37 | 38 | A1 B1 39 | \ / 40 | C = C 41 | / \ 42 | A2 B2 43 | 44 | Can be specified by: 45 | A1-C is an ENDDOWNRIGHT, C-B1 is an ENDUPRIGHT 46 | A1-C is an ENDDOWNRIGHT, C-B2 is an ENDDOWNRIGHT 47 | A1-C is an ENDDOWNRIGHT, B1-C is an ENDDOWNRIGHT 48 | A1-C is an ENDDOWNRIGHT, B2-C is an ENDUPRIGHT 49 | ...and twelve more definitions using different A1/A2 specs. 50 | 51 | ALSO - we can think about horizontally reflecting this bond entirely, 52 | which gets us even more definitions. 53 | 54 | So, the point of this function is to fully enumerate *all* of the ways 55 | in which chirality could have been specified. That way, we can take a 56 | reactant atom and check if its chirality is within the list of acceptable 57 | definitions to determine if a match was made. 58 | 59 | Gross. 60 | 61 | The way we do this is by first defining the *local* chirality of a double 62 | bond, which weights side chains based purely on the unique mapnum numbering. 63 | Once we have a local cis/trans definition for a double bond, we can enumerate 64 | the sixteen possible ways that a reactant could match it. 65 | 66 | ''' 67 | 68 | required_bond_defs = {} 69 | required_bond_defs_coreatoms = set() 70 | 71 | if PLEVEL >= 10: print('Looking at initializing template frag') 72 | for b in template_r.GetBonds(): 73 | if b.GetBondType() != BondType.DOUBLE: 74 | continue 75 | 76 | # Define begin and end atoms of the double bond 77 | ba = b.GetBeginAtom() 78 | bb = b.GetEndAtom() 79 | 80 | # Now check if it is even possible to specify 81 | if ba.GetDegree() == 1 or bb.GetDegree() == 1: 82 | continue 83 | 84 | ba_label = labeling_func(ba) 85 | bb_label = labeling_func(bb) 86 | 87 | if PLEVEL >= 10: print('Found a double bond with potential cis/trans (based on degree)') 88 | if PLEVEL >= 10: print('{} {} {}'.format(ba_label, 89 | b.GetSmarts(), 90 | bb_label)) 91 | 92 | # Save core atoms so we know that cis/trans was POSSIBLE to specify 93 | required_bond_defs_coreatoms.add((ba_label, bb_label)) 94 | required_bond_defs_coreatoms.add((bb_label, ba_label)) 95 | 96 | # Define heaviest mapnum neighbor for each atom, excluding the other side of the double bond 97 | ba_neighbor_labels = [labeling_func(a) for a in ba.GetNeighbors()] 98 | ba_neighbor_labels.remove(bb_label) # remove other side of = 99 | ba_neighbor_labels_max = max(ba_neighbor_labels) 100 | bb_neighbor_labels = [labeling_func(a) for a in bb.GetNeighbors()] 101 | bb_neighbor_labels.remove(ba_label) # remove other side of = 102 | bb_neighbor_labels_max = max(bb_neighbor_labels) 103 | 104 | # The direction of the bond being observed might need to be flipped, 105 | # based on 106 | # (a) if it is the heaviest atom on this side, and 107 | # (b) if the begin/end atoms for the directional bond are 108 | # in the wrong order (i.e., if the double-bonded atom 109 | # is the begin atom) 110 | front_spec = None 111 | back_spec = None 112 | for bab in ba.GetBonds(): 113 | if bab.GetBondDir() != BondDir.NONE: 114 | if labeling_func(bab.GetBeginAtom()) == ba_label: 115 | # Bond is in wrong order - flip 116 | if labeling_func(bab.GetEndAtom()) != ba_neighbor_labels_max: 117 | # Defined atom is not the heaviest - flip 118 | front_spec = bab.GetBondDir() 119 | break 120 | front_spec = BondDirOpposite[bab.GetBondDir()] 121 | break 122 | if labeling_func(bab.GetBeginAtom()) != ba_neighbor_labels_max: 123 | # Defined atom is not heaviest 124 | front_spec = BondDirOpposite[bab.GetBondDir()] 125 | break 126 | front_spec = bab.GetBondDir() 127 | break 128 | if front_spec is None: 129 
| if PLEVEL >= 10: print('Chirality not specified at front end of the bond!') 130 | else: 131 | if PLEVEL >= 10: print('Front specification: {}'.format(front_spec)) 132 | 133 | for bbb in bb.GetBonds(): 134 | if bbb.GetBondDir() != BondDir.NONE: 135 | # For the "back" specification, the double-bonded atom *should* be the BeginAtom 136 | if labeling_func(bbb.GetEndAtom()) == bb_label: 137 | # Bond is in wrong order - flip 138 | if labeling_func(bbb.GetBeginAtom()) != bb_neighbor_labels_max: 139 | # Defined atom is not the heaviest - flip 140 | back_spec = bbb.GetBondDir() 141 | break 142 | back_spec = BondDirOpposite[bbb.GetBondDir()] 143 | break 144 | if labeling_func(bbb.GetEndAtom()) != bb_neighbor_labels_max: 145 | # Defined atom is not heaviest - flip 146 | back_spec = BondDirOpposite[bbb.GetBondDir()] 147 | break 148 | back_spec = bbb.GetBondDir() 149 | break 150 | if back_spec is None: 151 | if PLEVEL >= 10: print('Chirality not specified at back end of the bond!') 152 | else: 153 | if PLEVEL >= 10: print('Back specification: {}'.format(back_spec)) 154 | 155 | # Is this an overall unspecified bond? Put it in the dictionary anyway, 156 | # so there is something to match 157 | if front_spec is None or back_spec is None: 158 | # Create a definition over this bond so that reactant MUST be unspecified, too 159 | for start_atom in ba_neighbor_labels: 160 | for end_atom in bb_neighbor_labels: 161 | required_bond_defs[(start_atom, ba_label, bb_label, end_atom)] = (BondDir.NONE, BondDir.NONE) 162 | required_bond_defs[(ba_label, start_atom, bb_label, end_atom)] = (BondDir.NONE, BondDir.NONE) 163 | required_bond_defs[(start_atom, ba_label, end_atom, bb_label)] = (BondDir.NONE, BondDir.NONE) 164 | required_bond_defs[(ba_label, start_atom, end_atom, bb_label)] = (BondDir.NONE, BondDir.NONE) 165 | required_bond_defs[(bb_label, end_atom, start_atom, ba_label)] = (BondDir.NONE, BondDir.NONE) 166 | required_bond_defs[(end_atom, bb_label, start_atom, ba_label)] = (BondDir.NONE, BondDir.NONE) 167 | required_bond_defs[(bb_label, end_atom, ba_label, start_atom)] = (BondDir.NONE, BondDir.NONE) 168 | required_bond_defs[(end_atom, bb_label, ba_label, start_atom)] = (BondDir.NONE, BondDir.NONE) 169 | continue 170 | 171 | if front_spec == back_spec: 172 | if PLEVEL >= 10: print('-> locally TRANS') 173 | b.SetProp('localChirality', 'trans') 174 | else: 175 | if PLEVEL >= 10: print('--> locally CIS') 176 | b.SetProp('localChirality', 'cis') 177 | 178 | possible_defs = {} 179 | for start_atom in ba_neighbor_labels: 180 | for end_atom in bb_neighbor_labels: 181 | needs_inversion = (start_atom != ba_neighbor_labels_max) != \ 182 | (end_atom != bb_neighbor_labels_max) 183 | for start_atom_dir in [BondDir.ENDUPRIGHT, BondDir.ENDDOWNRIGHT]: 184 | # When locally trans, BondDir of start shold be same as end, 185 | # unless we need inversion 186 | if (front_spec != back_spec) != needs_inversion: 187 | # locally cis and does not need inversion (True, False) 188 | # or locally trans and does need inversion (False, True) 189 | end_atom_dir = BondDirOpposite[start_atom_dir] 190 | else: 191 | # locally cis and does need inversion (True, True) 192 | # or locally trans and does not need inversion (False, False) 193 | end_atom_dir = start_atom_dir 194 | 195 | # Enumerate all combinations of atom orders... 
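                    # (Worked example with hypothetical mapnums: take start_atom=5,
                    # ba_label=1, bb_label=2, end_atom=6 and start_atom_dir=ENDUPRIGHT.
                    # The eight entries below then record the identical geometry under
                    # every begin/end ordering of the two single bonds; each swap of a
                    # bond's atom order flips its direction via BondDirOpposite.)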
196 | possible_defs[(start_atom, ba_label, bb_label, end_atom)] = (start_atom_dir, end_atom_dir) 197 | possible_defs[(ba_label, start_atom, bb_label, end_atom)] = (BondDirOpposite[start_atom_dir], end_atom_dir) 198 | possible_defs[(start_atom, ba_label, end_atom, bb_label)] = (start_atom_dir, BondDirOpposite[end_atom_dir]) 199 | possible_defs[(ba_label, start_atom, end_atom, bb_label)] = (BondDirOpposite[start_atom_dir], BondDirOpposite[end_atom_dir]) 200 | 201 | possible_defs[(bb_label, end_atom, start_atom, ba_label)] = (end_atom_dir, start_atom_dir) 202 | possible_defs[(bb_label, end_atom, ba_label, start_atom)] = (end_atom_dir, BondDirOpposite[start_atom_dir]) 203 | possible_defs[(end_atom, bb_label, start_atom, ba_label)] = (BondDirOpposite[end_atom_dir], start_atom_dir) 204 | possible_defs[(end_atom, bb_label, ba_label, start_atom)] = (BondDirOpposite[end_atom_dir], BondDirOpposite[start_atom_dir]) 205 | 206 | # Save to the definition of this bond (in either direction) 207 | required_bond_defs.update(possible_defs) 208 | 209 | if PLEVEL >= 10: print('All bond specs for this template:' ) 210 | if PLEVEL >= 10: print(str([(k, v) for (k, v) in required_bond_defs.items()])) 211 | return required_bond_defs, required_bond_defs_coreatoms 212 | 213 | def get_atoms_across_double_bonds(mol, labeling_func=lambda a:a.GetAtomMapNum()): 214 | ''' 215 | This function takes a molecule and returns a list of cis/trans specifications 216 | according to the following: 217 | 218 | (mapnums, dirs) 219 | 220 | where atoms = (a1, a2, a3, a4) and dirs = (d1, d2) 221 | and (a1, a2) defines the ENDUPRIGHT/ENDDOWNRIGHT direction of the "front" 222 | of the bond using d1, and (a3, a4) defines the direction of the "back" of 223 | the bond using d2. 224 | 225 | This is used to initialize reactants with a SINGLE definition constraining 226 | the chirality. Templates have their chirality fully enumerated, so we can 227 | match this specific definition to the full set of possible definitions 228 | when determining if a match should be made. 229 | 230 | NOTE: the atom mapnums are returned. 
This is so we can later use them 231 | to get the old_mapno property from the corresponding product atom, which is 232 | an outcome-specific assignment 233 | 234 | We also include implicit chirality here based on ring membership, but keep 235 | track of that separately 236 | ''' 237 | atoms_across_double_bonds = [] 238 | atomrings = None 239 | 240 | for b in mol.GetBonds(): 241 | if b.GetBondType() != BondType.DOUBLE: 242 | continue 243 | 244 | # Define begin and end atoms of the double bond 245 | ba = b.GetBeginAtom() 246 | bb = b.GetEndAtom() 247 | 248 | # Now check if it is even possible to specify 249 | if ba.GetDegree() == 1 or bb.GetDegree() == 1: 250 | continue 251 | 252 | ba_label = labeling_func(ba) 253 | bb_label = labeling_func(bb) 254 | 255 | if PLEVEL >= 5: print('Found a double bond with potential cis/trans (based on degree)') 256 | if PLEVEL >= 5: print('{} {} {}'.format(ba_label, 257 | b.GetSmarts(), 258 | bb_label)) 259 | 260 | # Try to specify front and back direction separately 261 | front_mapnums = None 262 | front_dir = None 263 | back_mapnums = None 264 | back_dir = None 265 | is_implicit = False 266 | bab = None; bbb = None; 267 | for bab in (z for z in ba.GetBonds() if z.GetBondType() != BondType.DOUBLE): 268 | if bab.GetBondDir() != BondDir.NONE: 269 | front_mapnums = (labeling_func(bab.GetBeginAtom()), labeling_func(bab.GetEndAtom())) 270 | front_dir = bab.GetBondDir() 271 | break 272 | for bbb in (z for z in bb.GetBonds() if z.GetBondType() != BondType.DOUBLE): 273 | if bbb.GetBondDir() != BondDir.NONE: 274 | back_mapnums = (labeling_func(bbb.GetBeginAtom()), labeling_func(bbb.GetEndAtom())) 275 | back_dir = bbb.GetBondDir() 276 | break 277 | 278 | # If impossible to spec, just continue 279 | if (bab is None or bbb is None): 280 | continue 281 | 282 | # Did we actually get a specification out? 283 | if (front_dir is None or back_dir is None): 284 | 285 | if b.IsInRing(): 286 | # Implicit cis! Now to figure out right definitions... 287 | if atomrings is None: 288 | atomrings = mol.GetRingInfo().AtomRings() # tuple of tuples of atomIdx 289 | for atomring in atomrings: 290 | if ba.GetIdx() in atomring and bb.GetIdx() in atomring: 291 | front_mapnums = (labeling_func(bab.GetOtherAtom(ba)), ba_label) 292 | back_mapnums = (bb_label, labeling_func(bbb.GetOtherAtom(bb))) 293 | if (bab.GetOtherAtomIdx(ba.GetIdx()) in atomring) != \ 294 | (bbb.GetOtherAtomIdx(bb.GetIdx()) in atomring): 295 | # one of these atoms are in the ring, one is outside -> trans 296 | if PLEVEL >= 10: print('Implicit trans found') 297 | front_dir = BondDir.ENDUPRIGHT 298 | back_dir = BondDir.ENDUPRIGHT 299 | else: 300 | if PLEVEL >= 10: print('Implicit cis found') 301 | front_dir = BondDir.ENDUPRIGHT 302 | back_dir = BondDir.ENDDOWNRIGHT 303 | is_implicit = True 304 | break 305 | 306 | else: 307 | # Okay no, actually unspecified 308 | # Specify direction as BondDir.NONE using whatever bab and bbb were at the end fo the loop 309 | # note: this is why we use "for bab in ___generator___", so that we know the current 310 | # value of bab and bbb correspond to a single bond we can def. 
by 311 | front_mapnums = (labeling_func(bab.GetBeginAtom()), labeling_func(bab.GetEndAtom())) 312 | front_dir = BondDir.NONE 313 | back_mapnums = (labeling_func(bbb.GetBeginAtom()), labeling_func(bbb.GetEndAtom())) 314 | back_dir = BondDir.NONE 315 | 316 | # Save this (a1, a2, a3, a4) -> (d1, d2) spec 317 | atoms_across_double_bonds.append( 318 | ( 319 | front_mapnums + back_mapnums, 320 | (front_dir, back_dir), 321 | is_implicit, 322 | ) 323 | ) 324 | 325 | return atoms_across_double_bonds 326 | 327 | def restore_bond_stereo_to_sp2_atom(a, bond_dirs_by_mapnum): 328 | '''Copy over single-bond directions (ENDUPRIGHT, ENDDOWNRIGHT) to 329 | the single bonds attached to some double-bonded atom, a 330 | 331 | a - atom with a double bond 332 | bond_dirs_by_mapnum - dictionary of (begin_mapnum, end_mapnum): bond_dir 333 | that defines if a bond should be ENDUPRIGHT or ENDDOWNRIGHT. The reverse 334 | key is also included with the reverse bond direction. If the source 335 | molecule did not have a specified chirality at this double bond, then 336 | the mapnum tuples will be missing from the dict 337 | 338 | In some cases, like C=C/O>>C=C/Br, we should assume that stereochem was 339 | preserved, even though mapnums won't match. There might be some reactions 340 | where the chirality is inverted (like C=C/O >> C=C\Br), but let's not 341 | worry about those for now... 342 | 343 | Returns True if a bond direction was copied''' 344 | 345 | for bond_to_spec in a.GetBonds(): 346 | if (bond_to_spec.GetOtherAtom(a).GetAtomMapNum(), a.GetAtomMapNum()) in bond_dirs_by_mapnum: 347 | bond_to_spec.SetBondDir( 348 | bond_dirs_by_mapnum[ 349 | (bond_to_spec.GetBeginAtom().GetAtomMapNum(), 350 | bond_to_spec.GetEndAtom().GetAtomMapNum()) 351 | ] 352 | ) 353 | if PLEVEL >= 2: print('Tried to copy bond direction b/w {} and {}'.format( 354 | bond_to_spec.GetBeginAtom().GetAtomMapNum(), 355 | bond_to_spec.GetEndAtom().GetAtomMapNum() 356 | )) 357 | return True 358 | 359 | # Weird case, like C=C/O >> C=C/Br 360 | if PLEVEL >= 2: print('Bond stereo could not be restored to sp2 atom, missing the branch that was used to define before...') 361 | 362 | if a.GetDegree() == 2: 363 | # Either the branch used to define was replaced with H (deg 3 -> deg 2) 364 | # or the branch used to define was reacted (deg 2 -> deg 2) 365 | for bond_to_spec in a.GetBonds(): 366 | if bond_to_spec.GetBondType() == BondType.DOUBLE: 367 | continue 368 | if not bond_to_spec.GetOtherAtom(a).HasProp('old_mapno'): 369 | # new atom, deg2->deg2, assume direction preserved 370 | if PLEVEL >= 5: print('Only single-bond attachment to atom {} is new, try to reproduce chirality'.format(a.GetAtomMapNum())) 371 | needs_inversion = False 372 | else: 373 | # old atom, just was not used in chirality definition - set opposite 374 | if PLEVEL >= 5: print('Only single-bond attachment to atom {} is old, try to reproduce chirality'.format(a.GetAtomMapNum())) 375 | needs_inversion = True 376 | 377 | for (i, j), bond_dir in bond_dirs_by_mapnum.items(): 378 | if bond_dir != BondDir.NONE: 379 | if i == bond_to_spec.GetBeginAtom().GetAtomMapNum(): 380 | if needs_inversion: 381 | bond_to_spec.SetBondDir(BondDirOpposite[bond_dir]) 382 | else: 383 | bond_to_spec.SetBondDir(bond_dir) 384 | return True 385 | 386 | if a.GetDegree() == 3: 387 | # If we lost the branch defining stereochem, it must have been replaced 388 | for bond_to_spec in a.GetBonds(): 389 | if bond_to_spec.GetBondType() == BondType.DOUBLE: 390 | continue 391 | oa = bond_to_spec.GetOtherAtom(a) 392 | if 
oa.HasProp('old_mapno') or oa.HasProp('react_atom_idx'): 393 | # looking at an old atom, which should have opposite direction as removed atom 394 | needs_inversion = True 395 | else: 396 | # looking at a new atom, assume same as removed atom 397 | needs_inversion = False 398 | 399 | for (i, j), bond_dir in bond_dirs_by_mapnum.items(): 400 | if bond_dir != BondDir.NONE: 401 | if i == bond_to_spec.GetBeginAtom().GetAtomMapNum(): 402 | if needs_inversion: 403 | bond_to_spec.SetBondDir(BondDirOpposite[bond_dir]) 404 | else: 405 | bond_to_spec.SetBondDir(bond_dir) 406 | return True 407 | 408 | return False -------------------------------------------------------------------------------- /gln/mods/rdchiral/chiral.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from rdkit.Chem.rdchem import ChiralType, BondType, BondDir 3 | 4 | from gln.mods.rdchiral.utils import vprint, parity4, PLEVEL 5 | 6 | def template_atom_could_have_been_tetra(a, strip_if_spec=False, cache=True): 7 | ''' 8 | Could this atom have been a tetrahedral center? 9 | If yes, template atom is considered achiral and will not match a chiral rct 10 | If no, the tempalte atom is auxilliary and we should not use it to remove 11 | a matched reaction. For example, a fully-generalized terminal [C:1] 12 | ''' 13 | 14 | if a.HasProp('tetra_possible'): 15 | return a.GetBoolProp('tetra_possible') 16 | if a.GetDegree() < 3 or (a.GetDegree() == 3 and 'H' not in a.GetSmarts()): 17 | if cache: 18 | a.SetBoolProp('tetra_possible', False) 19 | if strip_if_spec: # Clear chiral tag in case improperly set 20 | a.SetChiralTag(ChiralType.CHI_UNSPECIFIED) 21 | return False 22 | if cache: 23 | a.SetBoolProp('tetra_possible', True) 24 | return True 25 | 26 | 27 | 28 | def copy_chirality(a_src, a_new): 29 | 30 | # Not possible to be a tetrahedral center anymore? 31 | if a_new.GetDegree() < 3: 32 | return 33 | if a_new.GetDegree() == 3 and \ 34 | any(b.GetBondType() != BondType.SINGLE for b in a_new.GetBonds()): 35 | return 36 | 37 | if PLEVEL >= 3: print('For mapnum {}, copying src {} chirality tag to new'.format( 38 | a_src.GetAtomMapNum(), a_src.GetChiralTag())) 39 | a_new.SetChiralTag(a_src.GetChiralTag()) 40 | 41 | if atom_chirality_matches(a_src, a_new) == -1: 42 | if PLEVEL >= 3: print('For mapnum {}, inverting chirality'.format(a_new.GetAtomMapNum())) 43 | a_new.InvertChirality() 44 | 45 | def atom_chirality_matches(a_tmp, a_mol): 46 | ''' 47 | Checks for consistency in chirality between a template atom and a molecule atom. 48 | 49 | Also checks to see if chirality needs to be inverted in copy_chirality 50 | 51 | Returns +1 if it is a match and there is no need for inversion (or ambiguous) 52 | Returns -1 if it is a match but they are the opposite 53 | Returns 0 if an explicit NOT match 54 | Returns 2 if ambiguous or achiral-achiral 55 | ''' 56 | if a_mol.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 57 | if a_tmp.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 58 | if PLEVEL >= 3: print('atom {} is achiral & achiral -> match'.format(a_mol.GetAtomMapNum())) 59 | return 2 # achiral template, achiral molecule -> match 60 | # What if the template was chiral, but the reactant isn't just due to symmetry? 
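        # ('_ChiralityPossible' is the atom flag RDKit sets when AssignStereochemistry
        # is called with flagPossibleStereoCenters=True, as is done for reactants in
        # initialization.py; if the flag is absent, this atom could never be a
        # stereocenter in the molecule, e.g. because two of its substituents are
        # identical.)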
61 | if not a_mol.HasProp('_ChiralityPossible'): 62 | # It's okay to make a match, as long as the product is achiral (even 63 | # though the product template will try to impose chirality) 64 | if PLEVEL >= 3: print('atom {} is specified in template, but cant possibly be chiral in mol'.format(a_mol.GetAtomMapNum())) 65 | return 2 66 | 67 | # Discussion: figure out if we want this behavior - should a chiral template 68 | # be applied to an achiral molecule? For the retro case, if we have 69 | # a retro reaction that requires a specific stereochem, return False; 70 | # however, there will be many cases where the reaction would probably work 71 | if PLEVEL >= 3: print('atom {} is achiral in mol, but specified in template'.format(a_mol.GetAtomMapNum())) 72 | return 0 73 | if a_tmp.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 74 | if PLEVEL >= 3: print('Reactant {} atom chiral, rtemplate achiral...'.format(a_tmp.GetAtomMapNum())) 75 | if template_atom_could_have_been_tetra(a_tmp): 76 | if PLEVEL >= 3: print('...and that atom could have had its chirality specified! no_match') 77 | return 0 78 | if PLEVEL >= 3: print('...but the rtemplate atom could not have had chirality specified, match anyway') 79 | return 2 80 | 81 | mapnums_tmp = [a.GetAtomMapNum() for a in a_tmp.GetNeighbors()] 82 | mapnums_mol = [a.GetAtomMapNum() for a in a_mol.GetNeighbors()] 83 | 84 | # When there are fewer than 3 heavy neighbors, chirality is ambiguous... 85 | if len(mapnums_tmp) < 3 or len(mapnums_mol) < 3: 86 | return 2 87 | 88 | # Degree of 3 -> remaining atom is a hydrogen, add to list 89 | if len(mapnums_tmp) < 4: 90 | mapnums_tmp.append(-1) # H 91 | if len(mapnums_mol) < 4: 92 | mapnums_mol.append(-1) # H 93 | 94 | try: 95 | if PLEVEL >= 10: print(str(mapnums_tmp)) 96 | if PLEVEL >= 10: print(str(mapnums_mol)) 97 | if PLEVEL >= 10: print(str(a_tmp.GetChiralTag())) 98 | if PLEVEL >= 10: print(str(a_mol.GetChiralTag())) 99 | only_in_src = [i for i in mapnums_tmp if i not in mapnums_mol][::-1] # reverse for popping 100 | only_in_mol = [i for i in mapnums_mol if i not in mapnums_tmp] 101 | if len(only_in_src) <= 1 and len(only_in_mol) <= 1: 102 | tmp_parity = parity4(mapnums_tmp) 103 | mol_parity = parity4([i if i in mapnums_tmp else only_in_src.pop() for i in mapnums_mol]) 104 | if PLEVEL >= 10: print(str(tmp_parity)) 105 | if PLEVEL >= 10: print(str(mol_parity)) 106 | parity_matches = tmp_parity == mol_parity 107 | tag_matches = a_tmp.GetChiralTag() == a_mol.GetChiralTag() 108 | chirality_matches = parity_matches == tag_matches 109 | if PLEVEL >= 2: print('mapnum {} chiral match? {}'.format(a_tmp.GetAtomMapNum(), chirality_matches)) 110 | return 1 if chirality_matches else -1 111 | else: 112 | if PLEVEL >= 2: print('mapnum {} chiral match? 
Based on mapnum lists, ambiguous -> True'.format(a_tmp.GetAtomMapNum()))
113 |             return 2 # ambiguous case, just return for now
114 |
115 |     except IndexError as e:
116 |         print(a_tmp.GetPropsAsDict())
117 |         print(a_mol.GetPropsAsDict())
118 |         print(a_tmp.GetChiralTag())
119 |         print(a_mol.GetChiralTag())
120 |         print(str(e))
121 |         print(str(mapnums_tmp))
122 |         print(str(mapnums_mol))
123 |         raise KeyError('Pop from empty set - this should not happen!')
-------------------------------------------------------------------------------- /gln/mods/rdchiral/clean.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import rdkit.Chem as Chem
3 | import re
4 | from itertools import chain
5 |
6 | from gln.mods.rdchiral.utils import vprint, PLEVEL
7 |
8 |
9 | def canonicalize_outcome_smiles(smiles, ensure=True):
10 |     # Uniquify via SMILES string - a little sloppy
11 |     # Need a full SMILES->MOL->SMILES cycle to get a true canonical string
12 |     # also, split by '.' and sort when outcome contains multiple molecules
13 |     if ensure:
14 |         outcome = Chem.MolFromSmiles(smiles)
15 |         if outcome is None:
16 |             if PLEVEL >= 1: print('~~ could not parse self?')
17 |             if PLEVEL >= 1: print('Attempted SMILES: {}'.format(smiles))
18 |             return None
19 |
20 |         smiles = Chem.MolToSmiles(outcome, True)
21 |
22 |     return '.'.join(sorted(smiles.split('.')))
23 |
24 | def combine_enantiomers_into_racemic(final_outcomes):
25 |     '''
26 |     If two products are identical except for an inverted CW/CCW or an
27 |     opposite cis/trans, then just strip that from the product. Return
28 |     the achiral one instead.
29 |
30 |     This is not very sophisticated, since the chirality could affect the bond
31 |     order and thus the canonical SMILES. But, whatever. It also does not look
32 |     to invert multiple stereocenters at once.
33 |     '''
34 |
35 |     for smiles in list(final_outcomes)[:]:
36 |
37 |         # Look for @@ tetrahedral center
38 |         for match in re.finditer(r'@@', smiles):
39 |             smiles_inv = '%s@%s' % (smiles[:match.start()], smiles[match.end():])
40 |             if smiles_inv in final_outcomes:
41 |                 if smiles in final_outcomes:
42 |                     final_outcomes.remove(smiles)
43 |                 final_outcomes.remove(smiles_inv)
44 |                 # Re-parse smiles so that hydrogens can become implicit
45 |                 smiles = smiles[:match.start()] + smiles[match.end():]
46 |                 outcome = Chem.MolFromSmiles(smiles)
47 |                 if outcome is None:
48 |                     raise ValueError('Horrible mistake when fixing duplicate!')
49 |                 smiles = '.'.join(sorted(Chem.MolToSmiles(outcome, True).split('.')))
50 |                 final_outcomes.add(smiles)
51 |
52 |         # Look for // or \\ trans bond
53 |         # where [^=\.] is any non-double bond or period or slash
54 |         for match in chain(re.finditer(r'(\/)([^=\.\\\/]+=[^=\.\\\/]+)(\/)', smiles),
55 |                            re.finditer(r'(\\)([^=\.\\\/]+=[^=\.\\\/]+)(\\)', smiles)):
56 |             # See if cis version is present in list of outcomes
57 |             opposite = {'\\': '/', '/': '\\'}
58 |             smiles_cis1 = '%s%s%s%s%s' % (smiles[:match.start()],
59 |                 match.group(1), match.group(2), opposite[match.group(3)],
60 |                 smiles[match.end():])
61 |             smiles_cis2 = '%s%s%s%s%s' % (smiles[:match.start()],
62 |                 opposite[match.group(1)], match.group(2), match.group(3),
63 |                 smiles[match.end():])
64 |             # Also look for equivalent trans
65 |             smiles_trans2 = '%s%s%s%s%s' % (smiles[:match.start()],
66 |                 opposite[match.group(1)], match.group(2),
67 |                 opposite[match.group(3)], smiles[match.end():])
68 |             # Kind of weird remove conditionals...
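            # (In other words: if either cis counterpart co-occurs with this trans
            # SMILES, all specified forms are dropped and the slash-stripped achiral
            # string is re-added below; the equivalent-trans form is only removed
            # when this SMILES itself is still present in the outcome set.)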
69 | remove = False 70 | if smiles_cis1 in final_outcomes: 71 | final_outcomes.remove(smiles_cis1) 72 | remove = True 73 | if smiles_cis2 in final_outcomes: 74 | final_outcomes.remove(smiles_cis2) 75 | remove = True 76 | if smiles_trans2 in final_outcomes and smiles in final_outcomes: 77 | final_outcomes.remove(smiles_trans2) 78 | if remove: 79 | final_outcomes.remove(smiles) 80 | smiles = smiles[:match.start()] + match.group(2) + smiles[match.end():] 81 | outcome = Chem.MolFromSmiles(smiles) 82 | if outcome is None: 83 | raise ValueError('Horrible mistake when fixing duplicate!') 84 | smiles = '.'.join(sorted(Chem.MolToSmiles(outcome, True).split('.'))) 85 | final_outcomes.add(smiles) 86 | return final_outcomes 87 | -------------------------------------------------------------------------------- /gln/mods/rdchiral/initialization.py: -------------------------------------------------------------------------------- 1 | import rdkit.Chem as Chem 2 | import rdkit.Chem.AllChem as AllChem 3 | from rdkit.Chem.rdchem import ChiralType, BondType, BondDir, BondStereo 4 | 5 | from gln.mods.rdchiral.chiral import template_atom_could_have_been_tetra 6 | from gln.mods.rdchiral.utils import vprint, PLEVEL 7 | from gln.mods.rdchiral.bonds import enumerate_possible_cistrans_defs, bond_dirs_by_mapnum, \ 8 | get_atoms_across_double_bonds 9 | 10 | BondDirOpposite = {AllChem.BondDir.ENDUPRIGHT: AllChem.BondDir.ENDDOWNRIGHT, 11 | AllChem.BondDir.ENDDOWNRIGHT: AllChem.BondDir.ENDUPRIGHT} 12 | 13 | class rdchiralReaction(): 14 | ''' 15 | Class to store everything that should be pre-computed for a reaction. This 16 | makes library application much faster, since we can pre-do a lot of work 17 | instead of doing it for every mol-template pair 18 | ''' 19 | def __init__(self, reaction_smarts): 20 | # Keep smarts, useful for reporting 21 | self.reaction_smarts = reaction_smarts 22 | 23 | # Initialize - assigns stereochemistry and fills in missing rct map numbers 24 | self.rxn = initialize_rxn_from_smarts(reaction_smarts) 25 | 26 | # Combine template fragments so we can play around with mapnums 27 | self.template_r, self.template_p = get_template_frags_from_rxn(self.rxn) 28 | 29 | # Define molAtomMapNumber->atom dictionary for template rct and prd 30 | self.atoms_rt_map = {a.GetAtomMapNum(): a \ 31 | for a in self.template_r.GetAtoms() if a.GetAtomMapNum()} 32 | self.atoms_pt_map = {a.GetAtomMapNum(): a \ 33 | for a in self.template_p.GetAtoms() if a.GetAtomMapNum()} 34 | 35 | # Back-up the mapping for the reaction 36 | self.atoms_rt_idx_to_map = {a.GetIdx(): a.GetAtomMapNum() 37 | for a in self.template_r.GetAtoms()} 38 | self.atoms_pt_idx_to_map = {a.GetIdx(): a.GetAtomMapNum() 39 | for a in self.template_p.GetAtoms()} 40 | 41 | # Check consistency (this should not be necessary...) 42 | if any(self.atoms_rt_map[i].GetAtomicNum() != self.atoms_pt_map[i].GetAtomicNum() \ 43 | for i in self.atoms_rt_map if i in self.atoms_pt_map): 44 | raise ValueError('Atomic identity should not change in a reaction!') 45 | 46 | # Call template_atom_could_have_been_tetra to pre-assign value to atom 47 | [template_atom_could_have_been_tetra(a) for a in self.template_r.GetAtoms()] 48 | [template_atom_could_have_been_tetra(a) for a in self.template_p.GetAtoms()] 49 | 50 | # Pre-list chiral double bonds (for copying back into outcomes/matching) 51 | self.rt_bond_dirs_by_mapnum = bond_dirs_by_mapnum(self.template_r) 52 | self.pt_bond_dirs_by_mapnum = bond_dirs_by_mapnum(self.template_p) 53 | 54 | # Enumerate possible cis/trans... 
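        # (enumerate_possible_cistrans_defs, defined in bonds.py, expands every
        # double bond of the reactant template into the full set of atom-order /
        # bond-direction combinations that encode the same cis/trans geometry,
        # so a reactant's single observed specification can be matched against it.)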
55 | self.required_rt_bond_defs, self.required_bond_defs_coreatoms = \ 56 | enumerate_possible_cistrans_defs(self.template_r) 57 | 58 | def reset(self): 59 | for (idx, mapnum) in self.atoms_rt_idx_to_map.items(): 60 | self.template_r.GetAtomWithIdx(idx).SetAtomMapNum(mapnum) 61 | for (idx, mapnum) in self.atoms_pt_idx_to_map.items(): 62 | self.template_p.GetAtomWithIdx(idx).SetAtomMapNum(mapnum) 63 | 64 | class rdchiralReactants(): 65 | ''' 66 | Class to store everything that should be pre-computed for a reactant mol 67 | so that library application is faster 68 | ''' 69 | def __init__(self, reactant_smiles): 70 | # Keep original smiles, useful for reporting 71 | self.reactant_smiles = reactant_smiles 72 | 73 | # Initialize into RDKit mol 74 | self.reactants = initialize_reactants_from_smiles(reactant_smiles) 75 | 76 | # Set mapnum->atom dictionary 77 | # all reactant atoms must be mapped after initialization, so this is safe 78 | self.atoms_r = {a.GetAtomMapNum(): a for a in self.reactants.GetAtoms()} 79 | self.idx_to_mapnum = lambda idx: self.reactants.GetAtomWithIdx(idx).GetAtomMapNum() 80 | 81 | # Create copy of molecule without chiral information, used with 82 | # RDKit's naive runReactants 83 | self.reactants_achiral = initialize_reactants_from_smiles(reactant_smiles) 84 | [a.SetChiralTag(ChiralType.CHI_UNSPECIFIED) for a in self.reactants_achiral.GetAtoms()] 85 | [(b.SetStereo(BondStereo.STEREONONE), b.SetBondDir(BondDir.NONE)) \ 86 | for b in self.reactants_achiral.GetBonds()] 87 | 88 | # Pre-list reactant bonds (for stitching broken products) 89 | self.bonds_by_mapnum = [ 90 | (b.GetBeginAtom().GetAtomMapNum(), b.GetEndAtom().GetAtomMapNum(), b) \ 91 | for b in self.reactants.GetBonds() 92 | ] 93 | 94 | # Pre-list chiral double bonds (for copying back into outcomes/matching) 95 | self.bond_dirs_by_mapnum = {} 96 | for (i, j, b) in self.bonds_by_mapnum: 97 | if b.GetBondDir() != BondDir.NONE: 98 | self.bond_dirs_by_mapnum[(i, j)] = b.GetBondDir() 99 | self.bond_dirs_by_mapnum[(j, i)] = BondDirOpposite[b.GetBondDir()] 100 | 101 | # Get atoms across double bonds defined by mapnum 102 | self.atoms_across_double_bonds = get_atoms_across_double_bonds(self.reactants) 103 | 104 | 105 | def initialize_rxn_from_smarts(reaction_smarts): 106 | # Initialize reaction 107 | rxn = AllChem.ReactionFromSmarts(reaction_smarts) 108 | rxn.Initialize() 109 | if rxn.Validate()[1] != 0: 110 | raise ValueError('validation failed') 111 | if PLEVEL >= 2: print('Validated rxn without errors') 112 | 113 | 114 | # Figure out if there are unnecessary atom map numbers (that are not balanced) 115 | # e.g., leaving groups for retrosynthetic templates. 
This is because additional 116 | # atom map numbers in the input SMARTS template may conflict with the atom map 117 | # numbers of the molecules themselves 118 | prd_maps = [a.GetAtomMapNum() for prd in rxn.GetProducts() for a in prd.GetAtoms() if a.GetAtomMapNum()] 119 | 120 | unmapped = 700 121 | for rct in rxn.GetReactants(): 122 | rct.UpdatePropertyCache() 123 | Chem.AssignStereochemistry(rct) 124 | # Fill in atom map numbers 125 | for a in rct.GetAtoms(): 126 | if not a.GetAtomMapNum() or a.GetAtomMapNum() not in prd_maps: 127 | a.SetAtomMapNum(unmapped) 128 | unmapped += 1 129 | if PLEVEL >= 2: print('Added {} map nums to unmapped reactants'.format(unmapped-700)) 130 | if unmapped > 800: 131 | raise ValueError('Why do you have so many unmapped atoms in the template reactants?') 132 | 133 | return rxn 134 | 135 | def initialize_reactants_from_smiles(reactant_smiles): 136 | # Initialize reactants 137 | reactants = Chem.MolFromSmiles(reactant_smiles) 138 | Chem.AssignStereochemistry(reactants, flagPossibleStereoCenters=True) 139 | reactants.UpdatePropertyCache() 140 | # To have the product atoms match reactant atoms, we 141 | # need to populate the map number field, since this field 142 | # gets copied over during the reaction via reactant_atom_idx. 143 | [a.SetAtomMapNum(i+1) for (i, a) in enumerate(reactants.GetAtoms())] 144 | if PLEVEL >= 2: print('Initialized reactants, assigned map numbers, stereochem, flagpossiblestereocenters') 145 | return reactants 146 | 147 | def get_template_frags_from_rxn(rxn): 148 | # Copy reaction template so we can play around with map numbers 149 | for i, rct in enumerate(rxn.GetReactants()): 150 | if i == 0: 151 | template_r = rct 152 | else: 153 | template_r = AllChem.CombineMols(template_r, rct) 154 | for i, prd in enumerate(rxn.GetProducts()): 155 | if i == 0: 156 | template_p = prd 157 | else: 158 | template_p = AllChem.CombineMols(template_p, prd) 159 | return template_r, template_p -------------------------------------------------------------------------------- /gln/mods/rdchiral/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | PLEVEL = 0 4 | def vprint(level, txt, *args): 5 | if PLEVEL >= level: 6 | print(txt.format(*args)) 7 | 8 | def parity4(data): 9 | ''' 10 | Thanks to http://www.dalkescientific.com/writings/diary/archive/2016/08/15/fragment_parity_calculation.html 11 | ''' 12 | if data[0] < data[1]: 13 | if data[2] < data[3]: 14 | if data[0] < data[2]: 15 | if data[1] < data[2]: 16 | return 0 # (0, 1, 2, 3) 17 | else: 18 | if data[1] < data[3]: 19 | return 1 # (0, 2, 1, 3) 20 | else: 21 | return 0 # (0, 3, 1, 2) 22 | else: 23 | if data[0] < data[3]: 24 | if data[1] < data[3]: 25 | return 0 # (1, 2, 0, 3) 26 | else: 27 | return 1 # (1, 3, 0, 2) 28 | else: 29 | return 0 # (2, 3, 0, 1) 30 | else: 31 | if data[0] < data[3]: 32 | if data[1] < data[2]: 33 | if data[1] < data[3]: 34 | return 1 # (0, 1, 3, 2) 35 | else: 36 | return 0 # (0, 2, 3, 1) 37 | else: 38 | return 1 # (0, 3, 2, 1) 39 | else: 40 | if data[0] < data[2]: 41 | if data[1] < data[2]: 42 | return 1 # (1, 2, 3, 0) 43 | else: 44 | return 0 # (1, 3, 2, 0) 45 | else: 46 | return 1 # (2, 3, 1, 0) 47 | else: 48 | if data[2] < data[3]: 49 | if data[0] < data[3]: 50 | if data[0] < data[2]: 51 | return 1 # (1, 0, 2, 3) 52 | else: 53 | if data[1] < data[2]: 54 | return 0 # (2, 0, 1, 3) 55 | else: 56 | return 1 # (2, 1, 0, 3) 57 | else: 58 | if data[1] < data[2]: 59 | return 1 # (3, 0, 1, 2) 60 | else: 
61 | if data[1] < data[3]: 62 | return 0 # (3, 1, 0, 2) 63 | else: 64 | return 1 # (3, 2, 0, 1) 65 | else: 66 | if data[0] < data[2]: 67 | if data[0] < data[3]: 68 | return 0 # (1, 0, 3, 2) 69 | else: 70 | if data[1] < data[3]: 71 | return 1 # (2, 0, 3, 1) 72 | else: 73 | return 0 # (2, 1, 3, 0) 74 | else: 75 | if data[1] < data[2]: 76 | if data[1] < data[3]: 77 | return 0 # (3, 0, 2, 1) 78 | else: 79 | return 1 # (3, 1, 2, 0) 80 | else: 81 | return 0 # (3, 2, 1, 0) 82 | 83 | def bond_to_label(bond): 84 | '''This function takes an RDKit bond and creates a label describing 85 | the most important attributes''' 86 | 87 | a1_label = str(bond.GetBeginAtom().GetAtomicNum()) 88 | a2_label = str(bond.GetEndAtom().GetAtomicNum()) 89 | if bond.GetBeginAtom().GetAtomMapNum(): 90 | a1_label += str(bond.GetBeginAtom().GetAtomMapNum()) 91 | if bond.GetEndAtom().GetAtomMapNum(): 92 | a2_label += str(bond.GetEndAtom().GetAtomMapNum()) 93 | atoms = sorted([a1_label, a2_label]) 94 | 95 | return '{}{}{}'.format(atoms[0], bond.GetSmarts(), atoms[1]) 96 | 97 | 98 | def atoms_are_different(atom1, atom2): 99 | '''Compares two RDKit atoms based on basic properties''' 100 | 101 | if atom1.GetSmarts() != atom2.GetSmarts(): return True # should be very general 102 | if atom1.GetAtomicNum() != atom2.GetAtomicNum(): return True # must be true for atom mapping 103 | if atom1.GetTotalNumHs() != atom2.GetTotalNumHs(): return True 104 | if atom1.GetFormalCharge() != atom2.GetFormalCharge(): return True 105 | if atom1.GetDegree() != atom2.GetDegree(): return True 106 | if atom1.GetNumRadicalElectrons() != atom2.GetNumRadicalElectrons(): return True 107 | if atom1.GetIsAromatic() != atom2.GetIsAromatic(): return True 108 | 109 | # Check bonds and nearest neighbor identity 110 | bonds1 = sorted([bond_to_label(bond) for bond in atom1.GetBonds()]) 111 | bonds2 = sorted([bond_to_label(bond) for bond in atom2.GetBonds()]) 112 | if bonds1 != bonds2: return True 113 | 114 | return False -------------------------------------------------------------------------------- /gln/mods/torchext/__init__.py: -------------------------------------------------------------------------------- 1 | from .jagged_ops import jagged_log_softmax -------------------------------------------------------------------------------- /gln/mods/torchext/jagged_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import extlib 3 | try: 4 | import extlib_cuda 5 | except: 6 | print('not loading cuda jagged ops') 7 | from torch.autograd import Function 8 | from torch.nn import Module 9 | import numpy as np 10 | 11 | #---------------------- 12 | # jagged_log_softmax 13 | #---------------------- 14 | class JaggedLogSoftmaxFunc(Function): 15 | @staticmethod 16 | def forward(ctx, logits, prefix_sum): 17 | assert len(prefix_sum.size()) == 1 18 | if not logits.is_cuda: 19 | output = extlib.jagged_log_softmax_forward(logits, prefix_sum) 20 | else: 21 | output = extlib_cuda.jagged_log_softmax_forward_cuda(logits, prefix_sum) 22 | 23 | ctx.save_for_backward(prefix_sum, output) 24 | return output 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | prefix_sum, output = ctx.saved_variables 29 | if not grad_output.is_cuda: 30 | grad_input = extlib.jagged_log_softmax_backward(output.data, grad_output, prefix_sum.data) 31 | else: 32 | grad_input = extlib_cuda.jagged_log_softmax_backward_cuda(output.data, grad_output, prefix_sum.data) 33 | return grad_input, None 34 | 35 | 36 | class JaggedLogSoftmax(Module): 37 
| def forward(self, logits, prefix_sum):
38 |         return JaggedLogSoftmaxFunc.apply(logits, prefix_sum)
39 |
40 | jagged_log_softmax = JaggedLogSoftmax()
41 |
42 |
-------------------------------------------------------------------------------- /gln/mods/torchext/src/extlib.cpp: --------------------------------------------------------------------------------
1 | #include "extlib.h"
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | torch::Tensor make_contiguous(torch::Tensor& t)
11 | {
12 |     if (t.is_contiguous())
13 |         return t;
14 |     return t.contiguous();
15 | }
16 |
17 | template <typename scalar_t>
18 | void impl_jagged_log_softmax_forward(scalar_t *input_data_base, scalar_t *output_data_base, torch::Tensor prefix_sum)
19 | {
20 |     int64_t *ps = prefix_sum.data<int64_t>();
21 |     int64_t bsize = prefix_sum.sizes()[0];
22 |     int64_t i, d;
23 |
24 |     #pragma omp parallel for private(i, d)
25 |     for (i = 0; i < bsize; i++)
26 |     {
27 |         int64_t offset = (i == 0) ? 0 : ps[i - 1];
28 |
29 |         scalar_t* input_data = input_data_base + offset;
30 |         scalar_t* output_data = output_data_base + offset;
31 |
32 |         int64_t n_ele = ps[i] - offset;
33 |         scalar_t max_input = -FLT_MAX;
34 |
35 |         for (d = 0; d < n_ele; d++)
36 |             max_input = std::max(max_input, input_data[d]);
37 |
38 |         double logsum = 0;
39 |         for (d = 0; d < n_ele; d++)
40 |             logsum += exp(input_data[d] - max_input);
41 |         logsum = max_input + log(logsum);
42 |         for (d = 0; d < n_ele; d++)
43 |             output_data[d] = input_data[d] - logsum;
44 |     }
45 | }
46 |
47 |
48 | template void impl_jagged_log_softmax_forward(float *input_data_base, float *output_data_base, torch::Tensor prefix_sum);
49 | template void impl_jagged_log_softmax_forward(double *input_data_base, double *output_data_base, torch::Tensor prefix_sum);
50 |
51 |
52 | torch::Tensor jagged_log_softmax_forward(torch::Tensor logits, torch::Tensor prefix_sum)
53 | {
54 |     logits = make_contiguous(logits);
55 |     prefix_sum = make_contiguous(prefix_sum);
56 |     auto output = torch::zeros_like(logits);
57 |     AT_DISPATCH_FLOATING_TYPES(logits.type(), "jagged_log_softmax_forward", ([&] {
58 |         impl_jagged_log_softmax_forward(logits.data<scalar_t>(),
59 |                                         output.data<scalar_t>(),
60 |                                         prefix_sum);
61 |     }));
62 |     return output;
63 | }
64 |
65 | template <typename scalar_t>
66 | void impl_jagged_log_softmax_backward(scalar_t *output_data_base, scalar_t *gradOutput_data_base, torch::Tensor prefix_sum, scalar_t *gradInput_data_base)
67 | {
68 |     int64_t *ps = prefix_sum.data<int64_t>();
69 |     int64_t bsize = prefix_sum.sizes()[0];
70 |     int64_t i, d;
71 |
72 |     #pragma omp parallel for private(i, d)
73 |     for (i = 0; i < bsize; i++)
74 |     {
75 |         int64_t offset = (i == 0) ? 0 : ps[i - 1];
76 |         scalar_t *gradInput_data = gradInput_data_base + offset;
77 |         scalar_t *output_data = output_data_base + offset;
78 |         scalar_t *gradOutput_data = gradOutput_data_base + offset;
79 |
80 |         double sum = 0;
81 |         int64_t n_ele = ps[i] - offset;
82 |         for (d = 0; d < n_ele; d++)
83 |             sum += gradOutput_data[d];
84 |
85 |         for (d = 0; d < n_ele; d++)
86 |             gradInput_data[d] = gradOutput_data[d] - exp(output_data[d]) * sum;
87 |     }
88 | }
89 |
90 | template void impl_jagged_log_softmax_backward(float *output_data_base, float *gradOutput_data_base, torch::Tensor prefix_sum, float *gradInput_data_base);
91 | template void impl_jagged_log_softmax_backward(double *output_data_base, double *gradOutput_data_base, torch::Tensor prefix_sum, double *gradInput_data_base);
92 |
93 | torch::Tensor jagged_log_softmax_backward(torch::Tensor output, torch::Tensor grad_output, torch::Tensor prefix_sum)
94 | {
95 |     output = make_contiguous(output);
96 |     grad_output = make_contiguous(grad_output);
97 |     prefix_sum = make_contiguous(prefix_sum);
98 |
99 |     auto grad_input = torch::zeros_like(output);
100 |
101 |     AT_DISPATCH_FLOATING_TYPES(output.type(), "jagged_log_softmax_backward", ([&] {
102 |         impl_jagged_log_softmax_backward(output.data<scalar_t>(),
103 |                                          grad_output.data<scalar_t>(),
104 |                                          prefix_sum,
105 |                                          grad_input.data<scalar_t>());
106 |     }));
107 |
108 |     return grad_input;
109 | }
110 |
111 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
112 |     m.def("jagged_log_softmax_forward", &jagged_log_softmax_forward, "Jagged Log Softmax Forward");
113 |     m.def("jagged_log_softmax_backward", &jagged_log_softmax_backward, "Jagged Log Softmax Backward");
114 | }
115 |
-------------------------------------------------------------------------------- /gln/mods/torchext/src/extlib.h: --------------------------------------------------------------------------------
1 | #ifndef EXTLIB_H
2 | #define EXTLIB_H
3 |
4 | #include
5 | #include
6 |
7 | torch::Tensor jagged_log_softmax_forward(torch::Tensor logits, torch::Tensor prefix_sum);
8 |
9 | torch::Tensor jagged_log_softmax_backward(torch::Tensor output, torch::Tensor grad_output, torch::Tensor prefix_sum);
10 |
11 |
12 | #endif
13 |
-------------------------------------------------------------------------------- /gln/mods/torchext/src/extlib_cuda.cpp: --------------------------------------------------------------------------------
1 | #include "extlib_cuda.h"
2 | #include "extlib_cuda_kernels.h"
3 | #include
4 | #include
5 |
6 | torch::Tensor make_contiguous(torch::Tensor& t)
7 | {
8 |     if (t.is_contiguous())
9 |         return t;
10 |     return t.contiguous();
11 | }
12 |
13 | torch::Tensor jagged_log_softmax_forward_cuda(torch::Tensor logits, torch::Tensor prefix_sum)
14 | {
15 |     logits = make_contiguous(logits);
16 |     prefix_sum = make_contiguous(prefix_sum);
17 |     auto output = torch::zeros_like(logits);
18 |     int64_t bsize = prefix_sum.sizes()[0];
19 |     int64_t* ps = prefix_sum.data<int64_t>();
20 |
21 |     AT_DISPATCH_FLOATING_TYPES(logits.type(), "jagged_log_softmax_forward_cuda", ([&] {
22 |         HostLogSoftmaxForward(logits.data<scalar_t>(),
23 |                               output.data<scalar_t>(),
24 |                               ps, bsize);
25 |     }));
26 |     return output;
27 | }
28 |
29 | torch::Tensor jagged_log_softmax_backward_cuda(torch::Tensor output, torch::Tensor grad_output, torch::Tensor prefix_sum)
30 | {
31 |     output = make_contiguous(output);
32 |     grad_output = make_contiguous(grad_output);
33 |     prefix_sum = make_contiguous(prefix_sum);
34 |
35 |     auto grad_input = torch::zeros_like(output);
36 |
37 |     int64_t bsize = prefix_sum.sizes()[0];
38 |     int64_t* ps = prefix_sum.data<int64_t>();
39 |
AT_DISPATCH_FLOATING_TYPES(output.type(), "jagged_log_softmax_backward_cuda", ([&] { 40 | HostLogSoftmaxBackward(grad_output.data(), 41 | grad_input.data(), 42 | output.data(), 43 | ps, bsize); 44 | })); 45 | return grad_input; 46 | } 47 | 48 | 49 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 50 | m.def("jagged_log_softmax_forward_cuda", &jagged_log_softmax_forward_cuda, "Jagged Log Softmax Forward (CUDA)"); 51 | m.def("jagged_log_softmax_backward_cuda", &jagged_log_softmax_backward_cuda, "Jagged Log Softmax Backward (CUDA)"); 52 | } -------------------------------------------------------------------------------- /gln/mods/torchext/src/extlib_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef EXTLIB_CUDA_H 2 | #define EXTLIB_CUDA_H 3 | 4 | #include 5 | 6 | torch::Tensor jagged_log_softmax_forward_cuda(torch::Tensor logits, torch::Tensor prefix_sum); 7 | 8 | torch::Tensor jagged_log_softmax_backward_cuda(torch::Tensor output, torch::Tensor grad_output, torch::Tensor prefix_sum); 9 | 10 | 11 | #endif -------------------------------------------------------------------------------- /gln/mods/torchext/src/extlib_cuda_kernels.cu: -------------------------------------------------------------------------------- 1 | #include "extlib_cuda_kernels.h" 2 | #include 3 | #include 4 | 5 | struct SharedMem 6 | { 7 | __device__ double *getPointer() { 8 | extern __shared__ double s_double[]; 9 | return s_double; 10 | } 11 | }; 12 | 13 | 14 | struct Max 15 | { 16 | template 17 | __device__ __forceinline__ double operator()(double x, scalar_t y) const { 18 | return x > static_cast(y) ? x : static_cast(y); 19 | } 20 | }; 21 | 22 | struct Add 23 | { 24 | template 25 | __device__ __forceinline__ double operator()(double x, scalar_t y) const { 26 | return x + y; 27 | } 28 | }; 29 | 30 | 31 | struct SumExp 32 | { 33 | __device__ __forceinline__ SumExp(double v) : max_k(v) {} 34 | 35 | template 36 | __device__ __forceinline__ double operator()(double sum, scalar_t v) const { 37 | return sum + static_cast(exp((double)v - max_k)); 38 | } 39 | 40 | const double max_k; 41 | }; 42 | 43 | 44 | template 45 | __device__ __forceinline__ double 46 | blockReduce(double* smem, double val, 47 | const Reduction& r, 48 | double defaultVal) 49 | { 50 | // To avoid RaW races from chaining blockReduce calls together, we need a sync here 51 | __syncthreads(); 52 | 53 | smem[threadIdx.x] = val; 54 | 55 | __syncthreads(); 56 | 57 | double warpVal = defaultVal; 58 | 59 | // First warp will perform per-warp reductions for the remaining warps 60 | if (threadIdx.x < 32) { 61 | int lane = threadIdx.x % 32; 62 | if (lane < blockDim.x / 32) { 63 | #pragma unroll 64 | for (int i = 0; i < 32; ++i) { 65 | warpVal = r(warpVal, smem[lane * 32 + i]); 66 | } 67 | smem[lane] = warpVal; 68 | } 69 | } 70 | 71 | __syncthreads(); 72 | 73 | // First thread will perform a reduction of the above per-warp reductions 74 | double blockVal = defaultVal; 75 | 76 | if (threadIdx.x == 0) { 77 | for (int i = 0; i < blockDim.x / 32; ++i) { 78 | blockVal = r(blockVal, smem[i]); 79 | } 80 | smem[0] = blockVal; 81 | } 82 | 83 | // Sync and broadcast 84 | __syncthreads(); 85 | return smem[0]; 86 | } 87 | 88 | 89 | template 90 | __device__ __forceinline__ double 91 | ilpReduce(scalar_t* data, 92 | int size, 93 | const Reduction& r, 94 | double defaultVal) 95 | { 96 | double threadVal = defaultVal; 97 | int offset = threadIdx.x; 98 | 99 | int last = size % (ILP * blockDim.x); 100 | 101 | // Body (unroll by ILP times) 102 | 
for (; offset < size - last; offset += blockDim.x * ILP) { 103 | scalar_t tmp[ILP]; 104 | 105 | #pragma unroll 106 | for (int j = 0; j < ILP; ++j) 107 | tmp[j] = data[offset + j * blockDim.x]; 108 | 109 | #pragma unroll 110 | for (int j = 0; j < ILP; ++j) 111 | threadVal = r(threadVal, tmp[j]); 112 | } 113 | 114 | // Epilogue 115 | for (; offset < size; offset += blockDim.x) 116 | threadVal = r(threadVal, data[offset]); 117 | 118 | return threadVal; 119 | } 120 | 121 | 122 | template <int ILP, typename scalar_t> 123 | __global__ void cunn_SoftMaxForward(scalar_t *output, scalar_t *input, int64_t* ps) 124 | { 125 | SharedMem smem; 126 | double *buffer = smem.getPointer(); 127 | // forward pointers to batch[blockIdx.x] 128 | // each block handles a sample in the mini-batch 129 | int64_t ofs = (blockIdx.x == 0) ? 0 : ps[blockIdx.x - 1]; 130 | int64_t n_ele = ps[blockIdx.x] - ofs; 131 | input += ofs; 132 | output += ofs; 133 | 134 | // find the max 135 | double threadMax = ilpReduce<ILP>(input, n_ele, Max(), -DBL_MAX); 136 | double max_k = blockReduce(buffer, threadMax, Max(), -DBL_MAX); 137 | 138 | // reduce all values 139 | double threadExp = ilpReduce<ILP>(input, n_ele, SumExp(max_k), static_cast<double>(0)); 140 | 141 | double sumAll = blockReduce(buffer, threadExp, Add(), static_cast<double>(0)); 142 | double logsum = max_k + log(sumAll); 143 | 144 | int offset = threadIdx.x; 145 | int last = n_ele % (ILP * blockDim.x); 146 | for (; offset < n_ele - last; offset += blockDim.x * ILP) { 147 | scalar_t tmp[ILP]; 148 | 149 | #pragma unroll 150 | for (int j = 0; j < ILP; ++j) 151 | tmp[j] = input[offset + j * blockDim.x]; 152 | 153 | #pragma unroll 154 | for (int j = 0; j < ILP; ++j) 155 | output[offset + j * blockDim.x] = (double)tmp[j] - logsum; 156 | } 157 | 158 | for (; offset < n_ele; offset += blockDim.x) 159 | output[offset] = (double)input[offset] - logsum; 160 | } 161 | 162 | 163 | template <typename scalar_t> 164 | void HostLogSoftmaxForward(scalar_t* input, scalar_t *output, int64_t* ps, int64_t bsize) 165 | { 166 | dim3 grid(bsize); 167 | dim3 block(1024); 168 | 169 | cunn_SoftMaxForward<2> 170 | <<<grid, block, block.x * sizeof(double)>>>( 171 | output, input, ps 172 | ); 173 | } 174 | 175 | template void HostLogSoftmaxForward(float* input, float* output, int64_t* ps, int64_t bsize); 176 | template void HostLogSoftmaxForward(double* input, double* output, int64_t* ps, int64_t bsize); 177 | 178 | template <int ILP, typename scalar_t> 179 | __global__ void cunn_SoftMaxBackward(scalar_t *gradInput, scalar_t *output, scalar_t *gradOutput, int64_t* ps) 180 | { 181 | SharedMem smem; 182 | double *buffer = smem.getPointer(); 183 | int64_t ofs = (blockIdx.x == 0) ?
0 : ps[blockIdx.x - 1]; 184 | int64_t n_ele = ps[blockIdx.x] - ofs; 185 | 186 | gradInput += ofs; 187 | output += ofs; 188 | gradOutput += ofs; 189 | 190 | double threadSum = ilpReduce<ILP>(gradOutput, n_ele, Add(), double(0)); 191 | double sum_k = blockReduce(buffer, threadSum, Add(), double(0)); 192 | 193 | int offset = threadIdx.x; 194 | int last = n_ele % (ILP * blockDim.x); 195 | for (; offset < n_ele - last; offset += blockDim.x * ILP) { 196 | scalar_t tmpGradOutput[ILP]; 197 | scalar_t tmpOutput[ILP]; 198 | 199 | #pragma unroll 200 | for (int j = 0; j < ILP; ++j) { 201 | tmpGradOutput[j] = gradOutput[offset + j * blockDim.x]; 202 | tmpOutput[j] = output[offset + j * blockDim.x]; 203 | } 204 | 205 | #pragma unroll 206 | for (int j = 0; j < ILP; ++j) 207 | gradInput[offset + j * blockDim.x] = tmpGradOutput[j] - exp((double)tmpOutput[j]) * sum_k; 208 | } 209 | 210 | for (; offset < n_ele; offset += blockDim.x) 211 | gradInput[offset] = gradOutput[offset] - exp((double)output[offset]) * sum_k; // log-softmax backward: g_i - exp(y_i) * sum_j g_j 212 | } 213 | 214 | 215 | template <typename scalar_t> 216 | void HostLogSoftmaxBackward(scalar_t *gradOutput, scalar_t *gradInput, scalar_t *output, int64_t* ps, int64_t bsize) 217 | { 218 | dim3 grid(bsize); 219 | dim3 block(1024); 220 | 221 | cunn_SoftMaxBackward<2> 222 | <<<grid, block, block.x * sizeof(double)>>>( 223 | gradInput, output, gradOutput, ps 224 | ); 225 | } 226 | 227 | template void HostLogSoftmaxBackward(float *gradOutput, float *gradInput, float *output, int64_t* ps, int64_t bsize); 228 | template void HostLogSoftmaxBackward(double *gradOutput, double *gradInput, double *output, int64_t* ps, int64_t bsize); 229 | -------------------------------------------------------------------------------- /gln/mods/torchext/src/extlib_cuda_kernels.h: -------------------------------------------------------------------------------- 1 | #ifndef EXTLIB_CUDA_KERNELS_H 2 | #define EXTLIB_CUDA_KERNELS_H 3 | 4 | #include <cstdint> 5 | 6 | template <typename scalar_t> 7 | void HostLogSoftmaxForward(scalar_t* input, scalar_t *output, int64_t* ps, int64_t bsize); 8 | 9 | template <typename scalar_t> 10 | void HostLogSoftmaxBackward(scalar_t *gradOutput, scalar_t *gradInput, scalar_t *output, int64_t* ps, int64_t bsize); 11 | 12 | 13 | #endif -------------------------------------------------------------------------------- /gln/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hanjun-Dai/GLN/b5bd7b181a61a8289cc1d1a33825b2c417bed0ef/gln/test/__init__.py -------------------------------------------------------------------------------- /gln/test/main_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import os 7 | import sys 8 | import rdkit 9 | from rdkit import Chem 10 | import random 11 | import csv 12 | from gln.common.cmd_args import cmd_args 13 | from gln.data_process.data_info import DataInfo, load_center_maps 14 | from gln.test.model_inference import RetroGLN 15 | from gln.common.evaluate import get_score, canonicalize 16 | 17 | from tqdm import tqdm 18 | import torch 19 | 20 | from rdkit import rdBase 21 | rdBase.DisableLog('rdApp.error') 22 | rdBase.DisableLog('rdApp.warning') 23 | 24 | import argparse 25 | cmd_opt = argparse.ArgumentParser(description='Argparser for test only') 26 | cmd_opt.add_argument('-model_for_test', default=None, help='model for test') 27 | local_args, _ = cmd_opt.parse_known_args() 28 | 29 | 30 | def
load_raw_reacts(name): 31 | print('loading raw', name) 32 | args = cmd_args 33 | csv_file = os.path.join(args.dropbox, args.data_name, 'raw_%s.csv' % name) 34 | reactions = [] 35 | print('loading reactions') 36 | with open(csv_file, 'r') as f: 37 | reader = csv.reader(f) 38 | header = next(reader) 39 | for row in tqdm(reader): 40 | reactions.append((row[1], row[2])) 41 | print('num %s:' % name, len(reactions)) 42 | return reactions 43 | 44 | 45 | def rxn_data_gen(phase, model): 46 | list_reactions = load_raw_reacts(phase) 47 | 48 | eval_cnt = 0 49 | for pid in range(cmd_args.num_parts): 50 | fname = os.path.join(cmd_args.dropbox, 'cooked_' + cmd_args.data_name, 'tpl-%s' % cmd_args.tpl_name, 'np-%d' % cmd_args.num_parts, '%s-prod_center_maps-part-%d.csv' % (phase, pid)) 51 | model.prod_center_maps = load_center_maps(fname) 52 | 53 | tot_num = len(list_reactions) 54 | if cmd_args.num_parts > 1: 55 | part_size = tot_num // cmd_args.num_parts + 1 56 | else: 57 | part_size = tot_num 58 | indices = range(pid * part_size, min((pid + 1) * part_size, tot_num)) 59 | 60 | for idx in indices: 61 | rxn_type, rxn = list_reactions[idx] 62 | _, _, raw_prod = rxn.split('>') 63 | eval_cnt += 1 64 | yield rxn_type, rxn, raw_prod 65 | assert eval_cnt == len(list_reactions) 66 | 67 | 68 | def eval_model(phase, model, fname_pred): 69 | case_gen = rxn_data_gen(phase, model) 70 | 71 | cnt = 0 72 | topk_scores = [0.0] * cmd_args.topk 73 | 74 | pbar = tqdm(case_gen) 75 | 76 | fpred = open(fname_pred, 'w') 77 | for rxn_type, rxn, raw_prod in pbar: 78 | pred_struct = model.run(raw_prod, cmd_args.beam_size, cmd_args.topk, rxn_type=rxn_type) 79 | reactants, _, prod = rxn.split('>') 80 | if pred_struct is not None and len(pred_struct['reactants']): 81 | predictions = pred_struct['reactants'] 82 | else: 83 | predictions = [prod] 84 | s = 0.0 85 | reactants = canonicalize(reactants) 86 | for i in range(cmd_args.topk): 87 | if i < len(predictions): 88 | pred = predictions[i] 89 | pred = canonicalize(pred) 90 | predictions[i] = pred 91 | cur_s = (pred == reactants) 92 | else: 93 | cur_s = s # fewer than topk predictions: carry the best score forward (cumulative top-k accuracy) 94 | s = max(cur_s, s) 95 | topk_scores[i] += s 96 | cnt += 1 97 | if pred_struct is None or len(pred_struct['reactants']) == 0: 98 | predictions = [] # don't write the fallback product to the prediction file 99 | fpred.write('%s %s %d\n' % (rxn_type, rxn, len(predictions))) 100 | for i in range(len(predictions)): 101 | fpred.write('%s %s\n' % (pred_struct['template'][i], predictions[i])) 102 | msg = 'average score' 103 | for k in range(0, min(cmd_args.topk, 10), 3): 104 | msg += ', t%d: %.4f' % (k + 1, topk_scores[k] / cnt) 105 | pbar.set_description(msg) 106 | fpred.close() 107 | h = '========%s results========' % phase 108 | print(h) 109 | for k in range(cmd_args.topk): 110 | print('top %d: %.4f' % (k + 1, topk_scores[k] / cnt)) 111 | print('=' * len(h)) 112 | 113 | f_summary = '.'.join(fname_pred.split('.')[:-1]) + '.summary' 114 | with open(f_summary, 'w') as f: 115 | f.write('type overall\n') 116 | for k in range(cmd_args.topk): 117 | f.write('top %d: %.4f\n' % (k + 1, topk_scores[k] / cnt)) 118 | 119 | 120 | if __name__ == '__main__': 121 | random.seed(cmd_args.seed) 122 | np.random.seed(cmd_args.seed) 123 | torch.manual_seed(cmd_args.seed) 124 | 125 | if local_args.model_for_test is None: # test all 126 | i = 0 127 | while True: 128 | model_dump = os.path.join(cmd_args.save_dir, 'model-%d.dump' % i) 129 | if not os.path.isdir(model_dump): 130 | break 131 | local_args.model_for_test = model_dump 132 | model = RetroGLN(cmd_args.dropbox, local_args.model_for_test) 133 | print('testing',
local_args.model_for_test) 134 | for phase in ['val', 'test']: 135 | fname_pred = os.path.join(cmd_args.save_dir, '%s-%d.pred' % (phase, i)) 136 | eval_model(phase, model, fname_pred) 137 | i += 1 138 | else: 139 | model = RetroGLN(cmd_args.dropbox, local_args.model_for_test) 140 | print('testing', local_args.model_for_test) 141 | fname_pred = os.path.join(cmd_args.save_dir, 'test.pred') 142 | eval_model('test', model, fname_pred) 143 | -------------------------------------------------------------------------------- /gln/test/model_inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | 6 | import rdkit 7 | from rdkit import Chem 8 | import os 9 | import numpy as np 10 | import torch 11 | import pickle as cp 12 | import math 13 | from scipy.special import softmax 14 | from gln.data_process.data_info import DataInfo, load_bin_feats 15 | from gln.mods.mol_gnn.mol_utils import SmartsMols, SmilesMols 16 | from gln.common.reactor import Reactor 17 | from gln.graph_logic.logic_net import GraphPath 18 | 19 | 20 | class RetroGLN(object): 21 | def __init__(self, dropbox, model_dump): 22 | """ 23 | Args: 24 | dropbox: the dropbox folder 25 | model_dump: the ckpt folder, which contains model dump and model args 26 | """ 27 | assert os.path.isdir(model_dump) 28 | 29 | arg_file = os.path.join(model_dump, 'args.pkl') 30 | with open(arg_file, 'rb') as f: 31 | self.args = cp.load(f) 32 | self.args.dropbox = dropbox 33 | 34 | DataInfo.init(dropbox, self.args) 35 | load_bin_feats(dropbox, self.args) 36 | 37 | model_file = os.path.join(model_dump, 'model.dump') 38 | self.gln = GraphPath(self.args) 39 | self.gln.load_state_dict(torch.load(model_file)) 40 | self.gln.cuda() 41 | self.gln.eval() 42 | 43 | self.prod_center_maps = {} 44 | self.cached_smarts = None 45 | 46 | def _ordered_tpls(self, cano_prod, beam_size, rxn_type): 47 | if (rxn_type, cano_prod) not in self.prod_center_maps: 48 | mol = Chem.MolFromSmiles(cano_prod) 49 | if mol is None: 50 | return None 51 | if self.cached_smarts is None: 52 | self.cached_smarts = [] 53 | print('caching smarts centers') 54 | for sm in DataInfo.prod_cano_smarts: 55 | self.cached_smarts.append(Chem.MolFromSmarts(sm)) 56 | 57 | prod_center_cand_idx = [] 58 | for i, sm in enumerate(self.cached_smarts): 59 | if sm is not None and mol.HasSubstructMatch(sm): 60 | prod_center_cand_idx.append(i) 61 | self.prod_center_maps[(rxn_type, cano_prod)] = prod_center_cand_idx 62 | prod_center_cand_idx = self.prod_center_maps[(rxn_type, cano_prod)] 63 | 64 | # infer the reaction center 65 | if not len(prod_center_cand_idx): 66 | return None 67 | prod_center_mols = [SmartsMols.get_mol_graph(DataInfo.prod_cano_smarts[m]) for m in prod_center_cand_idx] 68 | prod_mol = SmilesMols.get_mol_graph(cano_prod) 69 | prod_center_scores = self.gln.prod_center_predicate.inference([prod_mol], [prod_center_mols]) 70 | prod_center_scores = prod_center_scores.view(-1).data.cpu().numpy() 71 | top_centers = np.argsort(-1 * prod_center_scores)[:beam_size] 72 | top_center_scores = [prod_center_scores[i] for i in top_centers] 73 | top_center_mols = [prod_center_mols[i] for i in top_centers] 74 | top_center_smarts = [DataInfo.prod_cano_smarts[prod_center_cand_idx[i]] for i in top_centers] 75 | 76 | # infer the template 77 | list_of_list_tpls = [] 78 | for i, c in enumerate(top_center_smarts): 79 | assert c in DataInfo.unique_tpl_of_prod_center 
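# each kept center c already carries its score from prod_center_predicate; the loop below gathers the templates whose product pattern equals c for this reaction type, scores them with tpl_fwd_predicate, and ranks candidates by center score + template score (both in log-space); centers with no template of this type are skipped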
80 | if rxn_type not in DataInfo.unique_tpl_of_prod_center[c]: 81 | continue 82 | tpl_indices = DataInfo.unique_tpl_of_prod_center[c][rxn_type] 83 | tpls = [DataInfo.unique_templates[t][1] for t in tpl_indices] 84 | list_of_list_tpls.append(tpls) 85 | if not len(list_of_list_tpls): 86 | return None 87 | tpl_scores = self.gln.tpl_fwd_predicate.inference([prod_mol] * len(top_center_mols), list_of_list_tpls) 88 | tpl_scores = tpl_scores.view(-1).data.cpu().numpy() 89 | 90 | idx = 0 91 | tpl_with_scores = [] 92 | for i, c in enumerate(top_center_scores): 93 | for tpl in list_of_list_tpls[i]: 94 | t_score = tpl_scores[idx] 95 | tot_score = c + t_score 96 | tpl_with_scores.append((tot_score, tpl)) 97 | idx += 1 98 | tpl_with_scores = sorted(tpl_with_scores, key=lambda x: -1 * x[0]) 99 | 100 | return tpl_with_scores 101 | 102 | def run(self, raw_prod, beam_size, topk, rxn_type='UNK'): 103 | """ 104 | Args: 105 | raw_prod: the single product smiles 106 | beam_size: the size for beam search 107 | topk: top-k prediction of reactants 108 | rxn_type: (optional) reaction type 109 | Return: 110 | a dictionary with the following keys: 111 | { 112 | 'reactants': the top-k prediction of reactants 113 | 'template': the list of corresponding reaction templates used 114 | 'scores': the scores for the corresponding predictions, in descending order 115 | } 116 | if no valid reactions are found, None will be returned 117 | """ 118 | cano_prod = DataInfo.get_cano_smiles(raw_prod) 119 | prod_mol = SmilesMols.get_mol_graph(cano_prod) 120 | tpl_with_scores = self._ordered_tpls(cano_prod, beam_size, rxn_type) 121 | 122 | if tpl_with_scores is None: 123 | return None 124 | # filter out invalid tpls 125 | list_of_list_reacts = [] 126 | list_reacts = [] 127 | list_tpls = [] 128 | num_tpls = 0 129 | num_reacts = 0 130 | for prod_tpl_score, tpl in tpl_with_scores: 131 | pred_mols = Reactor.run_reaction(raw_prod, tpl) 132 | if pred_mols is not None and len(pred_mols): 133 | num_tpls += 1 134 | list_of_list_reacts.append(pred_mols) 135 | num_reacts += len(pred_mols) 136 | list_tpls.append((prod_tpl_score, tpl)) 137 | if num_tpls >= beam_size: 138 | break 139 | 140 | list_rxns = [] 141 | for i in range(len(list_of_list_reacts)): 142 | list_rxns.append([DataInfo.get_cano_smiles(r) + '>>' + cano_prod for r in list_of_list_reacts[i]]) 143 | if len(list_rxns) and len(list_tpls): 144 | react_scores = self.gln.reaction_predicate.inference([prod_mol] * len(list_tpls), list_rxns) 145 | react_scores = react_scores.view(-1).data.cpu().numpy() 146 | 147 | idx = 0 148 | final_joint = [] 149 | for i, (prod_tpl_score, tpl) in enumerate(list_tpls): 150 | for reacts in list_of_list_reacts[i]: 151 | r_score = react_scores[idx] 152 | tot_score = prod_tpl_score + r_score 153 | final_joint.append((tot_score, tpl, reacts)) 154 | idx += 1 155 | final_joint = sorted(final_joint, key=lambda x: -1 * x[0])[:topk] 156 | scores = [t[0] for t in final_joint] 157 | scores = softmax(scores) 158 | list_reacts = [t[2] for t in final_joint] 159 | ret_tpls = [t[1] for t in final_joint] 160 | result = {'template': ret_tpls, 161 | 'reactants': list_reacts, 162 | 'scores': scores} 163 | else: 164 | result = {'template': [], 165 | 'reactants': [], 166 | 'scores': []} 167 | return result 168 | -------------------------------------------------------------------------------- /gln/test/report_test_stats.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from
__future__ import absolute_import 3 | from __future__ import division 4 | 5 | 6 | import os 7 | import sys 8 | 9 | if __name__ == '__main__': 10 | files = os.listdir(sys.argv[1]) 11 | 12 | best_val = 0.0 13 | best_test = None 14 | for fname in files: 15 | if 'val-' in fname and 'summary' in fname: 16 | f_test = os.path.join(sys.argv[1], 'test' + fname[3:]) 17 | if not os.path.isfile(f_test): 18 | continue 19 | with open(os.path.join(sys.argv[1], fname), 'r') as f: 20 | f.readline() 21 | top1 = float(f.readline().strip().split()[-1].strip()) 22 | if top1 > best_val: 23 | best_val = top1 24 | best_test = f_test 25 | assert best_test is not None 26 | with open(best_test, 'r') as f: 27 | for row in f: 28 | print(row.strip()) 29 | print(best_test) 30 | -------------------------------------------------------------------------------- /gln/test/test_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dropbox=../../dropbox 4 | data_name=$1 5 | tpl_name=default 6 | 7 | save_dir=$2 8 | 9 | export CUDA_VISIBLE_DEVICES=0 10 | 11 | python main_test.py \ 12 | -dropbox $dropbox \ 13 | -data_name $data_name \ 14 | -save_dir $save_dir \ 15 | -tpl_name $tpl_name \ 16 | -f_atoms $dropbox/cooked_$data_name/atom_list.txt \ 17 | -topk 50 \ 18 | -beam_size 50 \ 19 | -gpu 0 \ 20 | 21 | -------------------------------------------------------------------------------- /gln/test/test_single.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dropbox=../../dropbox 4 | data_name=$1 5 | tpl_name=default 6 | 7 | save_dir=$data_name-results 8 | 9 | if [ ! -e $save_dir ]; 10 | then 11 | mkdir -p $save_dir 12 | fi 13 | 14 | export CUDA_VISIBLE_DEVICES=0 15 | 16 | python main_test.py \ 17 | -dropbox $dropbox \ 18 | -data_name $data_name \ 19 | -save_dir $save_dir \ 20 | -model_for_test $dropbox/$data_name.ckpt \ 21 | -tpl_name $tpl_name \ 22 | -f_atoms $dropbox/cooked_$data_name/atom_list.txt \ 23 | -topk 50 \ 24 | -beam_size 50 \ 25 | -gpu 0 \ 26 | 27 | -------------------------------------------------------------------------------- /gln/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hanjun-Dai/GLN/b5bd7b181a61a8289cc1d1a33825b2c417bed0ef/gln/training/__init__.py -------------------------------------------------------------------------------- /gln/training/data_gen.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import random 7 | import os 8 | import rdkit 9 | from rdkit import Chem 10 | import csv 11 | from gln.mods.mol_gnn.mol_utils import SmartsMols, SmilesMols 12 | from multiprocessing import Process, Queue 13 | import time 14 | 15 | from gln.data_process.data_info import DataInfo, load_train_reactions 16 | from gln.common.reactor import Reactor 17 | 18 | 19 | class DataSample(object): 20 | def __init__(self, prod, center, template, label=None, neg_centers=None, neg_tpls=None, 21 | reaction=None, neg_reactions=None): 22 | self.prod = prod 23 | self.center = center 24 | self.template = template 25 | self.label = label 26 | self.neg_centers = neg_centers 27 | self.neg_tpls = neg_tpls 28 | self.reaction = reaction 29 | self.neg_reactions = neg_reactions 30 | 31 | 32 | def _rand_sample_except(candidates, exclude, k=None): 33 | assert 
len(candidates) 34 | if k is None: 35 | if len(candidates) == 1: 36 | assert exclude is None or candidates[0] == exclude 37 | return candidates[0] 38 | else: 39 | while True: 40 | c = np.random.choice(candidates) 41 | if exclude is None or c != exclude: 42 | break 43 | return c 44 | else: 45 | if k <= 0 or len(candidates) <= k: 46 | return [c for c in candidates if exclude is None or c != exclude] 47 | cand_indices = np.random.permutation(len(candidates))[:k] 48 | selected = [] 49 | for i in cand_indices: 50 | c = candidates[i] 51 | if exclude is None or c != exclude: 52 | selected.append(c) 53 | if k <= 0: 54 | continue 55 | if len(selected) >= k: 56 | break 57 | return selected 58 | 59 | 60 | def worker_softmax(worker_id, seed, args): 61 | np.random.seed(seed) 62 | random.seed(seed) 63 | num_epochs = 0 64 | part_id = 0 65 | train_reactions = load_train_reactions(args) 66 | while True: 67 | if num_epochs % args.epochs_per_part == 0: 68 | DataInfo.load_cooked_part('train', part_id) 69 | tot_num = len(train_reactions) 70 | part_size = tot_num // args.num_parts + 1 71 | indices = range(part_id * part_size, min((part_id + 1) * part_size, tot_num)) 72 | indices = list(indices) 73 | part_id = (part_id + 1) % args.num_parts 74 | random.shuffle(indices) 75 | for sample_idx in indices: 76 | rxn_type, rxn_smiles = train_reactions[sample_idx] 77 | 78 | if sample_idx in DataInfo.train_pos_maps: 79 | pos_tpls, weights = DataInfo.train_pos_maps[sample_idx] 80 | pos_tpl_idx = pos_tpls[np.argmax(np.random.multinomial(1, weights))] 81 | rxn_type, rxn_template = DataInfo.unique_templates[pos_tpl_idx] 82 | else: 83 | continue 84 | 85 | reactants, _, prod = rxn_smiles.split('>') 86 | cano_prod = DataInfo.smiles_cano_map[prod] 87 | sm_prod, _, _ = rxn_template.split('>') 88 | cano_sm_prod = DataInfo.smarts_cano_map[sm_prod] 89 | 90 | # negative samples of prod centers 91 | assert (rxn_type, cano_prod) in DataInfo.prod_center_maps 92 | prod_center_cand_idx = DataInfo.prod_center_maps[(rxn_type, cano_prod)] 93 | 94 | neg_center_idxes = _rand_sample_except(prod_center_cand_idx, DataInfo.prod_smarts_idx[cano_sm_prod], args.neg_num) 95 | neg_centers = [DataInfo.prod_cano_smarts[c] for c in neg_center_idxes] 96 | 97 | # negative samples of templates 98 | assert cano_sm_prod in DataInfo.unique_tpl_of_prod_center 99 | assert rxn_type in DataInfo.unique_tpl_of_prod_center[cano_sm_prod] 100 | neg_tpl_idxes = _rand_sample_except(DataInfo.unique_tpl_of_prod_center[cano_sm_prod][rxn_type], pos_tpl_idx, args.neg_num) 101 | tpl_cand_idx = [] 102 | for c in neg_centers: 103 | tpl_cand_idx += DataInfo.unique_tpl_of_prod_center[c][rxn_type] 104 | if len(tpl_cand_idx): 105 | neg_tpl_idxes += _rand_sample_except(tpl_cand_idx, pos_tpl_idx, args.neg_num) 106 | neg_tpls = [DataInfo.unique_templates[i][1] for i in neg_tpl_idxes] 107 | 108 | sample = DataSample(prod=cano_prod, center=cano_sm_prod, template=rxn_template, 109 | neg_centers=neg_centers, neg_tpls=neg_tpls) 110 | 111 | if args.retro_during_train: 112 | sample.reaction = DataInfo.get_cano_smiles(reactants) + '>>' + cano_prod 113 | sample.neg_reactions = [] 114 | if len(DataInfo.neg_reactions_all[sample_idx]): 115 | neg_reacts = DataInfo.neg_reactions_all[sample_idx] 116 | if len(neg_reacts): 117 | neg_reactants = _rand_sample_except(neg_reacts, None, args.neg_num) 118 | sample.neg_reactions = [DataInfo.neg_reacts_list[r] + '>>' + cano_prod for r in neg_reactants] 119 | if len(sample.neg_tpls) or len(sample.neg_reactions): 120 | yield (worker_id, sample) 121 | num_epochs += 
1 122 | 123 | 124 | def worker_process(worker_func, worker_id, seed, data_q, *args): 125 | worker_gen = worker_func(worker_id, seed, *args) 126 | for t in worker_gen: 127 | data_q.put(t) 128 | 129 | 130 | def data_gen(num_workers, worker_func, worker_args, max_qsize=16384, max_gen=-1, timeout=60): 131 | cnt = 0 132 | data_q = Queue(max_qsize) 133 | 134 | if num_workers == 0: # single process generator 135 | worker_gen = worker_func(-1, np.random.randint(10000), *worker_args) 136 | while True: 137 | worker_id, data_sample = next(worker_gen) 138 | yield data_sample 139 | cnt += 1 140 | if max_gen > 0 and cnt >= max_gen: 141 | break 142 | return 143 | 144 | worker_procs = [Process(target=worker_process, args=[worker_func, i, np.random.randint(10000), data_q] + worker_args) for i in range(num_workers)] 145 | for p in worker_procs: 146 | p.start() 147 | last_update = [time.time()] * num_workers 148 | while True: 149 | if data_q.empty(): 150 | time.sleep(0.1) 151 | if not data_q.full(): 152 | for i in range(num_workers): 153 | if time.time() - last_update[i] > timeout: 154 | print('worker', i, 'is dead') 155 | worker_procs[i].terminate() 156 | while worker_procs[i].is_alive(): # busy-wait until the process has stopped 157 | time.sleep(0.01) 158 | worker_procs[i] = Process(target=worker_process, args=[worker_func, i, np.random.randint(10000), data_q] + worker_args) 159 | print('worker', i, 'restarts') 160 | worker_procs[i].start() 161 | last_update[i] = time.time() 162 | try: 163 | sample = data_q.get_nowait() 164 | except Exception: # queue still empty; retry 165 | continue 166 | cnt += 1 167 | worker_id, data_sample = sample 168 | last_update[worker_id] = time.time() 169 | yield data_sample 170 | if max_gen > 0 and cnt >= max_gen: 171 | break 172 | 173 | print('stopping') 174 | for p in worker_procs: 175 | p.terminate() 176 | for p in worker_procs: 177 | p.join() 178 | -------------------------------------------------------------------------------- /gln/training/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import os 7 | import sys 8 | import rdkit 9 | from rdkit import Chem 10 | import random 11 | import pickle as cp 12 | import csv 13 | from gln.common.cmd_args import cmd_args 14 | from gln.common.consts import t_float, DEVICE 15 | from gln.data_process.data_info import load_bin_feats, DataInfo 16 | from gln.graph_logic.logic_net import GraphPath 17 | from gln.training.data_gen import data_gen, worker_softmax 18 | from tqdm import tqdm 19 | import torch 20 | import torch.optim as optim 21 | 22 | from gln.common.reactor import Reactor 23 | from gln.common.evaluate import get_score, canonicalize 24 | from rdkit import rdBase 25 | rdBase.DisableLog('rdApp.error') 26 | rdBase.DisableLog('rdApp.warning') 27 | 28 | 29 | def main_train(): 30 | data_root = os.path.join(cmd_args.dropbox, cmd_args.data_name) 31 | train_sample_gen = data_gen(cmd_args.num_data_proc, worker_softmax, [cmd_args], max_gen=-1) 32 | 33 | if cmd_args.init_model_dump is not None: 34 | graph_path.load_state_dict(torch.load(cmd_args.init_model_dump)) 35 | 36 | optimizer = optim.Adam(graph_path.parameters(), lr=cmd_args.learning_rate) 37 | 38 | for epoch in range(cmd_args.num_epochs): 39 | 40 | pbar = tqdm(range(1, 1 + cmd_args.iters_per_val)) 41 | 42 | for it in pbar: 43 | samples = [next(train_sample_gen) for _ in range(cmd_args.batch_size)] 44 | optimizer.zero_grad() 45 |
loss = graph_path(samples) 46 | loss.backward() 47 | 48 | if cmd_args.grad_clip > 0: 49 | torch.nn.utils.clip_grad_norm_(graph_path.parameters(), max_norm=cmd_args.grad_clip) 50 | 51 | optimizer.step() 52 | pbar.set_description('epoch %.2f, loss %.4f' % (epoch + it / cmd_args.iters_per_val, loss.item())) 53 | 54 | if epoch % cmd_args.epochs2save == 0: 55 | out_folder = os.path.join(cmd_args.save_dir, 'model-%d.dump' % epoch) 56 | if not os.path.isdir(out_folder): 57 | os.makedirs(out_folder) 58 | torch.save(graph_path.state_dict(), os.path.join(out_folder, 'model.dump')) 59 | with open(os.path.join(out_folder, 'args.pkl'), 'wb') as f: 60 | cp.dump(cmd_args, f, cp.HIGHEST_PROTOCOL) 61 | 62 | 63 | if __name__ == '__main__': 64 | random.seed(cmd_args.seed) 65 | np.random.seed(cmd_args.seed) 66 | torch.manual_seed(cmd_args.seed) 67 | 68 | DataInfo.init(cmd_args.dropbox, cmd_args) 69 | load_bin_feats(cmd_args.dropbox, cmd_args) 70 | graph_path = GraphPath(cmd_args).to(DEVICE) 71 | main_train() 72 | -------------------------------------------------------------------------------- /gln/training/scripts/run_mf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dropbox=../../../dropbox 4 | data_name=$1 5 | tpl_name=default 6 | gm=mean_field 7 | act=relu 8 | msg_dim=128 9 | embed_dim=256 10 | neg_size=64 11 | lv=3 12 | tpl_enc=deepset 13 | subg_enc=mean_field 14 | graph_agg=max 15 | retro=True 16 | bn=True 17 | gen=weighted 18 | gnn_out=last 19 | neg_sample=all 20 | att_type=bilinear 21 | 22 | save_dir=$HOME/scratch/results/gln/$data_name/tpl-$tpl_name/${gm}-${act}-lv-${lv}-l-${msg_dim}-e-${embed_dim}-gagg-${graph_agg}-retro-${retro}-gen-${gen}-ng-${neg_size}-bn-${bn}-te-${tpl_enc}-se-${subg_enc}-go-${gnn_out}-ns-${neg_sample}-att-${att_type} 23 | 24 | if [ ! 
-e $save_dir ]; 25 | then 26 | mkdir -p $save_dir 27 | fi 28 | 29 | export CUDA_VISIBLE_DEVICES=0 30 | export OMP_NUM_THREADS=2 31 | 32 | python ../main.py \ 33 | -gm $gm \ 34 | -fp_degree 2 \ 35 | -neg_sample $neg_sample \ 36 | -att_type $att_type \ 37 | -gnn_out $gnn_out \ 38 | -tpl_enc $tpl_enc \ 39 | -subg_enc $subg_enc \ 40 | -latent_dim $msg_dim \ 41 | -bn $bn \ 42 | -gen_method $gen \ 43 | -retro_during_train $retro \ 44 | -neg_num $neg_size \ 45 | -embed_dim $embed_dim \ 46 | -readout_agg_type $graph_agg \ 47 | -act_func $act \ 48 | -act_last True \ 49 | -max_lv $lv \ 50 | -dropbox $dropbox \ 51 | -data_name $data_name \ 52 | -save_dir $save_dir \ 53 | -tpl_name $tpl_name \ 54 | -f_atoms $dropbox/cooked_$data_name/atom_list.txt \ 55 | -iters_per_val 3000 \ 56 | -gpu 0 \ 57 | -topk 50 \ 58 | -beam_size 50 \ 59 | -num_parts 1 \ 60 | 61 | -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | ignored-modules=numpy,numpy.linalg,numpy.random,scipy,cv2,tensorflow,mpi4py.MPI,torch,pandas 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CppExtension, BuildExtension, CUDAExtension 3 | 4 | from distutils.command.build import build 5 | from setuptools.command.install import install 6 | 7 | from setuptools.command.develop import develop 8 | 9 | import os 10 | import subprocess 11 | import platform 12 | BASEPATH = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | compile_args = [] 15 | link_args = [] 16 | 17 | if platform.system() != 'Darwin': # add openmp 18 | compile_args.append('-fopenmp') 19 | link_args.append('-lgomp') 20 | 21 | ext_modules=[CppExtension('extlib', 22 | ['gln/mods/torchext/src/extlib.cpp'], 23 | extra_compile_args=compile_args, 24 | extra_link_args=link_args)] 25 | 26 | # build cuda lib 27 | import torch 28 | if torch.cuda.is_available(): 29 | ext_modules.append(CUDAExtension('extlib_cuda', 30 | ['gln/mods/torchext/src/extlib_cuda.cpp', 'gln/mods/torchext/src/extlib_cuda_kernels.cu'])) 31 | 32 | class custom_develop(develop): 33 | def run(self): 34 | original_cwd = os.getcwd() 35 | 36 | folders = [ 37 | os.path.join(BASEPATH, 'gln/mods/mol_gnn/mg_clib'), 38 | ] 39 | for folder in folders: 40 | os.chdir(folder) 41 | subprocess.check_call(['make']) 42 | 43 | os.chdir(original_cwd) 44 | 45 | super().run() 46 | 47 | setup(name='gln', 48 | py_modules=['gln'], 49 | ext_modules=ext_modules, 50 | install_requires=[ 51 | 'torch', 52 | ], 53 | cmdclass={ 54 | 'develop': custom_develop, 55 | 'build_ext': BuildExtension, 56 | } 57 | ) 58 | --------------------------------------------------------------------------------