├── .gitignore ├── LICENSE.md ├── README.md ├── assets └── graphretro.png ├── configs ├── lg_ind │ └── defaults.yaml └── single_edit │ └── defaults.yaml ├── data_process ├── canonicalize_prod.py ├── core_edits │ ├── bond_edits.py │ └── bond_edits_seq.py ├── lg_edits │ ├── lg_classifier.py │ └── lg_tensors.py └── parse_info.py ├── datasets └── uspto-50k │ ├── canonicalized_eval.csv │ ├── canonicalized_test.csv │ ├── canonicalized_train.csv │ ├── eval.csv │ ├── test.csv │ └── train.csv ├── environment.yml ├── eval.sh ├── models ├── LGIndEmbedClassifier_18-04-2021--11-59-29 │ └── checkpoints │ │ └── step_110701.pt ├── LGIndEmbed_18-02-2021--12-23-26 │ └── checkpoints │ │ └── step_101951.pt ├── SingleEdit_10-02-2021--08-44-37 │ └── checkpoints │ │ └── epoch_156.pt └── SingleEdit_14-02-2021--19-26-20 │ └── checkpoints │ └── step_144228.pt ├── scripts ├── benchmarks │ └── run_model.py └── eval │ ├── edit_models.py │ ├── lg_models.py │ └── single_edit_lg.py ├── seq_graph_retro ├── __init__.py ├── __pycache__ │ └── __init__.cpython-37.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── collate_fns.cpython-37.pyc │ │ ├── dataset.cpython-37.pyc │ │ ├── edits_datasets.cpython-37.pyc │ │ ├── lg_datasets.cpython-37.pyc │ │ ├── pretrain_datasets.cpython-37.pyc │ │ └── shared_retro_datasets.cpython-37.pyc │ ├── collate_fns.py │ ├── dataset.py │ ├── edits_datasets.py │ ├── lg_datasets.py │ ├── pretrain_datasets.py │ └── shared_retro_datasets.py ├── layers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── encoder.cpython-37.pyc │ │ ├── graph_transformer.cpython-37.pyc │ │ ├── reaction.cpython-37.pyc │ │ └── rnn.cpython-37.pyc │ ├── encoder.py │ ├── graph_transformer.py │ ├── reaction.py │ └── rnn.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── focal_loss.cpython-37.pyc │ │ ├── model_builder.cpython-37.pyc │ │ └── trainer.cpython-37.pyc │ ├── core_edits │ │ ├── __init__.py │ │ 
├── __pycache__ │ │ │ ├── multi_edit.cpython-37.pyc │ │ │ └── single_edit.cpython-37.pyc │ │ ├── multi_edit.py │ │ └── single_edit.py │ ├── lg_edits │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── lg_ind_embed.cpython-37.pyc │ │ │ └── lg_shared_embed.cpython-37.pyc │ │ ├── lg_ind_embed.py │ │ └── lg_shared_embed.py │ ├── model_builder.py │ ├── retro │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── separate_edits_lg.cpython-37.pyc │ │ │ └── shared_edits_lg.cpython-37.pyc │ │ ├── separate_edits_lg.py │ │ └── shared_edits_lg.py │ └── trainer.py ├── molgraph │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── mol_features.cpython-37.pyc │ │ ├── rxn_graphs.cpython-37.pyc │ │ └── vocab.cpython-37.pyc │ ├── mol_features.py │ ├── rxn_graphs.py │ └── vocab.py ├── search │ ├── __init__.py │ └── __pycache__ │ │ └── __init__.cpython-37.pyc └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── chem.cpython-37.pyc │ ├── edit_mol.cpython-37.pyc │ ├── metrics.cpython-37.pyc │ ├── parse.cpython-37.pyc │ └── torch.cpython-37.pyc │ ├── chem.py │ ├── edit_mol.py │ ├── metrics.py │ ├── parse.py │ └── torch.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | datasets/uspto-* 4 | datasets/raw/* 5 | datasets/*.ipynb* 6 | 7 | tests/ 8 | images/ 9 | notebooks/ 10 | misc/ 11 | scripts/dummy/* 12 | 13 | *.egg-info* 14 | *~ 15 | 16 | local_experiments/ 17 | experiments/ 18 | test_* 19 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Vignesh Ram Somnath 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Graph Models for Retrosynthesis Prediction 2 | 3 | (Under Construction and Subject to Change) 4 | 5 | This is the official [PyTorch](https://pytorch.org/) implementation for _GraphRetro_ ([Somnath et al. 2021](https://openreview.net/pdf?id=SnONpXZ_uQ_)), a graph based model for one-step retrosynthesis prediction. Our model achieves the transformation from products to reactants using a two stage decomposition: 6 | 7 | ![graph-retro-overview](./assets/graphretro.png) 8 | 9 | a) __Edit Prediction__: Identifies edits given a product molecule, which upon application give intermediate molecules called _synthons_\ 10 | b) __Synthon Completion__: Completes _synthons_ into reactants by adding subgraphs called _leaving groups_ from a precomputed vocabulary. 
11 | 12 | ## Setup 13 | 14 | This assumes conda is installed on your system. \ 15 | If conda is not installed, download the [Miniconda installer](https://docs.conda.io/en/latest/miniconda.html#). 16 | If conda is installed, run the following commands: 17 | 18 | ``` 19 | echo 'export SEQ_GRAPH_RETRO=/path/to/dir/' >> ~/.bashrc 20 | source ~/.bashrc 21 | 22 | conda env create -f environment.yml 23 | source activate seq_gr 24 | python setup.py develop  # or: python setup.py install 25 | ``` 26 | 27 | ## Datasets 28 | The original and canonicalized files are provided under `datasets/uspto-50k/`. Please make sure they are available under `$SEQ_GRAPH_RETRO/datasets/uspto-50k/` before use. 29 | 30 | ## Input Preparation 31 | 32 | Before preparing inputs, we canonicalize the products. This can be done by running: 33 | 34 | ``` 35 | python data_process/canonicalize_prod.py --filename train.csv 36 | python data_process/canonicalize_prod.py --filename eval.csv 37 | python data_process/canonicalize_prod.py --filename test.csv 38 | ``` 39 | This step can be skipped if the canonicalized files are already present. 40 | The preprocessing steps work directly with the canonicalized files. 41 | 42 | #### 1. Reaction Info preparation 43 | ``` 44 | python data_process/parse_info.py --mode train 45 | python data_process/parse_info.py --mode eval 46 | python data_process/parse_info.py --mode test 47 | ``` 48 | 49 | #### 2. Prepare batches for Edit Prediction 50 | ``` 51 | python data_process/core_edits/bond_edits.py 52 | ``` 53 | 54 | #### 3. Prepare batches for Synthon Completion 55 | ``` 56 | python data_process/lg_edits/lg_classifier.py 57 | python data_process/lg_edits/lg_tensors.py 58 | ``` 59 | 60 | ## Run a Model 61 | Trained models are stored in `experiments/`. You can override this location with the `--exp_dir` argument before training. 62 | Model configurations are stored in `configs/MODEL_NAME/`, 63 | where `MODEL_NAME` is one of `{single_edit, lg_ind}`.
64 | 65 | To run a model: 66 | ``` 67 | python scripts/benchmarks/run_model.py --config_file configs/MODEL_NAME/defaults.yaml 68 | ``` 69 | NOTE: We recently updated the code to use wandb for experiment tracking. You need to set up [wandb](https://docs.wandb.ai/quickstart) before you can train a model. 70 | 71 | ## Evaluate using a Trained Model 72 | 73 | To evaluate a trained model, run: 74 | ``` 75 | python scripts/eval/single_edit_lg.py --edits_exp EDITS_EXP --edits_step EDITS_STEP \ 76 | --lg_exp LG_EXP --lg_step LG_STEP 77 | ``` 78 | This will set up a model with the edit prediction module loaded from experiment `EDITS_EXP` and checkpoint `EDITS_STEP` \ 79 | and the synthon completion module loaded from experiment `LG_EXP` and checkpoint `LG_STEP`. 80 | 81 | ## Reproducing our results 82 | To reproduce our results, please run: 83 | ``` 84 | ./eval.sh 85 | ``` 86 | This will display the results for both the reaction-class-unknown and reaction-class-known settings. 87 | 88 | ## License 89 | This project is licensed under the MIT License. Please see [LICENSE.md](https://github.com/vsomnath/graphretro/blob/main/LICENSE.md) for more details. 90 | 91 | ## Reference 92 | If you find our code useful for your work, please cite our paper: 93 | ``` 94 | @inproceedings{ 95 | somnath2021learning, 96 | title={Learning Graph Models for Retrosynthesis Prediction}, 97 | author={Vignesh Ram Somnath and Charlotte Bunne and Connor W. Coley and Andreas Krause and Regina Barzilay}, 98 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, 99 | year={2021}, 100 | url={https://openreview.net/forum?id=SnONpXZ_uQ_} 101 | } 102 | ``` 103 | 104 | ## Contact 105 | If you have any questions about the code, or want to report a bug, please raise a GitHub issue.
106 | 107 | -------------------------------------------------------------------------------- /assets/graphretro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/assets/graphretro.png -------------------------------------------------------------------------------- /configs/lg_ind/defaults.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | value: lg_ind 3 | mpnn: 4 | value: graph_feat 5 | rnn_type: 6 | value: gru 7 | mpn_size: 8 | value: 300 9 | depth: 10 | value: 10 11 | dropout_mpn: 12 | value: 0.15 13 | embed_size: 14 | value: 200 15 | mlp_size: 16 | value: 300 17 | dropout_mlp: 18 | value: 0.3 19 | clip_norm: 20 | value: 10.0 21 | scheduler_type: 22 | value: plateau 23 | metric_thresh: 24 | value: 0.01 25 | patience: 26 | value: 5 27 | anneal_rate: 28 | value: 0.9 29 | lr: 30 | value: 0.001 31 | use_rxn_class: 32 | value: False 33 | use_h_labels: 34 | value: True 35 | use_grad_noise: 36 | value: True 37 | use_prev_pred: 38 | value: True 39 | print_every: 40 | value: 200 41 | eval_every: 42 | value: 700 43 | epochs: 44 | value: 100 45 | -------------------------------------------------------------------------------- /configs/single_edit/defaults.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | value: single_edit 3 | mpnn: 4 | value: graph_feat 5 | mpn_size: 6 | value: 256 7 | depth: 8 | value: 10 9 | dropout_mpn: 10 | value: 0.15 11 | mlp_size: 12 | value: 512 13 | dropout_mlp: 14 | value: 0.3 15 | rnn_type: 16 | value: gru 17 | clip_norm: 18 | value: 10.0 19 | loss_type: 20 | value: softmax 21 | edits_type: 22 | value: bond_edits 23 | lr: 24 | value: 0.001 25 | pos_weight: 26 | value: 5.0 27 | propagate_logits: 28 | value: true 29 | print_every: 30 | value: 200 31 | eval_every: 32 | value: 600 33 | scheduler_type: 34 | value: 
plateau 35 | anneal_rate: 36 | value: 0.9 37 | metric_thresh: 38 | value: 0.01 39 | patience: 40 | value: 10 41 | epochs: 42 | value: 200 43 | use_prod_edits: 44 | value: true 45 | -------------------------------------------------------------------------------- /data_process/canonicalize_prod.py: -------------------------------------------------------------------------------- 1 | """ 2 | Canonicalize the product SMILES, and then use substructure matching to infer 3 | the correspondence to the original atom-mapped order. This correspondence is then 4 | used to renumber the reactant atoms. 5 | """ 6 | 7 | from rdkit import Chem 8 | import os 9 | import argparse 10 | import pandas as pd 11 | 12 | DATA_DIR = f"{os.environ['SEQ_GRAPH_RETRO']}/datasets/uspto-50k/" 13 | 14 | def canonicalize_prod(p): 15 | import copy 16 | p = copy.deepcopy(p) 17 | p = canonicalize(p) 18 | p_mol = Chem.MolFromSmiles(p) 19 | for atom in p_mol.GetAtoms(): 20 | atom.SetAtomMapNum(atom.GetIdx() + 1) 21 | p = Chem.MolToSmiles(p_mol) 22 | return p 23 | 24 | def remove_amap_not_in_product(rxn_smi): 25 | """ 26 | Corrects the atom map numbers of atoms only in reactants. 27 | This correction helps avoid the issue of duplicate atom mapping 28 | after the canonicalization step. 
29 | """ 30 | r, p = rxn_smi.split(">>") 31 | 32 | pmol = Chem.MolFromSmiles(p) 33 | pmol_amaps = set([atom.GetAtomMapNum() for atom in pmol.GetAtoms()]) 34 | max_amap = max(pmol_amaps) #Atoms only in reactants are labelled starting with max_amap 35 | 36 | rmol = Chem.MolFromSmiles(r) 37 | 38 | for atom in rmol.GetAtoms(): 39 | amap_num = atom.GetAtomMapNum() 40 | if amap_num not in pmol_amaps: 41 | atom.SetAtomMapNum(max_amap+1) 42 | max_amap += 1 43 | 44 | r_updated = Chem.MolToSmiles(rmol) 45 | rxn_smi_updated = r_updated + ">>" + p 46 | return rxn_smi_updated 47 | 48 | def canonicalize(smiles): 49 | try: 50 | tmp = Chem.MolFromSmiles(smiles) 51 | except: 52 | print('no mol', flush=True) 53 | return smiles 54 | if tmp is None: 55 | return smiles 56 | tmp = Chem.RemoveHs(tmp) 57 | [a.ClearProp('molAtomMapNumber') for a in tmp.GetAtoms()] 58 | return Chem.MolToSmiles(tmp) 59 | 60 | def infer_correspondence(p): 61 | orig_mol = Chem.MolFromSmiles(p) 62 | canon_mol = Chem.MolFromSmiles(canonicalize_prod(p)) 63 | matches = list(canon_mol.GetSubstructMatches(orig_mol))[0] 64 | idx_amap = {atom.GetIdx(): atom.GetAtomMapNum() for atom in orig_mol.GetAtoms()} 65 | 66 | correspondence = {} 67 | for idx, match_idx in enumerate(matches): 68 | match_anum = canon_mol.GetAtomWithIdx(match_idx).GetAtomMapNum() 69 | old_anum = idx_amap[idx] 70 | correspondence[old_anum] = match_anum 71 | return correspondence 72 | 73 | def remap_rxn_smi(rxn_smi): 74 | r, p = rxn_smi.split(">>") 75 | canon_mol = Chem.MolFromSmiles(canonicalize_prod(p)) 76 | correspondence = infer_correspondence(p) 77 | 78 | rmol = Chem.MolFromSmiles(r) 79 | for atom in rmol.GetAtoms(): 80 | atomnum = atom.GetAtomMapNum() 81 | if atomnum in correspondence: 82 | newatomnum = correspondence[atomnum] 83 | atom.SetAtomMapNum(newatomnum) 84 | 85 | rmol = Chem.MolFromSmiles(Chem.MolToSmiles(rmol)) 86 | rxn_smi_new = Chem.MolToSmiles(rmol) + ">>" + Chem.MolToSmiles(canon_mol) 87 | return rxn_smi_new, correspondence 88 | 
89 | 90 | def main(): 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("--data_dir", default=DATA_DIR, help="Directory where data is located.") 93 | parser.add_argument("--filename", required=True, help="File with reactions to canonicalize") 94 | args = parser.parse_args() 95 | 96 | new_file = f"canonicalized_{args.filename}" 97 | df = pd.read_csv(f"{args.data_dir}/{args.filename}") 98 | print(f"Processing file of size: {len(df)}") 99 | 100 | new_dict = {'id': [], 'class': [], 'reactants>reagents>production': []} 101 | for idx in range(len(df)): 102 | element = df.loc[idx] 103 | uspto_id, class_id, rxn_smi = element['id'], element['class'], element['reactants>reagents>production'] 104 | 105 | rxn_smi_new = remove_amap_not_in_product(rxn_smi) 106 | rxn_smi_new, _ = remap_rxn_smi(rxn_smi_new) 107 | new_dict['id'].append(uspto_id) 108 | new_dict['class'].append(class_id) 109 | new_dict['reactants>reagents>production'].append(rxn_smi_new) 110 | 111 | new_df = pd.DataFrame.from_dict(new_dict) 112 | new_df.to_csv(f"{args.data_dir}/{new_file}", index=False) 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /data_process/core_edits/bond_edits.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rdkit import Chem 3 | import argparse 4 | import joblib 5 | import os 6 | import sys 7 | import copy 8 | from typing import Any 9 | 10 | from seq_graph_retro.utils.parse import extract_leaving_groups 11 | from seq_graph_retro.utils.chem import apply_edits_to_mol, get_mol 12 | from seq_graph_retro.molgraph import BondEditsRxn, RxnElement, MultiElement 13 | from seq_graph_retro.molgraph.vocab import Vocab 14 | from seq_graph_retro.data.collate_fns import (pack_graph_feats, prepare_lg_labels, 15 | tensorize_bond_graphs) 16 | from seq_graph_retro.utils import str2bool 17 | 18 | DATA_DIR = "./datasets/uspto-50k" 19 | INFO_FILE 
= "uspto_50k.info.kekulized" 20 | NUM_SHARDS = 5 21 | 22 | def process_batch(edit_graphs, mol_list, args): 23 | assert len(edit_graphs) == len(mol_list) 24 | 25 | if args.mode == 'dummy': 26 | lg_vocab_file = os.path.join(args.data_dir, 'dummy') 27 | else: 28 | lg_vocab_file = os.path.join(args.data_dir, 'train') 29 | 30 | if args.use_h_labels: 31 | lg_vocab_file += "/h_labels/lg_vocab.txt" 32 | else: 33 | lg_vocab_file += "/without_h_labels/lg_vocab.txt" 34 | 35 | lg_vocab = Vocab(joblib.load(lg_vocab_file)) 36 | _, lg_groups, _ = extract_leaving_groups(mol_list) 37 | 38 | mol_attrs = ['prod_mol', 'frag_mol'] 39 | if args.use_h_labels: 40 | label_attrs = ['edit_label', 'h_label'] 41 | else: 42 | label_attrs = ['edit_label', 'done_label'] 43 | 44 | attributes = [graph.get_attributes(mol_attrs=mol_attrs, label_attrs=label_attrs) for graph in edit_graphs] 45 | prod_batch, frag_batch, edit_labels = list(zip(*attributes)) 46 | 47 | if len(edit_labels[0]) == 1: 48 | edit_labels = torch.tensor(edit_labels, dtype=torch.long) 49 | else: 50 | edit_labels = [torch.tensor(edit_labels[i], dtype=torch.float) for i in range(len(edit_labels))] 51 | 52 | if args.mpnn == 'graph_feat': 53 | directed = True 54 | elif args.mpnn == 'wln': 55 | directed = False 56 | 57 | prod_inputs = pack_graph_feats(prod_batch, directed=directed, use_rxn_class=args.use_rxn_class) 58 | frag_inputs = pack_graph_feats(frag_batch, directed=directed, use_rxn_class=args.use_rxn_class) 59 | lg_labels, lengths = prepare_lg_labels(lg_vocab, lg_groups) 60 | 61 | if args.parse_bond_graph: 62 | bond_graph_inputs = tensorize_bond_graphs(prod_batch, directed=directed, use_rxn_class=args.use_rxn_class) 63 | return prod_inputs, edit_labels, frag_inputs, lg_labels, lengths, bond_graph_inputs 64 | return prod_inputs, edit_labels, frag_inputs, lg_labels, lengths, None 65 | 66 | def parse_bond_edits_forward(args: Any, mode: str = 'train') -> None: 67 | """Parse reactions. 
68 | 69 | Parameters 70 | ---------- 71 | rxns: List 72 | List of reaction SMILES 73 | args: Namespace object 74 | Args supplied via command line 75 | mode: str, default train 76 | Type of dataset being parsed. 77 | """ 78 | if args.use_h_labels: 79 | base_file = os.path.join(args.data_dir, f"{mode}", "h_labels", args.info_file) 80 | else: 81 | base_file = os.path.join(args.data_dir, f"{mode}", "without_h_labels", args.info_file) 82 | 83 | info_all = [] 84 | for shard_num in range(5): 85 | shard_file = base_file + f"-shard-{shard_num}" 86 | info_all.extend(joblib.load(shard_file)) 87 | 88 | bond_edits_graphs = [] 89 | mol_list = [] 90 | 91 | if args.augment: 92 | save_dir = os.path.join(args.data_dir, f"{mode}_aug") 93 | else: 94 | save_dir = os.path.join(args.data_dir, f"{mode}") 95 | 96 | if args.use_h_labels: 97 | save_dir = os.path.join(save_dir, "h_labels") 98 | else: 99 | save_dir = os.path.join(save_dir, "without_h_labels") 100 | 101 | if args.use_rxn_class: 102 | save_dir = os.path.join(save_dir, "with_rxn", "bond_edits") 103 | else: 104 | save_dir = os.path.join(save_dir, "without_rxn", "bond_edits") 105 | 106 | save_dir = os.path.join(save_dir, args.mpnn) 107 | os.makedirs(save_dir, exist_ok=True) 108 | 109 | num_batches = 0 110 | total_examples = 0 111 | 112 | for idx, reaction_info in enumerate(info_all): 113 | rxn_smi = reaction_info.rxn_smi 114 | r, p = rxn_smi.split(">>") 115 | products = get_mol(p) 116 | 117 | assert len(bond_edits_graphs) == len(mol_list) 118 | if (len(mol_list) % args.batch_size == 0) and len(mol_list): 119 | print(f"Saving after {total_examples}") 120 | sys.stdout.flush() 121 | batch_tensors = process_batch(bond_edits_graphs, mol_list, args) 122 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 123 | 124 | num_batches += 1 125 | mol_list = [] 126 | bond_edits_graphs = [] 127 | 128 | if (products is None) or (products.GetNumAtoms() <= 1): 129 | print(f"Product has 0 or 1 atoms, Skipping reaction 
{idx}") 130 | print() 131 | sys.stdout.flush() 132 | continue 133 | 134 | reactants = get_mol(r) 135 | 136 | if (reactants is None) or (reactants.GetNumAtoms() <= 1): 137 | print(f"Reactant has 0 or 1 atoms, Skipping reaction {idx}") 138 | print() 139 | sys.stdout.flush() 140 | continue 141 | 142 | fragments = apply_edits_to_mol(Chem.Mol(products), reaction_info.core_edits) 143 | 144 | if (fragments is None) or (fragments.GetNumAtoms() <=1): 145 | print(f"Fragments are invalid. Skipping reaction {idx}") 146 | print() 147 | sys.stdout.flush() 148 | continue 149 | 150 | if len(Chem.rdmolops.GetMolFrags(fragments)) != len(Chem.rdmolops.GetMolFrags(reactants)): 151 | print(f"Number of fragments don't match reactants. Skipping reaction {idx}") 152 | print() 153 | sys.stdout.flush() 154 | continue 155 | 156 | tmp_frag = MultiElement(mol=Chem.Mol(fragments)).mols 157 | fragments = Chem.Mol() 158 | for mol in tmp_frag: 159 | fragments = Chem.CombineMols(fragments, mol) 160 | 161 | if len(reaction_info.core_edits) == 1: 162 | edit = reaction_info.core_edits[0] 163 | a1, a2, b1, b2 = edit.split(":") 164 | 165 | if float(b1) and float(b2) >= 0: 166 | bond_edits_graph = BondEditsRxn(prod_mol=Chem.Mol(products), 167 | frag_mol=Chem.Mol(fragments), 168 | reac_mol=Chem.Mol(reactants), 169 | edits_to_apply=[edit], 170 | rxn_class=reaction_info.rxn_class) 171 | frag_graph = MultiElement(mol=Chem.Mol(fragments)) 172 | 173 | frag_mols = copy.deepcopy(frag_graph.mols) 174 | reac_mols = copy.deepcopy(MultiElement(mol=Chem.Mol(reactants)).mols) 175 | 176 | bond_edits_graphs.append(bond_edits_graph) 177 | mol_list.append((products, copy.deepcopy(reac_mols), copy.deepcopy(frag_mols))) 178 | total_examples += 1 179 | 180 | if (idx % args.print_every == 0) and idx: 181 | print(f"{idx}/{len(info_all)} {mode} reactions processed.") 182 | sys.stdout.flush() 183 | 184 | print(f"All {mode} reactions complete.") 185 | sys.stdout.flush() 186 | 187 | assert len(bond_edits_graphs) == len(mol_list) 
188 | batch_tensors = process_batch(bond_edits_graphs, mol_list, args) 189 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 190 | 191 | num_batches += 1 192 | mol_list = [] 193 | bond_edits_graphs = [] 194 | 195 | return num_batches 196 | 197 | def parse_bond_edits_reverse(args: Any, mode: str = 'train', num_batches: int = None) -> None: 198 | """Parse reactions. 199 | 200 | Parameters 201 | ---------- 202 | rxns: List 203 | List of reaction SMILES 204 | args: Namespace object 205 | Args supplied via command line 206 | mode: str, default train 207 | Type of dataset being parsed. 208 | """ 209 | if args.use_h_labels: 210 | base_file = os.path.join(args.data_dir, f"{mode}", "h_labels", args.info_file) 211 | else: 212 | base_file = os.path.join(args.data_dir, f"{mode}", "without_h_labels", args.info_file) 213 | 214 | info_all = [] 215 | for shard_num in range(5): 216 | shard_file = base_file + f"-shard-{shard_num}" 217 | info_all.extend(joblib.load(shard_file)) 218 | 219 | bond_edits_graphs = [] 220 | mol_list = [] 221 | 222 | if args.augment: 223 | save_dir = os.path.join(args.data_dir, f"{mode}_aug") 224 | else: 225 | save_dir = os.path.join(args.data_dir, f"{mode}") 226 | 227 | if args.use_h_labels: 228 | save_dir = os.path.join(save_dir, "h_labels") 229 | else: 230 | save_dir = os.path.join(save_dir, "without_h_labels") 231 | 232 | if args.use_rxn_class: 233 | save_dir = os.path.join(save_dir, "with_rxn", "bond_edits") 234 | else: 235 | save_dir = os.path.join(save_dir, "without_rxn", "bond_edits") 236 | os.makedirs(save_dir, exist_ok=True) 237 | 238 | for idx, reaction_info in enumerate(info_all): 239 | rxn_smi = reaction_info.rxn_smi 240 | r, p = rxn_smi.split(">>") 241 | products = get_mol(p) 242 | 243 | if (products is None) or (products.GetNumAtoms() <= 1): 244 | print(f"Product has 0 or 1 atoms, Skipping reaction {idx}") 245 | print() 246 | sys.stdout.flush() 247 | continue 248 | 249 | reactants = get_mol(r) 250 | 251 | if 
(reactants is None) or (reactants.GetNumAtoms() <= 1): 252 | print(f"Reactant has 0 or 1 atoms, Skipping reaction {idx}") 253 | print() 254 | sys.stdout.flush() 255 | continue 256 | 257 | fragments = apply_edits_to_mol(Chem.Mol(products), reaction_info.core_edits) 258 | 259 | if (fragments is None) or (fragments.GetNumAtoms() <=1): 260 | print(f"Fragments are invalid. Skipping reaction {idx}") 261 | print() 262 | sys.stdout.flush() 263 | continue 264 | 265 | if len(Chem.rdmolops.GetMolFrags(fragments)) != len(Chem.rdmolops.GetMolFrags(reactants)): 266 | print(f"Number of fragments don't match reactants. Skipping reaction {idx}") 267 | print() 268 | sys.stdout.flush() 269 | continue 270 | 271 | if len(Chem.rdmolops.GetMolFrags(fragments)) == 1: 272 | continue 273 | 274 | tmp_frag = MultiElement(mol=Chem.Mol(fragments)).mols 275 | fragments = Chem.Mol() 276 | for mol in tmp_frag: 277 | fragments = Chem.CombineMols(fragments, mol) 278 | 279 | if len(reaction_info.core_edits) == 1: 280 | edit = reaction_info.core_edits[0] 281 | a1, a2, b1, b2 = edit.split(":") 282 | 283 | if float(b1) and float(b2) >= 0 and int(a2) != 0: 284 | 285 | frag_mols = MultiElement(mol=fragments).mols 286 | reac_mols = MultiElement(mol=reactants).mols 287 | 288 | reac_mols, frag_mols = map_reac_and_frag(reac_mols, frag_mols) 289 | reac_mols_rev = copy.deepcopy(reac_mols[::-1]) 290 | frag_mols_rev = copy.deepcopy(frag_mols[::-1]) 291 | 292 | reactants_rev = Chem.Mol() 293 | for mol in reac_mols_rev: 294 | reactants_rev = Chem.CombineMols(reactants_rev, Chem.Mol(mol)) 295 | 296 | fragments_rev = Chem.Mol() 297 | for mol in frag_mols_rev: 298 | fragments_rev = Chem.CombineMols(fragments_rev, Chem.Mol(mol)) 299 | 300 | bond_edits_graph = BondEditsRxn(prod_mol=Chem.Mol(products), 301 | frag_mol=Chem.Mol(fragments_rev), 302 | reac_mol=Chem.Mol(reactants_rev), 303 | edits_to_apply=[edit], 304 | rxn_class=reaction_info.rxn_class) 305 | bond_edits_graphs.append(bond_edits_graph) 306 | 
mol_list.append((products, copy.deepcopy(reac_mols_rev), copy.deepcopy(frag_mols_rev))) 307 | 308 | if (idx % args.print_every == 0) and idx: 309 | print(f"{idx}/{len(info_all)} {mode} reactions processed.") 310 | sys.stdout.flush() 311 | 312 | assert len(bond_edits_graphs) == len(mol_list) 313 | if (len(mol_list) % args.batch_size == 0) and len(mol_list): 314 | batch_tensors = process_batch(bond_edits_graphs, mol_list, args) 315 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 316 | 317 | num_batches += 1 318 | mol_list = [] 319 | bond_edits_graphs = [] 320 | 321 | print(f"All {mode} reactions complete.") 322 | sys.stdout.flush() 323 | 324 | assert len(bond_edits_graphs) == len(mol_list) 325 | batch_tensors = process_batch(bond_edits_graphs, mol_list, args) 326 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 327 | 328 | num_batches += 1 329 | mol_list = [] 330 | bond_edits_graphs = [] 331 | 332 | def main() -> None: 333 | parser = argparse.ArgumentParser() 334 | 335 | parser.add_argument("--data_dir", default=DATA_DIR, help="Directory to parse from.") 336 | parser.add_argument("--info_file", default=INFO_FILE, help='File with the information.') 337 | parser.add_argument("--print_every", default=1000, type=int, help="Print during parsing.") 338 | parser.add_argument('--mode', default='train') 339 | parser.add_argument("--mpnn", default='graph_feat') 340 | parser.add_argument("--use_h_labels", type=str2bool, default=True, help='Whether to use h-labels') 341 | parser.add_argument("--use_rxn_class", type=str2bool, default=False, help='Whether to use rxn_class') 342 | parser.add_argument("--parse_bond_graph", type=str2bool, default=True) 343 | parser.add_argument("--batch_size", type=int, default=32, help='Batch size to use.') 344 | parser.add_argument("--augment", type=str2bool, default=False, help="Whether to augment") 345 | 346 | args = parser.parse_args() 347 | 348 | if args.augment: 349 | num_batches = 
parse_bond_edits_forward(args=args, mode=args.mode) 350 | parse_bond_edits_reverse(args=args, num_batches=num_batches, mode=args.mode) 351 | else: 352 | num_batches = parse_bond_edits_forward(args=args, mode=args.mode) 353 | 354 | if __name__ == "__main__": 355 | main() 356 | -------------------------------------------------------------------------------- /data_process/core_edits/bond_edits_seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rdkit import Chem 3 | import argparse 4 | import joblib 5 | import os 6 | import sys 7 | import copy 8 | from typing import Any 9 | 10 | from seq_graph_retro.utils.parse import extract_leaving_groups 11 | from seq_graph_retro.utils.chem import apply_edits_to_mol, get_mol 12 | from seq_graph_retro.molgraph import BondEditsRxn, RxnElement, MultiElement 13 | from seq_graph_retro.molgraph.vocab import Vocab 14 | from seq_graph_retro.data.collate_fns import pack_graph_feats, prepare_lg_labels 15 | from seq_graph_retro.utils import str2bool 16 | 17 | DATA_DIR = "./datasets/uspto-50k" 18 | INFO_FILE = "uspto_50k.info.kekulized" 19 | NUM_SHARDS = 5 20 | 21 | def check_edits(edits): 22 | for edit in edits: 23 | a1, a2, b1, b2 = edit.split(":") 24 | b1 = float(b1) 25 | if b1 == 0.0: 26 | return False 27 | 28 | return True 29 | 30 | def process_batch_seq(edit_graphs, frag_batch, mol_list, args): 31 | assert len(edit_graphs) == len(frag_batch) == len(mol_list) 32 | lengths = torch.tensor([len(graph_seq) for graph_seq in edit_graphs], dtype=torch.long) 33 | max_seq_len = max([len(graph_seq) for graph_seq in edit_graphs]) 34 | 35 | if args.mode == 'dummy': 36 | lg_vocab_file = os.path.join(args.data_dir, 'dummy') 37 | else: 38 | lg_vocab_file = os.path.join(args.data_dir, 'train') 39 | 40 | if args.use_h_labels: 41 | lg_vocab_file += "/h_labels/lg_vocab.txt" 42 | else: 43 | lg_vocab_file += "/without_h_labels/lg_vocab.txt" 44 | 45 | lg_vocab = Vocab(joblib.load(lg_vocab_file)) 46 
| 47 | seq_tensors = [] 48 | seq_labels = [] 49 | 50 | _, lg_groups, _ = extract_leaving_groups(mol_list) 51 | 52 | mol_attrs = ['prod_mol'] 53 | if args.use_h_labels: 54 | label_attrs = ['edit_label', 'h_label', 'done_label'] 55 | else: 56 | label_attrs = ['edit_label', 'done_label'] 57 | 58 | seq_mask = [] 59 | if args.mpnn == 'graph_feat': 60 | directed = True 61 | elif args.mpnn == 'wln': 62 | directed = False 63 | 64 | for idx in range(max_seq_len): 65 | graphs_idx = [copy.deepcopy(edit_graphs[i][min(idx, length-1)]).get_attributes(mol_attrs=mol_attrs, 66 | label_attrs=label_attrs) 67 | for i, length in enumerate(lengths)] 68 | mask = (idx < lengths).long() 69 | prod_graphs, edit_labels = list(zip(*graphs_idx)) 70 | assert all([isinstance(graph, RxnElement) for graph in prod_graphs]) 71 | 72 | if len(edit_labels[0]) == 1: 73 | edit_labels = torch.tensor(edit_labels, dtype=torch.long) 74 | else: 75 | edit_labels = [torch.tensor(edit_labels[i], dtype=torch.float) for i in range(len(edit_labels))] 76 | 77 | prod_tensors = pack_graph_feats(prod_graphs, directed=directed, use_rxn_class=args.use_rxn_class) 78 | seq_tensors.append(prod_tensors) 79 | seq_labels.append(edit_labels) 80 | seq_mask.append(mask) 81 | 82 | frag_tensors = pack_graph_feats(frag_batch, directed=directed, use_rxn_class=args.use_rxn_class) 83 | lg_labels, lengths = prepare_lg_labels(lg_vocab, lg_groups) 84 | seq_mask = torch.stack(seq_mask).long() 85 | assert seq_mask.shape[0] == max_seq_len 86 | assert seq_mask.shape[1] == len(mol_list) 87 | 88 | return seq_tensors, seq_labels, seq_mask, frag_tensors, lg_labels, lengths 89 | 90 | def parse_bond_edits_seq(args: Any, mode: str = 'train') -> None: 91 | """Parse reactions. 92 | 93 | Parameters 94 | ---------- 95 | rxns: List 96 | List of reaction SMILES 97 | args: Namespace object 98 | Args supplied via command line 99 | mode: str, default train 100 | Type of dataset being parsed. 
101 | """ 102 | if args.use_h_labels: 103 | base_file = os.path.join(args.data_dir, f"{mode}", "h_labels", args.info_file) 104 | else: 105 | base_file = os.path.join(args.data_dir, f"{mode}", "without_h_labels", args.info_file) 106 | 107 | info_shards = 5 108 | info_all = [] 109 | for shard_num in range(info_shards): 110 | shard_file = base_file + f"-shard-{shard_num}" 111 | info_all.extend(joblib.load(shard_file)) 112 | 113 | bond_edits_graphs = [] 114 | bond_edits_frags = [] 115 | mol_list = [] 116 | 117 | save_dir = os.path.join(args.data_dir, f"{mode}") 118 | 119 | if args.use_h_labels: 120 | save_dir = os.path.join(save_dir, "h_labels") 121 | else: 122 | save_dir = os.path.join(save_dir, "without_h_labels") 123 | 124 | if args.use_rxn_class: 125 | save_dir = os.path.join(save_dir, "with_rxn", "bond_edits_seq") 126 | else: 127 | save_dir = os.path.join(save_dir, "without_rxn", "bond_edits_seq") 128 | save_dir = os.path.join(save_dir, args.mpnn) 129 | os.makedirs(save_dir, exist_ok=True) 130 | 131 | num_batches = 0 132 | 133 | for idx, reaction_info in enumerate(info_all): 134 | graph_seq = [] 135 | rxn_smi = reaction_info.rxn_smi 136 | r, p = rxn_smi.split(">>") 137 | products = get_mol(p) 138 | 139 | if (products is None) or (products.GetNumAtoms() <= 1): 140 | print(f"Product has 0 or 1 atoms, Skipping reaction {idx}") 141 | print() 142 | sys.stdout.flush() 143 | continue 144 | 145 | reactants = get_mol(r) 146 | 147 | if (reactants is None) or (reactants.GetNumAtoms() <= 1): 148 | print(f"Reactant has 0 or 1 atoms, Skipping reaction {idx}") 149 | print() 150 | sys.stdout.flush() 151 | continue 152 | 153 | fragments = apply_edits_to_mol(Chem.Mol(products), reaction_info.core_edits) 154 | 155 | if len(Chem.rdmolops.GetMolFrags(fragments)) != len(Chem.rdmolops.GetMolFrags(reactants)): 156 | print(f"Number of fragments don't match reactants. 
Skipping reaction {idx}") 157 | print() 158 | sys.stdout.flush() 159 | continue 160 | 161 | tmp_frag = MultiElement(mol=Chem.Mol(fragments)).mols 162 | fragments = Chem.Mol() 163 | for mol in tmp_frag: 164 | fragments = Chem.CombineMols(fragments, mol) 165 | 166 | edits_accepted = check_edits(reaction_info.core_edits) 167 | if not edits_accepted: 168 | print(f"New addition edit. Skipping reaction {idx}") 169 | print() 170 | sys.stdout.flush() 171 | continue 172 | 173 | edits_applied = [] 174 | for _, edit in enumerate(reaction_info.core_edits): 175 | interim_mol = apply_edits_to_mol(Chem.Mol(products), edits_applied) 176 | if interim_mol is None: 177 | print("Interim mol is None") 178 | break 179 | graph = BondEditsRxn(prod_mol=Chem.Mol(interim_mol), 180 | frag_mol=Chem.Mol(fragments), 181 | reac_mol=Chem.Mol(reactants), 182 | edits_to_apply=[edit], 183 | rxn_class=reaction_info.rxn_class) 184 | edits_applied.append(edit) 185 | graph_seq.append(graph) 186 | 187 | interim_mol = apply_edits_to_mol(Chem.Mol(products), edits_applied) 188 | if interim_mol is not None: 189 | graph = BondEditsRxn(prod_mol=Chem.Mol(interim_mol), 190 | frag_mol=Chem.Mol(fragments), 191 | reac_mol=Chem.Mol(reactants), 192 | edits_to_apply=[], 193 | rxn_class=reaction_info.rxn_class) 194 | 195 | frag_graph = MultiElement(mol=Chem.Mol(fragments), 196 | rxn_class=reaction_info.rxn_class) 197 | 198 | frag_mols = copy.deepcopy(frag_graph.mols) 199 | reac_mols = copy.deepcopy(MultiElement(mol=Chem.Mol(reactants)).mols) 200 | 201 | graph_seq.append(graph) 202 | else: 203 | continue 204 | 205 | if len(graph_seq) == 0: 206 | print(f"No valid fragment states found. 
Skipping reaction {idx}") 207 | print() 208 | sys.stdout.flush() 209 | continue 210 | 211 | bond_edits_graphs.append(graph_seq) 212 | bond_edits_frags.append(frag_graph) 213 | mol_list.append((products, copy.deepcopy(reac_mols), copy.deepcopy(frag_mols))) 214 | 215 | if (idx % args.print_every == 0) and idx: 216 | print(f"{idx}/{len(info_all)} {mode} reactions processed.") 217 | sys.stdout.flush() 218 | 219 | assert len(bond_edits_graphs) == len(bond_edits_frags) == len(mol_list) 220 | if (len(mol_list) % args.batch_size == 0) and len(mol_list): 221 | batch_tensors = process_batch_seq(bond_edits_graphs, bond_edits_frags, mol_list, args) 222 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 223 | 224 | num_batches += 1 225 | bond_edits_frags = [] 226 | bond_edits_graphs = [] 227 | mol_list = [] 228 | 229 | print(f"All {mode} reactions complete.") 230 | sys.stdout.flush() 231 | 232 | batch_tensors = process_batch_seq(bond_edits_graphs, bond_edits_frags, mol_list, args) 233 | print("Saving..") 234 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 235 | 236 | num_batches += 1 237 | bond_edits_frags = [] 238 | bond_edits_graphs = [] 239 | mol_list = [] 240 | 241 | def main() -> None: 242 | parser = argparse.ArgumentParser() 243 | 244 | parser.add_argument("--data_dir", default=DATA_DIR, help="Directory to parse from.") 245 | parser.add_argument("--info_file", default=INFO_FILE, help='File with the information.') 246 | parser.add_argument("--print_every", default=1000, type=int, help="Print during parsing.") 247 | parser.add_argument('--mode', default='train') 248 | parser.add_argument("--mpnn", default='graph_feat') 249 | parser.add_argument("--use_h_labels", type=str2bool, default=True, help='Whether to use h-labels') 250 | parser.add_argument("--use_rxn_class", type=str2bool, default=False, help='Whether to use rxn_class') 251 | parser.add_argument("--batch_size", default=32, type=int, help="Number of shards") 
252 | 253 | args = parser.parse_args() 254 | parse_bond_edits_seq(args=args, mode=args.mode) 255 | 256 | if __name__ == "__main__": 257 | main() 258 | -------------------------------------------------------------------------------- /data_process/lg_edits/lg_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rdkit import Chem 3 | import argparse 4 | import joblib 5 | import os 6 | import sys 7 | import copy 8 | import random 9 | from typing import List, Any, Tuple 10 | 11 | from seq_graph_retro.utils.parse import extract_leaving_groups, map_reac_and_frag 12 | from seq_graph_retro.utils.chem import apply_edits_to_mol, get_mol 13 | from seq_graph_retro.molgraph import RxnElement, MultiElement 14 | from seq_graph_retro.molgraph.vocab import Vocab 15 | from seq_graph_retro.data.collate_fns import pack_graph_feats, prepare_lg_labels 16 | from seq_graph_retro.utils import str2bool 17 | 18 | DATA_DIR = "./datasets/uspto-50k" 19 | INFO_FILE = "uspto_50k.info.kekulized" 20 | NUM_SHARDS = 5 21 | 22 | def process_batch(prod_batch: List[RxnElement], 23 | frag_batch: List[MultiElement], 24 | mol_list: Tuple[Chem.Mol], args: Any) -> Tuple[Tuple[torch.Tensor]]: 25 | """Process batch of input graphs. 
26 | 27 | Parameters 28 | ---------- 29 | prod_batch: List[RxnElement], 30 | Product graphs in batch 31 | frag_batch: List[MultiElement], 32 | Fragment graphs in batch 33 | mol_List: Tuple[Chem.Mol], 34 | A tuple of (product, reactant, fragment) mols to extract leaving groups 35 | args: Any, 36 | Command line arguments 37 | """ 38 | assert len(frag_batch) == len(mol_list) 39 | 40 | if args.mode == 'dummy': 41 | lg_vocab_file = os.path.join(args.data_dir, 'dummy') 42 | else: 43 | lg_vocab_file = os.path.join(args.data_dir, 'train') 44 | 45 | if args.use_h_labels: 46 | lg_vocab_file += "/h_labels/lg_vocab.txt" 47 | else: 48 | lg_vocab_file += "/without_h_labels/lg_vocab.txt" 49 | 50 | lg_vocab = Vocab(joblib.load(lg_vocab_file)) 51 | _, lg_groups, _ = extract_leaving_groups(mol_list) 52 | 53 | if args.mpnn == 'graph_feat': 54 | directed = True 55 | elif args.mpnn == 'wln': 56 | directed = False 57 | 58 | prod_inputs = pack_graph_feats(prod_batch, directed=directed, use_rxn_class=args.use_rxn_class) 59 | frag_inputs = pack_graph_feats(frag_batch, directed=directed, use_rxn_class=args.use_rxn_class) 60 | lg_labels, lengths = prepare_lg_labels(lg_vocab, lg_groups) 61 | 62 | return prod_inputs, frag_inputs, lg_labels, lengths 63 | 64 | def parse_frags_forward(args: Any, mode: str = 'train') -> None: 65 | """Parse Fragments using same order as reactants. 66 | 67 | Parameters 68 | ---------- 69 | args: Namespace object 70 | Args supplied via command line 71 | mode: str, default train 72 | Type of dataset being parsed. 
73 | """ 74 | if args.use_h_labels: 75 | base_file = os.path.join(args.data_dir, f"{mode}", "h_labels", args.info_file) 76 | else: 77 | base_file = os.path.join(args.data_dir, f"{mode}", "without_h_labels", args.info_file) 78 | 79 | info_all = [] 80 | for shard_num in range(5): 81 | shard_file = base_file + f"-shard-{shard_num}" 82 | info_all.extend(joblib.load(shard_file)) 83 | 84 | frag_graphs = [] 85 | prod_graphs = [] 86 | mol_list = [] 87 | 88 | if args.augment: 89 | save_dir = os.path.join(args.data_dir, f"{mode}_aug") 90 | else: 91 | save_dir = os.path.join(args.data_dir, f"{mode}") 92 | 93 | if args.use_h_labels: 94 | save_dir = os.path.join(save_dir, "h_labels") 95 | else: 96 | save_dir = os.path.join(save_dir, "without_h_labels") 97 | 98 | if args.use_rxn_class: 99 | save_dir = os.path.join(save_dir, "with_rxn", "lg_classifier") 100 | else: 101 | save_dir = os.path.join(save_dir, "without_rxn", "lg_classifier") 102 | 103 | save_dir = os.path.join(save_dir, args.mpnn) 104 | os.makedirs(save_dir, exist_ok=True) 105 | num_batches = 0 106 | 107 | for idx, reaction_info in enumerate(info_all): 108 | rxn_smi = reaction_info.rxn_smi 109 | r, p = rxn_smi.split(">>") 110 | products = get_mol(p) 111 | 112 | if (products is None) or (products.GetNumAtoms() <= 1): 113 | print(f"Product has 0 or 1 atoms, Skipping reaction {idx}") 114 | print() 115 | sys.stdout.flush() 116 | continue 117 | 118 | reactants = get_mol(r) 119 | 120 | if (reactants is None) or (reactants.GetNumAtoms() <= 1): 121 | print(f"Reactant has 0 or 1 atoms, Skipping reaction {idx}") 122 | print() 123 | sys.stdout.flush() 124 | continue 125 | 126 | fragments = apply_edits_to_mol(Chem.Mol(products), reaction_info.core_edits) 127 | 128 | if (fragments is None) or (fragments.GetNumAtoms() <=1): 129 | print(f"Fragments are invalid. 
Skipping reaction {idx}") 130 | print() 131 | sys.stdout.flush() 132 | continue 133 | 134 | if len(Chem.rdmolops.GetMolFrags(fragments)) != len(Chem.rdmolops.GetMolFrags(reactants)): 135 | print(f"Number of fragments don't match reactants. Skipping reaction {idx}") 136 | print() 137 | sys.stdout.flush() 138 | continue 139 | 140 | prod_graph = RxnElement(mol=Chem.Mol(products), rxn_class=reaction_info.rxn_class) 141 | tmp_frags = MultiElement(mol=Chem.Mol(fragments)) 142 | tmp_reac = MultiElement(mol=Chem.Mol(reactants)) 143 | 144 | frag_mols = copy.deepcopy(tmp_frags.mols) 145 | reac_mols = copy.deepcopy(tmp_reac.mols) 146 | 147 | reac_mols, frag_mols = map_reac_and_frag(reac_mols, frag_mols) 148 | 149 | # Shuffling is introduced here to negate the effects that the 150 | # atom-mapping might bring on the order in which synthons are processed. 151 | shuffled_order = list(range(len(reac_mols))) 152 | #random.shuffle(shuffled_order) 153 | reac_mols = [reac_mols[idx] for idx in shuffled_order] 154 | frag_mols = [frag_mols[idx] for idx in shuffled_order] 155 | 156 | reac_aligned = Chem.Mol() 157 | frag_aligned = Chem.Mol() 158 | 159 | # Combine the shuffled mols into a single mol 160 | for reac_mol, frag_mol in zip(*(reac_mols, frag_mols)): 161 | reac_aligned = Chem.CombineMols(reac_aligned, reac_mol) 162 | frag_aligned = Chem.CombineMols(frag_aligned, frag_mol) 163 | 164 | frag_graph = MultiElement(mol=Chem.Mol(frag_aligned), rxn_class=reaction_info.rxn_class) 165 | 166 | prod_graphs.append(prod_graph) 167 | frag_graphs.append(frag_graph) 168 | mol_list.append((products, copy.deepcopy(reac_mols), copy.deepcopy(frag_mols))) 169 | 170 | if (idx % args.print_every == 0) and idx: 171 | print(f"{idx}/{len(info_all)} {mode} reactions processed.") 172 | sys.stdout.flush() 173 | 174 | assert len(frag_graphs) == len(mol_list) == len(prod_graphs) 175 | if (len(mol_list) % args.batch_size == 0) and len(mol_list): 176 | batch_tensors = process_batch(prod_graphs, frag_graphs, 
mol_list, args) 177 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 178 | 179 | num_batches += 1 180 | mol_list = [] 181 | prod_graphs = [] 182 | frag_graphs = [] 183 | 184 | 185 | print(f"All {mode} reactions complete.") 186 | sys.stdout.flush() 187 | 188 | if len(frag_graphs) != 0: 189 | assert len(frag_graphs) == len(mol_list) == len(prod_graphs) 190 | batch_tensors = process_batch(prod_graphs, frag_graphs, mol_list, args) 191 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 192 | 193 | num_batches += 1 194 | mol_list = [] 195 | prod_graphs = [] 196 | frag_graphs = [] 197 | 198 | return num_batches 199 | 200 | def parse_frags_reverse(args: Any, mode: str = 'train', num_batches: int = None) -> None: 201 | """Parse Fragments using reverse order as reactants. 202 | 203 | Parameters 204 | ---------- 205 | args: Namespace object 206 | Args supplied via command line 207 | mode: str, default train 208 | Type of dataset being parsed. 
209 | """ 210 | if args.use_h_labels: 211 | base_file = os.path.join(args.data_dir, f"{mode}", "h_labels", args.info_file) 212 | else: 213 | base_file = os.path.join(args.data_dir, f"{mode}", "without_h_labels", args.info_file) 214 | 215 | info_all = [] 216 | for shard_num in range(5): 217 | shard_file = base_file + f"-shard-{shard_num}" 218 | info_all.extend(joblib.load(shard_file)) 219 | 220 | frag_graphs = [] 221 | 222 | prod_graphs = [] 223 | mol_list = [] 224 | 225 | if args.augment: 226 | save_dir = os.path.join(args.data_dir, f"{mode}_aug") 227 | else: 228 | save_dir = os.path.join(args.data_dir, f"{mode}") 229 | 230 | if args.use_h_labels: 231 | save_dir = os.path.join(save_dir, "h_labels") 232 | else: 233 | save_dir = os.path.join(save_dir, "without_h_labels") 234 | 235 | if args.use_rxn_class: 236 | save_dir = os.path.join(save_dir, "with_rxn", "lg_classifier") 237 | else: 238 | save_dir = os.path.join(save_dir, "without_rxn", "lg_classifier") 239 | 240 | os.makedirs(save_dir, exist_ok=True) 241 | 242 | for idx, reaction_info in enumerate(info_all): 243 | rxn_smi = reaction_info.rxn_smi 244 | r, p = rxn_smi.split(">>") 245 | products = get_mol(p) 246 | 247 | if (products is None) or (products.GetNumAtoms() <= 1): 248 | print(f"Product has 0 or 1 atoms, Skipping reaction {idx}") 249 | print() 250 | sys.stdout.flush() 251 | continue 252 | 253 | reactants = get_mol(r) 254 | 255 | if (reactants is None) or (reactants.GetNumAtoms() <= 1): 256 | print(f"Reactant has 0 or 1 atoms, Skipping reaction {idx}") 257 | print() 258 | sys.stdout.flush() 259 | continue 260 | 261 | fragments = apply_edits_to_mol(Chem.Mol(products), reaction_info.core_edits) 262 | 263 | if (fragments is None) or (fragments.GetNumAtoms() <=1): 264 | print(f"Fragments are invalid. 
Skipping reaction {idx}") 265 | print() 266 | sys.stdout.flush() 267 | continue 268 | 269 | if len(Chem.rdmolops.GetMolFrags(fragments)) != len(Chem.rdmolops.GetMolFrags(reactants)): 270 | print(f"Number of fragments don't match reactants. Skipping reaction {idx}") 271 | print() 272 | sys.stdout.flush() 273 | continue 274 | 275 | if len(Chem.rdmolops.GetMolFrags(fragments)) == 1: 276 | continue 277 | 278 | prod_graph = RxnElement(mol=Chem.Mol(products), rxn_class=reaction_info.rxn_class) 279 | tmp_frags = MultiElement(mol=Chem.Mol(fragments)) 280 | tmp_reac = MultiElement(mol=Chem.Mol(reactants)) 281 | 282 | frag_mols = tmp_frags.mols 283 | reac_mols = tmp_reac.mols 284 | 285 | reac_mols, frag_mols = map_reac_and_frag(reac_mols, frag_mols) 286 | reac_mols_rev = copy.deepcopy(reac_mols[::-1]) 287 | frag_mols_rev = copy.deepcopy(frag_mols[::-1]) 288 | 289 | fragments_rev = Chem.Mol() 290 | reactants_rev = Chem.Mol() 291 | 292 | for reac_mol, frag_mol in zip(*(reac_mols_rev, frag_mols_rev)): 293 | fragments_rev = Chem.CombineMols(fragments_rev, Chem.Mol(frag_mol)) 294 | reactants_rev = Chem.CombineMols(reactants_rev, Chem.Mol(reac_mol)) 295 | 296 | frag_graph = MultiElement(mol=Chem.Mol(fragments_rev), rxn_class=reaction_info.rxn_class) 297 | reac_graph = MultiElement(mol=Chem.Mol(reactants_rev), rxn_class=reaction_info.rxn_class) 298 | 299 | prod_graphs.append(prod_graph) 300 | frag_graphs.append(frag_graph) 301 | mol_list.append((products, copy.deepcopy(reac_mols_rev), copy.deepcopy(frag_mols_rev))) 302 | 303 | if (idx % args.print_every == 0) and idx: 304 | print(f"{idx}/{len(info_all)} {mode} reactions processed.") 305 | sys.stdout.flush() 306 | 307 | assert len(frag_graphs) == len(mol_list) == len(prod_graphs) 308 | if (len(mol_list) % args.batch_size == 0) and len(mol_list): 309 | batch_tensors = process_batch(prod_graphs, frag_graphs, mol_list, args) 310 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 311 | 312 | num_batches += 1 
313 | mol_list = [] 314 | prod_graphs = [] 315 | frag_graphs = [] 316 | 317 | 318 | print(f"All {mode} reactions complete.") 319 | sys.stdout.flush() 320 | 321 | if len(frag_graphs) != 0: 322 | assert len(frag_graphs) == len(mol_list) == len(prod_graphs) 323 | batch_tensors = process_batch(prod_graphs, frag_graphs, mol_list, args) 324 | torch.save(batch_tensors, os.path.join(save_dir, f"batch-{num_batches}.pt")) 325 | 326 | num_batches += 1 327 | mol_list = [] 328 | prod_graphs = [] 329 | frag_graphs = [] 330 | 331 | 332 | def main() -> None: 333 | parser = argparse.ArgumentParser() 334 | 335 | parser.add_argument("--data_dir", default=DATA_DIR, help="Directory to parse from.") 336 | parser.add_argument("--info_file", default=INFO_FILE, help='File with the information.') 337 | parser.add_argument("--print_every", default=1000, type=int, help="Print during parsing.") 338 | parser.add_argument('--mode', default='train') 339 | parser.add_argument("--mpnn", default='graph_feat') 340 | parser.add_argument("--use_h_labels", type=str2bool, default=True, help='Whether to use h-labels') 341 | parser.add_argument("--use_rxn_class", type=str2bool, default=False, help='Whether to use reaction-class') 342 | parser.add_argument("--batch_size", type=int, default=32, help='Batch size to use.') 343 | parser.add_argument("--augment", action='store_true', help="Whether to augment") 344 | 345 | args = parser.parse_args() 346 | 347 | if args.augment: 348 | num_batches = parse_frags_forward(args=args, mode=args.mode) 349 | parse_frags_reverse(args=args, num_batches=num_batches, mode=args.mode) 350 | else: 351 | num_batches = parse_frags_forward(args=args, mode=args.mode) 352 | 353 | if __name__ == "__main__": 354 | main() 355 | -------------------------------------------------------------------------------- /data_process/lg_edits/lg_tensors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rdkit import Chem 3 | import argparse 
4 | import joblib 5 | import os 6 | import sys 7 | 8 | from seq_graph_retro.molgraph import RxnElement 9 | from seq_graph_retro.data.collate_fns import pack_graph_feats 10 | from seq_graph_retro.utils import str2bool 11 | 12 | DATA_DIR = "./datasets/uspto-50k" 13 | INFO_FILE = "uspto_50k.info.kekulized" 14 | NUM_SHARDS = 5 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument("--data_dir", default=DATA_DIR, help="Directory to parse from.") 20 | parser.add_argument('--mode', default='train') 21 | parser.add_argument("--mpnn", default='graph_feat') 22 | parser.add_argument("--use_h_labels", type=str2bool, default=True, help='Whether to use h-labels') 23 | parser.add_argument("--use_rxn_class", type=str2bool, default=False, help='Whether to use reaction class.') 24 | 25 | args = parser.parse_args() 26 | 27 | if args.mode == 'dummy': 28 | lg_mols_file = os.path.join(args.data_dir, 'dummy') 29 | else: 30 | lg_mols_file = os.path.join(args.data_dir, 'train') 31 | 32 | if args.use_h_labels: 33 | lg_mols_file += "/h_labels/lg_mols.file" 34 | else: 35 | lg_mols_file += "/without_h_labels/lg_mols.file" 36 | lg_mols = joblib.load(lg_mols_file) 37 | 38 | graphs = [] 39 | for idx, mol in enumerate(lg_mols): 40 | graphs.append(RxnElement(mol=Chem.Mol(mol), rxn_class=0)) 41 | 42 | print(len(graphs)) 43 | sys.stdout.flush() 44 | 45 | if args.mpnn == 'graph_feat': 46 | directed = True 47 | elif args.mpnn == 'wln': 48 | directed = False 49 | 50 | lg_inputs = pack_graph_feats(graphs, directed=directed, use_rxn_class=args.use_rxn_class) 51 | if args.use_h_labels: 52 | save_dir = os.path.join(args.data_dir, f"{args.mode}", "h_labels") 53 | else: 54 | save_dir = os.path.join(args.data_dir, f"{args.mode}", "without_h_labels") 55 | 56 | if args.use_rxn_class: 57 | save_dir = os.path.join(save_dir, "with_rxn") 58 | else: 59 | save_dir = os.path.join(save_dir, "without_rxn") 60 | 61 | os.makedirs(save_dir, exist_ok=True) 62 | 
torch.save(lg_inputs, os.path.join(save_dir, f"lg_inputs.pt")) 63 | print("Save complete.") 64 | sys.stdout.flush() 65 | -------------------------------------------------------------------------------- /data_process/parse_info.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | import pandas as pd 4 | import argparse 5 | import joblib 6 | import os 7 | import sys 8 | import copy 9 | from typing import List, Any 10 | 11 | from seq_graph_retro.utils.parse import get_reaction_info, extract_leaving_groups 12 | from seq_graph_retro.molgraph import MultiElement 13 | from seq_graph_retro.utils.chem import apply_edits_to_mol, get_mol 14 | from seq_graph_retro.utils import str2bool 15 | 16 | DATA_DIR = "./datasets/uspto-50k" 17 | 18 | 19 | def parse_info(rxns: List, rxn_classes: List, args: Any, mode: str = 'train') -> None: 20 | """Parse reactions. 21 | 22 | Parameters 23 | ---------- 24 | rxns: List 25 | List of reaction SMILES 26 | args: Namespace object 27 | Args supplied via command line 28 | mode: str, default train 29 | Type of dataset being parsed. 30 | """ 31 | info_all = [] 32 | mol_list = [] 33 | counter = [] 34 | if args.use_h_labels: 35 | save_dir = os.path.join(args.data_dir, f"{mode}", "h_labels") 36 | else: 37 | save_dir = os.path.join(args.data_dir, f"{mode}", "without_h_labels") 38 | 39 | os.makedirs(save_dir, exist_ok=True) 40 | 41 | for idx, rxn_smi in enumerate(rxns): 42 | try: 43 | reaction_info = get_reaction_info(rxn_smi, kekulize=args.kekulize, 44 | use_h_labels=args.use_h_labels, 45 | rxn_class=int(rxn_classes[idx])) 46 | except: 47 | print(f"Failed to extract reaction info. 
Skipping reaction {idx}") 48 | print() 49 | sys.stdout.flush() 50 | continue 51 | 52 | r, p = rxn_smi.split(">>") 53 | products = get_mol(p) 54 | 55 | if (products is None) or (products.GetNumAtoms() <= 1): 56 | print(f"Product has 0 or 1 atoms, Skipping reaction {idx}") 57 | print() 58 | sys.stdout.flush() 59 | continue 60 | 61 | reactants = get_mol(r) 62 | 63 | if (reactants is None) or (reactants.GetNumAtoms() <= 1): 64 | print(f"Reactant has 0 or 1 atoms, Skipping reaction {idx}") 65 | print() 66 | sys.stdout.flush() 67 | continue 68 | 69 | fragments = apply_edits_to_mol(Chem.Mol(products), reaction_info.core_edits) 70 | counter.append(len(reaction_info.core_edits)) 71 | 72 | if len(Chem.rdmolops.GetMolFrags(fragments)) != len(Chem.rdmolops.GetMolFrags(reactants)): 73 | print(f"Number of fragments don't match reactants. Skipping reaction {idx}") 74 | print() 75 | sys.stdout.flush() 76 | continue 77 | 78 | frag_mols = copy.deepcopy(MultiElement(mol=Chem.Mol(fragments)).mols) 79 | reac_mols = copy.deepcopy(MultiElement(mol=Chem.Mol(reactants)).mols) 80 | mol_list.append((products, copy.deepcopy(reac_mols), copy.deepcopy(frag_mols))) 81 | info_all.append(reaction_info) 82 | 83 | if (idx % args.print_every == 0) and idx: 84 | print(f"{idx}/{len(rxns)} {mode} reactions processed.") 85 | sys.stdout.flush() 86 | 87 | print(f"All {mode} reactions complete.") 88 | sys.stdout.flush() 89 | 90 | info_file = os.path.join(save_dir, args.save_file) 91 | if args.kekulize: 92 | info_file += ".kekulized" 93 | 94 | n_shards = 5 95 | indices_shards = np.array_split(np.arange(len(info_all)), n_shards) 96 | 97 | for shard_num, indices_per_shard in enumerate(indices_shards): 98 | info_shard = [] 99 | frag_shard = [] 100 | 101 | for index in indices_per_shard: 102 | info_shard.append(info_all[index]) 103 | 104 | info_file_shard = info_file + f"-shard-{shard_num}" 105 | joblib.dump(info_shard, info_file_shard, compress=3) 106 | 107 | print("Extracting leaving groups.") 108 | lg_dict, 
lg_groups, lg_mols = extract_leaving_groups(mol_list) 109 | 110 | print("Leaving groups extracted...") 111 | print(f"{mode}: {len(lg_groups)}, {len(info_all)}") 112 | sys.stdout.flush() 113 | 114 | if mode == 'train' or mode == 'dummy': 115 | joblib.dump(lg_dict, os.path.join(save_dir, "lg_vocab.txt")) 116 | joblib.dump(lg_mols, os.path.join(save_dir, "lg_mols.file")) 117 | print(lg_dict) 118 | 119 | from collections import Counter 120 | print(Counter(counter)) 121 | joblib.dump(lg_groups, os.path.join(save_dir, "lg_groups.txt")) 122 | joblib.dump(lg_mols, os.path.join(save_dir, "lg_mols.file")) 123 | 124 | def main() -> None: 125 | parser = argparse.ArgumentParser() 126 | 127 | parser.add_argument("--data_dir", default=DATA_DIR, help="Directory to parse from.") 128 | parser.add_argument("--save_file", default="uspto_50k.info", help='Base filename to save') 129 | parser.add_argument('--mode', required=True, help="Type of dataset being prepared.") 130 | parser.add_argument("--print_every", default=1000, type=int, help="Print during parsing.") 131 | parser.add_argument("--kekulize", type=str2bool, default=True, help='Whether to kekulize mols during training') 132 | parser.add_argument("--use_h_labels", type=str2bool, default=True, help='Whether to use h-labels') 133 | args = parser.parse_args() 134 | 135 | rxn_key = "reactants>reagents>production" 136 | filename = f"canonicalized_{args.mode}.csv" 137 | df = pd.read_csv(os.path.join(args.data_dir, filename)) 138 | parse_info(rxns=df[rxn_key], rxn_classes=df['class'], args=args, mode=args.mode) 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: seq_gr 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - rdkit 8 | - python=3.7.3 9 | - pytorch=1.7.0 10 | - cudatoolkit=10.1.243 11 | - pip 12 | - 
pip: 13 | - networkx 14 | - wandb 15 | - joblib 16 | - tqdm 17 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate seq_gr 4 | EDITS_EXP="SingleEdit_10-02-2021--08-44-37" 5 | EDITS_STEP="epoch_156" 6 | LG_EXP="LGIndEmbed_18-02-2021--12-23-26" 7 | LG_STEP="step_101951" 8 | 9 | echo "Performance evaluation for reaction class unknown setting" 10 | python scripts/eval/single_edit_lg.py \ 11 | --edits_exp $EDITS_EXP \ 12 | --lg_exp $LG_EXP \ 13 | --edits_step $EDITS_STEP \ 14 | --lg_step $LG_STEP \ 15 | --exp_dir models 16 | 17 | EDITS_EXP="SingleEdit_14-02-2021--19-26-20" 18 | EDITS_STEP="step_144228" 19 | LG_EXP="LGIndEmbedClassifier_18-04-2021--11-59-29" 20 | LG_STEP="step_110701" 21 | 22 | echo "Performance evaluation for reaction class known setting" 23 | python scripts/eval/single_edit_lg.py \ 24 | --edits_exp $EDITS_EXP \ 25 | --lg_exp $LG_EXP \ 26 | --edits_step $EDITS_STEP \ 27 | --lg_step $LG_STEP \ 28 | --exp_dir models/ 29 | 30 | -------------------------------------------------------------------------------- /models/LGIndEmbedClassifier_18-04-2021--11-59-29/checkpoints/step_110701.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/models/LGIndEmbedClassifier_18-04-2021--11-59-29/checkpoints/step_110701.pt -------------------------------------------------------------------------------- /models/LGIndEmbed_18-02-2021--12-23-26/checkpoints/step_101951.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/models/LGIndEmbed_18-02-2021--12-23-26/checkpoints/step_101951.pt -------------------------------------------------------------------------------- 
/models/SingleEdit_10-02-2021--08-44-37/checkpoints/epoch_156.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/models/SingleEdit_10-02-2021--08-44-37/checkpoints/epoch_156.pt -------------------------------------------------------------------------------- /models/SingleEdit_14-02-2021--19-26-20/checkpoints/step_144228.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/models/SingleEdit_14-02-2021--19-26-20/checkpoints/step_144228.pt -------------------------------------------------------------------------------- /scripts/benchmarks/run_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | import json 5 | from datetime import datetime as dt 6 | import sys 7 | from rdkit import RDLogger 8 | 9 | from seq_graph_retro.molgraph.mol_features import ATOM_FDIM 10 | from seq_graph_retro.molgraph.mol_features import BOND_FDIM, BINARY_FDIM 11 | from seq_graph_retro.models.model_builder import build_model, MODEL_ATTRS 12 | from seq_graph_retro.models import Trainer 13 | from seq_graph_retro.utils import str2bool 14 | import wandb 15 | import yaml 16 | lg = RDLogger.logger() 17 | lg.setLevel(RDLogger.CRITICAL) 18 | 19 | try: 20 | ROOT_DIR = os.environ["SEQ_GRAPH_RETRO"] 21 | DATA_DIR = os.path.join(ROOT_DIR, "datasets", "uspto-50k") 22 | out_dir = os.path.join(ROOT_DIR, "experiments") 23 | 24 | except KeyError: 25 | ROOT_DIR = "./" 26 | DATA_DIR = os.path.join(ROOT_DIR, "datasets", "uspto-50k") 27 | out_dir = os.path.join(ROOT_DIR, "local_experiments") 28 | 29 | INFO_FILE = "uspto_50k.info.kekulized" 30 | LABELS_FILE = "lg_groups.txt" 31 | VOCAB_FILE = "lg_vocab.txt" 32 | 33 | NUM_SHARDS = 5 34 | 35 | DEVICE = "cuda" if 
torch.cuda.is_available() else "cpu" 36 | 37 | 38 | def get_model_dir(model_name): 39 | MODEL_DIRS = { 40 | 'single_edit': 'bond_edits', 41 | 'multi_edit': 'bond_edits_seq', 42 | 'single_shared': 'bond_edits', 43 | 'lg_classifier': 'lg_classifier', 44 | 'lg_ind': 'lg_classifier' 45 | } 46 | 47 | return MODEL_DIRS.get(model_name) 48 | 49 | 50 | def run_model(config): 51 | print(config) 52 | if config.get('use_doubles', False): 53 | torch.set_default_dtype(torch.float64) 54 | else: 55 | torch.set_default_dtype(torch.float32) 56 | 57 | if config.get('use_augment', False): 58 | aug_suffix = "_aug" 59 | else: 60 | aug_suffix = "" 61 | 62 | model = build_model(config, device=DEVICE) 63 | print(f"Converting model to device: {DEVICE}") 64 | sys.stdout.flush() 65 | model.to(DEVICE) 66 | 67 | if config.get('restore_enc_from', None): 68 | ckpt_id = config.get("enc_ckpt", "best_model") 69 | loaded = torch.load(os.path.join(config['out_dir'], "wandb", config['restore_enc_from'], "files", 70 | f"{ckpt_id}.pt"), map_location=DEVICE) 71 | if 'saveables' in loaded: 72 | loaded_enc_name = loaded['saveables']['encoder_name'] 73 | msg = "Encoder name of pretrained encoder and current encoder must match" 74 | assert loaded_enc_name == model.encoder_name, msg 75 | 76 | state = loaded['state'] 77 | enc_keys = [key for key in state.keys() if 'encoder' in key] 78 | enc_dict = {key: state[key] for key in enc_keys} 79 | 80 | model_dict = model.state_dict() 81 | model_dict.update(enc_dict) 82 | for key in enc_keys: 83 | assert torch.sum(torch.eq(model_dict[key], enc_dict[key])) 84 | 85 | print("Loading from pretrained encoder.") 86 | sys.stdout.flush() 87 | model.load_state_dict(model_dict) 88 | 89 | print("Param Count: ", sum([x.nelement() for x in model.parameters()]) / 10**6, "M") 90 | print() 91 | sys.stdout.flush() 92 | 93 | print(f"Device used: {DEVICE}") 94 | sys.stdout.flush() 95 | 96 | _, train_dataset_class, eval_dataset_class, use_labels = MODEL_ATTRS.get(config['model']) 97 | 
model_dir_name = get_model_dir(config['model']) 98 | 99 | if config.get('use_h_labels', True): 100 | train_dir = os.path.join(config['data_dir'], "train" + aug_suffix, "h_labels") 101 | eval_dir = os.path.join(config['data_dir'], "eval", "h_labels") 102 | else: 103 | train_dir = os.path.join(config['data_dir'], "train" + aug_suffix, "without_h_labels") 104 | eval_dir = os.path.join(config['data_dir'], "eval", "without_h_labels") 105 | 106 | if config.get('use_rxn_class', False): 107 | train_dir = os.path.join(train_dir, "with_rxn", model_dir_name) 108 | else: 109 | train_dir = os.path.join(train_dir, "without_rxn", model_dir_name) 110 | 111 | train_dataset = train_dataset_class(data_dir=train_dir, 112 | mpnn=config['mpnn']) 113 | 114 | if eval_dataset_class is not None: 115 | eval_dataset = eval_dataset_class(data_dir=eval_dir, 116 | data_file=config['info_file'], 117 | labels_file=config['labels_file'] if use_labels else None, 118 | use_rxn_class=config.get('use_rxn_class', False), 119 | num_shards=config['num_shards']) 120 | 121 | train_data = train_dataset.create_loader(batch_size=1, shuffle=True, 122 | num_workers=config['num_workers']) 123 | 124 | if eval_dataset_class is None: 125 | eval_data = None 126 | else: 127 | eval_data = eval_dataset.create_loader(batch_size=1, 128 | num_workers=config['num_workers']) 129 | 130 | date_and_time = dt.now().strftime("%d-%m-%Y--%H-%M-%S") 131 | 132 | trainer = Trainer(model=model, print_every=config['print_every'], 133 | eval_every=config['eval_every']) 134 | trainer.build_optimizer(learning_rate=config['lr'], finetune_encoder=False) 135 | trainer.build_scheduler(type=config['scheduler_type'], anneal_rate=config['anneal_rate'], 136 | patience=config['patience'], thresh=config['metric_thresh']) 137 | trainer.train_epochs(train_data, eval_data, config['epochs'], 138 | **{"accum_every": config.get('accum_every', None), 139 | "clip_norm": config['clip_norm']}) 140 | 141 | def main(args): 142 | # initialize wandb 143 | 
wandb.init(project='seq_graph_retro', dir=args.out_dir, 144 | config=args.config_file) 145 | config = wandb.config 146 | tmp_dict = vars(args) 147 | for key, value in tmp_dict.items(): 148 | config[key] = value 149 | 150 | run_model(config) 151 | 152 | 153 | def sweep(args): 154 | # load config 155 | with open(args.config_file) as file: 156 | default_config = yaml.load(file, Loader=yaml.FullLoader) 157 | 158 | loaded_config = {} 159 | for key in default_config: 160 | loaded_config[key] = default_config[key]['value'] 161 | 162 | tmp_dict = vars(args) 163 | for key, value in tmp_dict.items(): 164 | loaded_config[key] = value 165 | 166 | # init wandb 167 | wandb.init(allow_val_change=True, dir=args.out_dir) 168 | 169 | # update wandb config 170 | wandb.config.update(loaded_config) 171 | config = wandb.config 172 | 173 | # start run 174 | run_model(config) 175 | 176 | 177 | def get_args(): 178 | parser = argparse.ArgumentParser() 179 | 180 | parser.add_argument("--data_dir", default=DATA_DIR, help="Data directory") 181 | parser.add_argument("--out_dir", default=out_dir, help="Experiments directory") 182 | parser.add_argument("--info_file", default=INFO_FILE, 183 | help="File containing info. Used only for validation") 184 | parser.add_argument("--labels_file", default=LABELS_FILE, 185 | help='File containing leaving groups. 
Used only for validation') 186 | parser.add_argument("--vocab_file", default=VOCAB_FILE, 187 | help='File containing the vocabulary of leaving groups.') 188 | parser.add_argument("--num_shards", default=NUM_SHARDS, help="Number of shards") 189 | parser.add_argument("--num_workers", default=6, help="Number of workers") 190 | parser.add_argument("--config_file", required=True, help='File containing the configuration.') 191 | parser.add_argument("--sweep", action='store_true') 192 | 193 | args = parser.parse_args() 194 | return args 195 | 196 | 197 | if __name__ == "__main__": 198 | args = get_args() 199 | 200 | if not os.path.exists(args.out_dir): 201 | os.mkdir(args.out_dir) 202 | 203 | if args.sweep: 204 | sweep(args) 205 | else: 206 | main(args) 207 | -------------------------------------------------------------------------------- /scripts/eval/edit_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import os 5 | import argparse 6 | import tqdm 7 | import wandb 8 | import yaml 9 | 10 | from seq_graph_retro.utils.parse import get_reaction_info 11 | from seq_graph_retro.models import SingleEdit, MultiEdit 12 | from seq_graph_retro.search import EditSearch 13 | 14 | try: 15 | ROOT_DIR = os.environ["SEQ_GRAPH_RETRO"] 16 | DATA_DIR = os.path.join(ROOT_DIR, "datasets", "uspto-50k") 17 | EXP_DIR = os.path.join(ROOT_DIR, "experiments") 18 | 19 | except KeyError: 20 | ROOT_DIR = "./" 21 | DATA_DIR = os.path.join(ROOT_DIR, "datasets", "uspto-50k") 22 | EXP_DIR = os.path.join(ROOT_DIR, "local_experiments") 23 | 24 | 25 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 26 | DEFAULT_TEST_FILE = f"{DATA_DIR}/canonicalized_test.csv" 27 | MODELS = {'SingleEdit': SingleEdit, "single_edit": SingleEdit, 28 | 'MultiEdit': MultiEdit, "multi_edit": MultiEdit} 29 | 30 | 31 | def main(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--data_dir", 
default=DATA_DIR, help="Data directory") 34 | parser.add_argument("--exp_dir", default=EXP_DIR, help="Experiments directory.") 35 | parser.add_argument("--test_file", default=DEFAULT_TEST_FILE, help="Test file") 36 | parser.add_argument("--edits_exp", default="SingleEdit_21-03-2020--20-33-05", 37 | help="Name of edit prediction experiment.") 38 | parser.add_argument("--edits_step", default=None, 39 | help="Checkpoint to load for the edits experiment.") 40 | parser.add_argument("--beam_width", type=int, default=5, help="Beam width") 41 | 42 | args = parser.parse_args() 43 | 44 | test_df = pd.read_csv(args.test_file) 45 | 46 | edits_step = args.edits_step 47 | if edits_step is None: 48 | edits_step = "best_model" 49 | 50 | if "run" in args.edits_exp: 51 | # This addition because some of the new experiments were run using wandb 52 | edits_loaded = torch.load(os.path.join(args.exp_dir, "wandb", args.edits_exp, "files", edits_step + ".pt"), map_location=DEVICE) 53 | with open(f"{args.exp_dir}/wandb/{args.edits_exp}/files/config.yaml", "r") as f: 54 | tmp_loaded = yaml.load(f, Loader=yaml.FullLoader) 55 | 56 | model_name = tmp_loaded['model']['value'] 57 | 58 | else: 59 | edits_loaded = torch.load(os.path.join(args.exp_dir, args.edits_exp, 60 | "checkpoints", edits_step + ".pt"), 61 | map_location=DEVICE) 62 | model_name = args.edits_exp.split("_")[0] 63 | 64 | model_class = MODELS.get(model_name) 65 | config = edits_loaded["saveables"] 66 | 67 | em = model_class(**config, device=DEVICE) 68 | em.load_state_dict(edits_loaded['state']) 69 | em.to(DEVICE) 70 | em.eval() 71 | 72 | toggles = config['toggles'] 73 | 74 | if model_name == 'single_edit' or model_name == "SingleEdit": 75 | beam_model = EditSearch(model=em, beam_width=args.beam_width, max_edits=1) 76 | else: 77 | beam_model = EditSearch(model=em, beam_width=args.beam_width, max_edits=6) 78 | 79 | pbar = tqdm.tqdm(list(range(len(test_df)))) 80 | n_matched = np.zeros(args.beam_width) 81 | 82 | for idx in pbar: 83 | 
rxn_smi = test_df.loc[idx, 'reactants>reagents>production'] 84 | r, p = rxn_smi.split(">>") 85 | rxn_class = test_df.loc[idx, 'class'] 86 | 87 | if rxn_class != 'UNK': 88 | rxn_class = int(rxn_class) 89 | try: 90 | info = get_reaction_info(rxn_smi=rxn_smi, 91 | kekulize=True, use_h_labels=True, 92 | rxn_class=rxn_class) 93 | true_edit = info.core_edits 94 | 95 | if toggles.get("use_rxn_class", False): 96 | top_k_nodes = beam_model.run_edit_step(p, max_steps=6, rxn_class=rxn_class) 97 | else: 98 | top_k_nodes = beam_model.run_edit_step(p, max_steps=6) 99 | 100 | beam_matched = False 101 | for beam_idx, node in enumerate(top_k_nodes): 102 | edit = node.edit 103 | if not isinstance(edit, list): 104 | edit = [edit] 105 | 106 | if set(edit) == set(true_edit) and not beam_matched: 107 | n_matched[beam_idx] += 1 108 | beam_matched = True 109 | 110 | msg = 'average score' 111 | for beam_idx in [1, 2, 3, 5]: 112 | match_perc = np.sum(n_matched[:beam_idx]) / (idx + 1) 113 | msg += ', t%d: %.4f' % (beam_idx, match_perc) 114 | pbar.set_description(msg) 115 | except Exception as e: 116 | print(e) 117 | continue 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /scripts/eval/lg_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | import pandas as pd 4 | import torch 5 | import os 6 | import argparse 7 | import tqdm 8 | import yaml 9 | 10 | from seq_graph_retro.utils.parse import get_reaction_info, extract_leaving_groups 11 | from seq_graph_retro.utils.chem import apply_edits_to_mol 12 | from seq_graph_retro.molgraph import MultiElement 13 | from seq_graph_retro.models import LGClassifier, LGIndEmbed 14 | from seq_graph_retro.search import LGSearch 15 | 16 | try: 17 | ROOT_DIR = os.environ["SEQ_GRAPH_RETRO"] 18 | DATA_DIR = os.path.join(ROOT_DIR, "datasets", "uspto-50k") 19 | EXP_DIR = 
os.path.join(ROOT_DIR, "experiments") 20 | 21 | except KeyError: 22 | ROOT_DIR = "./" 23 | DATA_DIR = os.path.join(ROOT_DIR, "datasets", "uspto-50k") 24 | EXP_DIR = os.path.join(ROOT_DIR, "local_experiments") 25 | 26 | 27 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 28 | DEFAULT_TEST_FILE = f"{DATA_DIR}/canonicalized_test.csv" 29 | MODELS = {'LGClassifier': LGClassifier, "lg_classifier": LGClassifier, 30 | "LGIndEmbed": LGIndEmbed, "lg_ind": LGIndEmbed, "LGIndEmbedClassifier": LGIndEmbed} 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--data_dir", default=DATA_DIR, help="Data directory") 35 | parser.add_argument("--exp_dir", default=EXP_DIR, help="Experiments directory.") 36 | parser.add_argument("--test_file", default=DEFAULT_TEST_FILE, help="Test file") 37 | parser.add_argument("--lg_exp", default="LGClassifier_02-04-2020--02-06-17", 38 | help="Name of synthon completion experiment") 39 | parser.add_argument("--lg_step", default=None, 40 | help="Checkpoint from synthon completion experiment") 41 | parser.add_argument("--beam_width", type=int, default=5, help="Beam width") 42 | args = parser.parse_args() 43 | 44 | test_df = pd.read_csv(args.test_file) 45 | 46 | lg_step = args.lg_step 47 | if lg_step is None: 48 | lg_step = "best_model" 49 | 50 | if "run" in args.lg_exp: 51 | # This addition because some of the new experiments were run using wandb 52 | lg_loaded = torch.load(os.path.join(args.exp_dir, "wandb", args.lg_exp, "files", lg_step + ".pt"), map_location=DEVICE) 53 | with open(f"{args.exp_dir}/wandb/{args.lg_exp}/files/config.yaml", "r") as f: 54 | tmp_loaded = yaml.load(f, Loader=yaml.FullLoader) 55 | 56 | model_name = tmp_loaded['model']['value'] 57 | 58 | else: 59 | lg_loaded = torch.load(os.path.join(args.exp_dir, args.lg_exp, 60 | "checkpoints", lg_step + ".pt"), 61 | map_location=DEVICE) 62 | model_name = args.lg_exp.split("_")[0] 63 | 64 | model_class = MODELS.get(model_name) 65 | config = 
lg_loaded["saveables"] 66 | toggles = config['toggles'] 67 | 68 | if 'tensor_file' in config: 69 | if not os.path.isfile(config['tensor_file']): 70 | if not toggles.get("use_rxn_class", False): 71 | tensor_file = os.path.join(args.data_dir, "train/h_labels/without_rxn/lg_inputs.pt") 72 | else: 73 | tensor_file = os.path.join(args.data_dir, "train/h_labels/with_rxn/lg_inputs.pt") 74 | config['tensor_file'] = tensor_file 75 | 76 | lg_model = model_class(**config, device=DEVICE) 77 | lg_model.load_state_dict(lg_loaded['state']) 78 | lg_model.to(DEVICE) 79 | lg_model.eval() 80 | 81 | n_matched = np.zeros(args.beam_width) 82 | beam_model = LGSearch(model=lg_model, beam_width=args.beam_width, max_edits=1) 83 | pbar = tqdm.tqdm(list(range(len(test_df)))) 84 | 85 | for idx in pbar: 86 | rxn_smi = test_df.loc[idx, 'reactants>reagents>production'] 87 | r, p = rxn_smi.split(">>") 88 | rxn_class = test_df.loc[idx, 'class'] 89 | 90 | if rxn_class != 'UNK': 91 | rxn_class = int(rxn_class) 92 | 93 | info = get_reaction_info(rxn_smi=rxn_smi, 94 | kekulize=True, use_h_labels=True, 95 | rxn_class=rxn_class) 96 | 97 | reactants = Chem.MolFromSmiles(r) 98 | products = Chem.MolFromSmiles(p) 99 | fragments = apply_edits_to_mol(Chem.Mol(products), info.core_edits) 100 | 101 | frag_mols = MultiElement(Chem.Mol(fragments)).mols 102 | reac_mols = MultiElement(Chem.Mol(reactants)).mols 103 | 104 | _, labels, _ = extract_leaving_groups([(products, reac_mols, frag_mols)]) 105 | assert len(labels) == 1 106 | lg_group = labels[0] 107 | 108 | if toggles.get("use_rxn_class", False): 109 | top_k_nodes = beam_model.run_search(p, edits=info.core_edits, max_steps=6, rxn_class=rxn_class) 110 | else: 111 | top_k_nodes = beam_model.run_search(p, edits=info.core_edits, max_steps=6) 112 | 113 | beam_matched = False 114 | for beam_idx, node in enumerate(top_k_nodes): 115 | pred_labels = node.lg_groups 116 | if pred_labels == lg_group and not beam_matched: 117 | n_matched[beam_idx] += 1 118 | beam_matched = 
True 119 | 120 | msg = 'average score' 121 | for beam_idx in [1, 2, 3, 5]: 122 | match_perc = np.sum(n_matched[:beam_idx]) / (idx + 1) 123 | msg += ', t%d: %.4f' % (beam_idx, match_perc) 124 | 125 | pbar.set_description(msg) 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /scripts/eval/single_edit_lg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import os 5 | import argparse 6 | from tqdm import tqdm 7 | from rdkit import RDLogger, Chem 8 | import yaml 9 | 10 | from seq_graph_retro.utils.parse import get_reaction_info, extract_leaving_groups 11 | from seq_graph_retro.utils.chem import apply_edits_to_mol 12 | from seq_graph_retro.utils.edit_mol import canonicalize, generate_reac_set 13 | from seq_graph_retro.models import EditLGSeparate 14 | from seq_graph_retro.search import BeamSearch 15 | from seq_graph_retro.molgraph import MultiElement 16 | lg = RDLogger.logger() 17 | lg.setLevel(4) 18 | 19 | try: 20 | ROOT_DIR = os.environ["SEQ_GRAPH_RETRO"] 21 | DATA_DIR = os.path.join(ROOT_DIR, "datasets", "uspto-50k") 22 | EXP_DIR = os.path.join(ROOT_DIR, "experiments") 23 | 24 | except KeyError: 25 | ROOT_DIR = "./" 26 | DATA_DIR = os.path.join(ROOT_DIR, "datasets", "uspto-50k") 27 | EXP_DIR = os.path.join(ROOT_DIR, "local_experiments") 28 | 29 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 30 | DEFAULT_TEST_FILE = f"{DATA_DIR}/canonicalized_test.csv" 31 | 32 | def canonicalize_prod(p): 33 | pcanon = canonicalize(p) 34 | pmol = Chem.MolFromSmiles(pcanon) 35 | [atom.SetAtomMapNum(atom.GetIdx()+1) for atom in pmol.GetAtoms()] 36 | p = Chem.MolToSmiles(pmol) 37 | return p 38 | 39 | 40 | def load_edits_model(args): 41 | edits_step = args.edits_step 42 | if edits_step is None: 43 | edits_step = "best_model" 44 | 45 | if "run" in args.edits_exp: 46 | # This addition because 
some of the new experiments were run using wandb 47 | edits_loaded = torch.load(os.path.join(args.exp_dir, "wandb", args.edits_exp, "files", edits_step + ".pt"), map_location=DEVICE) 48 | with open(f"{args.exp_dir}/wandb/{args.edits_exp}/files/config.yaml", "r") as f: 49 | tmp_loaded = yaml.load(f, Loader=yaml.FullLoader) 50 | 51 | model_name = tmp_loaded['model']['value'] 52 | 53 | else: 54 | edits_loaded = torch.load(os.path.join(args.exp_dir, args.edits_exp, 55 | "checkpoints", edits_step + ".pt"), 56 | map_location=DEVICE) 57 | model_name = args.edits_exp.split("_")[0] 58 | 59 | return edits_loaded, model_name 60 | 61 | 62 | def load_lg_model(args): 63 | lg_step = args.lg_step 64 | if lg_step is None: 65 | lg_step = "best_model" 66 | 67 | if "run" in args.lg_exp: 68 | # This addition because some of the new experiments were run using wandb 69 | lg_loaded = torch.load(os.path.join(args.exp_dir, "wandb", args.lg_exp, "files", lg_step + ".pt"), map_location=DEVICE) 70 | with open(f"{args.exp_dir}/wandb/{args.lg_exp}/files/config.yaml", "r") as f: 71 | tmp_loaded = yaml.load(f, Loader=yaml.FullLoader) 72 | 73 | model_name = tmp_loaded['model']['value'] 74 | 75 | else: 76 | lg_loaded = torch.load(os.path.join(args.exp_dir, args.lg_exp, 77 | "checkpoints", lg_step + ".pt"), 78 | map_location=DEVICE) 79 | model_name = args.lg_exp.split("_")[0] 80 | 81 | return lg_loaded, model_name 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument("--data_dir", default=DATA_DIR, help="Data directory") 86 | parser.add_argument("--exp_dir", default=EXP_DIR, help="Experiments directory.") 87 | parser.add_argument("--test_file", default=DEFAULT_TEST_FILE, help="Test file.") 88 | parser.add_argument("--edits_exp", default="SingleEdit_21-03-2020--20-33-05", 89 | help="Name of edit prediction experiment.") 90 | parser.add_argument("--edits_step", default=None, 91 | help="Checkpoint for edit prediction experiment.") 92 | parser.add_argument("--lg_exp", 
default="LGClassifier_02-04-2020--02-06-17", 93 | help="Name of synthon completion experiment.") 94 | parser.add_argument("--lg_step", default=None, 95 | help="Checkpoint for synthon completion experiment.") 96 | parser.add_argument("--beam_width", default=10, type=int, help="Beam width") 97 | parser.add_argument("--use_rxn_class", action='store_true', help="Whether to use reaction class.") 98 | parser.add_argument("--rxn_class_acc", action="store_true", 99 | help="Whether to print reaction class accuracy.") 100 | args = parser.parse_args() 101 | 102 | test_df = pd.read_csv(args.test_file) 103 | 104 | edits_loaded, edit_net_name = load_edits_model(args) 105 | lg_loaded, lg_net_name = load_lg_model(args) 106 | 107 | edits_config = edits_loaded["saveables"] 108 | lg_config = lg_loaded['saveables'] 109 | lg_toggles = lg_config['toggles'] 110 | 111 | if 'tensor_file' in lg_config: 112 | if not os.path.isfile(lg_config['tensor_file']): 113 | if not lg_toggles.get("use_rxn_class", False): 114 | tensor_file = os.path.join(args.data_dir, "train/h_labels/without_rxn/lg_inputs.pt") 115 | else: 116 | tensor_file = os.path.join(args.data_dir, "train/h_labels/with_rxn/lg_inputs.pt") 117 | lg_config['tensor_file'] = tensor_file 118 | 119 | rm = EditLGSeparate(edits_config=edits_config, lg_config=lg_config, edit_net_name=edit_net_name, 120 | lg_net_name=lg_net_name, device=DEVICE) 121 | rm.load_state_dict(edits_loaded['state'], lg_loaded['state']) 122 | rm.to(DEVICE) 123 | rm.eval() 124 | 125 | n_matched = np.zeros(args.beam_width) 126 | 127 | beam_model = BeamSearch(model=rm, beam_width=args.beam_width, max_edits=1) 128 | pbar = tqdm(list(range(len(test_df)))) 129 | 130 | for idx in pbar: 131 | rxn_smi = test_df.loc[idx, 'reactants>reagents>production'] 132 | r, p = rxn_smi.split(">>") 133 | 134 | rxn_class = test_df.loc[idx, 'class'] 135 | 136 | if rxn_class != 'UNK': 137 | rxn_class = int(rxn_class) 138 | 139 | # Canonicalize the product and reactant sets, just for security. 
140 | # The product is already canonicalized since the dataset we use is the canonicalized one. 141 | p = canonicalize_prod(p) 142 | r_can = canonicalize(r) 143 | rset = set(r_can.split(".")) 144 | 145 | try: 146 | if lg_toggles.get("use_rxn_class", False): 147 | top_k_nodes = beam_model.run_search(p, max_steps=6, rxn_class=rxn_class) 148 | else: 149 | top_k_nodes = beam_model.run_search(p, max_steps=6) 150 | 151 | beam_matched = False 152 | for beam_idx, node in enumerate(top_k_nodes): 153 | pred_edit = node.edit 154 | pred_label = node.lg_groups 155 | 156 | if isinstance(pred_edit, list): 157 | pred_edit = pred_edit[0] 158 | try: 159 | pred_set = generate_reac_set(p, pred_edit, pred_label, verbose=False) 160 | except BaseException as e: 161 | print(e, flush=True) 162 | pred_set = None 163 | 164 | if pred_set == rset and not beam_matched: 165 | n_matched[beam_idx] += 1 166 | beam_matched = True 167 | 168 | except Exception as e: 169 | print(e) 170 | continue 171 | 172 | msg = 'average score' 173 | for beam_idx in [1, 3, 5, 10, 20, 50]: 174 | match_perc = np.sum(n_matched[:beam_idx]) / (idx + 1) 175 | msg += ', t%d: %.4f' % (beam_idx, match_perc) 176 | pbar.set_description(msg) 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /seq_graph_retro/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/__init__.py -------------------------------------------------------------------------------- /seq_graph_retro/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /seq_graph_retro/data/__init__.py: -------------------------------------------------------------------------------- 1 | from seq_graph_retro.data.dataset import BaseDataset, EvalDataset 2 | from seq_graph_retro.data.edits_datasets import SingleEditDataset, EditsEvalDataset, MultiEditDataset 3 | from seq_graph_retro.data.lg_datasets import LGClassifierDataset, LGEvalDataset 4 | from seq_graph_retro.data.shared_retro_datasets import SingleEditSharedDataset, SharedEvalDataset 5 | from seq_graph_retro.data.pretrain_datasets import EncoderDataset, EncoderEvalDataset, ContextPredDataset 6 | -------------------------------------------------------------------------------- /seq_graph_retro/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/data/__pycache__/collate_fns.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/data/__pycache__/collate_fns.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/data/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/data/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/data/__pycache__/edits_datasets.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/data/__pycache__/edits_datasets.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/data/__pycache__/lg_datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/data/__pycache__/lg_datasets.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/data/__pycache__/pretrain_datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/data/__pycache__/pretrain_datasets.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/data/__pycache__/shared_retro_datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/data/__pycache__/shared_retro_datasets.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/data/collate_fns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rdkit import Chem 3 | import networkx as nx 4 | 5 | from seq_graph_retro.molgraph.mol_features import get_atom_features, get_bond_features 6 | from seq_graph_retro.molgraph.mol_features import BOND_FDIM, ATOM_FDIM, BOND_TYPES 7 | from seq_graph_retro.utils.torch import create_pad_tensor 8 | 9 | from typing import Any, List, Dict, Tuple 10 | 11 | def 
prepare_lg_labels(lg_dict: Dict, lg_data: List) -> torch.Tensor: 12 | """Prepare leaving group tensors. 13 | 14 | Parameters 15 | ---------- 16 | lg_dict: Dict 17 | Dictionary containing leaving groups to indices map 18 | lg_data: List 19 | List of lists containing the leaving groups 20 | """ 21 | pad_idx, unk_idx = lg_dict[""], lg_dict[""] 22 | lg_labels = [[lg_dict.get(lg_group, unk_idx) for lg_group in labels] for labels in lg_data] 23 | 24 | lengths = [len(lg) for lg in lg_labels] 25 | labels = torch.full(size=(len(lg_labels), max(lengths)), fill_value=pad_idx, dtype=torch.long) 26 | for i, lgs in enumerate(lg_labels): 27 | labels[i, :len(lgs)] = torch.tensor(lgs) 28 | return labels, lengths 29 | 30 | def pack_graph_feats(graph_batch: List[Any], directed: bool, use_rxn_class: bool = False, 31 | return_graphs: bool = False) -> Tuple[torch.Tensor, List[Tuple[int]]]: 32 | """Prepare graph tensors. 33 | 34 | Parameters 35 | ---------- 36 | graph_batch: List[Any], 37 | Batch of graph objects. 
Should have attributes G_dir, G_undir 38 | directed: bool, 39 | Whether to prepare tensors for directed message passing 40 | use_rxn_class: bool, default False, 41 | Whether to use reaction class as additional input 42 | return_graphs: bool, default False, 43 | Whether to return the graphs 44 | """ 45 | if directed: 46 | fnode = [get_atom_features(Chem.Atom("*"), use_rxn_class=use_rxn_class, rxn_class=0)] 47 | fmess = [[0,0] + [0] * BOND_FDIM] 48 | agraph, bgraph = [[]], [[]] 49 | atoms_in_bonds = [[]] 50 | 51 | atom_scope, bond_scope = [], [] 52 | edge_dict = {} 53 | all_G = [] 54 | 55 | for bid, graph in enumerate(graph_batch): 56 | mol = graph.mol 57 | assert mol.GetNumAtoms() == len(graph.G_dir) 58 | atom_offset = len(fnode) 59 | bond_offset = len(atoms_in_bonds) 60 | 61 | bond_to_tuple = {bond.GetIdx(): tuple(sorted((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))) 62 | for bond in mol.GetBonds()} 63 | tuple_to_bond = {val: key for key, val in bond_to_tuple.items()} 64 | 65 | atom_scope.append(graph.update_atom_scope(atom_offset)) 66 | bond_scope.append(graph.update_bond_scope(bond_offset)) 67 | 68 | G = nx.convert_node_labels_to_integers(graph.G_dir, first_label=atom_offset) 69 | all_G.append(G) 70 | fnode.extend( [None for v in G.nodes] ) 71 | 72 | for v, attr in G.nodes(data='label'): 73 | G.nodes[v]['batch_id'] = bid 74 | fnode[v] = get_atom_features(mol.GetAtomWithIdx(v-atom_offset), 75 | use_rxn_class=use_rxn_class, 76 | rxn_class=graph.rxn_class) 77 | agraph.append([]) 78 | 79 | bond_comp = [None for _ in range(mol.GetNumBonds())] 80 | for u, v, attr in G.edges(data='label'): 81 | bond_feat = get_bond_features(mol.GetBondBetweenAtoms(u-atom_offset, v-atom_offset)).tolist() 82 | 83 | bond = sorted([u, v]) 84 | mess_vec = [u, v] + bond_feat 85 | if [v, u] not in bond_comp: 86 | idx_to_add = tuple_to_bond[(u-atom_offset, v-atom_offset)] 87 | bond_comp[idx_to_add] = [u, v] 88 | 89 | fmess.append(mess_vec) 90 | edge_dict[(u, v)] = eid = len(edge_dict) + 1 91 
| G[u][v]['mess_idx'] = eid 92 | agraph[v].append(eid) 93 | bgraph.append([]) 94 | atoms_in_bonds.extend(bond_comp) 95 | 96 | for u, v in G.edges: 97 | eid = edge_dict[(u, v)] 98 | for w in G.predecessors(u): 99 | if w == v: continue 100 | bgraph[eid].append( edge_dict[(w, u)] ) 101 | 102 | fnode = torch.tensor(fnode, dtype=torch.float) 103 | fmess = torch.tensor(fmess, dtype=torch.float) 104 | atoms_in_bonds = create_pad_tensor(atoms_in_bonds).long() 105 | agraph = create_pad_tensor(agraph) 106 | bgraph = create_pad_tensor(bgraph) 107 | 108 | graph_tensors = (fnode, fmess, agraph, bgraph, atoms_in_bonds) 109 | scopes = (atom_scope, bond_scope) 110 | 111 | if return_graphs: 112 | return graph_tensors, scopes, nx.union_all(all_G) 113 | else: 114 | return graph_tensors, scopes 115 | 116 | else: 117 | afeat = [get_atom_features(Chem.Atom("*"), use_rxn_class=use_rxn_class, rxn_class=0)] 118 | bfeat = [[0] * BOND_FDIM] 119 | atoms_in_bonds = [[]] 120 | agraph, bgraph = [[]], [[]] 121 | atom_scope = [] 122 | bond_scope = [] 123 | edge_dict = {} 124 | all_G = [] 125 | 126 | for bid, graph in enumerate(graph_batch): 127 | mol = graph.mol 128 | assert mol.GetNumAtoms() == len(graph.G_undir) 129 | atom_offset = len(afeat) 130 | bond_offset = len(bfeat) 131 | 132 | atom_scope.append(graph.update_atom_scope(atom_offset)) 133 | bond_scope.append(graph.update_bond_scope(bond_offset)) 134 | 135 | G = nx.convert_node_labels_to_integers(graph.G_undir, first_label=atom_offset) 136 | all_G.append(G) 137 | afeat.extend( [None for v in G.nodes] ) 138 | 139 | for v, attr in G.nodes(data='label'): 140 | G.nodes[v]['batch_id'] = bid 141 | afeat[v] = get_atom_features(mol.GetAtomWithIdx(v-atom_offset), 142 | use_rxn_class=use_rxn_class, 143 | rxn_class=graph.rxn_class) 144 | agraph.append([]) 145 | bgraph.append([]) 146 | 147 | for u, v, attr in G.edges(data='label'): 148 | bond_feat = get_bond_features(mol.GetBondBetweenAtoms(u-atom_offset, v-atom_offset)).tolist() 149 | 
bfeat.append(bond_feat) 150 | atoms_in_bonds.append([u, v]) 151 | 152 | edge_dict[(u, v)] = eid = len(edge_dict) + 1 153 | G[u][v]['mess_idx'] = eid 154 | 155 | agraph[v].append(u) 156 | agraph[u].append(v) 157 | 158 | bgraph[u].append(eid) 159 | bgraph[v].append(eid) 160 | 161 | afeat = torch.tensor(afeat, dtype=torch.float) 162 | bfeat = torch.tensor(bfeat, dtype=torch.float) 163 | atoms_in_bonds = create_pad_tensor(atoms_in_bonds).long() 164 | agraph = create_pad_tensor(agraph) 165 | bgraph = create_pad_tensor(bgraph) 166 | 167 | graph_tensors = (afeat, bfeat, agraph, bgraph, atoms_in_bonds) 168 | scopes = (atom_scope, bond_scope) 169 | 170 | if return_graphs: 171 | return graph_tensors, scopes, nx.union_all(all_G) 172 | else: 173 | return graph_tensors, scopes 174 | 175 | def tensorize_bond_graphs(graph_batch, directed: bool, use_rxn_class: False, 176 | return_graphs: bool = False): 177 | if directed: 178 | edge_dict = {} 179 | fnode = [[0] * BOND_FDIM] 180 | if use_rxn_class: 181 | fmess = [[0, 0] + [0] * (ATOM_FDIM + 10) + [0] + [0] * 2 * (BOND_FDIM - 1)] 182 | else: 183 | fmess = [[0, 0] + [0] * ATOM_FDIM + [0] + [0] * 2 * (BOND_FDIM - 1)] 184 | agraph, bgraph = [[]], [[]] 185 | scope = [] 186 | 187 | for bid, graph in enumerate(graph_batch): 188 | mol = graph.mol 189 | assert mol.GetNumAtoms() == len(graph.G_undir) 190 | offset = len(fnode) 191 | bond_graph = nx.line_graph(graph.G_undir) 192 | bond_graph = nx.to_directed(bond_graph) 193 | fnode.extend([None for v in bond_graph.nodes]) 194 | 195 | scope.append((offset, mol.GetNumBonds())) 196 | ri = mol.GetRingInfo() 197 | 198 | bond_rings = ri.BondRings() 199 | bond_to_tuple = {bond.GetIdx(): tuple(sorted((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))) 200 | for bond in mol.GetBonds()} 201 | tuple_to_bond = {val: key for key, val in bond_to_tuple.items()} 202 | 203 | for u in bond_graph.nodes(): 204 | agraph.append([]) 205 | atom_idx_a, atom_idx_b = u 206 | bond_idx = tuple_to_bond[u] + offset 207 | 
fnode[bond_idx] = get_bond_features(mol.GetBondBetweenAtoms(atom_idx_a, atom_idx_b)).tolist() 208 | 209 | for u, v in bond_graph.edges(): 210 | edge_dict[(u, v)] = eid = len(edge_dict) + 1 211 | bond_idx_u = tuple_to_bond[tuple(sorted(u))] + offset 212 | bond_idx_v = tuple_to_bond[tuple(sorted(v))] + offset 213 | 214 | common_atom_idx = set(u).intersection(set(v)) 215 | incommon_ring = 0 216 | for ring in bond_rings: 217 | if (bond_idx_u-offset) in ring and (bond_idx_v-offset) in ring: 218 | incommon_ring = 1 219 | break 220 | 221 | common_atom = mol.GetAtomWithIdx(list(common_atom_idx)[0]) 222 | edge_feats = get_atom_features(common_atom, 223 | use_rxn_class=use_rxn_class, 224 | rxn_class=graph.rxn_class) + [incommon_ring] 225 | atom_idx_a, atom_idx_b = u 226 | atom_idx_c, atom_idx_d = v 227 | 228 | bond_u = mol.GetBondBetweenAtoms(atom_idx_a, atom_idx_b) 229 | bond_v = mol.GetBondBetweenAtoms(atom_idx_c, atom_idx_d) 230 | 231 | bt_u, bt_v = bond_u.GetBondType(), bond_v.GetBondType() 232 | conj_u, conj_v = bond_u.GetIsConjugated(), bond_v.GetIsConjugated() 233 | sorted_u, sorted_v = sorted([bt_u, bt_v]) 234 | 235 | feats_u = [float(sorted_u == bond_type) for bond_type in BOND_TYPES[1:]] 236 | feats_v = [float(sorted_v == bond_type) for bond_type in BOND_TYPES[1:]] 237 | 238 | edge_feats.extend(feats_u) 239 | edge_feats.extend(feats_v) 240 | edge_feats.extend(sorted([conj_u, conj_v])) 241 | 242 | mess_vec = [bond_idx_u, bond_idx_v] + edge_feats 243 | fmess.append(mess_vec) 244 | agraph[bond_idx_v].append(eid) 245 | bgraph.append([]) 246 | 247 | for u, v in bond_graph.edges(): 248 | eid = edge_dict[(u, v)] 249 | for w in bond_graph.predecessors(u): 250 | if w == v: continue 251 | bgraph[eid].append(edge_dict[(w, u)]) 252 | 253 | fnode = torch.tensor(fnode, dtype=torch.float) 254 | fmess = torch.tensor(fmess, dtype=torch.float) 255 | agraph = create_pad_tensor(agraph) 256 | bgraph = create_pad_tensor(bgraph) 257 | 258 | graph_tensors = (fnode, fmess, agraph, bgraph, 
None) 259 | return graph_tensors, scope 260 | -------------------------------------------------------------------------------- /seq_graph_retro/data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Tuple, List 3 | import os 4 | import joblib 5 | 6 | from seq_graph_retro.utils.parse import ReactionInfo 7 | 8 | class BaseDataset(torch.utils.data.Dataset): 9 | """BaseDataset is an abstract class that loads the saved tensor batches and 10 | passes them to the model for training.""" 11 | 12 | def __init__(self, data_dir: str, mpnn: str = 'graph_feat', **kwargs): 13 | """ 14 | Parameters 15 | ---------- 16 | data_dir: str, 17 | Data directory to load batches from 18 | mpnn: str, default graph_feat 19 | MPNN to load batches for 20 | num_batches: int, default None, 21 | Number of batches to load in the directory 22 | """ 23 | if mpnn == 'gtrans': 24 | mpnn = 'graph_feat' 25 | self.data_dir = data_dir 26 | self.data_files = [ 27 | os.path.join(self.data_dir, mpnn, file) 28 | for file in os.listdir(os.path.join(self.data_dir, mpnn)) 29 | if "batch-" in file 30 | ] 31 | self.__dict__.update(**kwargs) 32 | 33 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor]: 34 | """Retrieves a particular batch of tensors. 35 | 36 | Parameters 37 | ---------- 38 | idx: int, 39 | Batch index 40 | """ 41 | batch_tensors = torch.load(self.data_files[idx], map_location='cpu') 42 | return batch_tensors 43 | 44 | def __len__(self) -> int: 45 | """Returns length of the Dataset.""" 46 | return len(self.data_files) 47 | 48 | def create_loader(self, batch_size: int, num_workers: int = 6, shuffle: bool = False) -> torch.utils.data.DataLoader: 49 | """Creates a DataLoader from given batches. 
50 | 51 | Parameters 52 | ---------- 53 | batch_size: int, 54 | Batch size of outputs 55 | num_workers: int, default 6 56 | Number of workers to use 57 | shuffle: bool, default True 58 | Whether to shuffle batches 59 | """ 60 | return torch.utils.data.DataLoader(dataset=self, batch_size=batch_size, 61 | shuffle=shuffle, num_workers=num_workers, 62 | collate_fn=self.collater) 63 | 64 | def collater(self, attributes: List[Any]): 65 | """Processes the batch of tensors to yield corresponding inputs.""" 66 | raise NotImplementedError("Subclasses must implement for themselves") 67 | 68 | class EvalDataset(torch.utils.data.Dataset): 69 | 70 | """EvalDataset is an abstract class that handles evaluation during training.""" 71 | 72 | def __init__(self, data_dir: str, data_file: str, labels_file: str = None, 73 | num_shards: int = None, use_rxn_class: bool = False) -> None: 74 | """ 75 | Parameters 76 | ---------- 77 | data_dir: str, 78 | Data directory to load batches from 79 | data_file: str, 80 | Info file to load 81 | labels_file: str, default None, 82 | Labels file. If None, nothing to load 83 | num_shards: int, default None, 84 | Number of info file shards present 85 | use_rxn_class: bool, default False, 86 | Whether to use reaction class as additional feature. 
87 | """ 88 | self.data_dir = data_dir 89 | self.data_file = os.path.join(data_dir, data_file) 90 | self.use_rxn_class = use_rxn_class 91 | 92 | if num_shards is not None: 93 | self.dataset = [] 94 | for shard_num in range(num_shards): 95 | shard_file = self.data_file + f"-shard-{shard_num}" 96 | self.dataset.extend(joblib.load(shard_file)) 97 | 98 | else: 99 | self.dataset = joblib.load(self.data_file) 100 | 101 | self.labels = None 102 | if labels_file is not None: 103 | self.labels = joblib.load(os.path.join(data_dir, labels_file)) 104 | assert len(self.labels) == len(self.dataset) 105 | 106 | def __len__(self) -> int: 107 | """Returns length of the Dataset.""" 108 | return len(self.dataset) 109 | 110 | def __getitem__(self, idx: int) -> ReactionInfo: 111 | """Retrieves the corresponding ReactionInfo 112 | 113 | Parameters 114 | ---------- 115 | idx: int, 116 | Index of particular element 117 | """ 118 | if self.labels is not None: 119 | return self.dataset[idx], self.labels[idx] 120 | return self.dataset[idx] 121 | 122 | def create_loader(self, batch_size: int, num_workers: int = 6, shuffle: bool = False) -> torch.utils.data.DataLoader: 123 | """Creates a DataLoader from given batches. 
124 | 125 | Parameters 126 | ---------- 127 | batch_size: int, 128 | Batch size of outputs 129 | num_workers: int, default 6 130 | Number of workers to use 131 | shuffle: bool, default True 132 | Whether to shuffle batches 133 | """ 134 | return torch.utils.data.DataLoader(dataset=self, batch_size=batch_size, 135 | shuffle=shuffle, num_workers=num_workers, 136 | collate_fn=self.collater) 137 | 138 | def collater(self, attributes: List[Any]): 139 | """Processes the batch of tensors to yield corresponding inputs.""" 140 | raise NotImplementedError("Subclasses must implement for themselves") 141 | -------------------------------------------------------------------------------- /seq_graph_retro/data/edits_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, List, Optional 3 | 4 | from seq_graph_retro.utils.parse import ReactionInfo 5 | from seq_graph_retro.data import BaseDataset, EvalDataset 6 | 7 | class SingleEditDataset(BaseDataset): 8 | 9 | def collater(self, attributes: List[Tuple[torch.tensor]]) -> Tuple[torch.Tensor]: 10 | assert isinstance(attributes, list) 11 | assert len(attributes) == 1 12 | 13 | attributes = attributes[0] 14 | prod_inputs, edit_labels, frag_inputs, lg_labels, lengths, bg_inputs = attributes 15 | prod_tensors, prod_scopes = prod_inputs 16 | return prod_tensors, prod_scopes, bg_inputs, edit_labels 17 | 18 | class MultiEditDataset(BaseDataset): 19 | 20 | def collater(self, attributes: List[Tuple[torch.tensor]]) -> Tuple[torch.Tensor]: 21 | assert isinstance(attributes, list) 22 | assert len(attributes) == 1 23 | 24 | attributes = attributes[0] 25 | prod_seq_inputs, edit_labels, seq_masks, frag_inputs, lg_labels, lengths = attributes 26 | return prod_seq_inputs, edit_labels, seq_masks 27 | 28 | 29 | class EditsEvalDataset(EvalDataset): 30 | 31 | def collater(self, attributes: List[ReactionInfo]) -> Tuple[str, List[str], Optional[List[int]]]: 32 | info_batch 
= attributes 33 | prod_smi = [info.rxn_smi.split(">>")[-1] for info in info_batch] 34 | core_edits = [set(info.core_edits) for info in info_batch] 35 | 36 | if self.use_rxn_class: 37 | rxn_classes = [info.rxn_class for info in info_batch] 38 | return prod_smi, core_edits, rxn_classes 39 | else: 40 | return prod_smi, core_edits, None 41 | -------------------------------------------------------------------------------- /seq_graph_retro/data/lg_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, List, Optional 3 | 4 | from seq_graph_retro.utils.parse import ReactionInfo 5 | from seq_graph_retro.data import BaseDataset, EvalDataset 6 | 7 | class LGClassifierDataset(BaseDataset): 8 | 9 | def collater(self, attributes: List[Tuple[torch.tensor]]) -> Tuple[torch.Tensor]: 10 | assert isinstance(attributes, list) 11 | assert len(attributes) == 1 12 | 13 | attributes = attributes[0] 14 | prod_inputs, frag_inputs, lg_labels, lengths = attributes 15 | return prod_inputs, frag_inputs, lg_labels, lengths 16 | 17 | class LGEvalDataset(EvalDataset): 18 | 19 | def collater(self, attributes: List[ReactionInfo]) -> Tuple[str, List[str], Optional[List[int]]]: 20 | info_batch, label_batch = list(zip(*attributes)) 21 | prod_smi = [info.rxn_smi.split(">>")[-1] for info in info_batch] 22 | core_edits = [info.core_edits for info in info_batch] 23 | 24 | if self.use_rxn_class: 25 | rxn_classes = [info.rxn_class for info in info_batch] 26 | return prod_smi, core_edits, label_batch, rxn_classes 27 | else: 28 | return prod_smi, core_edits, label_batch, None 29 | -------------------------------------------------------------------------------- /seq_graph_retro/data/pretrain_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rdkit import Chem 3 | from typing import Tuple, List 4 | 5 | from seq_graph_retro.data.collate_fns import 
pack_graph_feats 6 | from seq_graph_retro.data import BaseDataset, EvalDataset 7 | from seq_graph_retro.utils.parse import ReactionInfo 8 | from seq_graph_retro.molgraph import MultiElement, RxnElement 9 | 10 | def prep_graphs(info: ReactionInfo) -> Tuple[RxnElement, MultiElement]: 11 | """Prepares reaction graphs for the collater. 12 | 13 | Parameters 14 | ---------- 15 | info: ReactionInfo, 16 | ReactionInfo for the particular reaction. 17 | """ 18 | r, p = info.rxn_smi.split(">>") 19 | return (RxnElement(Chem.MolFromSmiles(p), rxn_class=info.rxn_class), 20 | MultiElement(Chem.MolFromSmiles(r), rxn_class=info.rxn_class)) 21 | 22 | def prep_dgi_eval(info: ReactionInfo) -> List[RxnElement]: 23 | """Prepares a list of RxnElements for the DGI collater. 24 | 25 | Parameters 26 | ---------- 27 | info: ReactionInfo, 28 | ReactionInfo for a particular reaction. 29 | """ 30 | rxn_elements = [] 31 | r, p = info.rxn_smi.split(">>") 32 | rxn_elements.append(RxnElement(Chem.MolFromSmiles(p))) 33 | rxn_elements.extend([RxnElement(Chem.MolFromSmiles(smi)) for smi in r.split(".")]) 34 | return rxn_elements 35 | 36 | class EncoderDataset(BaseDataset): 37 | 38 | def collater(self, attributes: List[Tuple[torch.tensor]]) -> Tuple[torch.Tensor]: 39 | assert isinstance(attributes, list) 40 | assert len(attributes) == 1 41 | 42 | attributes = attributes[0] 43 | prod_inputs, reac_inputs, frag_inputs = attributes 44 | return prod_inputs, reac_inputs, frag_inputs 45 | 46 | class EncoderEvalDataset(EvalDataset): 47 | 48 | def collater(self, attributes: List[ReactionInfo]) -> Tuple[torch.Tensor]: 49 | rxn_smi_batch = [prep_graphs(info) for info in attributes] 50 | prod_batch, reac_batch = list(zip(*rxn_smi_batch)) 51 | prod_inputs = pack_graph_feats(prod_batch, directed=True, use_rxn_class=self.use_rxn_class) 52 | reac_inputs = pack_graph_feats(reac_batch, directed=True, use_rxn_class=self.use_rxn_class) 53 | return prod_inputs, reac_inputs 54 | 55 | class ContextPredDataset(BaseDataset): 
56 | 57 | def collater(self, attributes: List[Tuple[torch.Tensor]]) -> Tuple[torch.Tensor]: 58 | assert isinstance(attributes, list) 59 | assert len(attributes) == 1 60 | 61 | attributes = attributes[0] 62 | substruct_inputs, context_inputs, root_idxs, overlaps, overlap_scopes = attributes 63 | return substruct_inputs, context_inputs, root_idxs, overlaps, overlap_scopes 64 | -------------------------------------------------------------------------------- /seq_graph_retro/data/shared_retro_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, List, Optional 3 | 4 | from seq_graph_retro.utils.parse import ReactionInfo 5 | from seq_graph_retro.data import BaseDataset, EvalDataset 6 | 7 | class SingleEditSharedDataset(BaseDataset): 8 | 9 | def collater(self, attributes: List[Tuple[torch.tensor]]) -> Tuple[torch.Tensor]: 10 | assert isinstance(attributes, list) 11 | assert len(attributes) == 1 12 | attributes = attributes[0] 13 | prod_inputs, edit_labels, frag_inputs, lg_labels, lengths, bg_inputs = attributes 14 | return prod_inputs, bg_inputs, frag_inputs, edit_labels, lg_labels, lengths 15 | 16 | class SharedEvalDataset(EvalDataset): 17 | 18 | def collater(self, attributes: List[ReactionInfo]) -> Tuple[str, List[str], List[str], Optional[List[int]]]: 19 | info_batch, label_batch = list(zip(*attributes)) 20 | prod_smi = [info.rxn_smi.split(">>")[-1] for info in info_batch] 21 | core_edits = [set(info.core_edits) for info in info_batch] 22 | if self.use_rxn_class: 23 | rxn_classes = [info.rxn_class for info in info_batch] 24 | return prod_smi, core_edits, label_batch, rxn_classes 25 | else: 26 | return prod_smi, core_edits, label_batch, None 27 | -------------------------------------------------------------------------------- /seq_graph_retro/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from seq_graph_retro.layers.reaction 
import AtomAttention, PairFeat 2 | from seq_graph_retro.layers.rnn import GRU, LSTM, MPNLayer 3 | from seq_graph_retro.layers.graph_transformer import (SublayerConnection, MultiHeadBlock, 4 | MultiHeadAttention, PositionwiseFeedForward) 5 | from seq_graph_retro.layers.encoder import GraphFeatEncoder, WLNEncoder, LogitEncoder, GTransEncoder 6 | -------------------------------------------------------------------------------- /seq_graph_retro/layers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/layers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/layers/__pycache__/encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/layers/__pycache__/encoder.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/layers/__pycache__/graph_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/layers/__pycache__/graph_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/layers/__pycache__/reaction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/layers/__pycache__/reaction.cpython-37.pyc -------------------------------------------------------------------------------- 
/seq_graph_retro/layers/__pycache__/rnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/layers/__pycache__/rnn.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/layers/graph_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from seq_graph_retro.layers.rnn import MPNLayer 7 | 8 | class SublayerConnection(nn.Module): 9 | """ 10 | A residual connection followed by a layer norm. 11 | Note for code simplicity the norm is first as opposed to last. 12 | """ 13 | 14 | def __init__(self, hsize: int, dropout_p: float = 0.15): 15 | """Initialization. 16 | :param size: the input dimension. 17 | :param dropout: the dropout ratio. 18 | """ 19 | super(SublayerConnection, self).__init__() 20 | self.norm = nn.LayerNorm(hsize, elementwise_affine=True) 21 | self.dropout_layer = nn.Dropout(dropout_p) 22 | 23 | def forward(self, inputs, outputs): 24 | """Apply residual connection to any sublayer with the same size.""" 25 | if inputs is None: 26 | return self.dropout_layer(self.norm(outputs)) 27 | return inputs + self.dropout_layer(self.norm(outputs)) 28 | 29 | 30 | class PositionwiseFeedForward(nn.Module): 31 | """Implements FFN equation.""" 32 | 33 | def __init__(self, in_dim, h_dim, out_dim=None, dropout_p=0.3, **kwargs): 34 | super(PositionwiseFeedForward, self).__init__(**kwargs) 35 | if out_dim is None: 36 | out_dim = in_dim 37 | self.in_dim = in_dim 38 | self.out_dim = out_dim 39 | self.h_dim = h_dim 40 | self.dropout_p = dropout_p 41 | self._build_components() 42 | 43 | def _build_components(self): 44 | self.W_1 = nn.Linear(self.in_dim, self.h_dim) 45 | self.W_2 = nn.Linear(self.h_dim, self.out_dim) 46 | 
self.dropout_layer = nn.Dropout(self.dropout_p) 47 | 48 | def forward(self, x): 49 | return self.W_2(self.dropout_layer(F.relu(self.W_1(x)))) 50 | 51 | 52 | class Head(nn.Module): 53 | 54 | def __init__(self, 55 | rnn_type: str, 56 | edge_fdim: int, 57 | node_fdim: int, 58 | hsize: int, 59 | depth: int, 60 | dropout_p: float = 0.15, 61 | **kwargs): 62 | super(Head, self).__init__(**kwargs) 63 | self.rnn_type = rnn_type 64 | self.edge_fdim = edge_fdim 65 | self.node_fdim = node_fdim 66 | self.hsize = hsize 67 | self.depth = depth 68 | self.dropout_p = dropout_p 69 | self._build_components() 70 | 71 | def _build_components(self): 72 | self.mpn_q = MPNLayer(rnn_type=self.rnn_type, edge_fdim=self.edge_fdim, 73 | node_fdim=self.node_fdim, hsize=self.hsize, 74 | depth=self.depth, dropout_p=self.dropout_p) 75 | self.mpn_k = MPNLayer(rnn_type=self.rnn_type, edge_fdim=self.edge_fdim, 76 | node_fdim=self.node_fdim, hsize=self.hsize, 77 | depth=self.depth, dropout_p=self.dropout_p) 78 | self.mpn_v = MPNLayer(rnn_type=self.rnn_type, edge_fdim=self.edge_fdim, 79 | node_fdim=self.node_fdim, hsize=self.hsize, 80 | depth=self.depth, dropout_p=self.dropout_p) 81 | 82 | def embed_graph(self, graph_tensors): 83 | """Replaces input graph tensors with corresponding feature vectors. 84 | 85 | Parameters 86 | ---------- 87 | graph_tensors: Tuple[torch.Tensor], 88 | Tuple of graph tensors - Contains atom features, message vector details, 89 | atom graph and bond graph for encoding neighborhood connectivity. 
90 | """ 91 | fnode, fmess, agraph, bgraph, _ = graph_tensors 92 | hnode = fnode.clone() 93 | fmess1 = hnode.index_select(index=fmess[:, 0].long(), dim=0) 94 | fmess2 = fmess[:, 2:].clone() 95 | hmess = torch.cat([fmess1, fmess2], dim=-1) 96 | return hnode, hmess, agraph, bgraph 97 | 98 | def forward(self, graph_tensors, mask=None): 99 | graph_tensors = self.embed_graph(graph_tensors) 100 | q, _ = self.mpn_q(*graph_tensors, mask=mask) 101 | k, _ = self.mpn_k(*graph_tensors, mask=mask) 102 | v, _ = self.mpn_v(*graph_tensors, mask=mask) 103 | return q, k, v 104 | 105 | 106 | class Attention(nn.Module): 107 | 108 | def forward(self, query, key, value, mask=None, dropout=None): 109 | # query: n_heads x n_atoms x dk 110 | # key: n_heads x n_atoms x dk 111 | # value: n_heads x n_atoms x dk 112 | scores = torch.bmm(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1)) 113 | if mask is not None: 114 | scores = scores.masked_fill(mask.unsqueeze(0).expand(scores.shape) == 0, -1e9) 115 | p_attn = F.softmax(scores, dim=-1) # n_heads x n_atoms x n_atoms 116 | if dropout is not None: 117 | p_attn = dropout(p_attn) 118 | scaled_vals = torch.bmm(p_attn, value).transpose(0, 1) 119 | return scaled_vals, p_attn 120 | 121 | class MultiHeadAttention(nn.Module): 122 | 123 | def __init__(self, n_heads: int, hsize: int, dropout: float = 0.1, bias: bool = False): 124 | super().__init__() 125 | assert hsize % n_heads == 0 126 | 127 | # We assume d_v always equals d_k 128 | self.hsize = hsize 129 | self.d_k = hsize // n_heads 130 | self.n_heads = n_heads # number of heads 131 | self.bias = bias 132 | self.dropout = dropout 133 | self._build_components() 134 | 135 | def _build_components(self): 136 | self.linear_layers = nn.ModuleList([nn.Linear(self.hsize, self.hsize) for _ in range(3)]) # why 3: query, key, value 137 | self.output_linear = nn.Linear(self.hsize, self.hsize, self.bias) 138 | self.attention = Attention() 139 | self.dropout = nn.Dropout(p=self.dropout) 140 | 141 | def 
forward(self, query, key, value, mask=None): 142 | n_atoms = query.size(0) 143 | 144 | # 1) Do all the linear projections in batch from d_model => h x d_k 145 | query, key, value = [l(x).view(n_atoms, self.n_heads, self.d_k).transpose(1, 0) 146 | for l, x in zip(self.linear_layers, (query, key, value))] 147 | 148 | # 2) Apply attention on all the projected vectors in batch. 149 | x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout) 150 | 151 | # 3) "Concat" using a view and apply a final linear. 152 | x = x.contiguous().view(n_atoms, self.n_heads * self.d_k) 153 | return self.output_linear(x) 154 | 155 | 156 | class MultiHeadBlock(nn.Module): 157 | 158 | def __init__(self, 159 | rnn_type: str, 160 | hsize: int, 161 | depth: int, 162 | n_heads: int, 163 | node_fdim: int, 164 | edge_fdim: int, 165 | bias: bool = False, 166 | dropout_p: float = 0.15, 167 | res_connection: bool = False, 168 | **kwargs): 169 | super(MultiHeadBlock, self).__init__(**kwargs) 170 | self.hsize = hsize 171 | self.rnn_type = rnn_type 172 | self.n_heads = n_heads 173 | self.depth = depth 174 | self.node_fdim = node_fdim 175 | self.edge_fdim = edge_fdim 176 | self.bias = bias 177 | self.dropout_p = dropout_p 178 | self.res_connection = res_connection 179 | self._build_layers() 180 | 181 | def _build_layers(self): 182 | self.W_i = nn.Linear(self.node_fdim, self.hsize, bias=False) 183 | self.W_o = nn.Linear(self.hsize, self.hsize, bias=self.bias) 184 | 185 | self.layernorm = nn.LayerNorm(self.hsize, elementwise_affine=True) 186 | self.heads = [Head(rnn_type=self.rnn_type, depth=self.depth, 187 | hsize=self.hsize // self.n_heads, node_fdim=self.hsize, 188 | edge_fdim=self.edge_fdim, dropout_p=self.dropout_p) 189 | for _ in range(self.n_heads)] 190 | self.heads = nn.ModuleList(self.heads) 191 | self.attention = MultiHeadAttention(n_heads=self.n_heads, hsize=self.hsize, 192 | dropout=self.dropout_p, bias=self.bias) 193 | self.sub_layer = SublayerConnection(hsize=self.hsize, 
dropout_p=self.dropout_p) 194 | 195 | def forward(self, graph_tensors, scopes): 196 | fnode, fmess, agraph, mess_graph, _ = graph_tensors 197 | queries, keys, values = [], [], [] 198 | 199 | if fnode.size(1) != self.hsize: 200 | fnode = self.W_i(fnode) 201 | 202 | tensors = (fnode,) + tuple(graph_tensors[1:]) 203 | for head in self.heads: 204 | q, k, v = head(tensors) 205 | queries.append(q.unsqueeze(1)) 206 | keys.append(k.unsqueeze(1)) 207 | values.append(v.unsqueeze(1)) 208 | 209 | n_atoms = q.size(0) 210 | dk, dv = q.size(1), v.size(1) 211 | queries = torch.cat(queries, dim=1).view(n_atoms, -1) # n_atoms x hsize 212 | keys = torch.cat(keys, dim=0).view(n_atoms, -1) # n_atoms x hsize 213 | values = torch.cat(values, dim=0).view(n_atoms, -1) # n_atoms x hsize 214 | 215 | assert queries.shape == (n_atoms, self.hsize) 216 | assert keys.shape == (n_atoms, self.hsize) 217 | assert values.shape == (n_atoms, self.hsize) 218 | 219 | # This boolean mask is for making sure attention only happens over 220 | # atoms of the same molecule 221 | mask = queries.new_zeros(n_atoms, n_atoms) 222 | a_scope = scopes[0] 223 | for a_start, a_len in a_scope: 224 | mask[a_start: a_start + a_len, a_start: a_start + a_len] = 1 225 | mask[0, 0] = 1 226 | 227 | x_out = self.attention(queries, keys, values, mask=mask) 228 | x_out = self.W_o(x_out) 229 | 230 | x_in = None 231 | if self.res_connection: 232 | x_in = fnode 233 | 234 | h_atom = self.sub_layer(x_in, x_out) 235 | next_tensors = (h_atom,) + graph_tensors[1:] 236 | return next_tensors, scopes 237 | -------------------------------------------------------------------------------- /seq_graph_retro/layers/reaction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import List, Tuple 5 | import math 6 | 7 | from seq_graph_retro.molgraph.mol_features import BINARY_FDIM 8 | 9 | class AtomAttention(nn.Module): 10 | 
"""Pairwise atom attention layer.""" 11 | # Update this attention layer 12 | 13 | def __init__(self, 14 | n_bin_feat: int = BINARY_FDIM, 15 | hsize: int = 64, 16 | n_heads: int = 4, 17 | bias: bool = False, 18 | **kwargs) -> None: 19 | """ 20 | Parameters 21 | ---------- 22 | n_bin_feat: int, default BINARY_FDIM(11): 23 | Number of binary features used 24 | hsize: int, default 64 25 | Size of the embeddings 26 | n_heads: int, default 4 27 | Number of attention heads 28 | device: str, default cpu 29 | Device on which the programme is running 30 | bias: bool, default False 31 | Whether to use a bias term in the linear layers 32 | """ 33 | super(AtomAttention, self).__init__(**kwargs) 34 | self.n_bin_feat = n_bin_feat 35 | self.hsize = hsize 36 | self.n_heads = n_heads 37 | self.bias = bias 38 | self._build_layer_components() 39 | 40 | def _build_layer_components(self) -> None: 41 | """Builds the different layers associated.""" 42 | self.Wa_pair = nn.Linear(self.hsize, self.hsize * self.n_heads, self.bias) 43 | # self.Wa_bin = nn.Linear(self.n_bin_feat, self.hsize * self.n_heads, bias=True) 44 | self.Wa_score = nn.Parameter(torch.FloatTensor(self.hsize, 1, self.n_heads)) 45 | self.W_proj = nn.Linear(self.hsize * self.n_heads, self.hsize, self.bias) 46 | nn.init.kaiming_uniform_(self.Wa_score, a=math.sqrt(5)) 47 | 48 | def forward(self, inputs: torch.Tensor, scopes: List[torch.Tensor]) -> Tuple[torch.Tensor]: 49 | """Forward propagation step. 
50 | 51 | Parameters 52 | ---------- 53 | inputs: torch.Tensor 54 | Atom embeddings from MPNN-Encoder 55 | scopes: Tuple[List] 56 | Scopes is composed of atom and bond scopes, which keep track of 57 | atom and bond indices for each molecule in the 2D feature list 58 | """ 59 | c_atom = inputs 60 | atom_scope, bond_scope = scopes 61 | scope_tensor, scope_rev_tensor = create_scope_tensor(atom_scope, device=c_atom.device) 62 | c_atom_batch = flat_to_batch(c_atom, scope_tensor) 63 | atom_pair = get_pair(c_atom_batch) 64 | 65 | bs, max_atoms = atom_pair.size(0), atom_pair.size(1) 66 | target_shape = [bs, max_atoms, max_atoms, self.hsize, self.n_heads] 67 | 68 | pair_att = self.Wa_pair(atom_pair) 69 | total_att = F.relu(pair_att).view(target_shape) 70 | assert list(total_att.shape) == target_shape 71 | 72 | eq = '...hn,hjn->...jn' 73 | att_score = torch.sigmoid(torch.einsum(eq, [total_att, self.Wa_score])) 74 | #att_score = att_score * attn_mask.unsqueeze(-1).unsqueeze(-1) # Mask deals with dummy atoms 75 | 76 | c_atom_exp = c_atom_batch.unsqueeze(1).unsqueeze(-1) 77 | c_atom_att = att_score * c_atom_exp 78 | c_atom_att = self.W_proj(torch.sum(c_atom_att, dim=2).view(bs, max_atoms, -1)) 79 | assert list(c_atom_att.shape) == [bs, max_atoms, self.hsize] 80 | c_mol_att = c_atom_att.sum(dim=1) 81 | c_atom_att = batch_to_flat(c_atom_att, scope_rev_tensor) 82 | 83 | c_atom_att = torch.cat([c_atom_att.new_zeros(1, self.hsize), c_atom_att], dim=0) 84 | return c_mol_att, c_atom_att 85 | 86 | class PairFeat(nn.Module): 87 | """Computes embeddings for pairs of atoms. 
Precursor to predicting bond formation.""" 88 | 89 | def __init__(self, 90 | n_bin_feat: int = BINARY_FDIM, 91 | hsize: int = 64, 92 | n_heads: int = 4, 93 | bias: bool = False, 94 | **kwargs) -> None: 95 | """ 96 | Parameters 97 | ---------- 98 | n_bin_feat: int, default BINARY_FDIM(11): 99 | Number of binary features used 100 | hsize: int, default 64 101 | Size of the embeddings 102 | n_heads: int, default 4, 103 | Number of attention heads 104 | bias: bool, default False 105 | Whether to use bias in linear layers 106 | """ 107 | super(PairFeat, self).__init__(**kwargs) 108 | self.n_bin_feat = n_bin_feat 109 | self.hsize = hsize 110 | self.bias = bias 111 | self._build_layer_components() 112 | 113 | def _build_layer_components(self) -> None: 114 | """Builds layer components.""" 115 | self.Wp_a_pair = nn.Linear(self.hsize, self.hsize, self.bias) 116 | self.Wp_att_pair = nn.Linear(self.hsize, self.hsize, self.bias) 117 | self.Wp_bin = nn.Linear(self.n_bin_feat, self.hsize, self.bias) 118 | 119 | def forward(self, inputs: Tuple[torch.Tensor]) -> torch.Tensor: 120 | """Forward pass. 121 | 122 | Parameters 123 | ---------- 124 | inputs: Tuple[torch.Tensor] 125 | Inputs for pair feat computation 126 | """ 127 | atom_pair, c_atom_att, bin_feat = inputs 128 | atom_att_pair = get_pair(c_atom_att) 129 | pair_hidden = self.Wp_a_pair(atom_pair) + self.Wp_att_pair(atom_att_pair) + \ 130 | self.Wp_bin(bin_feat) 131 | pair_hidden = F.relu(pair_hidden) 132 | return pair_hidden 133 | -------------------------------------------------------------------------------- /seq_graph_retro/layers/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple 4 | 5 | from seq_graph_retro.utils.torch import index_select_ND, index_scatter 6 | 7 | class MPNLayer(nn.Module): 8 | """MessagePassing Network based encoder. 
Messages are updated using an RNN 9 | and the final message is used to update atom embeddings.""" 10 | 11 | def __init__(self, 12 | rnn_type: str, 13 | node_fdim: int, 14 | edge_fdim: int, 15 | hsize: int, 16 | depth: int, 17 | dropout_p: float = 0.15, 18 | **kwargs) -> None: 19 | """ 20 | Parameters 21 | ---------- 22 | rnn_type: str, 23 | Type of RNN used (gru/lstm) 24 | input_size: int, 25 | Input size 26 | node_fdim: int, 27 | Number of node features 28 | hsize: int, 29 | Hidden state size 30 | depth: int, 31 | Number of timesteps in the RNN 32 | """ 33 | super(MPNLayer, self).__init__(**kwargs) 34 | self.hsize = hsize 35 | self.edge_fdim = edge_fdim 36 | self.rnn_type = rnn_type 37 | self.depth = depth 38 | self.node_fdim = node_fdim 39 | self.dropout_p = dropout_p 40 | self._build_layers() 41 | 42 | def _build_layers(self) -> None: 43 | """Build layers associated with the MPNLayer.""" 44 | self.W_o = nn.Sequential(nn.Linear(self.node_fdim + self.hsize, self.hsize), nn.ReLU()) 45 | if self.rnn_type == 'gru': 46 | self.rnn = GRU(input_size=self.node_fdim + self.edge_fdim, 47 | hsize=self.hsize, 48 | depth=self.depth, 49 | dropout_p=self.dropout_p) 50 | 51 | elif self.rnn_type == 'lstm': 52 | self.rnn = LSTM(input_size=self.node_fdim + self.edge_fdim, 53 | hsize=self.hsize, 54 | depth=self.depth, 55 | dropout_p=self.dropout_p) 56 | else: 57 | raise ValueError('unsupported rnn cell type ' + self.rnn_type) 58 | 59 | def forward(self, fnode: torch.Tensor, fmess: torch.Tensor, 60 | agraph: torch.Tensor, bgraph: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor]: 61 | """Forward pass of the MPNLayer. 
62 | 63 | Parameters 64 | ---------- 65 | fnode: torch.Tensor, 66 | Node feature tensor 67 | fmess: torch.Tensor, 68 | Message features 69 | agraph: torch.Tensor, 70 | Neighborhood of an atom 71 | bgraph: torch.Tensor, 72 | Neighborhood of a bond, except the directed bond from the destination 73 | node to the source node 74 | mask: torch.Tensor, 75 | Masks on nodes 76 | """ 77 | h = self.rnn(fmess, bgraph) 78 | h = self.rnn.get_hidden_state(h) 79 | nei_message = index_select_ND(h, 0, agraph) 80 | nei_message = nei_message.sum(dim=1) 81 | node_hiddens = torch.cat([fnode, nei_message], dim=1) 82 | node_hiddens = self.W_o(node_hiddens) 83 | 84 | if mask is None: 85 | mask = torch.ones(node_hiddens.size(0), 1, device=fnode.device) 86 | mask[0, 0] = 0 #first node is padding 87 | 88 | return node_hiddens * mask, h 89 | 90 | 91 | class GRU(nn.Module): 92 | """GRU Message Passing layer.""" 93 | 94 | def __init__(self, 95 | input_size: int, 96 | hsize: int, 97 | depth: int, 98 | dropout_p: float = 0.15, 99 | **kwargs) -> None: 100 | """ 101 | Parameters 102 | ---------- 103 | input_size: int, 104 | Size of the input 105 | hsize: int, 106 | Hidden state size 107 | depth: int, 108 | Number of time steps of message passing 109 | device: str, default cpu 110 | Device used for training 111 | """ 112 | super(GRU, self).__init__(**kwargs) 113 | self.hsize = hsize 114 | self.input_size = input_size 115 | self.depth = depth 116 | self.dropout_p = dropout_p 117 | self._build_layer_components() 118 | 119 | def _build_layer_components(self) -> None: 120 | """Build layer components.""" 121 | self.W_z = nn.Linear(self.input_size + self.hsize, self.hsize) 122 | self.W_r = nn.Linear(self.input_size, self.hsize, bias=False) 123 | self.U_r = nn.Linear(self.hsize, self.hsize) 124 | self.W_h = nn.Linear(self.input_size + self.hsize, self.hsize) 125 | 126 | self.dropouts = [] 127 | for i in range(self.depth): 128 | self.dropouts.append(nn.Dropout(p=self.dropout_p)) 129 | self.dropouts = 
nn.ModuleList(self.dropouts) 130 | 131 | def get_init_state(self, fmess: torch.Tensor, init_state: torch.Tensor = None) -> torch.Tensor: 132 | """Get the initial hidden state of the RNN. 133 | 134 | Parameters 135 | ---------- 136 | fmess: torch.Tensor, 137 | Contains the initial features passed as messages 138 | init_state: torch.Tensor, default None 139 | Custom initial state supplied. 140 | """ 141 | h = torch.zeros(len(fmess), self.hsize, device=fmess.device) 142 | return h if init_state is None else torch.cat( (h, init_state), dim=0) 143 | 144 | def get_hidden_state(self, h: torch.Tensor) -> torch.Tensor: 145 | """Gets the hidden state. 146 | 147 | Parameters 148 | ---------- 149 | h: torch.Tensor, 150 | Hidden state of the GRU 151 | """ 152 | return h 153 | 154 | def GRU(self, x: torch.Tensor, h_nei: torch.Tensor) -> torch.Tensor: 155 | """Implements the GRU gating equations. 156 | 157 | Parameters 158 | ---------- 159 | x: torch.Tensor, 160 | Input tensor 161 | h_nei: torch.Tensor, 162 | Hidden states of the neighbors 163 | """ 164 | sum_h = h_nei.sum(dim=1) 165 | z_input = torch.cat([x,sum_h], dim=1) 166 | z = torch.sigmoid(self.W_z(z_input)) 167 | 168 | r_1 = self.W_r(x).view(-1, 1, self.hsize) 169 | r_2 = self.U_r(h_nei) 170 | r = torch.sigmoid(r_1 + r_2) 171 | 172 | gated_h = r * h_nei 173 | sum_gated_h = gated_h.sum(dim=1) 174 | h_input = torch.cat([x,sum_gated_h], dim=1) 175 | pre_h = torch.tanh(self.W_h(h_input)) 176 | new_h = (1.0 - z) * sum_h + z * pre_h 177 | return new_h 178 | 179 | def forward(self, fmess: torch.Tensor, bgraph: torch.Tensor) -> torch.Tensor: 180 | """Forward pass of the RNN 181 | 182 | Parameters 183 | ---------- 184 | fmess: torch.Tensor, 185 | Contains the initial features passed as messages 186 | bgraph: torch.Tensor, 187 | Bond graph tensor. Contains who passes messages to whom. 
188 | """ 189 | h = torch.zeros(fmess.size(0), self.hsize, device=fmess.device) 190 | mask = torch.ones(h.size(0), 1, device=h.device) 191 | mask[0, 0] = 0 #first message is padding 192 | 193 | for i in range(self.depth): 194 | h_nei = index_select_ND(h, 0, bgraph) 195 | h = self.GRU(fmess, h_nei) 196 | h = h * mask 197 | h = self.dropouts[i](h) 198 | return h 199 | 200 | def sparse_forward(self, h: torch.Tensor, fmess: torch.Tensor, 201 | submess: torch.Tensor, bgraph: torch.Tensor) -> torch.Tensor: 202 | """Unknown use. 203 | 204 | Parameters 205 | ---------- 206 | h: torch.Tensor, 207 | Hidden state tensor 208 | fmess: torch.Tensor, 209 | Contains the initial features passed as messages 210 | submess: torch.Tensor, 211 | bgraph: torch.Tensor, 212 | Bond graph tensor. Contains who passes messages to whom. 213 | """ 214 | mask = h.new_ones(h.size(0)).scatter_(0, submess, 0) 215 | h = h * mask.unsqueeze(1) 216 | for i in range(self.depth): 217 | h_nei = index_select_ND(h, 0, bgraph) 218 | sub_h = self.GRU(fmess, h_nei) 219 | h = index_scatter(sub_h, h, submess) 220 | return h 221 | 222 | class LSTM(nn.Module): 223 | 224 | def __init__(self, 225 | input_size: int, 226 | hsize: int, 227 | depth: int, 228 | dropout_p: float = 0.15, 229 | **kwargs): 230 | """ 231 | Parameters 232 | ---------- 233 | input_size: int, 234 | Size of the input 235 | hsize: int, 236 | Hidden state size 237 | depth: int, 238 | Number of time steps of message passing 239 | device: str, default cpu 240 | Device used for training 241 | """ 242 | super(LSTM, self).__init__(**kwargs) 243 | self.hsize = hsize 244 | self.input_size = input_size 245 | self.depth = depth 246 | self.dropout_p = dropout_p 247 | self._build_layer_components() 248 | 249 | def _build_layer_components(self): 250 | """Build layer components.""" 251 | self.W_i = nn.Sequential(nn.Linear(self.input_size + self.hsize, self.hsize), nn.Sigmoid()) 252 | self.W_o = nn.Sequential(nn.Linear(self.input_size + self.hsize, self.hsize), 
nn.Sigmoid()) 253 | self.W_f = nn.Sequential(nn.Linear(self.input_size + self.hsize, self.hsize), nn.Sigmoid()) 254 | self.W = nn.Sequential(nn.Linear(self.input_size + self.hsize, self.hsize), nn.Tanh()) 255 | 256 | self.dropouts = [] 257 | for i in range(self.depth): 258 | self.dropouts.append(nn.Dropout(p=self.dropout_p)) 259 | self.dropouts = nn.ModuleList(self.dropouts) 260 | 261 | def get_init_state(self, fmess: torch.Tensor, 262 | init_state: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: 263 | """Get the initial hidden state of the RNN. 264 | 265 | Parameters 266 | ---------- 267 | fmess: torch.Tensor, 268 | Contains the initial features passed as messages 269 | init_state: torch.Tensor, default None 270 | Custom initial state supplied. 271 | """ 272 | h = torch.zeros(len(fmess), self.hsize, device=fmess.device) 273 | c = torch.zeros(len(fmess), self.hsize, device=fmess.device) 274 | if init_state is not None: 275 | h = torch.cat((h, init_state), dim=0) 276 | c = torch.cat((c, torch.zeros_like(init_state)), dim=0) 277 | return h,c 278 | 279 | def get_hidden_state(self, h: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: 280 | """Gets the hidden state. 281 | 282 | Parameters 283 | ---------- 284 | h: Tuple[torch.Tensor, torch.Tensor], 285 | Hidden state tuple of the LSTM 286 | """ 287 | return h[0] 288 | 289 | def LSTM(self, x: torch.Tensor, h_nei: torch.Tensor, c_nei: torch.Tensor) -> torch.Tensor: 290 | """Implements the LSTM gating equations. 
291 | 292 | Parameters 293 | ---------- 294 | x: torch.Tensor, 295 | Input tensor 296 | h_nei: torch.Tensor, 297 | Hidden states of the neighbors 298 | c_nei: torch.Tensor, 299 | Memory state of the neighbors 300 | """ 301 | h_sum_nei = h_nei.sum(dim=1) 302 | x_expand = x.unsqueeze(1).expand(-1, h_nei.size(1), -1) 303 | i = self.W_i( torch.cat([x, h_sum_nei], dim=-1) ) 304 | o = self.W_o( torch.cat([x, h_sum_nei], dim=-1) ) 305 | f = self.W_f( torch.cat([x_expand, h_nei], dim=-1) ) 306 | u = self.W( torch.cat([x, h_sum_nei], dim=-1) ) 307 | c = i * u + (f * c_nei).sum(dim=1) 308 | h = o * torch.tanh(c) 309 | return h, c 310 | 311 | def forward(self, fmess: torch.Tensor, bgraph: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 312 | """Forward pass of the RNN. 313 | 314 | Parameters 315 | ---------- 316 | fmess: torch.Tensor, 317 | Contains the initial features passed as messages 318 | bgraph: torch.Tensor, 319 | Bond graph tensor. Contains who passes messages to whom. 320 | """ 321 | h = torch.zeros(fmess.size(0), self.hsize, device=fmess.device) 322 | c = torch.zeros(fmess.size(0), self.hsize, device=fmess.device) 323 | mask = torch.ones(h.size(0), 1, device=h.device) 324 | mask[0, 0] = 0 #first message is padding 325 | 326 | for i in range(self.depth): 327 | h_nei = index_select_ND(h, 0, bgraph) 328 | c_nei = index_select_ND(c, 0, bgraph) 329 | h,c = self.LSTM(fmess, h_nei, c_nei) 330 | h = h * mask 331 | c = c * mask 332 | h = self.dropouts[i](h) 333 | c = self.dropouts[i](c) 334 | return h,c 335 | 336 | def sparse_forward(self, h: torch.Tensor, fmess: torch.Tensor, 337 | submess: torch.Tensor, bgraph: torch.Tensor) -> torch.Tensor: 338 | """Unknown use. 339 | 340 | Parameters 341 | ---------- 342 | h: torch.Tensor, 343 | Hidden state tensor 344 | fmess: torch.Tensor, 345 | Contains the initial features passed as messages 346 | submess: torch.Tensor, 347 | bgraph: torch.Tensor, 348 | Bond graph tensor. Contains who passes messages to whom. 
349 | """ 350 | h,c = h 351 | mask = h.new_ones(h.size(0)).scatter_(0, submess, 0) 352 | h = h * mask.unsqueeze(1) 353 | c = c * mask.unsqueeze(1) 354 | for i in range(self.depth): 355 | h_nei = index_select_ND(h, 0, bgraph) 356 | c_nei = index_select_ND(c, 0, bgraph) 357 | sub_h, sub_c = self.LSTM(fmess, h_nei, c_nei) 358 | h = index_scatter(sub_h, h, submess) 359 | c = index_scatter(sub_c, c, submess) 360 | return h,c 361 | -------------------------------------------------------------------------------- /seq_graph_retro/models/__init__.py: -------------------------------------------------------------------------------- 1 | from seq_graph_retro.models.core_edits.single_edit import SingleEdit 2 | from seq_graph_retro.models.core_edits.multi_edit import MultiEdit 3 | from seq_graph_retro.models.lg_edits.lg_shared_embed import LGClassifier 4 | from seq_graph_retro.models.lg_edits.lg_ind_embed import LGIndEmbed 5 | from seq_graph_retro.models.retro.shared_edits_lg import SingleEditShared 6 | from seq_graph_retro.models.retro.separate_edits_lg import EditLGSeparate 7 | from seq_graph_retro.models.trainer import Trainer 8 | -------------------------------------------------------------------------------- /seq_graph_retro/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/models/__pycache__/focal_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/__pycache__/focal_loss.cpython-37.pyc -------------------------------------------------------------------------------- 
/seq_graph_retro/models/__pycache__/model_builder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/__pycache__/model_builder.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/models/__pycache__/trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/__pycache__/trainer.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/models/core_edits/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/core_edits/__init__.py -------------------------------------------------------------------------------- /seq_graph_retro/models/core_edits/__pycache__/multi_edit.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/core_edits/__pycache__/multi_edit.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/models/core_edits/__pycache__/single_edit.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/core_edits/__pycache__/single_edit.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/models/lg_edits/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/lg_edits/__init__.py -------------------------------------------------------------------------------- /seq_graph_retro/models/lg_edits/__pycache__/lg_ind_embed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/lg_edits/__pycache__/lg_ind_embed.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/models/lg_edits/__pycache__/lg_shared_embed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/lg_edits/__pycache__/lg_shared_embed.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/models/lg_edits/lg_ind_embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import List, Dict, Tuple, Union 4 | from rdkit import Chem 5 | 6 | from seq_graph_retro.molgraph.vocab import Vocab 7 | from seq_graph_retro.utils.torch import build_mlp 8 | from seq_graph_retro.utils.metrics import get_accuracy_lg 9 | from seq_graph_retro.layers import AtomAttention, GraphFeatEncoder, WLNEncoder 10 | 11 | from seq_graph_retro.utils.parse import apply_edits_to_mol 12 | from seq_graph_retro.data.collate_fns import pack_graph_feats 13 | from seq_graph_retro.molgraph.rxn_graphs import MultiElement, RxnElement 14 | 15 | 16 | class LGIndEmbed(nn.Module): 17 | """LGIndEmbed is a classifier for predicting leaving groups on fragments.""" 18 | 19 | def __init__(self, 20 | config: 
Dict, 21 | lg_vocab: Vocab, 22 | encoder_name: str, 23 | toggles: Dict = None, 24 | device: str = 'cpu', 25 | **kwargs): 26 | """ 27 | Parameters 28 | ---------- 29 | config: Dict, 30 | Config for all sub-modules and self 31 | lg_vocab: Vocab 32 | Vocabulary of leaving groups 33 | encoder_name: str, 34 | Name of the encoder network 35 | use_prev_pred: bool, default True 36 | Whether to use previous leaving group prediction 37 | device: str 38 | Device on which program runs 39 | """ 40 | super(LGIndEmbed, self).__init__(**kwargs) 41 | self.config = config 42 | self.lg_vocab = lg_vocab 43 | self.encoder_name = encoder_name 44 | self.toggles = toggles if toggles is not None else {} 45 | self.device = device 46 | self.E_lg = torch.eye(len(lg_vocab)).to(device) 47 | 48 | self._build_layers() 49 | 50 | def _build_layers(self) -> None: 51 | """Builds the layers in the classifier.""" 52 | config = self.config 53 | if self.encoder_name == 'GraphFeatEncoder': 54 | self.encoder = GraphFeatEncoder(node_fdim=config['n_atom_feat'], 55 | edge_fdim=config['n_bond_feat'], 56 | rnn_type=config['rnn_type'], 57 | hsize=config['mpn_size'], 58 | depth=config['depth'], 59 | dropout_p=config['dropout_mpn']) 60 | 61 | elif self.encoder_name == 'WLNEncoder': 62 | self.encoder = WLNEncoder(node_fdim=config['n_atom_feat'], 63 | edge_fdim=config['n_bond_feat'], 64 | hsize=config['mpn_size'], 65 | depth=config['depth'], 66 | bias=config['bias'], 67 | dropout_p=config['dropout_mpn']) 68 | else: 69 | raise ValueError() 70 | 71 | if self.toggles.get('use_attn', False): 72 | self.attn_layer = AtomAttention(n_bin_feat=config['n_bin_feat'], 73 | hsize=config['mpn_size'], 74 | n_heads=config['n_heads'], 75 | bias=config['bias']) 76 | 77 | lg_score_in_dim = 2 * config['mpn_size'] 78 | if self.toggles.get('use_prev_pred', False): 79 | lg_score_in_dim += config['embed_size'] 80 | 81 | self.lg_embedding = nn.Linear(in_features=len(self.lg_vocab), 82 | out_features=config['embed_size'], 83 | 
bias=config['embed_bias']) 84 | 85 | self.lg_score = build_mlp(in_dim=lg_score_in_dim, 86 | h_dim=config['mlp_size'], 87 | out_dim=len(self.lg_vocab), 88 | dropout_p=config['dropout_mlp']) 89 | 90 | self.lg_loss = nn.CrossEntropyLoss(ignore_index=self.lg_vocab[""]) 91 | 92 | def _compute_lg_step(self, graph_vecs, prod_vecs, prev_embed=None): 93 | if self.toggles.get('use_prev_pred', False): 94 | if prev_embed is None: 95 | init_state = torch.zeros(graph_vecs.size(0), len(self.lg_vocab), device=self.device) 96 | init_state[:, self.lg_vocab.get("")] = 1 97 | prev_lg_emb = self.lg_embedding(init_state) 98 | else: 99 | prev_lg_emb = prev_embed 100 | 101 | if self.toggles.get('use_prev_pred', False): 102 | scores_lg = self.lg_score(torch.cat([prev_lg_emb, prod_vecs, graph_vecs], dim=-1)) 103 | else: 104 | scores_lg = self.lg_score(torch.cat([prod_vecs, graph_vecs], dim=-1)) 105 | return scores_lg, None 106 | 107 | def _compute_lg_logits(self, graph_vecs_pad, prod_vecs, lg_labels=None): 108 | scores = torch.tensor([], device=self.device) 109 | prev_lg_emb = None 110 | 111 | if lg_labels is None: 112 | for idx in range(graph_vecs_pad.size(1)): 113 | scores_lg, _ = self._compute_lg_step(graph_vecs_pad[:, idx], prod_vecs, prev_embed=prev_lg_emb) 114 | prev_lg_emb = self.lg_embedding(self.E_lg.index_select(index=torch.argmax(scores_lg, dim=-1), dim=0)) 115 | scores = torch.cat([scores, scores_lg.unsqueeze(1)], dim=1) 116 | 117 | else: 118 | for idx in range(graph_vecs_pad.size(1)): 119 | scores_lg, _ = self._compute_lg_step(graph_vecs_pad[:, idx], prod_vecs, prev_embed=prev_lg_emb) 120 | prev_lg_emb = self.lg_embedding(self.E_lg.index_select(index=lg_labels[:, idx], dim=0)) 121 | scores = torch.cat([scores, scores_lg.unsqueeze(1)], dim=1) 122 | 123 | return scores 124 | 125 | def forward(self, prod_inputs, frag_inputs): 126 | prod_tensors, prod_scopes = prod_inputs 127 | frag_tensors, frag_scopes = frag_inputs 128 | 129 | prod_tensors = self.to_device(prod_tensors) 130 | 
frag_tensors = self.to_device(frag_tensors) 131 | 132 | prod_vecs, _ = self.encoder(prod_tensors, prod_scopes) 133 | frag_vecs, c_atom = self.encoder(frag_tensors, frag_scopes) 134 | frag_vecs_pad = torch.nn.utils.rnn.pad_sequence(frag_vecs, batch_first=True) 135 | 136 | return prod_vecs, frag_vecs_pad 137 | 138 | def get_saveables(self): 139 | saveables = {} 140 | saveables['config'] = self.config 141 | saveables['lg_vocab'] = self.lg_vocab 142 | saveables['encoder_name'] = self.encoder_name 143 | saveables['toggles'] = None if self.toggles == {} else self.toggles 144 | return saveables 145 | 146 | def to_device(self, tensors): 147 | """Converts all inputs to the device used.""" 148 | if isinstance(tensors, list) or isinstance(tensors, tuple): 149 | tensors = [tensor.to(self.device) for tensor in tensors] 150 | return tensors 151 | elif isinstance(tensors, torch.Tensor): 152 | return tensors.to(self.device) 153 | else: 154 | raise ValueError(f"Tensors of type {type(tensors)} unsupported") 155 | 156 | def _compute_lg_stats(self, lg_logits, lg_labels, lengths): 157 | loss = self.lg_loss(lg_logits.view(-1, len(self.lg_vocab)), lg_labels.reshape(-1)) 158 | acc_lg = get_accuracy_lg(lg_logits, lg_labels, lengths, device=self.device) 159 | return loss, acc_lg 160 | 161 | def train_step(self, prod_inputs, frag_inputs, lg_labels, lengths, **kwargs): 162 | prod_vecs, frag_vecs_pad = self(prod_inputs, frag_inputs) 163 | lg_labels = self.to_device(lg_labels) 164 | lg_logits = self._compute_lg_logits(frag_vecs_pad, prod_vecs=prod_vecs, lg_labels=lg_labels) 165 | 166 | lg_loss, lg_acc = self._compute_lg_stats(lg_logits, lg_labels, lengths) 167 | metrics = {'loss': lg_loss.item(), "accuracy": lg_acc.item()} 168 | return lg_loss, metrics 169 | 170 | def eval_step(self, prod_smi_batch: List[str], 171 | core_edits_batch: List[List], 172 | lg_label_batch: List[List], 173 | rxn_classes: List[int] = None, 174 | **kwargs) -> Tuple[torch.Tensor, Dict]: 175 | """Eval step of the model. 
176 | 177 | Parameters 178 | ---------- 179 | prod_smi_batch: List[str], 180 | List of product smiles 181 | core_edits_batch: List[List], 182 | List of edits for each element in batch. 183 | lg_label_batch: List[List], 184 | Leaving groups for each element in the batch 185 | """ 186 | acc_lg = 0.0 187 | 188 | for idx, prod_smi in enumerate(prod_smi_batch): 189 | if rxn_classes is None: 190 | labels = self.predict(prod_smi, core_edits_batch[idx]) 191 | else: 192 | labels = self.predict(prod_smi, core_edits_batch[idx], rxn_class=rxn_classes[idx]) 193 | if labels == lg_label_batch[idx]: 194 | acc_lg += 1.0 195 | 196 | metrics = {'loss': None, 'accuracy': acc_lg} 197 | return None, metrics 198 | 199 | def predict(self, prod_smi: str, core_edits: List, rxn_class: int = None): 200 | """Make predictions for given product smiles string. 201 | 202 | Parameters 203 | ---------- 204 | prod_smi: str, 205 | Product SMILES string 206 | core_edits: List, 207 | Edits associated with product molecule 208 | """ 209 | if self.encoder_name == 'WLNEncoder': 210 | directed = False 211 | elif self.encoder_name == 'GraphFeatEncoder': 212 | directed = True 213 | 214 | use_rxn_class = False 215 | if rxn_class is not None: 216 | use_rxn_class = True 217 | 218 | with torch.no_grad(): 219 | mol = Chem.MolFromSmiles(prod_smi) 220 | prod_graph = RxnElement(mol=Chem.Mol(mol), rxn_class=rxn_class) 221 | 222 | prod_inputs = pack_graph_feats([prod_graph], directed=directed, 223 | return_graphs=False, use_rxn_class=use_rxn_class) 224 | fragments = apply_edits_to_mol(Chem.Mol(mol), core_edits) 225 | tmp_frags = MultiElement(Chem.Mol(fragments)).mols 226 | 227 | if fragments is None: 228 | return [] 229 | 230 | else: 231 | fragments = Chem.Mol() 232 | for mol in tmp_frags: 233 | fragments = Chem.CombineMols(fragments, mol) 234 | 235 | frag_graph = MultiElement(mol=Chem.Mol(fragments), rxn_class=rxn_class) 236 | frag_inputs = pack_graph_feats([frag_graph], directed=directed, 237 | return_graphs=False, 
use_rxn_class=use_rxn_class) 238 | 239 | prod_vecs, frag_vecs_pad = self(prod_inputs, frag_inputs) 240 | 241 | lg_logits = self._compute_lg_logits(frag_vecs_pad, prod_vecs, lg_labels=None) 242 | 243 | _, preds = torch.max(lg_logits, dim=-1) 244 | preds = preds.squeeze(0) 245 | pred_labels = [self.lg_vocab.get_elem(pred.item()) for pred in preds] 246 | 247 | return pred_labels 248 | -------------------------------------------------------------------------------- /seq_graph_retro/models/model_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import joblib 3 | 4 | from seq_graph_retro.models import (SingleEdit, MultiEdit, SingleEditShared, 5 | LGClassifier, LGIndEmbed) 6 | from seq_graph_retro.molgraph.vocab import Vocab 7 | 8 | from seq_graph_retro.data import (SingleEditDataset, MultiEditDataset, 9 | SingleEditSharedDataset, LGClassifierDataset, LGEvalDataset, EditsEvalDataset, 10 | SharedEvalDataset) 11 | 12 | from seq_graph_retro.molgraph.mol_features import (ATOM_FDIM, BOND_FDIM, BOND_TYPES, 13 | RXN_CLASSES, BOND_FLOATS, PATTERN_DIM, BINARY_FDIM) 14 | from seq_graph_retro.molgraph.vocab import ATOM_LIST 15 | 16 | MODEL_ATTRS = { 17 | 'single_edit': (SingleEdit, SingleEditDataset, EditsEvalDataset, False), 18 | 'multi_edit': (MultiEdit, MultiEditDataset, EditsEvalDataset, False), 19 | 'single_shared': (SingleEditShared, SingleEditSharedDataset, SharedEvalDataset, True), 20 | 'lg_classifier': (LGClassifier, LGClassifierDataset, LGEvalDataset, True), 21 | 'lg_ind': (LGIndEmbed, LGClassifierDataset, LGEvalDataset, True) 22 | } 23 | 24 | def build_edits_config(loaded_config): 25 | model_config = {} 26 | config = {} 27 | if loaded_config.get('use_rxn_class', False): 28 | config['n_atom_feat'] = ATOM_FDIM + len(RXN_CLASSES) 29 | else: 30 | config['n_atom_feat'] = ATOM_FDIM 31 | config['n_bond_feat'] = BOND_FDIM 32 | config['n_bin_feat'] = BINARY_FDIM 33 | config['rnn_type'] = loaded_config['rnn_type'] 34 | 
config['mpn_size'] = loaded_config['mpn_size'] 35 | config['mlp_size'] = loaded_config['mlp_size'] 36 | config['depth'] = loaded_config['depth'] 37 | config['bias'] = False 38 | config['edit_loss'] = loaded_config['loss_type'] 39 | if 'n_mt_blocks' in loaded_config: 40 | config['n_mt_blocks'] = loaded_config['n_mt_blocks'] 41 | 42 | if loaded_config['edits_type'] == 'bond_edits': 43 | bs_outdim = len(BOND_FLOATS) 44 | elif loaded_config['edits_type'] == 'bond_disconn': 45 | bs_outdim = 1 46 | else: 47 | raise ValueError() 48 | 49 | config['bs_outdim'] = bs_outdim 50 | if loaded_config.get("propagate_logits", False): 51 | if loaded_config.get('use_rxn_class', False): 52 | config['bond_label_feat'] = ATOM_FDIM + 1 + 2 * (BOND_FDIM-1) + len(RXN_CLASSES) 53 | else: 54 | config['bond_label_feat'] = ATOM_FDIM + 1 + 2 * (BOND_FDIM-1) 55 | config['dropout_mlp'] = loaded_config['dropout_mlp'] 56 | config['dropout_mpn'] = loaded_config['dropout_mpn'] 57 | config['pos_weight'] = loaded_config['pos_weight'] 58 | 59 | toggles = {} 60 | toggles['use_attn'] = loaded_config.get('use_attn', False) 61 | toggles['use_rxn_class'] = loaded_config.get('use_rxn_class', False) 62 | toggles['use_h_labels'] = loaded_config.get('use_h_labels', True) 63 | toggles['use_prod'] = loaded_config.get('use_prod_edits', False) 64 | toggles['propagate_logits'] = loaded_config.get('propagate_logits', False) 65 | toggles['use_res'] = loaded_config.get('use_res', False) 66 | 67 | if 'n_heads' in loaded_config: 68 | config['n_heads'] = loaded_config['n_heads'] 69 | 70 | model_config['config'] = config 71 | model_config['toggles'] = toggles 72 | return model_config 73 | 74 | def build_shared_edit_config(loaded_config): 75 | model_config = {} 76 | config = {} 77 | if loaded_config.get('use_rxn_class', False): 78 | config['n_atom_feat'] = ATOM_FDIM + len(RXN_CLASSES) 79 | else: 80 | config['n_atom_feat'] = ATOM_FDIM 81 | config['n_bond_feat'] = BOND_FDIM 82 | config['n_bin_feat'] = 
loaded_config['n_bin_feat'] 83 | config['rnn_type'] = loaded_config['rnn_type'] 84 | config['mpn_size'] = loaded_config['mpn_size'] 85 | config['mlp_size'] = loaded_config['mlp_size'] 86 | config['depth'] = loaded_config['depth'] 87 | config['bias'] = False 88 | config['embed_size'] = loaded_config['embed_size'] 89 | config['edit_loss'] = loaded_config['loss_type'] 90 | 91 | if loaded_config['edits_type'] == 'bond_edits': 92 | bs_outdim = len(BOND_FLOATS) 93 | elif loaded_config['edits_type'] == 'bond_disconn': 94 | bs_outdim = 1 95 | else: 96 | raise ValueError() 97 | config['bs_outdim'] = bs_outdim 98 | config['dropout_mlp'] = loaded_config['dropout_mlp'] 99 | config['dropout_mpn'] = loaded_config['dropout_mpn'] 100 | 101 | toggles = {} 102 | toggles['use_attn'] = loaded_config.get('use_attn', False) 103 | toggles['use_prev_pred'] = loaded_config.get('use_prev_pred', True) 104 | toggles['use_rxn_class'] = loaded_config.get('use_rxn_class', False) 105 | toggles['use_h_labels'] = loaded_config.get('use_h_labels', True) 106 | toggles['use_prod'] = loaded_config.get('use_prod_edits', False) 107 | toggles['propagate_logits'] = loaded_config.get('propagate_logits', False) 108 | toggles['use_res'] = loaded_config.get('use_res', False) 109 | 110 | if 'n_heads' in loaded_config: 111 | config['n_heads'] = loaded_config['n_heads'] 112 | 113 | config['embed_bias'] = False 114 | config['lam_edits'] = loaded_config['lam_edits'] 115 | config['lam_lg'] = loaded_config['lam_lg'] 116 | 117 | if loaded_config.get('use_h_labels', True): 118 | lg_dict = joblib.load(os.path.join(loaded_config['data_dir'], "train", "h_labels", loaded_config['vocab_file'])) 119 | lg_tensor_dir = os.path.join(loaded_config['data_dir'], "train", "h_labels") 120 | else: 121 | lg_dict = joblib.load(os.path.join(loaded_config['data_dir'], "train", "without_h_labels", loaded_config['vocab_file'])) 122 | lg_tensor_dir = os.path.join(loaded_config['data_dir'], "train", "without_h_labels") 123 | 124 | if 
loaded_config.get('use_rxn_class', False): 125 | lg_tensor_dir = os.path.join(lg_tensor_dir, "with_rxn") 126 | else: 127 | lg_tensor_dir = os.path.join(lg_tensor_dir, "without_rxn") 128 | 129 | lg_tensor_file = os.path.join(lg_tensor_dir, "lg_inputs.pt") 130 | lg_vocab = Vocab(lg_dict) 131 | 132 | model_config['config'] = config 133 | model_config['toggles'] = toggles 134 | model_config['lg_vocab'] = lg_vocab 135 | model_config['tensor_file'] = lg_tensor_file 136 | return model_config 137 | 138 | def build_lg_classifier_config(loaded_config): 139 | model_config = {} 140 | config = {} 141 | config['rnn_type'] = loaded_config['rnn_type'] 142 | config['mpn_size'] = loaded_config['mpn_size'] 143 | config['mlp_size'] = loaded_config['mlp_size'] 144 | config['depth'] = loaded_config['depth'] 145 | config['bias'] = False 146 | config['embed_size'] = loaded_config['embed_size'] 147 | config['dropout_mlp'] = loaded_config['dropout_mlp'] 148 | config['dropout_mpn'] = loaded_config['dropout_mpn'] 149 | if 'n_mt_blocks' in loaded_config: 150 | config['n_mt_blocks'] = loaded_config['n_mt_blocks'] 151 | 152 | if loaded_config.get('use_rxn_class', False): 153 | config['n_atom_feat'] = ATOM_FDIM + len(RXN_CLASSES) 154 | else: 155 | config['n_atom_feat'] = ATOM_FDIM 156 | config['n_bond_feat'] = BOND_FDIM 157 | 158 | toggles = {} 159 | toggles['use_attn'] = loaded_config.get('use_attn', False) 160 | toggles['use_prev_pred'] = loaded_config.get('use_prev_pred', True) 161 | toggles['use_rxn_class'] = loaded_config.get('use_rxn_class', False) 162 | 163 | if 'n_heads' in loaded_config: 164 | config['n_heads'] = loaded_config['n_heads'] 165 | config['embed_bias'] = False 166 | 167 | if loaded_config.get('use_h_labels', True): 168 | lg_dict = joblib.load(os.path.join(loaded_config['data_dir'], "train", "h_labels", loaded_config['vocab_file'])) 169 | lg_tensor_dir = os.path.join(loaded_config['data_dir'], "train", "h_labels") 170 | else: 171 | lg_dict = 
joblib.load(os.path.join(loaded_config['data_dir'], "train", "without_h_labels", loaded_config['vocab_file'])) 172 | lg_tensor_dir = os.path.join(loaded_config['data_dir'], "train", "without_h_labels") 173 | 174 | if loaded_config.get('use_rxn_class', False): 175 | lg_tensor_dir = os.path.join(lg_tensor_dir, "with_rxn") 176 | else: 177 | lg_tensor_dir = os.path.join(lg_tensor_dir, "without_rxn") 178 | 179 | lg_tensor_file = os.path.join(lg_tensor_dir, "lg_inputs.pt") 180 | lg_vocab = Vocab(lg_dict) 181 | 182 | model_config['config'] = config 183 | model_config['toggles'] = toggles 184 | model_config['lg_vocab'] = lg_vocab 185 | model_config['tensor_file'] = lg_tensor_file 186 | return model_config 187 | 188 | def build_lg_ind_config(loaded_config): 189 | model_config = {} 190 | config = {} 191 | config['rnn_type'] = loaded_config['rnn_type'] 192 | config['mpn_size'] = loaded_config['mpn_size'] 193 | config['mlp_size'] = loaded_config['mlp_size'] 194 | config['depth'] = loaded_config['depth'] 195 | config['bias'] = False 196 | config['embed_size'] = loaded_config['embed_size'] 197 | config['dropout_mlp'] = loaded_config['dropout_mlp'] 198 | config['dropout_mpn'] = loaded_config['dropout_mpn'] 199 | if 'n_mt_blocks' in loaded_config: 200 | config['n_mt_blocks'] = loaded_config['n_mt_blocks'] 201 | 202 | if loaded_config.get('use_rxn_class', False): 203 | config['n_atom_feat'] = ATOM_FDIM + len(RXN_CLASSES) 204 | else: 205 | config['n_atom_feat'] = ATOM_FDIM 206 | config['n_bond_feat'] = BOND_FDIM 207 | 208 | toggles = {} 209 | toggles['use_attn'] = loaded_config.get('use_attn', False) 210 | toggles['use_prev_pred'] = loaded_config.get('use_prev_pred', True) 211 | toggles['use_rxn_class'] = loaded_config.get('use_rxn_class', False) 212 | 213 | if 'n_heads' in loaded_config: 214 | config['n_heads'] = loaded_config['n_heads'] 215 | config['embed_bias'] = False 216 | 217 | if loaded_config.get('use_h_labels', True): 218 | lg_dict = 
joblib.load(os.path.join(loaded_config['data_dir'], "train", "h_labels", loaded_config['vocab_file'])) 219 | lg_tensor_dir = os.path.join(loaded_config['data_dir'], "train", "h_labels") 220 | else: 221 | lg_dict = joblib.load(os.path.join(loaded_config['data_dir'], "train", "without_h_labels", loaded_config['vocab_file'])) 222 | lg_tensor_dir = os.path.join(loaded_config['data_dir'], "train", "without_h_labels") 223 | 224 | if loaded_config.get('use_rxn_class', False): 225 | lg_tensor_dir = os.path.join(lg_tensor_dir, "with_rxn") 226 | else: 227 | lg_tensor_dir = os.path.join(lg_tensor_dir, "without_rxn") 228 | 229 | lg_vocab = Vocab(lg_dict) 230 | 231 | model_config['config'] = config 232 | model_config['toggles'] = toggles 233 | model_config['lg_vocab'] = lg_vocab 234 | return model_config 235 | 236 | CONFIG_FNS = { 237 | 'single_edit': build_edits_config, 238 | 'multi_edit': build_edits_config, 239 | 'single_shared': build_shared_edit_config, 240 | 'lg_classifier': build_lg_classifier_config, 241 | 'lg_ind': build_lg_ind_config 242 | } 243 | 244 | def build_model(loaded_config, device='cpu'): 245 | config_fn = CONFIG_FNS.get(loaded_config['model']) 246 | model_config = config_fn(loaded_config) 247 | 248 | if loaded_config['mpnn'] == 'graph_feat': 249 | encoder_name = 'GraphFeatEncoder' 250 | elif loaded_config['mpnn'] == 'wln': 251 | encoder_name = 'WLNEncoder' 252 | elif loaded_config['mpnn'] == 'gtrans': 253 | encoder_name = 'GTransEncoder' 254 | 255 | model_class = MODEL_ATTRS.get(loaded_config['model'])[0] 256 | model = model_class(**model_config, encoder_name=encoder_name, device=device) 257 | return model 258 | -------------------------------------------------------------------------------- /seq_graph_retro/models/retro/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/retro/__init__.py 
-------------------------------------------------------------------------------- /seq_graph_retro/models/retro/__pycache__/separate_edits_lg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/retro/__pycache__/separate_edits_lg.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/models/retro/__pycache__/shared_edits_lg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/models/retro/__pycache__/shared_edits_lg.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/models/retro/separate_edits_lg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Dict, Tuple, Union 3 | 4 | from seq_graph_retro.models import SingleEdit, MultiEdit, LGClassifier, LGIndEmbed 5 | 6 | EDIT_NET_DICT = {'SingleEdit': SingleEdit, "single_edit": SingleEdit, 'MultiEdit': MultiEdit} 7 | LG_NET_DICT = {'LGClassifier': LGClassifier, "lg_classifier": LGClassifier, 8 | "lg_ind": LGIndEmbed, 'LGIndEmbed': LGIndEmbed, 'LGIndEmbedClassifier': LGIndEmbed} 9 | 10 | class EditLGSeparate: 11 | 12 | def __init__(self, 13 | edits_config: Dict, 14 | lg_config: Dict, 15 | edit_net_name: str = 'SingleEdit', 16 | lg_net_name: str = 'LGClassifier', 17 | device: str = 'cpu', 18 | **kwargs): 19 | """ 20 | Parameters 21 | ---------- 22 | edits_config: Dict, 23 | Config for the edit prediction model 24 | lg_config: Dict, 25 | Config for the leaving group prediction model 26 | edit_net_name: str, default BondEdits, 27 | Name of the edit prediction network 28 | lg_net_name: str, default LGClassifier, 29 | Name 
of LGClassifier network 30 | """ 31 | edit_model_class = EDIT_NET_DICT.get(edit_net_name) 32 | lg_model_class = LG_NET_DICT.get(lg_net_name) 33 | self.edit_net = edit_model_class(**edits_config, device=device) 34 | self.lg_net = lg_model_class(**lg_config, device=device) 35 | self.device = device 36 | 37 | def to_device(self, tensors: Union[List, torch.Tensor]) -> Union[List, torch.Tensor]: 38 | """Converts all inputs to the device used. 39 | 40 | Parameters 41 | ---------- 42 | tensors: Union[List, torch.Tensor], 43 | Tensors to convert to model device. The tensors can be either a 44 | single tensor or an iterable of tensors. 45 | """ 46 | if isinstance(tensors, list) or isinstance(tensors, tuple): 47 | tensors = [tensor.to(self.device, non_blocking=True) if tensor is not None else None for tensor in tensors] 48 | return tensors 49 | elif isinstance(tensors, torch.Tensor): 50 | return tensors.to(self.device, non_blocking=True) 51 | else: 52 | raise ValueError(f"Tensors of type {type(tensors)} unsupported") 53 | 54 | def _compute_edit_logits(self, graph_tensors: Union[Tuple[torch.Tensor], List[Tuple[torch.Tensor]]], 55 | scopes: Tuple[List], ha=None, bg_inputs=None) -> Tuple[torch.Tensor]: 56 | """Computes the edit logits for given tensors. 57 | 58 | Parameters 59 | ---------- 60 | graph_tensors: Union[Tuple[torch.Tensor], List[Tuple[torch.Tensor]]], 61 | Graph tensors used. 
Could be a List of these tensors, or just an 62 | individual one 63 | scopes: Tuple[List], 64 | 65 | """ 66 | return self.edit_net._compute_edit_logits(graph_tensors, scopes, ha=ha, bg_inputs=bg_inputs) 67 | 68 | def _compute_lg_logits(self, graph_vecs_pad, prod_vecs, lg_labels=None): 69 | return self.lg_net._compute_lg_logits(graph_vecs_pad=graph_vecs_pad, 70 | prod_vecs=prod_vecs, 71 | lg_labels=lg_labels) 72 | 73 | def _compute_lg_step(self, graph_vecs, prod_vecs, prev_embed=None): 74 | return self.lg_net._compute_lg_step(graph_vecs=graph_vecs, 75 | prod_vecs=prod_vecs, 76 | prev_embed=prev_embed) 77 | 78 | def to(self, device: str) -> None: 79 | """Convert to device. 80 | 81 | Parameters 82 | ---------- 83 | device: str, 84 | Device used 85 | """ 86 | self.edit_net.to(device) 87 | self.lg_net.to(device) 88 | 89 | def eval(self) -> None: 90 | """Turn the network into eval mode.""" 91 | self.edit_net.eval() 92 | self.lg_net.eval() 93 | 94 | def load_state_dict(self, edit_state: Dict, lg_state: Dict) -> None: 95 | """Loads state dict. 96 | 97 | Parameters 98 | ---------- 99 | edit_state: Dict, 100 | State dict for the edit prediction network 101 | lg_state: Dict, 102 | State dict for the leaving group network. 103 | """ 104 | self.edit_net.load_state_dict(edit_state) 105 | self.lg_net.load_state_dict(lg_state) 106 | 107 | def predict(self, prod_smi: str, rxn_class: int = None, **kwargs) -> Tuple[List]: 108 | """Make predictions for given product smiles string. 
109 | 110 | Parameters 111 | ---------- 112 | prod_smi: str, 113 | Product SMILES string 114 | """ 115 | edits = self.edit_net.predict(prod_smi, rxn_class=rxn_class) 116 | if not isinstance(edits, list): 117 | edits = [edits] 118 | 119 | try: 120 | labels = self.lg_net.predict(prod_smi, core_edits=edits, rxn_class=rxn_class) 121 | return edits, labels 122 | 123 | except: 124 | return edits, [] 125 | -------------------------------------------------------------------------------- /seq_graph_retro/molgraph/__init__.py: -------------------------------------------------------------------------------- 1 | from seq_graph_retro.molgraph.rxn_graphs import RxnGraph, RxnElement 2 | from seq_graph_retro.molgraph.rxn_graphs import BondEditsRxn, MultiElement 3 | -------------------------------------------------------------------------------- /seq_graph_retro/molgraph/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/molgraph/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/molgraph/__pycache__/mol_features.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/molgraph/__pycache__/mol_features.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/molgraph/__pycache__/rxn_graphs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/molgraph/__pycache__/rxn_graphs.cpython-37.pyc -------------------------------------------------------------------------------- 
/seq_graph_retro/molgraph/__pycache__/vocab.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/molgraph/__pycache__/vocab.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/molgraph/mol_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | from typing import Set, Any, List, Union 4 | 5 | from seq_graph_retro.utils.chem import get_mol 6 | 7 | idxfunc = lambda a : a.GetAtomMapNum() - 1 8 | bond_idx_fn = lambda a, b, mol: mol.GetBondBetweenAtoms(a.GetIdx(), b.GetIdx()).GetIdx() 9 | 10 | # Symbols for different atoms 11 | ATOM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', \ 12 | 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', \ 13 | 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', \ 14 | 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Ta', 'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir', \ 15 | 'Ce','Gd','Ga','Cs', '*', 'unk'] 16 | 17 | MAX_NB = 10 18 | DEGREES = list(range(MAX_NB)) 19 | HYBRIDIZATION = [Chem.rdchem.HybridizationType.SP, 20 | Chem.rdchem.HybridizationType.SP2, 21 | Chem.rdchem.HybridizationType.SP3, 22 | Chem.rdchem.HybridizationType.SP3D, 23 | Chem.rdchem.HybridizationType.SP3D2] 24 | 25 | FORMAL_CHARGE = [-1, -2, 1, 2, 0] 26 | VALENCE = [0, 1, 2, 3, 4, 5, 6] 27 | NUM_Hs = [0, 1, 3, 4, 5] 28 | 29 | BOND_TYPES = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, \ 30 | Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC] 31 | BOND_FLOAT_TO_TYPE = { 32 | 0.0: BOND_TYPES[0], 33 | 1.0: BOND_TYPES[1], 34 | 2.0: BOND_TYPES[2], 35 | 3.0: BOND_TYPES[3], 36 | 1.5: BOND_TYPES[4], 37 | } 38 | 39 | BOND_DELTAS = {-3: 0, -2: 1, -1.5: 
2, -1: 3, -0.5: 4, 0: 5, 0.5: 6, 1: 7, 1.5: 8, 2:9, 3:10} 40 | BOND_FLOATS = [0.0, 1.0, 2.0, 3.0, 1.5] 41 | RXN_CLASSES = list(range(10)) 42 | 43 | ATOM_FDIM = len(ATOM_LIST) + len(DEGREES) + len(FORMAL_CHARGE) + len(HYBRIDIZATION) \ 44 | + len(VALENCE) + len(NUM_Hs) + 1 45 | BOND_FDIM = 6 46 | BINARY_FDIM = 5 + BOND_FDIM 47 | INVALID_BOND = -1 48 | PATTERN_DIM = 389 49 | 50 | def sanitize(mol, kekulize: bool = True) -> Chem.Mol: 51 | """Sanitize mol. 52 | Parameters 53 | ---------- 54 | mol: Chem.Mol 55 | Molecule to sanitize 56 | kekulize: bool 57 | Whether to kekulize the molecule 58 | """ 59 | try: 60 | smiles = get_smiles(mol) if kekulize else Chem.MolToSmiles(mol) 61 | mol = get_mol(smiles) if kekulize else Chem.MolFromSmiles(smiles) 62 | except: 63 | mol = None 64 | return mol 65 | 66 | def onek_encoding_unk(x: Any, allowable_set: Union[List, Set]) -> List: 67 | """Converts x to one hot encoding. 68 | 69 | Parameters 70 | ---------- 71 | x: Any, 72 | An element of any type 73 | allowable_set: Union[List, Set] 74 | Allowable element collection 75 | """ 76 | if x not in allowable_set: 77 | x = allowable_set[-1] 78 | return list(map(lambda s: float(x == s), allowable_set)) 79 | 80 | def get_atom_features(atom: Chem.Atom, rxn_class: int = None, use_rxn_class: bool = False) -> np.ndarray: 81 | """Get atom features. 
82 | 83 | Parameters 84 | ---------- 85 | atom: Chem.Atom, 86 | Atom object from RDKit 87 | rxn_class: int, None 88 | Reaction class the molecule was part of 89 | use_rxn_class: bool, default False, 90 | Whether to use reaction class as additional input 91 | """ 92 | if atom.GetSymbol() == '*': 93 | symbol = onek_encoding_unk(atom.GetSymbol(), ATOM_LIST) 94 | if use_rxn_class: 95 | padding = [0] * (ATOM_FDIM + len(RXN_CLASSES)- len(symbol)) 96 | else: 97 | padding = [0] * (ATOM_FDIM - len(symbol)) 98 | feature_array = symbol + padding 99 | return feature_array 100 | 101 | if use_rxn_class: 102 | return np.array(onek_encoding_unk(atom.GetSymbol(), ATOM_LIST) + 103 | onek_encoding_unk(atom.GetDegree(), DEGREES) + 104 | onek_encoding_unk(atom.GetFormalCharge(), FORMAL_CHARGE) + 105 | onek_encoding_unk(atom.GetHybridization(), HYBRIDIZATION) + 106 | onek_encoding_unk(atom.GetTotalValence(), VALENCE) + 107 | onek_encoding_unk(atom.GetTotalNumHs(), NUM_Hs) + 108 | [float(atom.GetIsAromatic())] + onek_encoding_unk(rxn_class, RXN_CLASSES)).tolist() 109 | 110 | else: 111 | return np.array(onek_encoding_unk(atom.GetSymbol(), ATOM_LIST) + 112 | onek_encoding_unk(atom.GetDegree(), DEGREES) + 113 | onek_encoding_unk(atom.GetFormalCharge(), FORMAL_CHARGE) + 114 | onek_encoding_unk(atom.GetHybridization(), HYBRIDIZATION) + 115 | onek_encoding_unk(atom.GetTotalValence(), VALENCE) + 116 | onek_encoding_unk(atom.GetTotalNumHs(), NUM_Hs) + 117 | [float(atom.GetIsAromatic())]).tolist() 118 | 119 | def get_binary_features(mol: Chem.Mol) -> np.ndarray: 120 | """ 121 | This function is used to generate descriptions of atom-atom relationships, including 122 | the bond type between the atoms (if any) and whether they belong to the same molecule. 123 | It is used in the global attention mechanism. 124 | 125 | Parameters 126 | ---------- 127 | mol: Chem.Mol, 128 | Molecule for which we want to compute binary features. 
129 | """ 130 | comp = {} 131 | amap_idx = {atom.GetAtomMapNum(): atom.GetIdx() for atom in mol.GetAtoms()} 132 | for atom in mol.GetAtoms(): 133 | comp[amap_idx[atom.GetAtomMapNum()]] = 0 134 | 135 | n_comp = 1 136 | n_atoms = mol.GetNumAtoms() 137 | 138 | bond_map = {} 139 | for bond in mol.GetBonds(): 140 | a1 = amap_idx[bond.GetBeginAtom().GetAtomMapNum()] 141 | a2 = amap_idx[bond.GetEndAtom().GetAtomMapNum()] 142 | bond_map[(a1, a2)] = bond_map[(a2, a1)] = bond 143 | 144 | features = [] 145 | for i in range(n_atoms): 146 | for j in range(n_atoms): 147 | f = np.zeros((BINARY_FDIM,)) 148 | if (i, j) in bond_map: 149 | bond = bond_map[(i, j)] 150 | f[1:1 + BOND_FDIM] = get_bond_features(bond) 151 | else: 152 | f[0] = 1.0 153 | f[-4] = 1.0 if comp[i] != comp[j] else 0.0 154 | f[-3] = 1.0 if comp[i] == comp[j] else 0.0 155 | f[-2] = 1.0 if n_comp == 1 else 0.0 156 | f[-1] = 1.0 if n_comp > 1 else 0.0 157 | features.append(f) 158 | return np.vstack(features).reshape((n_atoms, n_atoms, BINARY_FDIM)) 159 | 160 | def get_bond_features(bond: Chem.Bond) -> np.ndarray: 161 | """Get bond features. 162 | 163 | Parameters 164 | ---------- 165 | bond: Chem.Bond, 166 | bond object 167 | """ 168 | bt = bond.GetBondType() 169 | bond_features = [float(bt == bond_type) for bond_type in BOND_TYPES[1:]] 170 | bond_features.extend([float(bond.GetIsConjugated()), float(bond.IsInRing())]) 171 | bond_features = np.array(bond_features, dtype=np.float32) 172 | return bond_features 173 | 174 | def get_atom_graph(mol: Chem.Mol) -> np.ndarray: 175 | """Atom graph is the adjacency list of atoms and its neighbors. 176 | 177 | Parameters 178 | ---------- 179 | mol: Chem.Mol, 180 | Molecule for which we want to compute atom graph. 
181 | """ 182 | agraph = np.zeros((mol.GetNumAtoms(), MAX_NB), dtype=np.int32) 183 | for idx, atom in enumerate(mol.GetAtoms()): 184 | nei_indices = [nei.GetIdx() + 1 for nei in atom.GetNeighbors()] 185 | agraph[idx, :len(nei_indices)] = nei_indices 186 | return agraph 187 | 188 | def get_bond_graph(mol: Chem.Mol) -> np.ndarray: 189 | """Bond graph is the adjacency list of bond indices for each atom and its neighbors. 190 | 191 | Parameters 192 | ---------- 193 | mol: Chem.Mol, 194 | Molecule for which we want to compute bond graph 195 | """ 196 | bgraph = np.zeros((mol.GetNumAtoms(), MAX_NB), dtype=np.int32) 197 | for idx, atom in enumerate(mol.GetAtoms()): 198 | bond_indices = [bond_idx_fn(atom, nei, mol) + 1 for nei in atom.GetNeighbors()] 199 | bgraph[idx, :len(bond_indices)] = bond_indices 200 | return bgraph 201 | -------------------------------------------------------------------------------- /seq_graph_retro/molgraph/rxn_graphs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | import networkx as nx 4 | from typing import List, Tuple, Union 5 | 6 | from seq_graph_retro.utils.chem import get_sub_mol 7 | from seq_graph_retro.molgraph.mol_features import BOND_TYPES, BOND_FLOATS 8 | 9 | class RxnGraph: 10 | """ 11 | RxnGraph is an abstract class for storing all elements of a reaction, like 12 | reactants, products and fragments. The edits associated with the reaction 13 | are also captured in edit labels. One can also use h_labels, which keep track 14 | of atoms with hydrogen changes. For reactions with multiple edits, a done 15 | label is also added to account for termination of edits. 
16 | """ 17 | 18 | def __init__(self, 19 | prod_mol: Chem.Mol, 20 | frag_mol: Chem.Mol = None, 21 | reac_mol: Chem.Mol = None, 22 | edits_to_apply: List = [], 23 | rxn_class: int = None) -> None: 24 | """ 25 | Parameters 26 | ---------- 27 | prod_mol: Chem.Mol, 28 | Product molecule 29 | frag_mol: Chem.Mol, default None 30 | Fragment molecule(s) 31 | reac_mol: Chem.Mol, default None 32 | Reactant molecule(s) 33 | edits_to_apply: List, default [], 34 | Edits to apply to the product molecule, captured in edit_/h_labels 35 | rxn_class: int, default None, 36 | Reaction class for this reaction. 37 | """ 38 | self.prod_mol = RxnElement(mol=prod_mol, rxn_class=rxn_class) 39 | if frag_mol is not None: 40 | self.frag_mol = MultiElement(mol=frag_mol, rxn_class=rxn_class) 41 | if reac_mol is not None: 42 | self.reac_mol = MultiElement(mol=reac_mol, rxn_class=rxn_class) 43 | self.edits_to_apply = edits_to_apply 44 | self.edit_label, self.h_label, self.done_label = self._get_labels() 45 | self.rxn_class = rxn_class 46 | 47 | def _get_labels(self) -> Tuple[np.ndarray]: 48 | """Returns the different labels associated with the reaction.""" 49 | return None, None, None 50 | 51 | def get_attributes(self, 52 | mol_attrs: List = ['prod_mol', 'frag_mol', 'reac_mol'], 53 | label_attrs: List = ['edit_label', 'h_label']) -> Tuple: 54 | """ 55 | Parameters 56 | ---------- 57 | Returns the different attributes associated with the reaction graph. 58 | 59 | mol_attrs: List, 60 | Molecule objects to return 61 | label_attrs: List, 62 | Label attributes to return. 
Individual label attrs are coerced into 63 | a single label 64 | """ 65 | mol_tuple = () 66 | label_tuple = () 67 | 68 | for attr in mol_attrs: 69 | if hasattr(self, attr): 70 | mol_tuple += (getattr(self, attr),) 71 | else: 72 | print(f"Does not have attr {attr}") 73 | 74 | for attr in label_attrs: 75 | if hasattr(self, attr): 76 | label_tuple += (getattr(self, attr).flatten(), ) 77 | 78 | if len(label_tuple): 79 | label_tuple = np.concatenate(label_tuple) 80 | new_tuple = mol_tuple + (label_tuple,) 81 | return new_tuple 82 | 83 | return mol_tuple 84 | 85 | class RxnElement: 86 | """ 87 | RxnElement is an abstract class for dealing with single molecule. The graph 88 | and corresponding molecule attributes are built for the molecule. The constructor 89 | accepts only mol objects, sidestepping the use of SMILES string which may always 90 | not be achievable, especially for a unkekulizable molecule. 91 | """ 92 | 93 | def __init__(self, mol: Chem.Mol, rxn_class: int = None) -> None: 94 | """ 95 | Parameters 96 | ---------- 97 | mol: Chem.Mol, 98 | Molecule 99 | rxn_class: int, default None, 100 | Reaction class for this reaction. 
101 | """ 102 | self.mol = mol 103 | self.rxn_class = rxn_class 104 | self._build_mol() 105 | self._build_graph() 106 | 107 | def _build_mol(self) -> None: 108 | """Builds the molecule attributes.""" 109 | self.num_atoms = self.mol.GetNumAtoms() 110 | self.num_bonds = self.mol.GetNumBonds() 111 | self.amap_to_idx = {atom.GetAtomMapNum(): atom.GetIdx() 112 | for atom in self.mol.GetAtoms()} 113 | self.idx_to_amap = {value: key for key, value in self.amap_to_idx.items()} 114 | 115 | def _build_graph(self) -> None: 116 | """Builds the graph attributes.""" 117 | self.G_undir = nx.Graph(Chem.rdmolops.GetAdjacencyMatrix(self.mol)) 118 | self.G_dir = nx.DiGraph(Chem.rdmolops.GetAdjacencyMatrix(self.mol)) 119 | 120 | for atom in self.mol.GetAtoms(): 121 | self.G_undir.nodes[atom.GetIdx()]['label'] = atom.GetSymbol() 122 | self.G_dir.nodes[atom.GetIdx()]['label'] = atom.GetSymbol() 123 | 124 | for bond in self.mol.GetBonds(): 125 | a1 = bond.GetBeginAtom().GetIdx() 126 | a2 = bond.GetEndAtom().GetIdx() 127 | btype = BOND_TYPES.index( bond.GetBondType() ) 128 | self.G_undir[a1][a2]['label'] = btype 129 | self.G_dir[a1][a2]['label'] = btype 130 | self.G_dir[a2][a1]['label'] = btype 131 | 132 | self.atom_scope = (0, self.num_atoms) 133 | self.bond_scope = (0, self.num_bonds) 134 | 135 | #CHECK IF THESE TWO ARE NEEDED 136 | def update_atom_scope(self, offset: int) -> Union[List, Tuple]: 137 | """Updates the atom indices by the offset. 138 | 139 | Parameters 140 | ---------- 141 | offset: int, 142 | Offset to apply 143 | """ 144 | if isinstance(self.atom_scope, list): 145 | return [(st + offset, le) for (st, le) in self.atom_scope] 146 | st, le = self.atom_scope 147 | return (st + offset, le) 148 | 149 | def update_bond_scope(self, offset: int) -> Union[List, Tuple]: 150 | """Updates the atom indices by the offset. 
151 | 152 | Parameters 153 | ---------- 154 | offset: int, 155 | Offset to apply 156 | """ 157 | if isinstance(self.bond_scope, list): 158 | return [(st + offset, le) for (st, le) in self.bond_scope] 159 | st, le = self.bond_scope 160 | return (st + offset, le) 161 | 162 | 163 | class BondEditsRxn(RxnGraph): 164 | 165 | def _get_labels(self) -> Tuple[np.ndarray]: 166 | """Returns the different labels associated with the reaction.""" 167 | edit_label = np.zeros((self.prod_mol.num_bonds, len(BOND_FLOATS))) 168 | h_label = np.zeros(self.prod_mol.num_atoms) 169 | done_label = np.zeros((1,)) 170 | 171 | if not isinstance(self.edits_to_apply, list): 172 | edits_to_apply = [self.edits_to_apply] 173 | else: 174 | edits_to_apply = self.edits_to_apply 175 | 176 | if len(edits_to_apply) == 0: 177 | done_label[0] = 1.0 178 | return edit_label, h_label, done_label 179 | 180 | else: 181 | for edit in edits_to_apply: 182 | a1, a2, b1, b2 = edit.split(":") 183 | a1, a2 = int(a1), int(a2) 184 | b1, b2 = float(b1), float(b2) 185 | 186 | if a2 == 0: 187 | a_start = self.prod_mol.amap_to_idx[a1] 188 | h_label[a_start] = 1 189 | else: 190 | #delta = b2 - b1 191 | a_start, a_end = self.prod_mol.amap_to_idx[a1], self.prod_mol.amap_to_idx[a2] 192 | 193 | b_idx = self.prod_mol.mol.GetBondBetweenAtoms(a_start, a_end).GetIdx() 194 | edit_label[b_idx][BOND_FLOATS.index(b2)] = 1 195 | 196 | return edit_label, h_label, done_label 197 | 198 | 199 | class MultiElement(RxnElement): 200 | """ 201 | MultiElement is an abstract class for dealing with multiple molecules. The graph 202 | is built with all molecules, but different molecules and their sizes are stored. 203 | The constructor accepts only mol objects, sidestepping the use of SMILES string 204 | which may always not be achievable, especially for an invalid intermediates. 
205 | """ 206 | def _build_graph(self) -> None: 207 | """Builds the graph attributes.""" 208 | self.G_undir = nx.Graph(Chem.rdmolops.GetAdjacencyMatrix(self.mol)) 209 | self.G_dir = nx.DiGraph(Chem.rdmolops.GetAdjacencyMatrix(self.mol)) 210 | 211 | for atom in self.mol.GetAtoms(): 212 | self.G_undir.nodes[atom.GetIdx()]['label'] = atom.GetSymbol() 213 | self.G_dir.nodes[atom.GetIdx()]['label'] = atom.GetSymbol() 214 | 215 | for bond in self.mol.GetBonds(): 216 | a1 = bond.GetBeginAtom().GetIdx() 217 | a2 = bond.GetEndAtom().GetIdx() 218 | btype = BOND_TYPES.index( bond.GetBondType() ) 219 | self.G_undir[a1][a2]['label'] = btype 220 | self.G_dir[a1][a2]['label'] = btype 221 | self.G_dir[a2][a1]['label'] = btype 222 | 223 | frag_indices = [c for c in nx.strongly_connected_components(self.G_dir)] 224 | self.mols = [get_sub_mol(self.mol, sub_atoms) for sub_atoms in frag_indices] 225 | 226 | atom_start = 0 227 | bond_start = 0 228 | self.atom_scope = [] 229 | self.bond_scope = [] 230 | 231 | for mol in self.mols: 232 | self.atom_scope.append((atom_start, mol.GetNumAtoms())) 233 | self.bond_scope.append((bond_start, mol.GetNumBonds())) 234 | atom_start += mol.GetNumAtoms() 235 | bond_start += mol.GetNumBonds() 236 | -------------------------------------------------------------------------------- /seq_graph_retro/molgraph/vocab.py: -------------------------------------------------------------------------------- 1 | from seq_graph_retro.molgraph.mol_features import ATOM_LIST 2 | 3 | class Vocab: 4 | """Vocab class to deal with atom vocabularies and other attributes.""" 5 | 6 | def __init__(self, elem_list=ATOM_LIST[:-1]) -> None: 7 | """ 8 | Parameters 9 | ---------- 10 | elem_list: List, default ATOM_LIST 11 | Element list used for setting up the vocab 12 | """ 13 | self.elem_list = elem_list 14 | if isinstance(elem_list, dict): 15 | self.elem_list = list(elem_list.keys()) 16 | self.elem_to_idx = {a: idx for idx, a in enumerate(self.elem_list)} 17 | self.idx_to_elem = 
{idx: a for idx, a in enumerate(self.elem_list)} 18 | 19 | def __getitem__(self, a_type: str) -> int: 20 | return self.elem_to_idx[a_type] 21 | 22 | def get(self, elem: str, idx: int = None) -> int: 23 | """Returns the index of the element, else a None for missing element. 24 | 25 | Parameters 26 | ---------- 27 | elem: str, 28 | Element to query 29 | idx: int, default None 30 | Index to return if element not in vocab 31 | """ 32 | return self.elem_to_idx.get(elem, idx) 33 | 34 | def get_elem(self, idx: int) -> str: 35 | """Returns the element at given index. 36 | 37 | Parameters 38 | ---------- 39 | idx: int, 40 | Index to return if element not in vocab 41 | """ 42 | return self.idx_to_elem[idx] 43 | 44 | def __len__(self) -> int: 45 | return len(self.elem_list) 46 | 47 | def index(self, elem: str) -> int: 48 | """Returns the index of the element. 49 | 50 | Parameters 51 | ---------- 52 | elem: str, 53 | Element to query 54 | """ 55 | return self.elem_to_idx[elem] 56 | 57 | def size(self) -> int: 58 | """Returns length of Vocab.""" 59 | return len(self.elem_list) 60 | 61 | COMMON_ATOMS = [('B', 0), ('Br', 0), ('C', -1), ('C', 0), ('Cl', 0), ('Cu', 0), ('F', 0), 62 | ('I', 0), ('Mg', 0), ('Mg', 1), ('N', -1), ('N', 0), ('N', 1), ('O', -1), ('O', 0), 63 | ('P', 0), ('P', 1), ('S', -1), ('S', 0), ('S', 1), ('Se', 0), ('Si', 0), ('Sn', 0), 64 | ('Zn', 0), ('Zn', 1)] 65 | 66 | common_atom_vocab = Vocab(COMMON_ATOMS) 67 | -------------------------------------------------------------------------------- /seq_graph_retro/search/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/search/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def str2bool(v: str) -> bool: 4 | v = v.lower() 5 | if v == "true": 6 | return True 7 | elif v == "false": 8 | return False 9 | else: 10 | raise argparse.ArgumentTypeError(f"Boolean value expected, got '{v}'.") 11 | -------------------------------------------------------------------------------- /seq_graph_retro/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/utils/__pycache__/chem.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/utils/__pycache__/chem.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/utils/__pycache__/edit_mol.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/utils/__pycache__/edit_mol.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/utils/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/utils/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/utils/__pycache__/parse.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/utils/__pycache__/parse.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/utils/__pycache__/torch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsomnath/graphretro/971712b3874c178fe511daacabf3310f9340e061/seq_graph_retro/utils/__pycache__/torch.cpython-37.pyc -------------------------------------------------------------------------------- /seq_graph_retro/utils/chem.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from typing import Iterable, List 3 | 4 | # Redefined from seq_graph_retro/molgraph/mol_features.py 5 | BOND_TYPES = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, \ 6 | Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC] 7 | BOND_FLOAT_TO_TYPE = { 8 | 0.0: BOND_TYPES[0], 9 | 1.0: BOND_TYPES[1], 10 | 2.0: BOND_TYPES[2], 11 | 3.0: BOND_TYPES[3], 12 | 1.5: BOND_TYPES[4], 13 | } 14 | 15 | def get_mol(smiles: str, kekulize: bool = False) -> Chem.Mol: 16 | """SMILES string to Mol; returns None if RDKit cannot parse the SMILES. 17 | 18 | Parameters 19 | ---------- 20 | smiles: str, 21 | SMILES string for molecule 22 | kekulize: bool, 23 | Whether to kekulize the molecule 24 | """ 25 | mol = Chem.MolFromSmiles(smiles) 26 | if mol is not None and kekulize: 27 | Chem.Kekulize(mol) 28 | return mol 29 | 30 | def apply_edits_to_mol(mol: Chem.Mol, edits: Iterable[str]) -> Chem.Mol: 31 | """Apply edits to molecular graph. 32 | 33 | Parameters 34 | ---------- 35 | mol: Chem.Mol, 36 | RDKit mol object 37 | edits: Iterable[str], 38 | Iterable of edits to apply.
An edit is structured as a1:a2:b1:b2, where 39 | a1, a2 are atom maps of participating atoms and b1, b2 are previous and 40 | new bond orders. Edits with a2 = 0 (hydrogen-count updates) are skipped by this function. 41 | """ 42 | new_mol = Chem.RWMol(mol) 43 | amap = {atom.GetAtomMapNum(): atom.GetIdx() for atom in new_mol.GetAtoms()} 44 | 45 | # Keep track of aromatic nitrogens, might cause explicit hydrogen issues 46 | aromatic_nitrogen_idx = set() 47 | aromatic_carbonyl_adj_to_aromatic_nH = {} 48 | aromatic_carbondeg3_adj_to_aromatic_nH0 = {} 49 | for a in new_mol.GetAtoms(): 50 | if a.GetIsAromatic() and a.GetSymbol() == 'N': 51 | aromatic_nitrogen_idx.add(a.GetIdx()) 52 | for nbr in a.GetNeighbors(): 53 | nbr_is_carbon = (nbr.GetSymbol() == 'C') 54 | nbr_is_aromatic = nbr.GetIsAromatic() 55 | nbr_has_double_bond = any(b.GetBondTypeAsDouble() == 2 for b in nbr.GetBonds()) 56 | nbr_has_3_bonds = (len(nbr.GetBonds()) == 3) 57 | 58 | if (a.GetNumExplicitHs() ==1 and nbr_is_carbon and nbr_is_aromatic 59 | and nbr_has_double_bond): 60 | aromatic_carbonyl_adj_to_aromatic_nH[nbr.GetIdx()] = a.GetIdx() 61 | elif (a.GetNumExplicitHs() == 0 and nbr_is_carbon and nbr_is_aromatic 62 | and nbr_has_3_bonds): 63 | aromatic_carbondeg3_adj_to_aromatic_nH0[nbr.GetIdx()] = a.GetIdx() 64 | else: 65 | a.SetNumExplicitHs(0) 66 | new_mol.UpdatePropertyCache() 67 | 68 | # Apply the edits as predicted 69 | for edit in edits: 70 | x, y, prev_bo, new_bo = edit.split(":") 71 | x, y = int(x), int(y) 72 | new_bo = float(new_bo) 73 | 74 | if y == 0: 75 | continue 76 | 77 | bond = new_mol.GetBondBetweenAtoms(amap[x],amap[y]) 78 | a1 = new_mol.GetAtomWithIdx(amap[x]) 79 | a2 = new_mol.GetAtomWithIdx(amap[y]) 80 | 81 | if bond is not None: 82 | bond_order = bond.GetBondTypeAsDouble() # read first: RemoveBond deletes the bond that `bond` refers to 83 | new_mol.RemoveBond(amap[x],amap[y]) 84 | # Are we losing a bond on an aromatic nitrogen?
85 | if bond_order == 1.0: 86 | if amap[x] in aromatic_nitrogen_idx: 87 | if a1.GetTotalNumHs() == 0: 88 | a1.SetNumExplicitHs(1) 89 | elif a1.GetFormalCharge() == 1: 90 | a1.SetFormalCharge(0) 91 | elif amap[y] in aromatic_nitrogen_idx: 92 | if a2.GetTotalNumHs() == 0: 93 | a2.SetNumExplicitHs(1) 94 | elif a2.GetFormalCharge() == 1: 95 | a2.SetFormalCharge(0) 96 | 97 | # Are we losing a c=O bond on an aromatic ring? If so, remove H from adjacent nH if appropriate 98 | if bond_order == 2.0: 99 | if amap[x] in aromatic_carbonyl_adj_to_aromatic_nH: 100 | new_mol.GetAtomWithIdx(aromatic_carbonyl_adj_to_aromatic_nH[amap[x]]).SetNumExplicitHs(0) 101 | elif amap[y] in aromatic_carbonyl_adj_to_aromatic_nH: 102 | new_mol.GetAtomWithIdx(aromatic_carbonyl_adj_to_aromatic_nH[amap[y]]).SetNumExplicitHs(0) 103 | 104 | if new_bo > 0: 105 | new_mol.AddBond(amap[x],amap[y],BOND_FLOAT_TO_TYPE[new_bo]) 106 | 107 | # Special alkylation case? 108 | if new_bo == 1: 109 | if amap[x] in aromatic_nitrogen_idx: 110 | if a1.GetTotalNumHs() == 1: 111 | a1.SetNumExplicitHs(0) 112 | else: 113 | a1.SetFormalCharge(1) 114 | elif amap[y] in aromatic_nitrogen_idx: 115 | if a2.GetTotalNumHs() == 1: 116 | a2.SetNumExplicitHs(0) 117 | else: 118 | a2.SetFormalCharge(1) 119 | 120 | # Are we getting a c=O bond on an aromatic ring?
If so, add H to adjacent nH0 if appropriate 121 | if new_bo == 2: 122 | if amap[x] in aromatic_carbondeg3_adj_to_aromatic_nH0: 123 | new_mol.GetAtomWithIdx(aromatic_carbondeg3_adj_to_aromatic_nH0[amap[x]]).SetNumExplicitHs(1) 124 | elif amap[y] in aromatic_carbondeg3_adj_to_aromatic_nH0: 125 | new_mol.GetAtomWithIdx(aromatic_carbondeg3_adj_to_aromatic_nH0[amap[y]]).SetNumExplicitHs(1) 126 | 127 | pred_mol = new_mol.GetMol() 128 | 129 | # Clear formal charges to make molecules valid 130 | # Note: because S and P (among others) can change valence, be more flexible 131 | for atom in pred_mol.GetAtoms(): 132 | if atom.GetSymbol() == 'N' and atom.GetFormalCharge() == 1: # exclude negatively-charged azide 133 | bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]) 134 | if bond_vals <= 3: 135 | atom.SetFormalCharge(0) 136 | elif atom.GetSymbol() == 'N' and atom.GetFormalCharge() == -1: # handle negatively-charged azide addition 137 | bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]) 138 | if bond_vals == 3 and any([nbr.GetSymbol() == 'N' for nbr in atom.GetNeighbors()]): 139 | atom.SetFormalCharge(0) 140 | elif atom.GetSymbol() == 'N': 141 | bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]) 142 | if bond_vals == 4 and not atom.GetIsAromatic(): # and atom.IsInRingSize(5)): 143 | atom.SetFormalCharge(1) 144 | elif atom.GetSymbol() == 'C' and atom.GetFormalCharge() != 0: 145 | atom.SetFormalCharge(0) 146 | elif atom.GetSymbol() == 'O' and atom.GetFormalCharge() != 0: 147 | bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]) + atom.GetNumExplicitHs() 148 | if bond_vals == 2: 149 | atom.SetFormalCharge(0) 150 | elif atom.GetSymbol() in ['Cl', 'Br', 'I', 'F'] and atom.GetFormalCharge() != 0: 151 | bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]) 152 | if bond_vals == 1: 153 | atom.SetFormalCharge(0) 154 | elif atom.GetSymbol() == 'S' and atom.GetFormalCharge() !=
0: 155 | bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]) 156 | if bond_vals in [2, 4, 6]: 157 | atom.SetFormalCharge(0) 158 | elif atom.GetSymbol() == 'P': # quaternary phosphorus should be pos. charge with 0 H 159 | bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()] 160 | if sum(bond_vals) == 4 and len(bond_vals) == 4: 161 | atom.SetFormalCharge(1) 162 | atom.SetNumExplicitHs(0) 163 | elif sum(bond_vals) == 3 and len(bond_vals) == 3: # make sure neutral 164 | atom.SetFormalCharge(0) 165 | elif atom.GetSymbol() == 'B': # quaternary boron should be neg. charge with 0 H 166 | bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()] 167 | if sum(bond_vals) == 4 and len(bond_vals) == 4: 168 | atom.SetFormalCharge(-1) 169 | atom.SetNumExplicitHs(0) 170 | elif atom.GetSymbol() in ['Mg', 'Zn']: 171 | bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()] 172 | if sum(bond_vals) == 1 and len(bond_vals) == 1: 173 | atom.SetFormalCharge(1) 174 | elif atom.GetSymbol() == 'Si': 175 | bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()] 176 | if sum(bond_vals) == len(bond_vals): 177 | atom.SetNumExplicitHs(max(0, 4 - len(bond_vals))) 178 | 179 | return pred_mol 180 | 181 | def get_sub_mol(mol: Chem.Mol, sub_atoms: List[int]) -> Chem.Mol: 182 | """Extract subgraph from molecular graph. 183 | 184 | Parameters 185 | ---------- 186 | mol: Chem.Mol, 187 | RDKit mol object, 188 | sub_atoms: List[int], 189 | List of atom indices in the subgraph.
""" 191 | new_mol = Chem.RWMol() 192 | atom_map = {} 193 | for idx in sub_atoms: 194 | atom = mol.GetAtomWithIdx(idx) 195 | atom_map[idx] = new_mol.AddAtom(atom) 196 | 197 | sub_atoms = set(sub_atoms) 198 | for idx in sub_atoms: 199 | a = mol.GetAtomWithIdx(idx) 200 | for b in a.GetNeighbors(): 201 | if b.GetIdx() not in sub_atoms: continue 202 | bond = mol.GetBondBetweenAtoms(a.GetIdx(), b.GetIdx()) 203 | bt = bond.GetBondType() 204 | if a.GetIdx() < b.GetIdx(): #each bond is enumerated twice 205 | new_mol.AddBond(atom_map[a.GetIdx()], atom_map[b.GetIdx()], bt) 206 | 207 | return new_mol.GetMol() 208 | 209 | def get_sub_mol_stereo(mol: Chem.Mol, sub_atoms: List[int]) -> Chem.Mol: 210 | """Extract subgraph from molecular graph, while preserving stereochemistry. 211 | 212 | Parameters 213 | ---------- 214 | mol: Chem.Mol, 215 | RDKit mol object, 216 | sub_atoms: List[int], 217 | List of atom indices in the subgraph. 218 | """ 219 | # This version retains stereochemistry, as opposed to the other version 220 | new_mol = Chem.RWMol(Chem.Mol(mol)) 221 | sub_atoms = set(sub_atoms) # O(1) membership tests, as in get_sub_mol 222 | atoms_to_remove = [] 223 | for atom in mol.GetAtoms(): 224 | if atom.GetIdx() not in sub_atoms: 225 | atoms_to_remove.append(atom.GetIdx()) 226 | for atom_idx in sorted(atoms_to_remove, reverse=True): # highest index first so pending indices stay valid 227 | new_mol.RemoveAtom(atom_idx) 228 | 229 | return new_mol.GetMol() 230 | -------------------------------------------------------------------------------- /seq_graph_retro/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_accuracy_edits(edit_logits, labels): 4 | accuracy = torch.tensor(0.0) 5 | for i in range(len(edit_logits)): 6 | try: 7 | if torch.argmax(edit_logits[i]).item() == labels[i].item(): 8 | accuracy += 1.0 9 | except ValueError: 10 | if torch.argmax(edit_logits[i]).item() == torch.argmax(labels[i]).item(): 11 | accuracy += 1.0 12 | return accuracy / len(labels) 13 | 14 | def get_accuracy_overall(edit_logits,
lg_logits, edit_labels, lg_labels, lengths, device='cpu'): 15 | acc_edits = torch.zeros(len(edit_labels), dtype=torch.long).to(device) 16 | acc_lg = torch.zeros(len(lg_labels), dtype=torch.long).to(device) 17 | 18 | _, preds = torch.max(lg_logits, dim=-1) 19 | 20 | for i in range(lg_logits.size(0)): 21 | length = lengths[i] 22 | result = torch.eq(preds[i][:length], lg_labels[i][:length]).float() 23 | if int(result.sum().item()) == length: 24 | acc_lg[i] = 1 25 | 26 | try: 27 | if torch.argmax(edit_logits[i]).item() == edit_labels[i].item(): 28 | acc_edits[i] = 1 29 | except ValueError: 30 | if torch.argmax(edit_logits[i]).item() == torch.argmax(edit_labels[i]).item(): 31 | acc_edits[i] = 1 32 | 33 | acc_overall = (acc_edits & acc_lg).float() 34 | return torch.mean(acc_overall) 35 | 36 | 37 | def get_edit_seq_accuracy(seq_edit_logits, seq_labels, seq_mask): 38 | max_seq_len = seq_mask.shape[0] 39 | batch_size = seq_mask.shape[1] 40 | assert len(seq_edit_logits) == max_seq_len 41 | assert len(seq_labels) == max_seq_len 42 | assert len(seq_edit_logits[0]) == batch_size 43 | lengths = seq_mask.sum(dim=0).flatten() 44 | 45 | check_equals = lambda x, y: torch.argmax(x) == torch.argmax(y) 46 | 47 | acc_matrix = torch.stack([torch.stack([(check_equals(seq_edit_logits[idx][bid], seq_labels[idx][bid])).long() 48 | for idx in range(max_seq_len)]) for bid in range(batch_size)]) 49 | assert acc_matrix.shape == (batch_size, max_seq_len) 50 | acc_matrix = acc_matrix.to(seq_mask.device) * seq_mask.t() 51 | num_correct = acc_matrix.sum(dim=1) 52 | assert len(num_correct) == batch_size 53 | accuracy = (num_correct == lengths).float().mean() 54 | return accuracy 55 | 56 | 57 | def get_seq_accuracy_overall(seq_edit_logits, lg_logits, seq_labels, lg_labels, lg_lengths, seq_mask): 58 | acc_edits = torch.zeros(seq_mask.shape[-1], dtype=torch.long).to(seq_mask.device) # NOTE(review): dead store, reassigned from acc_matrix below 59 | acc_lg = torch.zeros(seq_mask.shape[-1], dtype=torch.long).to(seq_mask.device) 60 | 61 | _, preds =
torch.max(lg_logits, dim=-1) 62 | max_seq_len = len(seq_edit_logits) 63 | batch_size = seq_mask.shape[1] 64 | lengths = seq_mask.sum(dim=0).flatten() 65 | 66 | check_equals = lambda x, y: torch.argmax(x) == torch.argmax(y) 67 | 68 | acc_matrix = torch.stack([torch.stack([(check_equals(seq_edit_logits[idx][bid], seq_labels[idx][bid])).long() 69 | for idx in range(max_seq_len)]) for bid in range(batch_size)]) 70 | acc_matrix = acc_matrix.to(seq_mask.device) * seq_mask.t() 71 | num_correct = acc_matrix.sum(dim=1) 72 | 73 | acc_edits = (num_correct == lengths).long() 74 | 75 | for i in range(lg_logits.size(0)): 76 | length = lg_lengths[i] 77 | result = torch.eq(preds[i][:length], lg_labels[i][:length]).float() 78 | if int(result.sum().item()) == length: 79 | acc_lg[i] = 1 80 | 81 | acc_overall = (acc_edits & acc_lg).float() 82 | return torch.mean(acc_overall) 83 | 84 | 85 | def get_accuracy_bin(scores, labels): 86 | preds = torch.ge(scores, 0).float() 87 | acc = torch.eq(preds, labels.float()).float() 88 | return torch.sum(acc) / labels.nelement() 89 | 90 | 91 | def get_accuracy(scores, labels): 92 | _,preds = torch.max(scores, dim=-1) 93 | acc = torch.eq(preds, labels).float() 94 | return torch.sum(acc) / labels.nelement() 95 | 96 | 97 | def get_accuracy_lg(scores, labels, lengths, device='cpu'): 98 | _, preds = torch.max(scores, dim=-1) 99 | results = torch.zeros(scores.size(0), dtype=torch.float).to(device) 100 | 101 | for i in range(scores.size(0)): 102 | length = lengths[i] 103 | result = torch.eq(preds[i][:length], labels[i][:length]).float() 104 | if int(result.sum().item()) == length: 105 | results[i] = 1 106 | return torch.sum(results) / scores.size(0) 107 | -------------------------------------------------------------------------------- /seq_graph_retro/utils/torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import
Optimizer 5 | import numpy as np 6 | from typing import List, Tuple, Any, Optional, Union, Set 7 | 8 | def zip_tensors(tup_list): 9 | arr0, arr1, arr2 = zip(*tup_list) 10 | if type(arr2[0]) is int: 11 | arr0 = torch.stack(arr0, dim=0) 12 | arr1 = torch.tensor(arr1, dtype=torch.long) 13 | arr2 = torch.tensor(arr2, dtype=torch.long) 14 | else: 15 | arr0 = torch.cat(arr0, dim=0) 16 | arr1 = [x for a in arr1 for x in a] 17 | arr1 = torch.tensor(arr1, dtype=torch.long) 18 | arr2 = torch.cat(arr2, dim=0) 19 | return arr0, arr1, arr2 20 | 21 | def create_pad_tensor(alist): 22 | max_len = max([len(a) for a in alist]) 23 | for a in alist: 24 | pad_len = max_len - len(a) 25 | a.extend([0] * pad_len) # NOTE: pads the caller's lists in place 26 | return torch.tensor(alist, dtype=torch.long) 27 | 28 | def index_scatter(sub_data, all_data, index): 29 | d0, d1 = all_data.size() 30 | buf = torch.zeros_like(all_data).scatter_(0, index.repeat(d1, 1).t(), sub_data) 31 | mask = torch.ones(d0, device=all_data.device).scatter_(0, index, 0) 32 | return all_data * mask.unsqueeze(-1) + buf 33 | 34 | 35 | def stack_pad_tensor(tensor_list): 36 | max_len = max([t.size(0) for t in tensor_list]) 37 | for i,tensor in enumerate(tensor_list): 38 | pad_len = max_len - tensor.size(0) 39 | tensor_list[i] = F.pad( tensor, (0,0,0,pad_len) ) # NOTE: overwrites entries of the caller's list 40 | return torch.stack(tensor_list, dim=0) 41 | 42 | 43 | def index_select_ND(source, dim, index): 44 | index_size = index.size() 45 | suffix_dim = source.size()[1:] 46 | final_size = index_size + suffix_dim 47 | target = source.index_select(dim, index.view(-1)) 48 | return target.view(final_size) 49 | 50 | 51 | class EncOptimizer(Optimizer): 52 | 53 | def __init__(self, optimizer: Optimizer, enc_opt: Optional[Optimizer]) -> None: 54 | enc_params = [] 55 | if enc_opt is not None: 56 | enc_params = enc_opt.param_groups 57 | super().__init__(optimizer.param_groups + enc_params, {}) 58 | self.optimizer = optimizer 59 | self.enc_opt = enc_opt 60 | 61 | def zero_grad(self) -> None: 62 |
self.optimizer.zero_grad() 63 | if self.enc_opt is not None: 64 | self.enc_opt.zero_grad() 65 | 66 | def step(self, closure: Optional[Any] = None) -> None: 67 | self.optimizer.step(closure) # NOTE(review): a non-None closure is evaluated by each wrapped optimizer 68 | if self.enc_opt is not None: 69 | self.enc_opt.step(closure) 70 | 71 | 72 | def build_mlp(in_dim: int, 73 | h_dim: Union[int, List], 74 | out_dim: Optional[int] = None, 75 | dropout_p: Union[float, List[float]] = 0.2) -> nn.Sequential: 76 | """Builds an MLP. 77 | Parameters 78 | ---------- 79 | in_dim: int, 80 | Input dimension of the MLP 81 | h_dim: int or List[int], 82 | Hidden layer dimension(s) of the MLP; one Linear-ReLU-Dropout stage per entry 83 | out_dim: int, default None 84 | Output size of the MLP. If None, no final Linear layer is added and the MLP ends with ReLU and Dropout 85 | dropout_p: float or List[float], default 0.2, 86 | Dropout probability, shared by all hidden layers or given per hidden layer 87 | """ 88 | if isinstance(h_dim, int): 89 | h_dim = [h_dim] 90 | 91 | sizes = [in_dim] + list(h_dim) 92 | mlp_size_tuple = list(zip(*(sizes[:-1], sizes[1:]))) 93 | 94 | if isinstance(dropout_p, (int, float)): # scalar (including an int such as 0) -> same p for every layer 95 | dropout_p = [dropout_p] * len(mlp_size_tuple) 96 | 97 | layers = [] 98 | 99 | for idx, (prev_size, next_size) in enumerate(mlp_size_tuple): 100 | layers.append(nn.Linear(prev_size, next_size)) 101 | layers.append(nn.ReLU()) 102 | layers.append(nn.Dropout(dropout_p[idx])) 103 | 104 | if out_dim is not None: 105 | layers.append(nn.Linear(sizes[-1], out_dim)) 106 | 107 | return nn.Sequential(*layers) 108 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='seq_graph_retro', 5 | version='1.0', 6 | description='Sequential graph edit model for retrosynthesis.', 7 | packages=find_packages(exclude=[]), 8 | python_requires='>=3.5', 9 | ) 10 | --------------------------------------------------------------------------------