├── .gitignore ├── README.md ├── assets └── textreact.png ├── environment.yml ├── main.py ├── preprocess ├── dedup_corpus.py ├── gen_uspto.py ├── get_templates.py ├── preprocess_retrosynthesis.py ├── raw_retro_year_split.py ├── reagent_Ionic_compound.txt ├── reagent_unknown.txt ├── retro_year_split.py ├── template_extraction │ ├── template_extract_utils.py │ └── template_extractor.py └── uspto_script │ ├── 1.get_condition_from_uspto.py │ ├── 2.0.clean_up_rxn_condition.py │ ├── 2.0.clean_up_rxn_condition.sh │ ├── 2.1.merge_clean_up_rxn_conditon.py │ ├── 3.0.split_condition_and_slect.py │ ├── 4.0.split_train_val_test.py │ ├── 5.0.convert_context_tokens.py │ ├── condition_classfication.ipynb │ ├── extract_nosmiles.py │ ├── gen_grant_corpus.py │ ├── get_aug_condition_data.py │ ├── get_dataset_for_condition.py │ ├── get_dummy_model_results.py │ ├── get_fragment_from_rxn_dataset.py │ ├── merge_comp.py │ ├── uspto_condition.md │ └── utils.py ├── retrieve ├── condition_year.sh ├── convert_format.py ├── retrieve.py ├── retrieve_faiss.py ├── retro.sh └── retro_year.sh ├── scripts ├── train_RCR.sh ├── train_RCR_TS.sh ├── train_RetroSyn_tb.sh ├── train_RetroSyn_tb_TS.sh ├── train_RetroSyn_tf.sh └── train_RetroSyn_tf_TS.sh └── textreact ├── __init__.py ├── configs └── bert_l6.json ├── dataset.py ├── evaluate.py ├── model.py ├── template_decoder.py ├── tokenizer.py ├── utils.py └── vocab ├── vocab_condition.txt └── vocab_smiles.txt /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | scripts_old/* 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TextReact 2 | 3 | This repository contains the code for [TextReact](https://aclanthology.org/2023.emnlp-main.784/), a novel method that directly augments 4 | predictive chemistry with text retrieval. 5 | 6 | ![](assets/textreact.png) 7 | 8 | ``` 9 | @inproceedings{TextReact, 10 | author = {Yujie Qian and 11 | Zhening Li and 12 | Zhengkai Tu and 13 | Connor W. Coley and 14 | Regina Barzilay}, 15 | title = {Predictive Chemistry Augmented with Text Retrieval}, 16 | booktitle = {Proceedings of the 2023 Conference on Empirical Methods in Natural 17 | Language Processing, {EMNLP} 2023, Singapore, December 6-10, 2023}, 18 | pages = {12731--12745}, 19 | publisher = {Association for Computational Linguistics}, 20 | year = {2023}, 21 | url = {https://aclanthology.org/2023.emnlp-main.784} 22 | } 23 | ``` 24 | 25 | ## Requirements 26 | We implement the code with `torch==1.11.0`, `pytorch-lightning==2.0.0`, and `transformers==4.27.3`. 27 | To reproduce our experiments, we recommend creating a conda environment with the same dependencies: 28 | ```bash 29 | conda env create -f environment.yml -n textreact 30 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 31 | ``` 32 | 33 | ## Data 34 | 35 | Run the following commands to download and unzip the preprocessed datasets: 36 | ``` 37 | git clone https://huggingface.co/datasets/yujieq/TextReact data 38 | cd data 39 | unzip '*' 40 | ``` 41 | 42 | ## Training Scripts 43 | 44 | TextReact consists of two modules: SMILES-To-text retriever and 45 | text-augmented predictor. This repository only contains the code for 46 | training the predictor, while the code for the retriever is available in 47 | a separate repository: https://github.com/thomas0809/tevatron. 
48 | 49 | The training scripts are located under [`scripts`](scripts): 50 | * [`train_RCR.sh`](scripts/train_RCR.sh) trains a model for reaction condition recommendation (RCR) 51 | on the random split of the USPTO dataset. 52 | * [`train_RetroSyn_tf.sh`](scripts/train_RetroSyn_tf.sh) trains a template-free model for retrosynthesis 53 | on the random split of the USPTO-50K dataset. 54 | * [`train_RetroSyn_tb.sh`](scripts/train_RetroSyn_tb.sh) trains a template-based model for retrosynthesis 55 | on the random split of the USPTO-50K dataset. 56 | In addition, [`train_RCR_TS.sh`](scripts/train_RCR_TS.sh), [`train_RetroSyn_tf_TS.sh`](scripts/train_RetroSyn_tf_TS.sh), 57 | and [`train_RetroSyn_tb_TS.sh`](scripts/train_RetroSyn_tb_TS.sh) train the corresponding models 58 | on the time-based split of the dataset. 59 | 60 | If you're working on a distributed file system, it is recommended to 61 | add a `--cache_path` option to the script, specifying a local directory for caching, to reduce time spent on network I/O. 62 | 63 | To run a training script `scripts/train_MODEL.sh`, execute the following command from the repository root: 64 | ``` 65 | bash scripts/train_MODEL.sh 66 | ``` 67 | 68 | At the end of training, two dictionaries of top-k test accuracies are printed: 69 | the first corresponds to retrieving from the full corpus, 70 | and the second corresponds to retrieving from the gold-removed corpus. 71 | 72 | Models and test predictions are stored under the path specified by the `SAVE_PATH` variable in the script: 73 | * `best.ckpt` is the checkpoint with the highest validation accuracy, and 74 | * `last.ckpt` is the most recent checkpoint. 75 | * `prediction_test_0.json` contains the test predictions when retrieving from the full corpus. 76 | * `prediction_test_1.json` contains the test predictions when retrieving from the gold-removed corpus.
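
Each prediction file is a JSON dictionary written by `main.py` (`on_test_epoch_end`), mapping test example indices to the beam-search candidates and their scores. Below is a minimal sketch for inspecting the top-1 predictions of the sequence-based (non-template) models; the file path is a placeholder for your own `SAVE_PATH`:

```python
import json

# Predictions produced when retrieving from the full corpus.
with open('SAVE_PATH/prediction_test_0.json') as f:  # replace SAVE_PATH with your output directory
    predictions = json.load(f)

# Each entry holds `num_beams` candidate sequences ('prediction') and their scores ('score').
for idx, entry in list(predictions.items())[:5]:
    print(idx, entry['prediction'][0], entry['score'][0])
```

For the template-based retrosynthesis models, each entry additionally stores the raw template labels and a top-1 template match flag (see `test_step` in `main.py`).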
77 | -------------------------------------------------------------------------------- /assets/textreact.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomas0809/textreact/feb62d868a627d293997a56a72f50420377c59b4/assets/textreact.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: textreact 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=2_gnu 8 | - bzip2=1.0.8=h7f98852_4 9 | - ca-certificates=2022.12.7=ha878542_0 10 | - ld_impl_linux-64=2.40=h41732ed_0 11 | - libffi=3.4.2=h7f98852_5 12 | - libgcc-ng=12.2.0=h65d4601_19 13 | - libgomp=12.2.0=h65d4601_19 14 | - libnsl=2.0.0=h7f98852_0 15 | - libsqlite=3.40.0=h753d276_0 16 | - libuuid=2.32.1=h7f98852_1000 17 | - libzlib=1.2.13=h166bdaf_4 18 | - ncurses=6.3=h27087fc_1 19 | - openssl=3.1.0=h0b41bf4_0 20 | - pip=23.0.1=pyhd8ed1ab_0 21 | - python=3.8.16=he550d4f_1_cpython 22 | - readline=8.2=h8228510_1 23 | - setuptools=67.6.0=pyhd8ed1ab_0 24 | - tk=8.6.12=h27826a3_0 25 | - wheel=0.40.0=pyhd8ed1ab_0 26 | - xz=5.2.6=h166bdaf_0 27 | - pip: 28 | - aiohttp==3.8.4 29 | - aiosignal==1.3.1 30 | - anyio==3.6.2 31 | - appdirs==1.4.4 32 | - arrow==1.2.3 33 | - asttokens==2.2.1 34 | - async-timeout==4.0.2 35 | - attrs==22.2.0 36 | - backcall==0.2.0 37 | - beautifulsoup4==4.12.0 38 | - blessed==1.20.0 39 | - certifi==2022.12.7 40 | - charset-normalizer==3.1.0 41 | - click==8.1.3 42 | - croniter==1.3.8 43 | - dateutils==0.6.12 44 | - decorator==5.1.1 45 | - deepdiff==6.3.0 46 | - dnspython==2.3.0 47 | - docker-pycreds==0.4.0 48 | - email-validator==1.3.1 49 | - executing==1.2.0 50 | - fastapi==0.88.0 51 | - filelock==3.10.7 52 | - frozenlist==1.3.3 53 | - fsspec==2023.3.0 54 | - gitdb==4.0.10 55 | - gitpython==3.1.31 56 | - h11==0.14.0 57 | - httpcore==0.16.3 58 | - httptools==0.5.0 59 | - httpx==0.23.3 60 | - huggingface-hub==0.13.3 61 | - idna==3.4 62 | - inquirer==3.1.3 63 | - ipython==8.12.2 64 | - itsdangerous==2.1.2 65 | - jedi==0.18.2 66 | - jinja2==3.1.2 67 | - lightning==2.0.0 68 | - lightning-cloud==0.5.32 69 | - lightning-utilities==0.8.0 70 | - markdown-it-py==2.2.0 71 | - markupsafe==2.1.2 72 | - matplotlib-inline==0.1.6 73 | - mdurl==0.1.2 74 | - multidict==6.0.4 75 | - numpy==1.24.2 76 | - ordered-set==4.1.0 77 | - orjson==3.8.8 78 | - packaging==23.0 79 | - pandas==1.5.3 80 | - parso==0.8.3 81 | - pathtools==0.1.2 82 | - pexpect==4.8.0 83 | - pickleshare==0.7.5 84 | - pillow==9.4.0 85 | - prompt-toolkit==3.0.38 86 | - protobuf==4.22.1 87 | - psutil==5.9.4 88 | - ptyprocess==0.7.0 89 | - pure-eval==0.2.2 90 | - pydantic==1.10.7 91 | - pygments==2.14.0 92 | - pyjwt==2.6.0 93 | - python-dateutil==2.8.2 94 | - python-dotenv==1.0.0 95 | - python-editor==1.0.4 96 | - python-multipart==0.0.6 97 | - pytorch-lightning==2.0.0 98 | - pytz==2023.2 99 | - pyyaml==6.0 100 | - rdkit-pypi==2022.9.5 101 | - readchar==4.0.5 102 | - regex==2023.3.23 103 | - requests==2.28.2 104 | - rfc3986==1.5.0 105 | - rich==13.3.3 106 | - sentry-sdk==1.18.0 107 | - setproctitle==1.3.2 108 | - six==1.16.0 109 | - smmap==5.0.0 110 | - sniffio==1.3.0 111 | - soupsieve==2.4 112 | - stack-data==0.6.2 113 | - starlette==0.22.0 114 | - starsessions==1.3.0 115 | - tokenizers==0.13.2 116 | - tqdm==4.65.0 117 | - traitlets==5.9.0 118 | - transformers==4.27.3 119 | - typing-extensions==4.5.0 120 | - 
ujson==5.7.0 121 | - urllib3==1.26.15 122 | - uvicorn==0.21.1 123 | - uvloop==0.17.0 124 | - wandb==0.14.0 125 | - watchfiles==0.19.0 126 | - wcwidth==0.2.6 127 | - websocket-client==1.5.1 128 | - websockets==10.4 129 | - yarl==1.8.2 130 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import copy 5 | import random 6 | import argparse 7 | import collections 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.distributed as dist 14 | import pytorch_lightning as pl 15 | from pytorch_lightning import LightningModule, LightningDataModule 16 | from pytorch_lightning.strategies.ddp import DDPStrategy 17 | from transformers import get_scheduler, EncoderDecoderModel, EncoderDecoderConfig, AutoTokenizer, AutoConfig, AutoModel 18 | 19 | from textreact.tokenizer import get_tokenizers 20 | from textreact.model import get_model, get_mlm_head 21 | from textreact.dataset import ReactionConditionDataset, RetrosynthesisDataset, read_corpus, generate_train_label_corpus 22 | from textreact.evaluate import evaluate_reaction_condition, evaluate_retrosynthesis 23 | import textreact.utils as utils 24 | 25 | 26 | def get_args(notebook=False): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--task', type=str, default='condition') 29 | parser.add_argument('--do_train', action='store_true') 30 | parser.add_argument('--do_valid', action='store_true') 31 | parser.add_argument('--do_test', action='store_true') 32 | parser.add_argument('--precision', type=str, default='32') 33 | parser.add_argument('--seed', type=int, default=42) 34 | parser.add_argument('--gpus', type=int, default=1) 35 | parser.add_argument('--print_freq', type=int, default=200) 36 | parser.add_argument('--debug', action='store_true') 37 | # Model 38 | parser.add_argument('--template_based', action='store_true') 39 | parser.add_argument('--unattend_nonbonds', action='store_true') 40 | parser.add_argument('--encoder', type=str, default=None) 41 | parser.add_argument('--decoder', type=str, default=None) 42 | parser.add_argument('--encoder_pretrained', action='store_true') 43 | parser.add_argument('--decoder_pretrained', action='store_true') 44 | parser.add_argument('--share_embedding', action='store_true') 45 | parser.add_argument('--encoder_tokenizer', type=str, default='text') 46 | # Data 47 | parser.add_argument('--data_path', type=str, default=None) 48 | parser.add_argument('--template_path', type=str, default=None) 49 | parser.add_argument('--train_file', type=str, default=None) 50 | parser.add_argument('--valid_file', type=str, default=None) 51 | parser.add_argument('--test_file', type=str, default=None) 52 | parser.add_argument('--vocab_file', type=str, default=None) 53 | parser.add_argument('--corpus_file', type=str, default=None) 54 | parser.add_argument('--train_label_corpus', action='store_true') 55 | parser.add_argument('--cache_path', type=str, default=None) 56 | parser.add_argument('--nn_path', type=str, default=None) 57 | parser.add_argument('--train_nn_file', type=str, default=None) 58 | parser.add_argument('--valid_nn_file', type=str, default=None) 59 | parser.add_argument('--test_nn_file', type=str, default=None) 60 | parser.add_argument('--max_length', type=int, default=128) 61 | parser.add_argument('--max_dec_length', type=int, default=128) 62 | 
parser.add_argument('--num_workers', type=int, default=8) 63 | parser.add_argument('--shuffle_smiles', action='store_true') 64 | parser.add_argument('--no_smiles', action='store_true') 65 | parser.add_argument('--num_neighbors', type=int, default=-1) 66 | parser.add_argument('--use_gold_neighbor', action='store_true') 67 | parser.add_argument('--max_num_neighbors', type=int, default=10) 68 | parser.add_argument('--random_neighbor_ratio', type=float, default=0.8) 69 | parser.add_argument('--mlm', action='store_true') 70 | parser.add_argument('--mlm_ratio', type=float, default=0.15) 71 | parser.add_argument('--mlm_layer', type=str, default='linear') 72 | parser.add_argument('--mlm_lambda', type=float, default=1) 73 | # Training 74 | parser.add_argument('--epochs', type=int, default=8) 75 | parser.add_argument('--batch_size', type=int, default=256) 76 | parser.add_argument('--lr', type=float, default=1e-4) 77 | parser.add_argument('--weight_decay', type=float, default=0.01) 78 | parser.add_argument('--max_grad_norm', type=float, default=5.) 79 | parser.add_argument('--scheduler', type=str, choices=['cosine', 'constant'], default='cosine') 80 | parser.add_argument('--warmup_ratio', type=float, default=0) 81 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1) 82 | parser.add_argument('--load_ckpt', type=str, default='best.ckpt') 83 | parser.add_argument('--eval_per_epoch', type=int, default=1) 84 | parser.add_argument('--val_metric', type=str, default='val_acc') 85 | parser.add_argument('--save_path', type=str, default='output/') 86 | parser.add_argument('--overwrite', action='store_true') 87 | parser.add_argument('--num_train_example', type=int, default=None) 88 | parser.add_argument('--label_smoothing', type=float, default=0.0) 89 | # Inference 90 | parser.add_argument('--test_batch_size', type=int, default=64) 91 | parser.add_argument('--num_beams', type=int, default=1) 92 | parser.add_argument('--test_each_neighbor', action='store_true') 93 | parser.add_argument('--test_num_neighbors', type=int, default=1) 94 | 95 | args = parser.parse_args([]) if notebook else parser.parse_args() 96 | 97 | return args 98 | 99 | 100 | class ReactionConditionRecommender(LightningModule): 101 | 102 | def __init__(self, args): 103 | super().__init__() 104 | self.args = args 105 | self.enc_tokenizer, self.dec_tokenizer = get_tokenizers(args) 106 | self.model = get_model(args, self.enc_tokenizer, self.dec_tokenizer) 107 | if args.mlm: 108 | self.mlm_head = get_mlm_head(args, self.model) 109 | self.validation_outputs = collections.defaultdict(dict) 110 | self.test_outputs = collections.defaultdict(dict) 111 | 112 | def compute_loss(self, logits, batch, reduction='mean'): 113 | if self.args.template_based: 114 | atom_logits, bond_logits = logits 115 | batch_size, max_len, atom_vocab_size = atom_logits.size() 116 | bond_vocab_size = bond_logits.size()[-1] 117 | atom_template_loss = F.cross_entropy(input=atom_logits.reshape(-1, atom_vocab_size), 118 | target=batch['decoder_atom_template_labels'].reshape(-1), 119 | reduction=reduction) 120 | bond_template_loss = F.cross_entropy(input=bond_logits.reshape(-1, bond_vocab_size), 121 | target=batch['decoder_bond_template_labels'].reshape(-1), 122 | reduction=reduction) 123 | if reduction == 'none': 124 | atom_template_loss = atom_template_loss.view(batch_size, -1).mean(dim=1) 125 | bond_template_loss = bond_template_loss.view(batch_size, -1).mean(dim=1) 126 | loss = atom_template_loss + bond_template_loss 127 | else: 128 | batch_size, max_len, 
vocab_size = logits.size() 129 | labels = batch['decoder_input_ids'][:, 1:] 130 | loss = F.cross_entropy(input=logits[:, :-1].reshape(-1, vocab_size), target=labels.reshape(-1), 131 | ignore_index=self.dec_tokenizer.pad_token_id, reduction=reduction) 132 | if reduction == 'none': 133 | loss = loss.view(batch_size, -1).mean(dim=1) 134 | return loss 135 | 136 | def compute_acc(self, logits, batch, reduction='mean'): 137 | # This accuracy is equivalent to greedy search accuracy 138 | if self.args.template_based: 139 | atom_logits_batch, bond_logits_batch = logits 140 | atom_probs_batch = F.softmax(atom_logits_batch, dim=-1) 141 | bond_probs_batch = F.softmax(bond_logits_batch, dim=-1) 142 | atom_probs_batch[batch['decoder_atom_template_labels'] == -100] = 0 143 | bond_probs_batch[batch['decoder_bond_template_labels'] == -100] = 0 144 | acc = [] 145 | for atom_probs, bond_probs, bonds, raw_template_labels in zip( 146 | atom_probs_batch, bond_probs_batch, batch['bonds'], batch['decoder_raw_template_labels']): 147 | edit_pred = utils.combined_edit(atom_probs, bond_probs, bonds, 1)[0][0] 148 | acc.append(float(edit_pred in raw_template_labels) / max(len(raw_template_labels), 1)) 149 | acc = torch.tensor(acc) 150 | else: 151 | preds = logits.argmax(dim=-1)[:, :-1] 152 | labels = batch['decoder_input_ids'][:, 1:] 153 | acc = torch.logical_or(preds.eq(labels), labels.eq(self.dec_tokenizer.pad_token_id)).all(dim=-1) 154 | if reduction == 'mean': 155 | acc = acc.mean() 156 | return acc 157 | 158 | def compute_mlm_loss(self, encoder_last_hidden_state, labels): 159 | batch_size, trunc_len = labels.size() 160 | trunc_hidden_state = encoder_last_hidden_state[:, :trunc_len].contiguous() 161 | logits = self.mlm_head(trunc_hidden_state) 162 | return F.cross_entropy(input=logits.view(batch_size * trunc_len, -1), target=labels.view(-1)) 163 | 164 | def training_step(self, batch, batch_idx): 165 | indices, batch_in, batch_out = batch 166 | output = self.model(**batch_in) 167 | loss = self.compute_loss(output.logits, batch_in) 168 | self.log('train_loss', loss) 169 | total_loss = loss 170 | if self.args.mlm: 171 | mlm_loss = self.compute_mlm_loss(output.encoder_last_hidden_state, batch_out['mlm_labels']) 172 | total_loss += mlm_loss * self.args.mlm_lambda 173 | self.log('mlm_loss', mlm_loss) 174 | self.log('total_loss', total_loss) 175 | return total_loss 176 | 177 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 178 | indices, batch_in, batch_out = batch 179 | output = self.model(**batch_in) 180 | if self.args.val_metric == 'val_loss': 181 | scores = self.compute_loss(output.logits, batch_in, reduction='none').tolist() 182 | elif self.args.val_metric == 'val_acc': 183 | scores = self.compute_acc(output.logits, batch_in, reduction='none').tolist() 184 | else: 185 | raise ValueError 186 | for idx, score in zip(indices, scores): 187 | self.validation_outputs[dataloader_idx][idx] = score 188 | return output 189 | 190 | def on_validation_epoch_end(self): 191 | for dataloader_idx in self.validation_outputs: 192 | validation_outputs = self.gather_outputs(self.validation_outputs[dataloader_idx]) 193 | val_score = np.mean([v for v in validation_outputs.values()]) 194 | metric_name = self.args.val_metric if dataloader_idx == 0 else f'{self.args.val_metric}/{dataloader_idx}' 195 | self.log(metric_name, val_score, prog_bar=True, rank_zero_only=True) 196 | self.validation_outputs.clear() 197 | 198 | def test_step(self, batch, batch_idx, dataloader_idx=0): 199 | indices, batch_in, batch_out = batch 200 | 
num_beams = self.args.num_beams 201 | if self.args.template_based: 202 | atom_logits_batch, bond_logits_batch = self.model(**batch_in).logits 203 | atom_probs_batch = F.softmax(atom_logits_batch, dim=-1) 204 | bond_probs_batch = F.softmax(bond_logits_batch, dim=-1) 205 | atom_probs_batch[batch_in['decoder_atom_template_labels'] == -100] = 0 206 | bond_probs_batch[batch_in['decoder_bond_template_labels'] == -100] = 0 207 | acc = [] 208 | for idx, atom_probs, bond_probs, bonds, raw_template_labels in zip( 209 | indices, atom_probs_batch, bond_probs_batch, batch_in['bonds'], batch_in['decoder_raw_template_labels']): 210 | edit_pred, edit_prob = utils.combined_edit(atom_probs, bond_probs, bonds, top_num=500) 211 | self.test_outputs[dataloader_idx][idx] = { 212 | 'prediction': edit_pred, 213 | 'score': edit_prob, 214 | 'raw_template_labels': raw_template_labels, 215 | 'top1_template_match': edit_pred[0] in raw_template_labels 216 | } 217 | else: 218 | output = self.model.generate( 219 | **batch_in, num_beams=num_beams, num_return_sequences=num_beams, 220 | max_length=self.args.max_dec_length, length_penalty=0, 221 | bos_token_id=self.dec_tokenizer.bos_token_id, eos_token_id=self.dec_tokenizer.eos_token_id, 222 | pad_token_id=self.dec_tokenizer.pad_token_id, 223 | return_dict_in_generate=True, output_scores=True) 224 | predictions = self.dec_tokenizer.batch_decode(output.sequences, skip_special_tokens=True) 225 | if 'sequences_scores' in predictions: 226 | scores = output.sequences_scores.tolist() 227 | else: 228 | scores = [0] * len(predictions) 229 | for i, idx in enumerate(indices): 230 | self.test_outputs[dataloader_idx][idx] = { 231 | 'prediction': predictions[i * num_beams: (i + 1) * num_beams], 232 | 'score': scores[i * num_beams: (i + 1) * num_beams] 233 | } 234 | return 235 | 236 | def on_test_epoch_end(self): 237 | for dataloader_idx in self.test_outputs: 238 | test_outputs = self.gather_outputs(self.test_outputs[dataloader_idx]) 239 | if self.args.test_each_neighbor: 240 | test_outputs = utils.gather_prediction_each_neighbor(test_outputs, self.args.test_num_neighbors) 241 | if self.trainer.is_global_zero: 242 | # Save prediction 243 | with open(os.path.join(self.args.save_path, 244 | f'prediction_{self.test_dataset.name}_{dataloader_idx}.json'), 'w') as f: 245 | json.dump(test_outputs, f) 246 | # Evaluate 247 | if self.args.task == 'condition': 248 | accuracy = evaluate_reaction_condition(test_outputs, self.test_dataset.data_df) 249 | elif self.args.task == 'retro': 250 | accuracy = evaluate_retrosynthesis(test_outputs, self.test_dataset.data_df, self.args.num_beams, 251 | template_based=self.args.template_based, 252 | template_path=self.args.template_path) 253 | else: 254 | accuracy = [] 255 | self.print(self.ckpt_path) 256 | self.print(json.dumps(accuracy)) 257 | self.test_outputs.clear() 258 | 259 | def gather_outputs(self, outputs): 260 | if self.trainer.num_devices > 1: 261 | gathered = [{} for i in range(self.trainer.num_devices)] 262 | dist.all_gather_object(gathered, outputs) 263 | gathered_outputs = {} 264 | for outputs in gathered: 265 | gathered_outputs.update(outputs) 266 | else: 267 | gathered_outputs = outputs 268 | return gathered_outputs 269 | 270 | def configure_optimizers(self): 271 | num_training_steps = self.trainer.num_training_steps 272 | self.print(f'Num training steps: {num_training_steps}') 273 | num_warmup_steps = int(num_training_steps * self.args.warmup_ratio) 274 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, 
weight_decay=self.args.weight_decay) 275 | scheduler = get_scheduler(self.args.scheduler, optimizer, num_warmup_steps, num_training_steps) 276 | return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}} 277 | 278 | 279 | class ReactionConditionDataModule(LightningDataModule): 280 | 281 | DATASET_CLS = { 282 | 'condition': ReactionConditionDataset, 283 | 'retro': RetrosynthesisDataset, 284 | } 285 | 286 | def __init__(self, args, model): 287 | super().__init__() 288 | self.args = args 289 | self.enc_tokenizer = model.enc_tokenizer 290 | self.dec_tokenizer = model.dec_tokenizer 291 | self.train_dataset, self.val_dataset, self.test_dataset = None, None, None 292 | 293 | def prepare_data(self): 294 | args = self.args 295 | dataset_cls = self.DATASET_CLS[args.task] 296 | if args.do_train: 297 | data_file = os.path.join(args.data_path, args.train_file) 298 | self.train_dataset = dataset_cls( 299 | args, data_file, self.enc_tokenizer, self.dec_tokenizer, split='train') 300 | print(f'Train dataset: {len(self.train_dataset)}') 301 | if args.do_train or args.do_valid: 302 | data_file = os.path.join(args.data_path, args.valid_file) 303 | self.val_dataset = dataset_cls( 304 | args, data_file, self.enc_tokenizer, self.dec_tokenizer, split='val') 305 | print(f'Valid dataset: {len(self.val_dataset)}') 306 | if args.do_test: 307 | data_file = os.path.join(args.data_path, args.test_file) 308 | self.test_dataset = dataset_cls( 309 | args, data_file, self.enc_tokenizer, self.dec_tokenizer, split='test') 310 | print(f'Test dataset: {len(self.test_dataset)}') 311 | if args.corpus_file: 312 | if args.train_label_corpus: 313 | assert args.task == 'condition' 314 | corpus = generate_train_label_corpus(os.path.join(args.data_path, args.train_file)) 315 | else: 316 | corpus = read_corpus(args.corpus_file, args.cache_path) 317 | if self.train_dataset is not None: 318 | self.train_dataset.load_corpus(corpus, os.path.join(args.nn_path, args.train_nn_file)) 319 | self.train_dataset.print_example() 320 | if self.val_dataset is not None: 321 | self.val_dataset.load_corpus(corpus, os.path.join(args.nn_path, args.valid_nn_file)) 322 | if self.test_dataset is not None: 323 | self.test_dataset.load_corpus(corpus, os.path.join(args.nn_path, args.test_nn_file)) 324 | 325 | def train_dataloader(self): 326 | return torch.utils.data.DataLoader( 327 | self.train_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers, 328 | collate_fn=self.train_dataset.collator) 329 | 330 | def get_eval_dataloaders(self, dataset): 331 | args = self.args 332 | dataloader = torch.utils.data.DataLoader( 333 | dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=dataset.collator) 334 | if args.corpus_file is None: 335 | return dataloader 336 | dataset_skip_gold = copy.copy(dataset) 337 | dataset_skip_gold.skip_gold_neighbor = True 338 | dataloader_skip_gold = torch.utils.data.DataLoader( 339 | dataset_skip_gold, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=dataset.collator) 340 | return [dataloader, dataloader_skip_gold] 341 | 342 | def val_dataloader(self): 343 | return self.get_eval_dataloaders(self.val_dataset) 344 | 345 | def test_dataloader(self): 346 | return self.get_eval_dataloaders(self.test_dataset) 347 | 348 | 349 | def main(): 350 | args = get_args() 351 | pl.seed_everything(args.seed, workers=True) 352 | 353 | model = ReactionConditionRecommender(args) 354 | 355 | dm = ReactionConditionDataModule(args, model) 356 | dm.prepare_data() 
357 | 358 | checkpoint = pl.callbacks.ModelCheckpoint( 359 | monitor=args.val_metric, mode=utils.metric_to_mode[args.val_metric], save_top_k=1, filename='best', 360 | save_last=True, dirpath=args.save_path, auto_insert_metric_name=False) 361 | lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step') 362 | if args.do_train and not args.debug: 363 | project_name = 'TextReact' 364 | if args.task == 'retro': 365 | project_name += '_retro' 366 | logger = pl.loggers.WandbLogger( 367 | project=project_name, save_dir=args.save_path, name=os.path.basename(args.save_path)) 368 | else: 369 | logger = None 370 | 371 | trainer = pl.Trainer( 372 | strategy=DDPStrategy(find_unused_parameters=True), 373 | accelerator='gpu', 374 | devices=args.gpus, 375 | precision=args.precision, 376 | logger=logger, 377 | default_root_dir=args.save_path, 378 | callbacks=[checkpoint, lr_monitor], 379 | max_epochs=args.epochs, 380 | gradient_clip_val=args.max_grad_norm, 381 | accumulate_grad_batches=args.gradient_accumulation_steps, 382 | check_val_every_n_epoch=args.eval_per_epoch, 383 | log_every_n_steps=10, 384 | deterministic=True) 385 | 386 | if args.do_train: 387 | trainer.num_training_steps = math.ceil( 388 | len(dm.train_dataset) / (args.batch_size * args.gpus * args.gradient_accumulation_steps)) * args.epochs 389 | # Load or delete existing checkpoint 390 | if args.overwrite: 391 | utils.clear_path(args.save_path, trainer) 392 | ckpt_path = None 393 | else: 394 | ckpt_path = os.path.join(args.save_path, args.load_ckpt) 395 | ckpt_path = ckpt_path if checkpoint.file_exists(ckpt_path, trainer) else None 396 | # Train 397 | trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path) 398 | best_model_path = checkpoint.best_model_path 399 | else: 400 | best_model_path = os.path.join(args.save_path, args.load_ckpt) 401 | 402 | if args.do_valid or args.do_test: 403 | print('Load model checkpoint:', best_model_path) 404 | model = ReactionConditionRecommender.load_from_checkpoint(best_model_path, strict=False, args=args) 405 | model.ckpt_path = best_model_path 406 | 407 | if args.do_valid: 408 | trainer.validate(model, datamodule=dm) 409 | 410 | if args.do_test: 411 | model.test_dataset = dm.test_dataset 412 | trainer.test(model, datamodule=dm) 413 | 414 | 415 | if __name__ == "__main__": 416 | main() 417 | -------------------------------------------------------------------------------- /preprocess/dedup_corpus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | # Dedup 6 | 7 | # corpus_df = pd.read_csv('uspto_script/uspto_rxn_corpus.csv') 8 | # text_to_corpus_id = {} 9 | # id_to_corpus_id = {} 10 | # dedup_flags = [False] * len(corpus_df) 11 | # for i, (idx, text) in enumerate(zip(corpus_df['id'], corpus_df['paragraph_text'])): 12 | # if text not in text_to_corpus_id: 13 | # text_to_corpus_id[text] = idx 14 | # dedup_flags[i] = True 15 | # id_to_corpus_id[idx] = text_to_corpus_id[text] 16 | # 17 | # dedup_df = corpus_df[dedup_flags] 18 | # dedup_df.to_csv('uspto_script/uspto_rxn_corpus_dedup.csv', index=False) 19 | # 20 | # with open('id_to_corpus_id.json', 'w') as f: 21 | # json.dump(id_to_corpus_id, f) 22 | 23 | 24 | # Add corpus id 25 | with open('id_to_corpus_id.json') as f: 26 | id_to_corpus_id = json.load(f) 27 | 28 | # data_path = '../data/USPTO_condition/' 29 | # data_path = '../data/USPTO_condition_year/' 30 | # for file in ['USPTO_condition_train.csv', 'USPTO_condition_val.csv', 'USPTO_condition_test.csv']: 31 
| 32 | # data_path = '../data/USPTO_50K/matched1/' 33 | data_path = '../data/USPTO_50K_year/' 34 | for file in ['train.csv', 'valid.csv', 'test.csv']: 35 | path = os.path.join(data_path, file) 36 | print(path) 37 | df = pd.read_csv(path) 38 | corpus_id = [id_to_corpus_id.get(idx, idx) for idx in df['id']] 39 | df['corpus_id'] = corpus_id 40 | cols = ['id', 'corpus_id'] 41 | for col in df.columns: 42 | if col not in cols: 43 | cols.append(col) 44 | df = df[cols] 45 | df.to_csv(path, index=False) 46 | -------------------------------------------------------------------------------- /preprocess/gen_uspto.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | import zipfile 3 | import os 4 | import re 5 | import sys 6 | import glob 7 | import shutil 8 | import multiprocessing 9 | import pandas as pd 10 | from tqdm import tqdm 11 | from collections import Counter 12 | import json 13 | from json import encoder 14 | encoder.FLOAT_REPR = lambda o: format(o, '.3f') 15 | import numpy as np 16 | 17 | 18 | BASE = '/scratch/yujieq/uspto_grant_red/' 19 | BASE_TXT = '/scratch/yujieq/uspto_grant_fulltext/' 20 | BASE_TXT_ZIP = '/scratch/yujieq/uspto_grant_fulltext_zip/' 21 | 22 | 23 | # Download data 24 | def _download_file(url, output): 25 | if not os.path.exists(output): 26 | urllib.request.urlretrieve(url, output) 27 | 28 | 29 | def download(): 30 | for year in range(2016, 2017): 31 | url = f"https://bulkdata.uspto.gov/data/patent/grant/redbook/{year}/" 32 | f = urllib.request.urlopen(url) 33 | content = f.read().decode('utf-8') 34 | print(url) 35 | zip_files = re.findall(r"href=\"(I*\d\d\d\d\d\d\d\d(.ZIP|.zip|.tar))\"", content) 36 | print(zip_files) 37 | path = os.path.join(BASE, str(year)) 38 | os.makedirs(path, exist_ok=True) 39 | args = [] 40 | for file, ext in zip_files: 41 | output = os.path.join(path, file) 42 | args.append((url + file, output)) 43 | # with multiprocessing.Pool(8) as p: 44 | # p.starmap(_download_file, args) 45 | for url, output in args: 46 | print(url) 47 | _download_file(url, output) 48 | 49 | 50 | def download_fulltext(): 51 | for year in range(2002, 2017): 52 | url = f'https://bulkdata.uspto.gov/data/patent/grant/redbook/fulltext/{year}/' 53 | f = urllib.request.urlopen(url) 54 | content = f.read().decode('utf-8') 55 | print(url) 56 | zip_files = re.findall(r"href=\"(\w*(.ZIP|.zip|.tar))\"", content) 57 | print(zip_files) 58 | path = os.path.join(BASE_TXT, str(year)) 59 | os.makedirs(path, exist_ok=True) 60 | args = [] 61 | for file, ext in zip_files: 62 | output = os.path.join(path, file) 63 | args.append((url + file, output)) 64 | # with multiprocessing.Pool(8) as p: 65 | # p.starmap(_download_file, args) 66 | for url, output in args: 67 | print(url) 68 | _download_file(url, output) 69 | 70 | 71 | # Unzip 72 | def is_zip(file): 73 | return file[-4:] in ['.zip', '.ZIP'] 74 | 75 | 76 | def unzip(): 77 | for year in range(1976, 2017): 78 | path = os.path.join(BASE_TXT_ZIP, str(year)) 79 | outpath = os.path.join(BASE_TXT, str(year)) 80 | for datefile in sorted(os.listdir(path)): 81 | if is_zip(datefile): 82 | print(os.path.join(path, datefile)) 83 | with zipfile.ZipFile(os.path.join(path, datefile), 'r') as zipobj: 84 | zipobj.extractall(outpath) 85 | 86 | 87 | if __name__ == "__main__": 88 | if sys.argv[1] == 'download': 89 | download() 90 | elif sys.argv[1] == 'download_fulltext': 91 | download_fulltext() 92 | elif sys.argv[1] == 'unzip': 93 | unzip() 94 | 
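# Usage sketch (not part of the original script): gen_uspto.py is driven by the
# sys.argv[1] dispatch above ('download', 'download_fulltext', or 'unzip'). The
# lines below run the same stages programmatically, assuming the BASE, BASE_TXT,
# and BASE_TXT_ZIP constants in gen_uspto.py have been repointed to writable
# local directories.
import gen_uspto

gen_uspto.download()           # fetch grant red-book archives, one folder per year, into BASE
gen_uspto.download_fulltext()  # fetch full-text grant archives, one folder per year, into BASE_TXT
gen_uspto.unzip()              # extract BASE_TXT_ZIP/<year>/*.zip into BASE_TXT/<year>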
-------------------------------------------------------------------------------- /preprocess/get_templates.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import logging 3 | from copy import deepcopy 4 | import os 5 | import pandas as pd 6 | import re 7 | from abc import ABC, abstractmethod 8 | from collections import defaultdict 9 | # from models.localretro_model.Extract_from_train_data import get_full_template 10 | # from models.localretro_model.LocalTemplate.template_extractor import extract_from_reaction 11 | # from models.localretro_model.Run_preprocessing import get_edit_site_retro 12 | from typing import Dict, List 13 | from rdkit import Chem 14 | from rdkit.Chem import AllChem 15 | from template_extraction.template_extractor import extract_from_reaction 16 | from template_extraction.template_extract_utils import get_bonds_from_smiles 17 | 18 | import sys 19 | sys.path.append('../') 20 | from textreact.tokenizer import BasicSmilesTokenizer 21 | 22 | log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 23 | logging.basicConfig(format = log_format, level = logging.INFO) 24 | 25 | logger = logging.getLogger() 26 | 27 | 28 | ATOM_REGEX = re.compile(r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p") 29 | 30 | 31 | def get_full_template(template, H_change, Charge_change, Chiral_change): 32 | H_code = ''.join([str(H_change[k+1]) for k in range(len(H_change))]) 33 | Charge_code = ''.join([str(Charge_change[k+1]) for k in range(len(Charge_change))]) 34 | Chiral_code = ''.join([str(Chiral_change[k+1]) for k in range(len(Chiral_change))]) 35 | if Chiral_code == '': 36 | return '_'.join([template, H_code, Charge_code]) 37 | else: 38 | return '_'.join([template, H_code, Charge_code, Chiral_code]) 39 | 40 | 41 | def canonicalize_smiles(smiles): 42 | mol = Chem.MolFromSmiles(smiles) 43 | for a in mol.GetAtoms(): 44 | a.SetAtomMapNum(0) 45 | canon_smile = Chem.MolToSmiles(mol) 46 | canon_perm = eval(mol.GetProp("_smilesAtomOutputOrder")) 47 | # atomidx2stridx = [i 48 | # for i, token in enumerate(BasicSmilesTokenizer().tokenize(canon_smile)) 49 | # if ATOM_REGEX.fullmatch(token) is not None] 50 | # atomidx2canonstridx = [0 for _ in range(len(canon_perm))] 51 | # for canon_idx, orig_idx in enumerate(canon_perm): 52 | # atomidx2canonstridx[orig_idx] = atomidx2stridx[canon_idx] 53 | atomidx2canonidx = [None for _ in range(len(canon_perm))] 54 | for canon_idx, orig_idx in enumerate(canon_perm): 55 | atomidx2canonidx[orig_idx] = canon_idx 56 | return canon_smile, atomidx2canonidx 57 | 58 | 59 | class Processor(ABC): 60 | """Base class for processor""" 61 | 62 | @abstractmethod 63 | def __init__(self, 64 | model_name: str, 65 | model_args, # let's enforce everything to be passed in args 66 | model_config: Dict[str, any], # or config 67 | data_name: str, 68 | raw_data_files: List[str], 69 | processed_data_path: str): 70 | self.model_name = model_name 71 | self.model_args = model_args 72 | self.model_config = model_config 73 | self.data_name = data_name 74 | self.raw_data_files = raw_data_files 75 | self.processed_data_path = processed_data_path 76 | 77 | os.makedirs(self.processed_data_path, exist_ok=True) 78 | 79 | self.check_count = 100 80 | 81 | def check_data_format(self) -> None: 82 | """Check that all files exists and the data format is correct for the first few lines""" 83 | logger.info(f"Checking the first {self.check_count} entries for each file") 84 | for fn in self.raw_data_files: 85 | if not fn: 86 | continue 87 | assert 
os.path.exists(fn), f"{fn} does not exist!" 88 | 89 | with open(fn, "r") as csv_file: 90 | csv_reader = csv.DictReader(csv_file) 91 | for i, row in enumerate(csv_reader): 92 | if i > self.check_count: # check the first few rows 93 | break 94 | 95 | assert (c in row for c in ["class", "rxn_smiles"]), \ 96 | f"Error processing file {fn} line {i}, ensure columns 'class' and " \ 97 | f"'rxn_smiles' is included!" 98 | 99 | reactants, reagents, products = row["rxn_smiles"].split(">") 100 | Chem.MolFromSmiles(reactants) # simply ensures that SMILES can be parsed 101 | Chem.MolFromSmiles(products) # simply ensures that SMILES can be parsed 102 | 103 | logger.info("Data format check passed") 104 | 105 | @abstractmethod 106 | def preprocess(self) -> None: 107 | """Actual file-based preprocessing""" 108 | pass 109 | 110 | # Adpated from Extract_from_train_data.py 111 | class LocalRetroProcessor(Processor): 112 | """Class for LocalRetro Preprocessing""" 113 | 114 | def __init__(self, 115 | model_name: str, 116 | model_args, 117 | model_config: Dict[str, any], 118 | data_name: str, 119 | raw_data_files: List[str], 120 | processed_data_path: str, 121 | num_cores: int = None): 122 | super().__init__(model_name=model_name, 123 | model_args=model_args, 124 | model_config=model_config, 125 | data_name=data_name, 126 | raw_data_files=raw_data_files, 127 | processed_data_path=processed_data_path) 128 | self.train_file, self.val_file, self.test_file = raw_data_files 129 | self.num_cores = num_cores 130 | self.setting = {'verbose': False, 'use_stereo': True, 'use_symbol': True, 131 | 'max_unmap': 5, 'retro': True, 'remote': True, 'least_atom_num': 2, 132 | "max_edit_n": 8, "min_template_n": 1} 133 | self.RXNHASCLASS = False 134 | 135 | def preprocess(self) -> None: 136 | """Actual file-based preprocessing""" 137 | self.extract_templates() 138 | self.match_templates() 139 | 140 | def extract_templates(self): 141 | """Adapted from Extract_from_train_data.py""" 142 | if all(os.path.exists(os.path.join(self.processed_data_path, fn)) 143 | for fn in ["template_infos.csv", "atom_templates.csv", "bond_templates.csv"]): 144 | logger.info(f"Extracted templates found at {self.processed_data_path}, skipping extraction.") 145 | return 146 | 147 | logger.info(f"Extracting templates from {self.train_file}") 148 | 149 | with open(self.train_file, "r") as csv_file: 150 | csv_reader = csv.DictReader(csv_file) 151 | rxns = [row["rxn_smiles"].strip() for row in csv_reader] 152 | 153 | TemplateEdits = {} 154 | TemplateCs, TemplateHs, TemplateSs = {}, {}, {} 155 | TemplateFreq, templates_A, templates_B = defaultdict(int), defaultdict(int), defaultdict(int) 156 | 157 | for i, rxn in enumerate(rxns): 158 | try: 159 | rxn = {'reactants': rxn.split('>')[0], 'products': rxn.split('>')[-1], '_id': i} 160 | result = extract_from_reaction(rxn, self.setting) 161 | if 'reactants' not in result or 'reaction_smarts' not in result.keys(): 162 | logger.info(f'\ntemplate problem: id: {i}') 163 | continue 164 | reactant = result['reactants'] 165 | template = result['reaction_smarts'] 166 | edits = result['edits'] 167 | H_change = result['H_change'] 168 | Charge_change = result['Charge_change'] 169 | Chiral_change = result["Chiral_change"] if self.setting["use_stereo"] else {} 170 | 171 | template_H = get_full_template(template, H_change, Charge_change, Chiral_change) 172 | if template_H not in TemplateHs.keys(): 173 | TemplateEdits[template_H] = {edit_type: edits[edit_type][2] for edit_type in edits} 174 | TemplateHs[template_H] = H_change 175 | 
TemplateCs[template_H] = Charge_change 176 | TemplateSs[template_H] = Chiral_change 177 | 178 | TemplateFreq[template_H] += 1 179 | for edit_type, bonds in edits.items(): 180 | bonds = bonds[0] 181 | if len(bonds) > 0: 182 | if edit_type in ['A', 'R']: 183 | templates_A[template_H] += 1 184 | else: 185 | templates_B[template_H] += 1 186 | 187 | except Exception as e: 188 | logger.info(i, e) 189 | 190 | if i % 1000 == 0: 191 | logger.info(f'\r i = {i}, # of template: {len(TemplateFreq)}, ' 192 | f'# of atom template: {len(templates_A)}, ' 193 | f'# of bond template: {len(templates_B)}') 194 | logger.info('\n total # of template: %s' % len(TemplateFreq)) 195 | 196 | derived_templates = {'atom': templates_A, 'bond': templates_B} 197 | 198 | ofn = os.path.join(self.processed_data_path, "template_infos.csv") 199 | TemplateInfos = pd.DataFrame( 200 | {'Template': k, 201 | 'edit_site': TemplateEdits[k], 202 | 'change_H': TemplateHs[k], 203 | 'change_C': TemplateCs[k], 204 | 'change_S': TemplateSs[k], 205 | 'Frequency': TemplateFreq[k]} for k in TemplateHs.keys()) 206 | TemplateInfos.to_csv(ofn) 207 | 208 | for k, local_templates in derived_templates.items(): 209 | ofn = os.path.join(self.processed_data_path, f"{k}_templates.csv") 210 | with open(ofn, "w") as of: 211 | writer = csv.writer(of) 212 | header = ['Template', 'Frequency', 'Class'] 213 | writer.writerow(header) 214 | 215 | sorted_tuples = sorted(local_templates.items(), key=lambda item: item[1]) 216 | for i, (template, template_freq) in enumerate(sorted_tuples): 217 | writer.writerow([template, template_freq, i + 1]) 218 | 219 | def match_templates(self): 220 | """Adapted from Run_preprocessing.py""" 221 | # load_templates() 222 | template_dicts = {} 223 | 224 | for site in ['atom', 'bond']: 225 | fn = os.path.join(self.processed_data_path, f"{site}_templates.csv") 226 | with open(fn, "r") as csv_file: 227 | csv_reader = csv.DictReader(csv_file) 228 | template_dict = {row["Template"].strip(): int(row["Class"]) for row in csv_reader} 229 | logger.info(f'loaded {len(template_dict)} {site} templates') 230 | template_dicts[site] = template_dict 231 | 232 | fn = os.path.join(self.processed_data_path, "template_infos.csv") 233 | with open(fn, "r") as csv_file: 234 | csv_reader = csv.DictReader(csv_file) 235 | template_infos = { 236 | row["Template"]: { 237 | "edit_site": eval(row["edit_site"]), 238 | "frequency": int(row["Frequency"]) 239 | } for row in csv_reader 240 | } 241 | logger.info('loaded total %s templates' % len(template_infos)) 242 | 243 | # labeling_dataset() 244 | dfs = {} 245 | for phase, fn in [("train", self.train_file), 246 | ("val", self.val_file), 247 | ("test", self.test_file)]: 248 | with open(fn, "r") as csv_file: 249 | csv_reader = csv.DictReader(csv_file) 250 | rxns = [row["rxn_smiles"].strip() for row in csv_reader] 251 | reactants, products, reagents = [], [], [] 252 | labels, frequency = [], [] 253 | product_canon_smiles = [] 254 | # product_atomidx2canonstridxs = [] 255 | product_atomidx2canonidxs = [] 256 | product_canon_bondss = [] 257 | success = 0 258 | num_canon_smiles_mismatch = 0 259 | 260 | for i, rxn in enumerate(rxns): 261 | reactant, _, product = rxn.split(">") 262 | reagent = '' 263 | rxn_labels = [] 264 | 265 | # get canonical permutation of atoms 266 | product_canon_smile, product_atomidx2canonidx = canonicalize_smiles(product) 267 | product_canon_bonds = get_bonds_from_smiles(product_canon_smile) 268 | 269 | # parse reaction and store results 270 | try: 271 | rxn = {'reactants': reactant, 'products': 
product, '_id': i} 272 | result = extract_from_reaction(rxn, self.setting) 273 | 274 | template = result['reaction_smarts'] 275 | reactant = result['reactants'] 276 | product = result['products'] 277 | extracted_product_canon_smile, product_atomidx2canonidx = canonicalize_smiles(product) 278 | num_canon_smiles_mismatch += int(extracted_product_canon_smile != product_canon_smile) 279 | reagent = '.'.join(result['necessary_reagent']) 280 | edits = {edit_type: edit_bond[0] for edit_type, edit_bond in result['edits'].items()} 281 | H_change, Charge_change, Chiral_change = \ 282 | result['H_change'], result['Charge_change'], result['Chiral_change'] 283 | template_H = get_full_template(template, H_change, Charge_change, Chiral_change) 284 | 285 | if template_H not in template_infos.keys(): 286 | reactants.append(reactant) 287 | products.append(product) 288 | reagents.append(reagent) 289 | labels.append(rxn_labels) 290 | frequency.append(0) 291 | product_canon_smiles.append(product_canon_smile) 292 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx) 293 | product_atomidx2canonidxs.append(product_atomidx2canonidx) 294 | product_canon_bondss.append(product_canon_bonds) 295 | continue 296 | 297 | except Exception as e: 298 | logger.info(i, e) 299 | reactants.append(reactant) 300 | products.append(product) 301 | reagents.append(reagent) 302 | labels.append(rxn_labels) 303 | frequency.append(0) 304 | product_canon_smiles.append(product_canon_smile) 305 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx) 306 | product_atomidx2canonidxs.append(product_atomidx2canonidx) 307 | product_canon_bondss.append(product_canon_bonds) 308 | continue 309 | 310 | edit_n = 0 311 | for edit_type in edits: 312 | if edit_type == 'C': 313 | edit_n += len(edits[edit_type]) / 2 314 | else: 315 | edit_n += len(edits[edit_type]) 316 | 317 | if edit_n <= self.setting['max_edit_n']: 318 | try: 319 | success += 1 320 | for edit_type, edit in edits.items(): 321 | for e in edit: 322 | if edit_type in ['A', 'R']: 323 | rxn_labels.append( 324 | ('a', e, template_dicts['atom'][template_H])) 325 | else: 326 | rxn_labels.append( 327 | ('b', e, template_dicts['bond'][template_H])) 328 | reactants.append(reactant) 329 | products.append(product) 330 | reagents.append(reagent) 331 | labels.append(rxn_labels) 332 | frequency.append(template_infos[template_H]['frequency']) 333 | product_canon_smiles.append(product_canon_smile) 334 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx) 335 | product_atomidx2canonidxs.append(product_atomidx2canonidx) 336 | product_canon_bondss.append(product_canon_bonds) 337 | 338 | except Exception as e: 339 | logger.info(i, e) 340 | reactants.append(reactant) 341 | products.append(product) 342 | reagents.append(reagent) 343 | labels.append(rxn_labels) 344 | frequency.append(0) 345 | product_canon_smiles.append(product_canon_smile) 346 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx) 347 | product_atomidx2canonidxs.append(product_atomidx2canonidx) 348 | product_canon_bondss.append(product_canon_bonds) 349 | continue 350 | 351 | if i % 1000 == 0: 352 | logger.info(f'\r Processing {self.data_name} {phase} data..., ' 353 | f'success {success} data ({i}/{len(rxns)})') 354 | else: 355 | logger.info(f'\nReaction # {i} has too many edits ({edit_n})...may be wrong mapping!') 356 | reactants.append(reactant) 357 | products.append(product) 358 | reagents.append(reagent) 359 | labels.append(rxn_labels) 360 | frequency.append(0) 361 | 
product_canon_smiles.append(product_canon_smile) 362 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx) 363 | product_atomidx2canonidxs.append(product_atomidx2canonidx) 364 | product_canon_bondss.append(product_canon_bonds) 365 | 366 | logger.info(f'\nDerived templates cover {success / len(rxns): .3f} of {phase} data reactions') 367 | logger.info(f'\nNumber of canonical smiles mismatches: {num_canon_smiles_mismatch} / {len(rxns)}') 368 | ofn = os.path.join(self.processed_data_path, f"preprocessed_{phase}.csv") 369 | dfs[phase] = pd.DataFrame( 370 | {'Reactants': reactants, 371 | 'Products': products, 372 | 'Reagents': reagents, 373 | 'Labels': labels, 374 | 'Frequency': frequency, 375 | 'ProductCanonSmiles': product_canon_smiles, 376 | # 'ProductAtomIdx2CanonStrIdx': product_atomidx2canonstridxs, 377 | 'ProductAtomIdx2CanonIdx': product_atomidx2canonidxs, 378 | 'ProductCanonBonds': product_canon_bondss}) 379 | dfs[phase].to_csv(ofn) 380 | 381 | # make_simulate_output() 382 | df = dfs["test"] 383 | ofn = os.path.join(self.processed_data_path, "simulate_output.txt") 384 | with open(ofn, 'w') as of: 385 | of.write('Test_id\tReactant\tProduct\t%s\n' % '\t'.join( 386 | [f'Edit {i + 1}\tProba {i + 1}' for i in range(self.setting['max_edit_n'])])) 387 | for i in df.index: 388 | labels = [] 389 | for y in df['Labels'][i]: 390 | if y != 0: 391 | labels.append(y) 392 | if not labels: 393 | labels = [(0, 0)] 394 | string_labels = '\t'.join([f'{l}\t{1.0}' for l in labels]) 395 | of.write('%s\t%s\t%s\t%s\n' % (i, df['Reactants'][i], df['Products'][i], string_labels)) 396 | 397 | # combine_preprocessed_data() 398 | dfs["train"]['Split'] = ['train'] * len(dfs["train"]) 399 | dfs["val"]['Split'] = ['val'] * len(dfs["val"]) 400 | dfs["test"]['Split'] = ['test'] * len(dfs["test"]) 401 | all_valid = dfs["train"]._append(dfs["val"], ignore_index=True) 402 | all_valid = all_valid._append(dfs["test"], ignore_index=True) 403 | all_valid['Mask'] = [int(f >= self.setting['min_template_n']) for f in all_valid['Frequency']] 404 | ofn = os.path.join(self.processed_data_path, "labeled_data.csv") 405 | all_valid.to_csv(ofn, index=None) 406 | logger.info(f'Valid data size: {len(all_valid)}') 407 | 408 | 409 | if __name__ == "__main__": 410 | from argparse import Namespace 411 | 412 | model_name = "localretro" 413 | model_args = Namespace() 414 | model_config = {} 415 | data_name = "USPTO_50K_year" 416 | raw_data_files = ["train.csv", "valid.csv", "test.csv"] 417 | raw_data_files = [os.path.join("../data_USPTO_50K_year_raw/", file_name) for file_name in raw_data_files] 418 | processed_data_path = "../data_template/USPTO_50K_year/" 419 | preprocessor = LocalRetroProcessor(model_name, model_args, model_config, data_name, raw_data_files, processed_data_path) 420 | preprocessor.check_data_format() 421 | preprocessor.preprocess() 422 | -------------------------------------------------------------------------------- /preprocess/preprocess_retrosynthesis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import multiprocessing 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | 9 | import rdkit 10 | from rdkit import Chem 11 | rdkit.RDLogger.DisableLog('rdApp.*') 12 | import rdkit.Chem.rdChemReactions as rdChemReactions 13 | import rdkit.DataStructs as DataStructs 14 | 15 | 16 | BASE = 'data/USPTO_50K' 17 | 18 | 19 | def canonical_rxn_smiles(rxn_smiles): 20 | reactants, reagents, products = 
rxn_smiles.split(">") 21 | try: 22 | mols_r = Chem.MolFromSmiles(reactants) 23 | mols_p = Chem.MolFromSmiles(products) 24 | [a.ClearProp('molAtomMapNumber') for a in mols_r.GetAtoms()] 25 | [a.ClearProp('molAtomMapNumber') for a in mols_p.GetAtoms()] 26 | cano_smi_r = Chem.MolToSmiles(mols_r, isomericSmiles=True, canonical=True) 27 | cano_smi_p = Chem.MolToSmiles(mols_p, isomericSmiles=True, canonical=True) 28 | return cano_smi_r + '>>' + cano_smi_p, cano_smi_r, cano_smi_p, True 29 | except: 30 | return rxn_smiles, reactants, products, False 31 | 32 | 33 | def reaction_fingerprint(smiles): 34 | rxn = rdChemReactions.ReactionFromSmarts(smiles) 35 | fp = rdChemReactions.CreateDifferenceFingerprintForReaction(rxn) 36 | return fp 37 | 38 | 39 | def reaction_similarity(smiles1=None, smiles2=None, fp1=None, fp2=None): 40 | if fp1 is None and smiles1: 41 | fp1 = reaction_fingerprint(smiles1) 42 | if fp2 is None and smiles2: 43 | fp2 = reaction_fingerprint(smiles2) 44 | assert fp1 is not None 45 | assert fp2 is not None 46 | return DataStructs.TanimotoSimilarity(fp1, fp2) 47 | 48 | 49 | # Canonical SMILES 50 | 51 | # for split in ['train', 'valid', 'test']: 52 | # df = pd.read_csv(os.path.join(BASE, f'{split}.csv')) 53 | # invalid_count = 0 54 | # for i, row in df.iterrows(): 55 | # try: 56 | # reactants, reagents, products = row["rxn_smiles"].split(">") 57 | # mols_r = Chem.MolFromSmiles(reactants) 58 | # mols_p = Chem.MolFromSmiles(products) 59 | # if mols_r is None or mols_p is None: 60 | # invalid_count += 1 61 | # continue 62 | 63 | # [a.ClearProp('molAtomMapNumber') for a in mols_r.GetAtoms()] 64 | # [a.ClearProp('molAtomMapNumber') for a in mols_p.GetAtoms()] 65 | 66 | # cano_smi_r = Chem.MolToSmiles(mols_r, isomericSmiles=True, canonical=True) 67 | # cano_smi_p = Chem.MolToSmiles(mols_p, isomericSmiles=True, canonical=True) 68 | 69 | # df.loc[i, 'reactant_smiles'] = cano_smi_r 70 | # df.loc[i, 'product_smiles'] = cano_smi_p 71 | # except Exception as e: 72 | # logging.info(e) 73 | # logging.info(row["rxn_smiles"].split(">")) 74 | # invalid_count += 1 75 | 76 | # logging.info(f"Invalid count: {invalid_count}") 77 | # df.drop(columns='rxn_smiles', inplace=True) 78 | # df.to_csv(os.path.join(BASE, f'processed/{split}.csv'), index=False) 79 | 80 | 81 | # rxn_df = pd.read_csv('data/USPTO_rxn_condition.csv') 82 | # with multiprocessing.Pool(32) as p: 83 | # results = p.map(canonical_rxn_smiles, rxn_df['rxn_smiles'], chunksize=128) 84 | # canonical_rxn, reactants, products, success = zip(*results) 85 | # 86 | # print(np.mean(success)) 87 | # rxn_df['canonical_rxn'] = canonical_rxn 88 | # rxn_df['reactants'] = reactants 89 | # rxn_df['products'] = products 90 | # rxn_df = rxn_df[['id', 'source', 'year', 'patent_type', 'canonical_rxn', 'reactants', 'products']] 91 | # rxn_df.to_csv('data/USPTO_rxn_smiles.csv', index=False) 92 | 93 | 94 | # Match id 95 | 96 | corpus_df = pd.read_csv('preprocess/uspto_script/uspto_rxn_condition_remapped_and_reassign_condition_role.csv') 97 | rxn_smiles_to_id = {} 98 | for i, row in tqdm(corpus_df.iterrows()): 99 | canonical_rxn = row['canonical_rxn'] 100 | if canonical_rxn not in rxn_smiles_to_id: 101 | rxn_smiles_to_id[canonical_rxn] = [] 102 | rxn_smiles_to_id[canonical_rxn].append(row['id']) 103 | 104 | for split in ['train', 'valid', 'test']: 105 | df = pd.read_csv(f'data/USPTO_50K/processed/{split}.csv') 106 | cnt = 0 107 | match_patent_cnt = 0 108 | nomatch_cnt = 0 109 | matched_ids = [] 110 | f = open('tmp.txt', 'w') 111 | for i, row in tqdm(df.iterrows()): 
112 | rxn_smiles = row['reactant_smiles'] + '>>' + row['product_smiles'] 113 | if rxn_smiles in rxn_smiles_to_id: 114 | rxn_id = rxn_smiles_to_id[rxn_smiles][0] 115 | for idx in rxn_smiles_to_id[rxn_smiles]: 116 | if idx.startswith(row['id']): 117 | rxn_id = idx 118 | match_patent_cnt += 1 119 | break 120 | cnt += 1 121 | else: 122 | patent_df = corpus_df.loc[corpus_df['source'] == row['id']] 123 | f.write(row['id'] + '\n') 124 | rxn_id = f'unk_{split}_{i}' 125 | if len(patent_df) == 0: 126 | nomatch_cnt += 1 127 | f.write('No match\n\n') 128 | else: 129 | fp = reaction_fingerprint(rxn_smiles) 130 | # patent_rxn_smiles = [canonical_rxn_smiles(smiles)[0] for smiles in patent_df['rxn_smiles']] 131 | patent_rxn_smiles = patent_df['canonical_rxn'].tolist() 132 | similarities = [reaction_similarity(fp1=fp, smiles2=smiles) for smiles in patent_rxn_smiles] 133 | nearest_idx = np.argmax(similarities) 134 | nearest_row = patent_df.iloc[nearest_idx] 135 | if similarities[nearest_idx] > 0.9: 136 | rxn_id = nearest_row['id'] 137 | cnt += 1 138 | f.write(rxn_smiles + '\n') 139 | f.write(patent_rxn_smiles[nearest_idx] + '\n') 140 | f.write(json.dumps(similarities) + '\n') 141 | f.write(f'{similarities[nearest_idx]}\n') 142 | f.write('\n') 143 | f.flush() 144 | matched_ids.append(rxn_id) 145 | f.close() 146 | df['source'] = df['id'] 147 | df['id'] = matched_ids 148 | os.makedirs('data/USPTO_50K/matched1/', exist_ok=True) 149 | df.to_csv(f'data/USPTO_50K/matched1/{split}.csv', index=False) 150 | print(cnt, match_patent_cnt, nomatch_cnt, len(df)) 151 | -------------------------------------------------------------------------------- /preprocess/raw_retro_year_split.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import json 4 | 5 | with open('/Mounts/rbg-storage1/users/yujieq/textreact/preprocess/uspto_script/patent_info.json') as f: 6 | patent_info = json.load(f) 7 | 8 | train_df = pd.read_csv('data/USPTO_50K/train.csv') 9 | valid_df = pd.read_csv('data/USPTO_50K/valid.csv') 10 | test_df = pd.read_csv('data/USPTO_50K/test.csv') 11 | train_df_proc = pd.read_csv('data/USPTO_50K/matched1/train.csv') 12 | valid_df_proc = pd.read_csv('data/USPTO_50K/matched1/valid.csv') 13 | test_df_proc = pd.read_csv('data/USPTO_50K/matched1/test.csv') 14 | 15 | df = pd.concat([train_df, valid_df, test_df]).reindex() 16 | df_proc = pd.concat([train_df_proc, valid_df_proc, test_df_proc]).reindex() 17 | 18 | train_idx = [] 19 | valid_idx = [] 20 | test_idx = [] 21 | train_bad = 0 22 | valid_bad = 0 23 | test_bad = 0 24 | unk = 0 25 | for i, (proc_rxn_id, source, rxn_id) in enumerate(zip(df_proc['id'], df_proc['source'], df['id'])): 26 | p = proc_rxn_id.split('_')[0] 27 | if p != 'unk': 28 | bad = source != rxn_id 29 | if p in patent_info: 30 | year = patent_info[p]['year'] 31 | else: 32 | year = -1 33 | if year < 2012: 34 | train_idx.append(i) 35 | train_bad += int(bad) 36 | elif year in [2012, 2013]: 37 | valid_idx.append(i) 38 | valid_bad += int(bad) 39 | else: 40 | test_idx.append(i) 41 | test_bad += int(bad) 42 | else: 43 | unk += 1 44 | print(train_bad, valid_bad, test_bad) 45 | print(unk) 46 | 47 | os.makedirs('data_USPTO_50K_year_raw_alt/', exist_ok=True) 48 | train_df = df.iloc[train_idx].reindex() 49 | train_df.to_csv('data_USPTO_50K_year_raw_alt/train.csv', index=False) 50 | valid_df = df.iloc[valid_idx].reindex() 51 | valid_df.to_csv('data_USPTO_50K_year_raw_alt/valid.csv', index=False) 52 | test_df = df.iloc[test_idx].reindex() 53 | 
test_df.to_csv('data_USPTO_50K_year_raw_alt/test.csv', index=False) 54 | -------------------------------------------------------------------------------- /preprocess/reagent_Ionic_compound.txt: -------------------------------------------------------------------------------- 1 | [Na+].[OH-] 2 | [Li+].[OH-] 3 | [K+].[OH-] 4 | [NH4+].[OH-] 5 | [H-].[Na+] 6 | [H-].[K+] 7 | [H-].[Cs+] 8 | [BH4-].[Na+] 9 | [BH4-].[Li+] 10 | [Al+3].[H-].[H-].[H-].[H-].[Li+] 11 | [Al+3].[Cl-].[Cl-].[Cl-] 12 | O=P([O-])([O-])[O-].[K+].[K+].[K+] 13 | [BH3-]C#N.[Na+] 14 | [I-].[K+] 15 | [I-].[Li+] 16 | [I-].[Na+] 17 | [I-].[Cs+] 18 | [I-].[NH4+] 19 | [Cl-].[NH4+] 20 | [Cl-].[Na+] 21 | [Cl-].[K+] 22 | [Cl-].[Cs+] 23 | [Cl-].[Li+] 24 | [Cs+].[F-] 25 | [K+].[F-] 26 | [Na+].[F-] 27 | [Na+].[O-]Cl 28 | [Na+].[O-][Cl+][O-] 29 | [Na+].[O-][I+3]([O-])([O-])[O-] 30 | O=S([O-])([O-])=S.[Na+].[Na+] 31 | O=S([O-])S(=O)[O-].[Na+].[Na+] 32 | O=S(=O)([O-])[O-].[Mg+2] 33 | O=S(=O)([O-])[O-].[Na+].[Na+] 34 | O=S([O-])[O-].[Na+].[Na+] 35 | O=S([O-])O.[Na+] 36 | [Fe+2].c1cc[cH-]c1.c1cc[cH-]c1 37 | [Ca+2].[Cl-].[Cl-] 38 | [Br-].[K+] 39 | [C-]#N.[Na+] 40 | CC(C)[N-]C(C)C.[Li+] 41 | 42 | 43 | 44 | O=C([O-])[O-].[K+].[K+] 45 | O=C([O-])[O-].[Na+].[Na+] 46 | O=C([O-])[O-].[Cs+].[Cs+] 47 | O=C([O-])[O-].[Ca+2] 48 | O=C([O-])[O-].[Li+] 49 | O=C[O-].[NH4+] 50 | O=C([O-])O.[Na+] 51 | O=C([O-])O.[K+] 52 | CC(=O)[O-].[Na+] 53 | CC(=O)[O-].[K+] 54 | C[O-].[Na+] 55 | CC(C)(C)[O-].[K+] 56 | O=N[O-].[Na+] 57 | CN(C)C(On1nnc2cccnc21)=[N+](C)C.F[P-](F)(F)(F)(F)F 58 | CN(C)[P+](On1nnc2ccccc21)(N(C)C)N(C)C.F[P-](F)(F)(F)(F)F 59 | CCCC[N+](CCCC)(CCCC)CCCC.[F-] 60 | CC(=O)O[BH-](OC(C)=O)OC(C)=O.[Na+] 61 | C[Si](C)(C)[N-][Si](C)(C)C.[Li+] 62 | O=[Cr](=O)([O-])Cl.c1cc[nH+]cc1 63 | CC(C)C[Al+]CC(C)C.[H-] 64 | CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[B-](F)(F)F 65 | CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[P-](F)(F)(F)(F)F 66 | CN(C)C(On1nnc2cccnc21)=[N+](C)C.F[P-](F)(F)(F)(F)F 67 | Br[P+](N1CCCC1)(N1CCCC1)N1CCCC1.F[P-](F)(F)(F)(F)F 68 | F[P-](F)(F)(F)(F)F.c1ccc2c(c1)nnn2O[P+](N1CCCC1)(N1CCCC1)N1CCCC1 69 | C1CC[NH2+]CC1.CC(=O)[O-] 70 | CN(C)C(N(C)C)=[N+]1N=[N+]([O-])c2ncccc21.F[P-](F)(F)(F)(F)F 71 | CN1CC[NH+](C)C1Cl.[Cl-] 72 | COc1nc(OC)nc([N+]2(C)CCOCC2)n1.[Cl-] 73 | C[Si](C)(C)[N-][Si](C)(C)C.[K+] 74 | C[Si](C)(C)[N-][Si](C)(C)C.[Na+] 75 | C[n+]1ccccc1Cl.[I-] 76 | O=S(=O)([O-])C(F)(F)F.O=S(=O)([O-])C(F)(F)F.O=S(=O)([O-])C(F)(F)F.[Yb+3] 77 | O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].c1cc[nH+]cc1.c1cc[nH+]cc1 78 | Cc1ccc(S(=O)(=O)[O-])cc1.c1cc[nH+]cc1 79 | -------------------------------------------------------------------------------- /preprocess/reagent_unknown.txt: -------------------------------------------------------------------------------- 1 | [Na+],15 2 | O=C([O-])O,9 3 | CC(=O)O[BH-](OC(C)=O)OC(C)=O,8 4 | O=C([O-])[O-],6 5 | [K+],5 6 | [Cl-],5 7 | [BH4-],3 8 | [OH-],3 9 | [I-],3 10 | [K+].[K+],2 11 | [H-],2 12 | CC(=O)[O-],2 13 | O=S(=O)([O-])[O-].[Al+3].[Li+],2 14 | F[P-](F)(F)(F)(F)F,2 15 | [NH4+],1 16 | F[B-](F)(F)F,1 17 | F[B-](F)(F)F.F[B-](F)(F)F,1 18 | [Li+],1 19 | [BH3-]C#N,1 20 | [Na+].[Na+],1 21 | [Br-],1 22 | [Cs+].[Cs+],1 23 | C[O-],1 24 | CCCC[N+](CCCC)(CCCC)CCCC,1 25 | O=P([O-])(O)O,1 26 | -------------------------------------------------------------------------------- /preprocess/retro_year_split.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import json 4 | 5 | with open('preprocess/uspto_script/patent_info.json') as f: 6 | patent_info = json.load(f) 7 | 8 | train_df = 
pd.read_csv('data/USPTO_50K/matched1/train.csv') 9 | valid_df = pd.read_csv('data/USPTO_50K/matched1/valid.csv') 10 | test_df = pd.read_csv('data/USPTO_50K/matched1/test.csv') 11 | 12 | df = pd.concat([train_df, valid_df, test_df]).reindex() 13 | 14 | train_idx = [] 15 | valid_idx = [] 16 | test_idx = [] 17 | for i, rxn_id in enumerate(df['id']): 18 | p = rxn_id.split('_')[0] 19 | if p in patent_info: 20 | year = patent_info[p]['year'] 21 | else: 22 | year = -1 23 | if year < 2012: 24 | train_idx.append(i) 25 | elif year in [2012, 2013]: 26 | valid_idx.append(i) 27 | else: 28 | test_idx.append(i) 29 | 30 | os.makedirs('data/USPTO_50K_year/', exist_ok=True) 31 | train_df = df.iloc[train_idx].reindex() 32 | train_df.to_csv('data/USPTO_50K_year/train.csv', index=False) 33 | valid_df = df.iloc[valid_idx].reindex() 34 | valid_df.to_csv('data/USPTO_50K_year/valid.csv', index=False) 35 | test_df = df.iloc[test_idx].reindex() 36 | test_df.to_csv('data/USPTO_50K_year/test.csv', index=False) 37 | -------------------------------------------------------------------------------- /preprocess/template_extraction/template_extract_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | from collections import defaultdict 4 | 5 | from rdkit import Chem 6 | from rdkit.Chem import AllChem 7 | from rdkit.Chem.rdchem import ChiralType 8 | 9 | chiral_type_map = {ChiralType.CHI_UNSPECIFIED : 0, ChiralType.CHI_TETRAHEDRAL_CW: 1, ChiralType.CHI_TETRAHEDRAL_CCW: 2} 10 | bond_type_map = {'SINGLE': '-', 'DOUBLE': '=', 'TRIPLE': '#', 'AROMATIC': '@'} 11 | 12 | def get_template_bond(temp_order, bond_smarts): 13 | bond_match = {} 14 | for n, _ in enumerate(temp_order): 15 | bond_match[(temp_order[n], temp_order[n-1])] = bond_smarts[n-1] 16 | bond_match[(temp_order[n-1], temp_order[n])] = bond_smarts[n-1] 17 | return bond_match 18 | 19 | def bond_to_smiles(bond): 20 | a1_label = str(bond.GetBeginAtom().GetAtomicNum()) 21 | a2_label = str(bond.GetEndAtom().GetAtomicNum()) 22 | if bond.GetBeginAtom().HasProp('molAtomMapNumber'): 23 | a1_label += bond.GetBeginAtom().GetProp('molAtomMapNumber') 24 | if bond.GetEndAtom().HasProp('molAtomMapNumber'): 25 | a2_label += bond.GetEndAtom().GetProp('molAtomMapNumber') 26 | atoms = sorted([a1_label, a2_label]) 27 | bond_smarts = bond_type_map[str(bond.GetBondType())] 28 | return '{}{}{}'.format(atoms[0], bond_smarts, atoms[1]) 29 | 30 | def check_bond_break(bond1, bond2): 31 | if bond1 == None and bond2 != None: 32 | return False 33 | elif bond1 != None and bond2 == None: 34 | return True 35 | else: 36 | return False 37 | 38 | def check_bond_formed(bond1, bond2): 39 | if bond1 != None and bond2 == None: 40 | return False 41 | elif bond1 == None and bond2 != None: 42 | return True 43 | else: 44 | return False 45 | 46 | def check_bond_change(pbond, rbond): 47 | if pbond == None or rbond == None: 48 | return False 49 | elif bond_to_smiles(pbond) != bond_to_smiles(rbond): 50 | return True 51 | else: 52 | return False 53 | 54 | def atom_neighbors(atom): 55 | neighbor = [] 56 | for n in atom.GetNeighbors(): 57 | neighbor.append(n.GetAtomMapNum()) 58 | return sorted(neighbor) 59 | 60 | def extend_changed_atoms(changed_atom_tags, reactants, max_map): 61 | for reactant in reactants: 62 | extend_idx = [] 63 | for atom in reactant.GetAtoms(): 64 | if str(atom.GetAtomMapNum()) in changed_atom_tags: 65 | for n in atom.GetNeighbors(): 66 | if n.GetAtomMapNum() == 0: 67 | extend_idx.append(n.GetIdx()) 68 | for idx in extend_idx: 69 
| reactant.GetAtomWithIdx(idx).SetAtomMapNum(max_map) 70 | 71 | def check_atom_change(patom, ratom): 72 | return atom_neighbors(patom) != atom_neighbors(ratom) 73 | 74 | def label_retro_edit_site(products, reactants, edit_num): 75 | edit_num = [int(num) for num in edit_num] 76 | pmol = Chem.MolFromSmiles(products) 77 | rmol = Chem.MolFromSmiles(reactants) 78 | patom_map = {atom.GetAtomMapNum():atom.GetIdx() for atom in pmol.GetAtoms()} 79 | ratom_map = {atom.GetAtomMapNum():atom.GetIdx() for atom in rmol.GetAtoms()} 80 | used_atom = set() 81 | grow_atoms = [] 82 | broken_bonds = [] 83 | changed_bonds = [] 84 | 85 | # cut bond 86 | for a in edit_num: 87 | for b in edit_num: 88 | if a >= b: 89 | continue 90 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b]) 91 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b]) 92 | if check_bond_break(pbond, rbond): # cut bond 93 | broken_bonds.append((a, b)) 94 | used_atom.update([a, b]) 95 | 96 | # Add LG 97 | for a in edit_num: 98 | if a in used_atom: 99 | continue 100 | patom = pmol.GetAtomWithIdx(patom_map[a]) 101 | ratom = rmol.GetAtomWithIdx(ratom_map[a]) 102 | if check_atom_change(patom, ratom): 103 | used_atom.update([a]) 104 | grow_atoms.append(a) 105 | 106 | # change bond type 107 | for a in edit_num: 108 | for b in edit_num: 109 | if a >= b: 110 | continue 111 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b]) 112 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b]) 113 | if check_bond_change(pbond, rbond): 114 | if a not in used_atom and b not in used_atom: 115 | changed_bonds.append((a, b)) 116 | changed_bonds.append((b, a)) 117 | 118 | used_atoms = set(grow_atoms + [atom for bond in broken_bonds+changed_bonds for atom in bond]) 119 | remote_atoms = [atom for atom in edit_num if atom not in used_atoms] 120 | remote_atoms_ = [] 121 | for a in remote_atoms: 122 | atom = rmol.GetAtomWithIdx(ratom_map[a]) 123 | neighbors_map = [n.GetAtomMapNum() for n in atom.GetNeighbors()] 124 | connected_neighbors = [b for b in used_atoms if b in neighbors_map] 125 | if len(connected_neighbors) > 0: 126 | pass 127 | else: 128 | for n in neighbors_map: 129 | remote_atoms_.append(a) 130 | 131 | return grow_atoms, broken_bonds, changed_bonds, remote_atoms_ 132 | 133 | def label_foward_edit_site(reactants, products, edit_num): 134 | edit_num = [int(num) for num in edit_num] 135 | rmol = Chem.MolFromSmiles(reactants) 136 | pmol = Chem.MolFromSmiles(products) 137 | ratom_map = {atom.GetAtomMapNum():atom.GetIdx() for atom in rmol.GetAtoms()} 138 | patom_map = {atom.GetAtomMapNum():atom.GetIdx() for atom in pmol.GetAtoms()} 139 | atom_symbols = {atom.GetAtomMapNum():atom.GetSymbol() for atom in rmol.GetAtoms()} 140 | 141 | formed_bonds = [] 142 | broken_bonds = [] 143 | changed_bonds = [] 144 | acceptors1 = set() 145 | acceptors2 = set() 146 | donors = set() 147 | form_bond = False 148 | break_bond = False 149 | change_bond = False 150 | 151 | # cut bond 152 | for a in edit_num: 153 | for b in edit_num: 154 | if a >= b: 155 | continue 156 | try: 157 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b]) 158 | except: 159 | pbond = None 160 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b]) 161 | if check_bond_break(rbond, pbond): 162 | if a in patom_map: 163 | broken_bonds.append((a, b)) 164 | acceptors1.add(a) 165 | if b in patom_map: 166 | broken_bonds.append((b, a)) 167 | acceptors1.add(b) 168 | break_bond = True 169 | 170 | # change bond 171 | for a in edit_num: 172 | for b in edit_num: 173 | if a >= 
b: 174 | continue 175 | try: 176 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b]) 177 | except: 178 | pbond = None 179 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b]) 180 | if check_bond_change(rbond, pbond): 181 | changed_bonds.append((a, b)) 182 | changed_bonds.append((b, a)) 183 | change_bond = True 184 | acceptors2.update([a, b]) 185 | 186 | symmetric = True 187 | # form bond 188 | for a in edit_num: 189 | for b in edit_num: 190 | if a >= b: 191 | continue 192 | try: 193 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b]) 194 | except: 195 | pbond = None 196 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b]) 197 | if check_bond_formed(rbond, pbond): # cut bond 198 | form_bond = True 199 | if a not in acceptors1 and b not in acceptors1 and a not in acceptors2 and b not in acceptors2 : 200 | formed_bonds.append((a, b)) 201 | formed_bonds.append((b, a)) 202 | elif a in acceptors1 and b in acceptors1: 203 | symmetric = False 204 | formed_bonds.append((a, b)) 205 | formed_bonds.append((b, a)) 206 | else: 207 | symmetric = False 208 | if a in acceptors1: 209 | formed_bonds.append((b, a)) 210 | elif a in acceptors2 and b not in acceptors1: 211 | formed_bonds.append((b, a)) 212 | if b in acceptors1: 213 | formed_bonds.append((a, b)) 214 | elif b in acceptors2 and a not in acceptors1: 215 | formed_bonds.append((a, b)) 216 | 217 | if not symmetric: 218 | new_changed_bonds = [] 219 | # electron acceptor propagation 220 | acceptors = set([bond[1] for bond in formed_bonds]).union(acceptors1) 221 | for atom in acceptors: 222 | for bond in changed_bonds: 223 | if bond[0] == atom: 224 | new_changed_bonds.append(bond) 225 | donors = set([bond[0] for bond in formed_bonds]) 226 | for atom in donors: 227 | for bond in changed_bonds: 228 | if bond[1] == atom: 229 | new_changed_bonds.append(bond) 230 | changed_bonds = list(set(new_changed_bonds)) 231 | 232 | used_atoms = set([atom for bond in formed_bonds+broken_bonds+changed_bonds for atom in bond]) 233 | remote_atoms = [atom for atom in edit_num if atom not in used_atoms] 234 | remote_bonds = [] 235 | for a in remote_atoms: 236 | atom = rmol.GetAtomWithIdx(ratom_map[a]) 237 | neighbors_map = [n.GetAtomMapNum() for n in atom.GetNeighbors()] 238 | connected_neighbors = [b for b in used_atoms if b in neighbors_map] 239 | if len(connected_neighbors) > 0: 240 | pass 241 | else: 242 | for n in neighbors_map: 243 | remote_bonds.append((a, n)) 244 | return formed_bonds, broken_bonds, changed_bonds, remote_bonds 245 | 246 | def label_CHS_change(smiles1, smiles2, edit_num, replacement_dict, use_stereo): 247 | mol1, mol2 = Chem.MolFromSmiles(smiles1), Chem.MolFromSmiles(smiles2) 248 | atom_map_dict1 = {atom.GetAtomMapNum():atom.GetIdx() for atom in mol1.GetAtoms()} 249 | atom_map_dict2 = {atom.GetAtomMapNum():atom.GetIdx() for atom in mol2.GetAtoms()} 250 | H_dict = defaultdict(dict) 251 | C_dict = defaultdict(dict) 252 | S_dict = defaultdict(dict) 253 | for atom_map in edit_num: 254 | atom_map = int(atom_map) 255 | if atom_map in atom_map_dict2: 256 | atom1, atom2 = mol1.GetAtomWithIdx(atom_map_dict1[atom_map]), mol2.GetAtomWithIdx(atom_map_dict2[atom_map]) 257 | H_dict[atom_map]['smiles1'], C_dict[atom_map]['smiles1'], S_dict[atom_map]['smiles1'] = atom1.GetNumExplicitHs(), int(atom1.GetFormalCharge()), chiral_type_map[atom1.GetChiralTag()] 258 | H_dict[atom_map]['smiles2'], C_dict[atom_map]['smiles2'], S_dict[atom_map]['smiles2'] = atom2.GetNumExplicitHs(), int(atom2.GetFormalCharge()), 
chiral_type_map[atom2.GetChiralTag()] 259 | 260 | H_change = {replacement_dict[k]:v['smiles2'] - v['smiles1'] for k, v in H_dict.items()} 261 | Charge_change = {replacement_dict[k]:v['smiles2'] - v['smiles1'] for k, v in C_dict.items()} 262 | Chiral_change = {replacement_dict[k]:v['smiles2'] - v['smiles1'] for k, v in S_dict.items()} 263 | for k, v in S_dict.items(): 264 | if v['smiles2'] == v['smiles1'] or not use_stereo: # no chiral change 265 | Chiral_change[replacement_dict[k]] = 0 266 | # elif v['smiles1'] != 0: # opposite the stereo bond 267 | # Chiral_change[replacement_dict[k]] = 3 268 | else: 269 | Chiral_change[replacement_dict[k]] = v['smiles2'] 270 | return atom_map_dict1, H_change, Charge_change, Chiral_change 271 | 272 | def bondmap2idx(bond_maps, idx_dict, temp_dict, sort = False, remote = False): 273 | bond_idxs = [(idx_dict[bond_map[0]], idx_dict[bond_map[1]]) for bond_map in bond_maps] 274 | if remote: 275 | bond_temps = list(set([(temp_dict[bond_map[0]], -1) for bond_map in bond_maps])) 276 | return (bond_idxs, bond_maps, bond_temps) 277 | else: 278 | bond_temps = [(temp_dict[bond_map[0]], temp_dict[bond_map[1]]) for bond_map in bond_maps] 279 | if not sort: 280 | return (bond_idxs, bond_maps, bond_temps) 281 | else: 282 | sort_bond_idxs = [] 283 | sort_bond_maps = [] 284 | sort_bond_temps = [] 285 | for bond1, bond2, bond3 in zip(bond_idxs, bond_maps, bond_temps): 286 | if bond3[0] < bond3[1]: 287 | sort_bond_idxs.append(bond1) 288 | sort_bond_maps.append(bond2) 289 | sort_bond_temps.append(bond3) 290 | else: 291 | sort_bond_idxs.append(tuple(bond1[::-1])) 292 | sort_bond_maps.append(tuple(bond2[::-1])) 293 | sort_bond_temps.append(tuple(bond3[::-1])) 294 | return (sort_bond_idxs, sort_bond_maps, sort_bond_temps) 295 | 296 | def atommap2idx(atom_maps, idx_dict, temp_dict): 297 | atom_idxs = [idx_dict[atom_map] for atom_map in atom_maps] 298 | atom_temps = [temp_dict[atom_map] for atom_map in atom_maps] 299 | return (atom_idxs, atom_maps, atom_temps) 300 | 301 | def match_label(reactants, products, replacement_dict, edit_num, retro = True, remote = True, use_stereo = True): 302 | if retro: 303 | smiles1 = products 304 | smiles2 = reactants 305 | else: 306 | smiles1 = reactants 307 | smiles2 = products 308 | 309 | replacement_dict = {int(k): int(v) for k, v in replacement_dict.items()} 310 | atom_map_dict, H_change, Charge_change, Chiral_change = label_CHS_change(smiles1, smiles2, edit_num, replacement_dict, use_stereo) 311 | 312 | if retro: 313 | ALG_atoms, broken_bonds, changed_bonds, remote_atoms = label_retro_edit_site(smiles1, smiles2, edit_num) 314 | edits = {'A': atommap2idx(ALG_atoms, atom_map_dict, replacement_dict), 315 | 'B': bondmap2idx(broken_bonds, atom_map_dict, replacement_dict, True), 316 | 'C': bondmap2idx(changed_bonds, atom_map_dict, replacement_dict)} 317 | if remote: 318 | edits['R'] = atommap2idx(remote_atoms, atom_map_dict, replacement_dict) 319 | else: 320 | formed_bonds, broken_bonds, changed_bonds, remote_bonds = label_foward_edit_site(smiles1, smiles2, edit_num) 321 | edits = {'A': bondmap2idx(formed_bonds, atom_map_dict, replacement_dict), 322 | 'B': bondmap2idx(broken_bonds, atom_map_dict, replacement_dict), 323 | 'C': bondmap2idx(changed_bonds, atom_map_dict, replacement_dict)} 324 | if remote: 325 | edits['R'] = bondmap2idx(remote_bonds, atom_map_dict, replacement_dict, False, True) 326 | return edits, H_change, Charge_change, Chiral_change 327 | 328 | def get_bonds_from_smiles(smiles): 329 | """ Adapted from get_edit_site_retro from in 
Run_preprocessing.py """ 330 | mol = Chem.MolFromSmiles(smiles) 331 | B = set() 332 | for atom in mol.GetAtoms(): 333 | others = [] 334 | bonds = atom.GetBonds() 335 | for bond in bonds: 336 | atoms = [bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()] 337 | other = [a for a in atoms if a != atom.GetIdx()][0] 338 | others.append(other) 339 | B.update((atom.GetIdx(), other) for other in sorted(others)) 340 | return B 341 | -------------------------------------------------------------------------------- /preprocess/uspto_script/1.get_condition_from_uspto.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import xmltodict 3 | from collections import OrderedDict, Counter 4 | import pandas as pd 5 | import os 6 | import copy 7 | import json 8 | 9 | from utils import get_writer 10 | 11 | 12 | patent_cnt = Counter() 13 | patent_info = {} 14 | CONDITION_DICT = { 15 | 'id': [], 16 | 'source': [], 17 | 'year': [], 18 | 'patent_type': [], 19 | 'rxn_smiles': [], 20 | 'solvent': [], 21 | 'catalyst': [], 22 | 'reagent': [], 23 | } 24 | CORPUS_DICT = { 25 | 'id': [], 26 | 'year': [], 27 | 'patent_type': [], 28 | 'xml': [], 29 | 'heading_text': [], 30 | 'paragraph_text': [], 31 | } 32 | 33 | 34 | def read_xml2dict(xml_fpath): 35 | with open(xml_fpath, 'r') as f: 36 | data = xmltodict.parse(f.read()) 37 | reaction_and_condition_dict = copy.deepcopy(CONDITION_DICT) 38 | corpus_dict = copy.deepcopy(CORPUS_DICT) 39 | 40 | try: 41 | reaction_data_list = data['reactionList']['reaction'] 42 | except: 43 | return reaction_and_condition_dict, corpus_dict 44 | 45 | 46 | for rxn_data in reaction_data_list: 47 | 48 | try: 49 | patent_id = rxn_data['dl:source']['dl:documentId'] 50 | heading_text = rxn_data['dl:source'].get('dl:headingText', '') 51 | paragraph_text = rxn_data['dl:source'].get('dl:paragraphText', '') 52 | year = os.path.dirname(xml_fpath).split('/')[-1] 53 | patent_type = 'grant' if 'grants' in xml_fpath else 'application' 54 | patent_info[patent_id] = { 55 | 'year': int(year), 56 | 'type': patent_type 57 | } 58 | except: 59 | print(xml_fpath) 60 | print(rxn_data) 61 | continue 62 | 63 | if type(rxn_data) is str: 64 | continue 65 | if rxn_data['spectatorList'] is None: 66 | continue 67 | if rxn_data['dl:reactionSmiles'] is None: 68 | continue 69 | try: 70 | spectator_obj = rxn_data['spectatorList']['spectator'] 71 | except: 72 | continue 73 | 74 | s_list = [] 75 | c_list = [] 76 | r_list = [] 77 | if type(spectator_obj) is list: 78 | pass 79 | elif type(spectator_obj) is OrderedDict: 80 | spectator_obj = [spectator_obj] 81 | else: 82 | print('Warning spectator_obj is not in (list, OrderedDict)!!!') 83 | 84 | for one_spectator in spectator_obj: 85 | if 'identifier' not in one_spectator: 86 | continue 87 | if one_spectator['@role'] == 'solvent': 88 | if type(one_spectator['identifier']) is list: 89 | for identifier in one_spectator['identifier']: 90 | if identifier['@dictRef'] == 'cml:smiles': 91 | s_list.append(identifier['@value']) 92 | elif type(one_spectator['identifier']) is OrderedDict: 93 | identifier = one_spectator['identifier'] 94 | if identifier['@dictRef'] == 'cml:smiles': 95 | s_list.append(identifier['@value']) 96 | elif one_spectator['@role'] == 'catalyst': 97 | if type(one_spectator['identifier']) is list: 98 | for identifier in one_spectator['identifier']: 99 | if identifier['@dictRef'] == 'cml:smiles': 100 | c_list.append(identifier['@value']) 101 | elif type(one_spectator['identifier']) is OrderedDict: 102 | 
identifier = one_spectator['identifier'] 103 | if identifier['@dictRef'] == 'cml:smiles': 104 | c_list.append(identifier['@value']) 105 | elif one_spectator['@role'] == 'reagent': 106 | if type(one_spectator['identifier']) is list: 107 | for identifier in one_spectator['identifier']: 108 | if identifier['@dictRef'] == 'cml:smiles': 109 | r_list.append(identifier['@value']) 110 | elif type(one_spectator['identifier']) is OrderedDict: 111 | identifier = one_spectator['identifier'] 112 | if identifier['@dictRef'] == 'cml:smiles': 113 | r_list.append(identifier['@value']) 114 | else: 115 | print(one_spectator['@role']) 116 | 117 | rxn_id = patent_id + '_' + str(patent_cnt[patent_id]) 118 | patent_cnt[patent_id] += 1 119 | 120 | s_list = list(set(s_list)) 121 | c_list = list(set(c_list)) 122 | r_list = list(set(r_list)) 123 | reaction_and_condition_dict['solvent'].append('.'.join(s_list)) 124 | reaction_and_condition_dict['catalyst'].append('.'.join(c_list)) 125 | reaction_and_condition_dict['reagent'].append('.'.join(r_list)) 126 | reaction_and_condition_dict['rxn_smiles'].append(rxn_data['dl:reactionSmiles']) 127 | reaction_and_condition_dict['source'].append(patent_id) 128 | reaction_and_condition_dict['id'].append(rxn_id) 129 | reaction_and_condition_dict['year'].append(year) 130 | reaction_and_condition_dict['patent_type'].append(patent_type) 131 | 132 | corpus_dict['id'].append(rxn_id) 133 | corpus_dict['xml'].append(os.path.basename(xml_fpath)) 134 | corpus_dict['heading_text'].append(heading_text) 135 | corpus_dict['paragraph_text'].append(paragraph_text) 136 | corpus_dict['year'].append(year) 137 | corpus_dict['patent_type'].append(patent_type) 138 | 139 | return reaction_and_condition_dict, corpus_dict 140 | 141 | 142 | if __name__ == '__main__': 143 | uspto_org_path = '/Mounts/rbg-storage1/users/yujieq/USPTO/' 144 | 145 | # reaction_and_condition_df = pd.DataFrame() 146 | xml_path_list = [] 147 | for root, _, files in os.walk(uspto_org_path, topdown=False): 148 | if '.ipynb_checkpoints' in root: 149 | continue 150 | for fname in files: 151 | if fname.endswith('.xml'): 152 | xml_path_list.append(os.path.join(root, fname)) 153 | xml_path_list = sorted(xml_path_list) 154 | # fout, writer = get_writer('uspto_rxn_condition.csv', CONDITION_DICT.keys()) 155 | # corpus_fout, corpus_writer = get_writer('uspto_rxn_corpus.csv', CORPUS_DICT.keys()) 156 | cnt = 0 157 | for i, path in tqdm(enumerate(xml_path_list), total=len(xml_path_list)): 158 | reaction_and_condition_dict, corpus_dict = read_xml2dict(path) 159 | reaction_and_condition_df = pd.DataFrame(reaction_and_condition_dict) 160 | # for row in reaction_and_condition_df.itertuples(): 161 | # writer.writerow(list(row)[1:]) 162 | # fout.flush() 163 | corpus_df = pd.DataFrame(corpus_dict) 164 | # for row in corpus_df.itertuples(): 165 | # corpus_writer.writerow(list(row)[1:]) 166 | # corpus_fout.flush() 167 | cnt += len(reaction_and_condition_df) 168 | if i % 100 == 0: 169 | print(f'step {i}: {cnt} data') 170 | # fout.close() 171 | # corpus_fout.close() 172 | 173 | for patent_id in patent_cnt: 174 | patent_info[patent_id]['num_rxn'] = patent_cnt[patent_id] 175 | with open('patent_info.json', 'w') as f: 176 | json.dump(patent_info, f) 177 | 178 | print('Done') 179 | 180 | -------------------------------------------------------------------------------- /preprocess/uspto_script/2.0.clean_up_rxn_condition.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import OrderedDict 3 | 
from joblib import Parallel, delayed 4 | import pandas as pd 5 | import multiprocessing 6 | import os 7 | import argparse 8 | import torch 9 | from tqdm import tqdm 10 | from utils import canonicalize_smiles, get_writer 11 | from rxnmapper import RXNMapper 12 | 13 | 14 | debug = False 15 | 16 | 17 | def remap_and_reassign_condition_role(org_rxn, org_solvent, org_catalyst, org_reagent): 18 | if org_rxn.split('>') == 1: 19 | return None 20 | if '|' in org_rxn: 21 | rxn, frag = org_rxn.split(' ') 22 | else: 23 | rxn, frag = org_rxn, '' 24 | 25 | org_solvent, org_catalyst, org_reagent = [ 26 | canonicalize_smiles(x) for x in [org_solvent, org_catalyst, org_reagent] 27 | ] 28 | try: 29 | results = rxn_mapper.get_attention_guided_atom_maps([rxn])[0] 30 | except Exception as e: 31 | print('\n'+rxn+'\n') 32 | print(e) 33 | return None 34 | 35 | remapped_rxn = results['mapped_rxn'] 36 | confidence = results['confidence'] 37 | 38 | new_precursors, new_products = remapped_rxn.split('>>') 39 | 40 | pt = re.compile(r':(\d+)]') 41 | new_react_list = [] 42 | new_reag_list = [] 43 | for precursor in new_precursors.split('.'): 44 | if re.findall(pt, precursor): 45 | new_react_list.append(precursor) # 有原子映射-->反应物 46 | else: 47 | new_reag_list.append(precursor) # 无原子映射-->试剂 48 | 49 | new_reactants = '.'.join(new_react_list) 50 | react_maps = sorted(re.findall(pt, new_reactants)) 51 | prod_maps = sorted(re.findall(pt, new_products)) 52 | if react_maps != prod_maps: 53 | return None 54 | new_reagent_list = [] 55 | c_list = org_catalyst.split('.') 56 | s_list = org_solvent.split('.') 57 | r_list = org_reagent.split('.') 58 | for r in new_reag_list: 59 | if (r not in c_list + s_list) and (r not in r_list): 60 | new_reagent_list.append(r) 61 | new_reagent_list += [x for x in r_list if x != ''] 62 | catalyst = org_catalyst 63 | solvent = org_solvent 64 | reagent = '.'.join(new_reagent_list) 65 | can_react = canonicalize_smiles(new_reactants, clear_map=True) 66 | can_prod = canonicalize_smiles(new_products, clear_map=True) 67 | can_rxn = '{}>>{}'.format(can_react, can_prod) 68 | results = OrderedDict() 69 | results['remapped_rxn'] = remapped_rxn # remapped_rxn中包含有反应条件 70 | results['fragment'] = frag 71 | results['confidence'] = confidence 72 | results['canonical_rxn'] = can_rxn # can_rxn中无反应条件,只有原子参与贡献的反应物和产物 73 | results['catalyst'] = catalyst 74 | results['solvent'] = solvent 75 | results['reagent'] = reagent 76 | 77 | return results 78 | 79 | 80 | def run_tasks(task): 81 | idx, rxn, solvent, catalyst, reagent, source = task 82 | if pd.isna(solvent): 83 | solvent = '' 84 | if pd.isna(catalyst): 85 | catalyst = '' 86 | if pd.isna(reagent): 87 | reagent = '' 88 | results = remap_and_reassign_condition_role(rxn, solvent, catalyst, reagent) 89 | 90 | return idx, results, source 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('--gpu', type=int, default=0) 96 | parser.add_argument('--split_group', type=int, default=4) 97 | parser.add_argument('--group', type=int, default=1) 98 | args = parser.parse_args() 99 | 100 | assert args.group <= args.split_group-1 101 | 102 | print('Debug:', debug) 103 | print('Split group: {}'.format(args.split_group)) 104 | print('Group number: {}'.format(args.group)) 105 | print('GPU index: {}'.format(args.gpu)) 106 | 107 | # device = torch.device('cuda:{}'.format(args.gpu) if args.gpu >= 0 else 'cpu') 108 | rxn_mapper = RXNMapper() 109 | source_data_path = '.' 
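# Note on remap_and_reassign_condition_role() above: RXNMapper assigns atom maps to the raw
# reaction; precursors that receive atom maps are kept as reactants, while unmapped precursors
# are moved into the reagent list (together with the reagents reported in the patent). The
# returned remapped_rxn still contains the condition molecules, whereas canonical_rxn keeps
# only the atom-contributing reactants and products; reactions whose reactant and product
# atom-map sets disagree are discarded (None).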
110 | rxn_condition_fname = 'uspto_rxn_condition.csv' 111 | new_database_fpath = os.path.join( 112 | source_data_path, 'uspto_rxn_condition_remapped_and_reassign_condition_role_group_{}.csv'.format(args.group)) 113 | # n_core = 14 114 | 115 | # pool = multiprocessing.Pool(n_core) 116 | if debug: 117 | database = pd.read_csv(os.path.join(source_data_path, rxn_condition_fname), nrows=10001) 118 | else: 119 | database = pd.read_csv(os.path.join(source_data_path, rxn_condition_fname)) 120 | print('All data number: {}'.format(len(database))) 121 | 122 | group_size = len(database) // args.split_group 123 | 124 | if args.group >= args.split_group-1: 125 | database = database.iloc[args.group * group_size:] 126 | else: 127 | database = database.iloc[args.group * group_size:(args.group+1) * group_size] 128 | 129 | print('Caculate index {} to {}'.format(database.index.min(), database.index.max())) 130 | 131 | # rxn_smiles = database['rxn_smiles'].tolist() 132 | 133 | # tasks = [(idx, rxn, database.iloc[idx].solvent, database.iloc[idx].catalyst, database.iloc[idx].reagent, database.iloc[idx].source) 134 | # for idx, rxn in tqdm(enumerate(rxn_smiles), total=len(rxn_smiles))] 135 | header = [ 136 | 'id', 137 | 'source', 138 | 'org_rxn', 139 | 'fragment', 140 | 'remapped_rxn', 141 | 'confidence', 142 | 'canonical_rxn', 143 | 'catalyst', 144 | 'solvent', 145 | 'reagent', 146 | ] 147 | fout, writer = get_writer(new_database_fpath, header=header) 148 | all_results = [] 149 | for row in tqdm(database.itertuples(), total=len(database)): 150 | task = (row.id, row.rxn_smiles, row.solvent, row.catalyst, row.reagent, row.source) 151 | try: 152 | run_results = run_tasks(task) 153 | idx, results, source = run_results 154 | if results: 155 | results['id'] = row.id 156 | results['source'] = source 157 | results['org_rxn'] = row.rxn_smiles 158 | assert len(results) == len(header) 159 | writer.writerow([results[key] for key in header]) 160 | fout.flush() 161 | except Exception as e: 162 | print(e) 163 | pass 164 | # for results in tqdm(pool.imap_unordered(run_tasks, tasks), total=len(tasks)): 165 | # all_results.append(results) 166 | # all_results = Parallel(n_jobs=n_core, verbose=1)( 167 | # delayed(run_tasks)(task) for task in tqdm(tasks)) 168 | fout.close() 169 | # new_database = pd.read_csv(new_database_fpath) 170 | # reset_header = [ 171 | # 'source', 172 | # 'org_rxn', 173 | # 'fragment', 174 | # 'remapped_rxn', 175 | # 'confidence', 176 | # 'canonical_rxn', 177 | # 'catalyst', 178 | # 'solvent', 179 | # 'reagent', 180 | # ] 181 | # new_database = new_database[reset_header] 182 | # new_database.to_csv(new_database_fpath, index=False) 183 | print('Done!') 184 | -------------------------------------------------------------------------------- /preprocess/uspto_script/2.0.clean_up_rxn_condition.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p log 3 | CUDA_VISIBLE_DEVICES=2 nohup python -u 2.0.clean_up_rxn_condition.py --split_group 4 --group 0 > log/get_rxn_condition_uspto_0.log & 4 | CUDA_VISIBLE_DEVICES=3 nohup python -u 2.0.clean_up_rxn_condition.py --split_group 4 --group 1 > log/get_rxn_condition_uspto_1.log & 5 | CUDA_VISIBLE_DEVICES=6 nohup python -u 2.0.clean_up_rxn_condition.py --split_group 4 --group 2 > log/get_rxn_condition_uspto_2.log & 6 | CUDA_VISIBLE_DEVICES=7 nohup python -u 2.0.clean_up_rxn_condition.py --split_group 4 --group 3 > log/get_rxn_condition_uspto_3.log & 7 | 
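The shell script above fans 2.0.clean_up_rxn_condition.py out over four GPUs, each process handling one contiguous slice of the input CSV. Below is a minimal standalone sketch of that `--split_group`/`--group` slicing (the helper name and toy DataFrame are illustrative, not part of the repository); the last group absorbs the remainder rows:

```python
import pandas as pd

def slice_group(df: pd.DataFrame, split_group: int, group: int) -> pd.DataFrame:
    """Return the rows that process `group` out of `split_group` should handle."""
    group_size = len(df) // split_group
    if group >= split_group - 1:          # last group takes the remainder
        return df.iloc[group * group_size:]
    return df.iloc[group * group_size:(group + 1) * group_size]

# Example: 10 rows split 4 ways -> slice sizes 2, 2, 2, 4
df = pd.DataFrame({'rxn_smiles': [f'C{i}' for i in range(10)]})
print([len(slice_group(df, 4, g)) for g in range(4)])
```

Running the four groups concurrently therefore covers every row exactly once, and 2.1.merge_clean_up_rxn_conditon.py concatenates the per-group outputs afterwards.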
-------------------------------------------------------------------------------- /preprocess/uspto_script/2.1.merge_clean_up_rxn_conditon.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import os 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | from utils import calculate_frequency, get_writer 7 | 8 | def write_freq(fpath, freq_data): 9 | fout, writer = get_writer(fpath, ['smiles', 'freq_cnt']) 10 | for data in freq_data: 11 | writer.writerow(data) 12 | fout.flush() 13 | fout.close() 14 | 15 | if __name__ == '__main__': 16 | debug = False 17 | 18 | source_data_path = '.' 19 | merge_data_fname = 'uspto_rxn_condition_remapped_and_reassign_condition_role.csv' 20 | duplicate_removal_fname = 'uspto_rxn_condition_remapped_and_reassign_condition_role_rm_duplicate.csv' 21 | freq_info_path = os.path.join(source_data_path, 'freq_info') 22 | if not os.path.exists(freq_info_path): 23 | os.makedirs(freq_info_path) 24 | if not os.path.exists(os.path.join(source_data_path, merge_data_fname)): 25 | database = pd.DataFrame() 26 | for group_idx in [0, 1, 2, 3]: 27 | database = database.append(pd.read_csv(os.path.join( 28 | source_data_path, f'uspto_rxn_condition_remapped_and_reassign_condition_role_group_{group_idx}.csv'))) 29 | database.reset_index(inplace=True, drop=True) 30 | database.to_csv(os.path.join(os.path.join( 31 | source_data_path, merge_data_fname)), index=False) 32 | else: 33 | if not debug: 34 | database = pd.read_csv(os.path.join( 35 | source_data_path, merge_data_fname)) 36 | else: 37 | database = pd.read_csv(os.path.join( 38 | source_data_path, merge_data_fname), nrows=10000) 39 | 40 | # 按照 remapped_rxn + canonical_rxn + catalyst + solvent + reagent 比照标准去除重复, source, org_rxn, fragment, confidence 这几列取类别中的第一个 41 | info_row_name = ['remapped_rxn', 'canonical_rxn', 'catalyst', 'solvent', 'reagent'] 42 | database_duplicate_removal = database.drop_duplicates(subset=info_row_name, keep='first') 43 | database_duplicate_removal.reset_index(inplace=True, drop=True) 44 | print() 45 | print('catalyst count:', len(set(database_duplicate_removal['catalyst']))) 46 | catalyst_freq = calculate_frequency(database_duplicate_removal['catalyst'].tolist()) 47 | write_freq(os.path.join(freq_info_path, 'catalyst_freq.csv'), catalyst_freq) 48 | print() 49 | print('solvent count:', len(set(database_duplicate_removal['solvent']))) 50 | solvent_freq = calculate_frequency(database_duplicate_removal['solvent'].tolist()) 51 | write_freq(os.path.join(freq_info_path, 'solvent_freq.csv'), solvent_freq) 52 | print() 53 | print('reagent count:', len(set(database_duplicate_removal['reagent']))) 54 | reagent_freq = calculate_frequency(database_duplicate_removal['reagent'].tolist()) 55 | write_freq(os.path.join(freq_info_path, 'reagent_freq.csv'), reagent_freq) 56 | print() 57 | 58 | print('All dataset count:', len(database_duplicate_removal)) 59 | database_duplicate_removal.to_csv(os.path.join(source_data_path, duplicate_removal_fname), index=False) 60 | 61 | -------------------------------------------------------------------------------- /preprocess/uspto_script/3.0.split_condition_and_slect.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, namedtuple 2 | import os 3 | import pandas as pd 4 | from tqdm import tqdm 5 | from rdkit import Chem 6 | from utils import MolRemover, get_mol_charge, list_of_metal_atoms, mol_charge_class 7 | from rdkit import RDLogger 8 | 
RDLogger.DisableLog('rdApp.*') 9 | # if any(metal in cat for metal in list_of_metals): 10 | # metal_cat_dict[key] = True 11 | 12 | 13 | if __name__ == '__main__': 14 | debug = False 15 | unknown_check = False 16 | split_token = '分' 17 | print('Debug:', debug) 18 | remove_threshold = 100 19 | source_data_path = '.' 20 | duplicate_removal_fname = 'uspto_rxn_condition_remapped_and_reassign_condition_role_rm_duplicate.csv' 21 | freq_info_path = os.path.join(source_data_path, 'freq_info') 22 | 23 | if debug: 24 | database = pd.read_csv(os.path.join(source_data_path, duplicate_removal_fname), nrows=10000) 25 | else: 26 | database = pd.read_csv(os.path.join(source_data_path, duplicate_removal_fname)) 27 | 28 | # 去除条件频率少于阈值以下的数据 29 | print('Remove data with less than remove_threshold...') 30 | print(f'remove_threshold = {remove_threshold}') 31 | print('data number before remove: {}'.format(len(database))) 32 | condition_roles = ['catalyst', 'solvent', 'reagent'] 33 | remove_index = pd.isna(database.index) 34 | for role in condition_roles: 35 | df = pd.read_csv(os.path.join(freq_info_path, f'{role}_freq.csv')) 36 | remove_compund = df[df['freq_cnt'] < remove_threshold]['smiles'] 37 | remove_index = remove_index | database[role].isin(remove_compund) 38 | database_remove_below_threshold = database.loc[~remove_index].reset_index(drop=True) 39 | print('data number after remove: {}'.format(len(database_remove_below_threshold))) 40 | # database_remove_below_threshold.to_csv 41 | remover = MolRemover(defnFilename='../reagent_Ionic_compound.txt') 42 | reagent2index_dict = defaultdict(list) 43 | for idx, reagent in tqdm(enumerate(database_remove_below_threshold.reagent.tolist()), 44 | total=len(database_remove_below_threshold)): 45 | reagent2index_dict[reagent].append(idx) 46 | 47 | if unknown_check: 48 | # 检查带有离子的试剂,删去非电中性的组合 49 | unknown_combination = defaultdict(int) 50 | reagent2single_reagent_dict = {} 51 | for reagent in tqdm(reagent2index_dict): 52 | if pd.isna(reagent): 53 | reagent_namedtuple = namedtuple('reagent', ['known', 'unknown']) 54 | reagent2single_reagent_dict[reagent] = reagent_namedtuple([], []) 55 | continue 56 | reagent_mol = Chem.MolFromSmiles(reagent) 57 | reagent_mol_after_rm, remove_mol = remover.StripMolWithDeleted(reagent_mol, onlyFrags=False) 58 | reagent_smiles_after_rm = Chem.MolToSmiles(reagent_mol_after_rm) 59 | reagent_smiles_after_rm_list = reagent_smiles_after_rm.split('.') 60 | reagent_mols_after_rm = [Chem.MolFromSmiles(x) for x in reagent_smiles_after_rm_list] 61 | known_ionic = [Chem.MolToSmiles(x) for x in remove_mol] 62 | _unknown = [] 63 | reagent_charge_neutral = [] 64 | for mol in reagent_mols_after_rm: 65 | mol_charge_flag, mol_neutralization = get_mol_charge(mol) 66 | if mol_charge_flag != mol_charge_class[2]: 67 | smi = Chem.MolToSmiles(mol) 68 | if smi != '': 69 | _unknown.append(smi) 70 | else: 71 | smi = Chem.MolToSmiles(mol) 72 | if smi != '': 73 | reagent_charge_neutral.append(smi) 74 | reagent_namedtuple = namedtuple('reagent', ['known', 'unknown']) 75 | _known = reagent_charge_neutral + known_ionic 76 | if _unknown: 77 | unknown_combination['.'.join(_unknown)] += 1 78 | reagent2single_reagent_dict[reagent] = reagent_namedtuple(_known, _unknown) 79 | unknown_combination = list(unknown_combination.items()) 80 | unknown_combination.sort(key=lambda x:x[1], reverse=True) 81 | print('Unknown reagent data count:', sum([x[1] for x in unknown_combination])) 82 | with open('../reagent_unknown.txt', 'w', encoding='utf-8') as f: 83 | for line in 
unknown_combination: 84 | f.write('{},{}\n'.format(line[0], line[1])) 85 | else: 86 | block_unknown_combination = pd.read_csv('../reagent_unknown.txt', header=None) 87 | block_unknown_combination.columns = ['smiles', 'cnt'] 88 | print('Will block {} reagent combination, a total of {} data will be deleted.'.format( 89 | len(block_unknown_combination), block_unknown_combination['cnt'].sum())) 90 | 91 | reagent2single_reagent_dict = defaultdict(list) 92 | for reagent in tqdm(reagent2index_dict): 93 | if pd.isna(reagent): 94 | reagent2single_reagent_dict[reagent].append('') 95 | continue 96 | reagent_mol = Chem.MolFromSmiles(reagent) 97 | reagent_mol_after_rm, remove_mol = remover.StripMolWithDeleted(reagent_mol, onlyFrags=False) 98 | reagent_smiles_after_rm = Chem.MolToSmiles(reagent_mol_after_rm) 99 | reagent_smiles_after_rm_list = reagent_smiles_after_rm.split('.') 100 | reagent_mols_after_rm = [Chem.MolFromSmiles(x) for x in reagent_smiles_after_rm_list] 101 | known_ionic = [Chem.MolToSmiles(x) for x in remove_mol] 102 | _unknown = [] 103 | reagent_charge_neutral = [] 104 | for mol in reagent_mols_after_rm: 105 | mol_charge_flag, mol_neutralization = get_mol_charge(mol) 106 | if mol_charge_flag != mol_charge_class[2]: 107 | smi = Chem.MolToSmiles(mol) 108 | if smi != '': 109 | _unknown.append(smi) 110 | else: 111 | smi = Chem.MolToSmiles(mol) 112 | if smi != '': 113 | reagent_charge_neutral.append(smi) 114 | if _unknown: 115 | _unknown_smiles = '.'.join(_unknown) 116 | assert _unknown_smiles in block_unknown_combination['smiles'].tolist() 117 | _known = reagent_charge_neutral + known_ionic 118 | reagent2single_reagent_dict[reagent] += _known 119 | unknown_drop_index = [] 120 | for reagent in reagent2single_reagent_dict: 121 | if not reagent2single_reagent_dict[reagent]: 122 | unknown_drop_index.extend(reagent2index_dict[reagent]) 123 | reagent2single_reagent_dict = {k: v for k, v in reagent2single_reagent_dict.items() if v} 124 | database_remove_below_threshold = database_remove_below_threshold.drop(unknown_drop_index) 125 | database_remove_below_threshold = database_remove_below_threshold.reset_index(drop=True) 126 | reagent2index_dict = defaultdict(list) 127 | for idx, reagent in tqdm(enumerate(database_remove_below_threshold.reagent.tolist()), 128 | total=len(database_remove_below_threshold)): 129 | reagent2index_dict[reagent].append(idx) 130 | 131 | # 按照 https://pubs.acs.org/doi/10.1021/acscentsci.8b00357 论文所提到的将catalyst>1, solvent>2, reagent>2 的数据排除 132 | print('Exceeding one catalyst, two solvents, or two reagents --> remove') 133 | remove_index_for_excess = pd.isna(database_remove_below_threshold.index) 134 | # reagent > 2 remove! 
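# Following the filtering rule from the paper cited above (https://pubs.acs.org/doi/10.1021/acscentsci.8b00357),
# a reaction is discarded if it has more than one catalyst, more than two solvents, or more than
# two reagents after the ionic-compound regrouping; the three passes below flag such rows in
# remove_index_for_excess before they are dropped.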
135 | for reagent in reagent2index_dict: 136 | if len(reagent2single_reagent_dict[reagent]) > 2: 137 | remove_idx = reagent2index_dict[reagent] 138 | for _idx in remove_idx: 139 | remove_index_for_excess[_idx] = True 140 | for _idx, catalyst in enumerate(database_remove_below_threshold.catalyst.tolist()): 141 | if not pd.isna(catalyst): 142 | if len(catalyst.split('.')) > 1: 143 | remove_index_for_excess[_idx] = True 144 | for _idx, solvent in enumerate(database_remove_below_threshold.solvent.tolist()): 145 | if not pd.isna(solvent): 146 | if len(solvent.split('.')) > 2: 147 | remove_index_for_excess[_idx] = True 148 | database_remove_below_threshold = database_remove_below_threshold.loc[~remove_index_for_excess].reset_index(drop=True) 149 | reagent2index_dict = defaultdict(list) 150 | for idx, reagent in tqdm(enumerate(database_remove_below_threshold.reagent.tolist()), 151 | total=len(database_remove_below_threshold)): 152 | reagent2index_dict[reagent].append(idx) 153 | print('Spliting conditions...') 154 | database_remove_below_threshold['catalyst_split'] = database_remove_below_threshold['catalyst'] 155 | database_remove_below_threshold['solvent_split'] = ['']*len(database_remove_below_threshold) 156 | database_remove_below_threshold['reagent_split'] = ['']*len(database_remove_below_threshold) 157 | for reagent in reagent2index_dict: 158 | write_index = reagent2index_dict[reagent] 159 | write_value = split_token.join(reagent2single_reagent_dict[reagent]) 160 | database_remove_below_threshold.loc[write_index, 'reagent_split'] = write_value 161 | 162 | solvent_split = [] 163 | for x in database_remove_below_threshold['solvent'].tolist(): 164 | if pd.isna(x): 165 | solvent_split.append('') 166 | continue 167 | solvent_split.append(split_token.join(x.split('.'))) 168 | database_remove_below_threshold['solvent_split'] = solvent_split 169 | database_remove_below_threshold.to_csv( 170 | os.path.join( 171 | source_data_path, 172 | 'uspto_rxn_condition_remapped_and_reassign_condition_role_rm_duplicate_rm_excess.csv' 173 | ), 174 | index=False) 175 | print('Remaining data in the end:', len(database_remove_below_threshold)) 176 | print('Unique canonical reaction:', len(set(database_remove_below_threshold['canonical_rxn']))) 177 | print('Unique remapped reaction:', len(set(database_remove_below_threshold['remapped_rxn']))) 178 | print('Unique catalyst:', len(set(database_remove_below_threshold['catalyst']))) 179 | print('Unique solvent:', len(set(database_remove_below_threshold['solvent']))) 180 | print('Unique reagent:', len(set(database_remove_below_threshold['reagent']))) 181 | print('Done!') 182 | -------------------------------------------------------------------------------- /preprocess/uspto_script/4.0.split_train_val_test.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import os 3 | import json 4 | import pandas as pd 5 | import random 6 | from tqdm import tqdm 7 | 8 | 9 | if __name__ == '__main__': 10 | debug = False 11 | print('Debug:', debug) 12 | seed = 123 13 | random.seed(seed) 14 | # 数据集: canonical_rxn --> catalyst, solvent1, solvent2, reagent1, reagent2 15 | split_token = '分' 16 | split_frac = (0.8, 0.1, 0.1) # train:val:test 17 | source_data_path = '.' 
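# This script writes three versions of the condition dataset:
#   1) USPTO_condition_final/USPTO_condition.csv - random 80/10/10 split in which no canonical
#      reaction is shared between train and val/test (reactions that occur more than once all
#      go to train);
#   2) USPTO_condition_year/                      - time split: patents from 2015 -> val,
#      2016 -> test, all other years -> train;
#   3) USPTO_condition_grant_year/                - the same time split restricted to granted
#      patents (applications are excluded).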
18 | database_remove_below_threshold_fname = 'uspto_rxn_condition_remapped_and_reassign_condition_role_rm_duplicate_rm_excess.csv' 19 | final_condition_data_path = os.path.join(source_data_path, 'USPTO_condition_final') 20 | if not os.path.exists(final_condition_data_path): 21 | os.makedirs(final_condition_data_path) 22 | if debug: 23 | database = pd.read_csv(os.path.join(source_data_path, database_remove_below_threshold_fname), nrows=10000) 24 | else: 25 | database = pd.read_csv(os.path.join(source_data_path, database_remove_below_threshold_fname)) 26 | 27 | database = database[['id', 'source', 'canonical_rxn', 'catalyst_split', 'solvent_split', 'reagent_split']] 28 | database['catalyst1'] = database['catalyst_split'] 29 | split_solvent = database['solvent_split'].str.split(split_token, 1, expand=True) 30 | database['solvent1'], database['solvent2'] = split_solvent[0], split_solvent[1] 31 | split_reagent = database['reagent_split'].str.split(split_token, 1, expand=True) 32 | database['reagent1'], database['reagent2'] = split_reagent[0], split_reagent[1] 33 | columns = ['id', 'source', 'canonical_rxn', 'catalyst1', 'solvent1', 'solvent2', 'reagent1', 'reagent2'] 34 | database = database[columns] 35 | database_sample = database.sample(frac=1, random_state=seed) 36 | 37 | # Random split (no rxn overlap) 38 | can_rxn2idx_dict = defaultdict(list) 39 | for idx, row in tqdm(database_sample.iterrows(), total=len(database_sample)): 40 | can_rxn2idx_dict[row.canonical_rxn].append(idx) 41 | train_idx, val_idx, test_idx = [], [], [] 42 | can_rxn2idx_dict_items = list(can_rxn2idx_dict.items()) 43 | random.shuffle(can_rxn2idx_dict_items) 44 | all_data_number = len(database_sample) 45 | for rxn, idx_list in tqdm(can_rxn2idx_dict_items): 46 | if len(idx_list) == 1: 47 | if len(test_idx) < split_frac[2] * all_data_number: 48 | test_idx += idx_list 49 | elif len(val_idx) < split_frac[1] * all_data_number: 50 | val_idx += idx_list 51 | else: 52 | train_idx += idx_list 53 | else: 54 | train_idx += idx_list 55 | 56 | database_sample.loc[train_idx, 'dataset'] = 'train' 57 | database_sample.loc[val_idx, 'dataset'] = 'val' 58 | database_sample.loc[test_idx, 'dataset'] = 'test' 59 | database_sample.to_csv(os.path.join(final_condition_data_path, 'USPTO_condition.csv'), index=False) 60 | 61 | # Time split 62 | with open('patent_info.json') as f: 63 | patent_info = json.load(f) 64 | train_idx, val_idx, test_idx = [], [], [] 65 | for idx, patent_id in enumerate(database_sample['source']): 66 | if patent_info[patent_id]['year'] in [2016]: 67 | test_idx.append(idx) 68 | elif patent_info[patent_id]['year'] in [2015]: 69 | val_idx.append(idx) 70 | else: 71 | train_idx.append(idx) 72 | year_data_path = 'USPTO_condition_year' 73 | os.makedirs(year_data_path, exist_ok=True) 74 | train_df = database_sample.iloc[train_idx] 75 | train_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_train.csv'), index=False) 76 | val_df = database_sample.iloc[val_idx] 77 | val_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_val.csv'), index=False) 78 | test_df = database_sample.iloc[test_idx] 79 | test_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_test.csv'), index=False) 80 | 81 | # Grant time split 82 | with open('patent_info.json') as f: 83 | patent_info = json.load(f) 84 | train_idx, val_idx, test_idx = [], [], [] 85 | for idx, patent_id in enumerate(database_sample['source']): 86 | if patent_info[patent_id]['type'] != 'grant': 87 | continue 88 | if patent_info[patent_id]['year'] in [2016]: 89 | 
test_idx.append(idx) 90 | elif patent_info[patent_id]['year'] in [2015]: 91 | val_idx.append(idx) 92 | else: 93 | train_idx.append(idx) 94 | year_data_path = 'USPTO_condition_grant_year' 95 | os.makedirs(year_data_path, exist_ok=True) 96 | train_df = database_sample.iloc[train_idx] 97 | train_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_train.csv'), index=False) 98 | val_df = database_sample.iloc[val_idx] 99 | val_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_val.csv'), index=False) 100 | test_df = database_sample.iloc[test_idx] 101 | test_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_test.csv'), index=False) 102 | -------------------------------------------------------------------------------- /preprocess/uspto_script/5.0.convert_context_tokens.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from tqdm import tqdm 4 | import pickle 5 | 6 | BOS, EOS, PAD, MASK = '[BOS]', '[EOS]', '[PAD]', '[MASK]' 7 | UNK, SEP = '[UNK]', '[SEP]' 8 | 9 | 10 | def get_condition2idx_mapping(all_condition_data: pd.DataFrame): 11 | col_unique_data = [BOS, EOS, PAD, MASK] 12 | for col in all_condition_data.columns.tolist(): 13 | one_col_unique = list(set(all_condition_data[col].tolist())) 14 | col_unique_data.extend(one_col_unique) 15 | col_unique_data = list(set(col_unique_data)) 16 | col_unique_data.sort() 17 | idx2data = {i: x for i, x in enumerate(col_unique_data)} 18 | data2idx = {x: i for i, x in enumerate(col_unique_data)} 19 | return idx2data, data2idx 20 | 21 | 22 | def get_condition_vocab(all_condition_data: pd.DataFrame): 23 | col_unique_data = [] 24 | for col in all_condition_data.columns.tolist(): 25 | one_col_unique = list(set(all_condition_data[col].tolist())) 26 | col_unique_data.extend(one_col_unique) 27 | col_unique_data = list(set(col_unique_data)) 28 | col_unique_data.sort() 29 | vocab = [PAD, BOS, EOS, MASK, UNK, SEP] + col_unique_data 30 | return vocab 31 | 32 | 33 | if __name__ == '__main__': 34 | calculate_fps = False 35 | convert_aug_data = False 36 | 37 | debug = False 38 | print('Debug:', debug, 'Convert data:', convert_aug_data) 39 | fp_size = 16384 40 | source_data_path = '.' 
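# Each reaction's conditions are serialized as the token sequence
#   [BOS] catalyst1 solvent1 solvent2 reagent1 reagent2 [EOS]
# and mapped to integer ids with the vocabulary built below; empty condition slots are kept
# as the empty string (keep_default_na=False) and receive their own vocabulary entry.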
41 | final_condition_data_path = os.path.join(source_data_path, 'USPTO_condition_final') 42 | if convert_aug_data: 43 | database_fname = 'USPTO_condition_aug_n5.csv' 44 | else: 45 | database_fname = 'USPTO_condition.csv' 46 | 47 | if debug: 48 | database = pd.read_csv(os.path.join(final_condition_data_path, database_fname), nrows=10000) 49 | final_condition_data_path = os.path.join(source_data_path, 'USPTO_condition_final_debug') 50 | if not os.path.exists(final_condition_data_path): 51 | os.makedirs(final_condition_data_path) 52 | database.to_csv(os.path.join(final_condition_data_path, database_fname), index=False) 53 | else: 54 | database = pd.read_csv(os.path.join(final_condition_data_path, database_fname), keep_default_na=False) 55 | 56 | condition_cols = ['catalyst1', 'solvent1', 'solvent2', 'reagent1', 'reagent2'] 57 | 58 | vocab = get_condition_vocab(database[condition_cols]) 59 | vocab_path = os.path.join(final_condition_data_path, 'vocab_smiles.txt') 60 | with open(vocab_path, 'w') as f: 61 | f.write('\n'.join(vocab)) 62 | 63 | all_idx2data, all_data2idx = get_condition2idx_mapping(database[condition_cols]) 64 | all_idx_mapping_data_fpath = os.path.join( 65 | final_condition_data_path, 66 | '{}_alldata_idx.pkl'.format(database_fname.split('.')[0])) 67 | with open(all_idx_mapping_data_fpath, 'wb') as f: 68 | pickle.dump((all_idx2data, all_data2idx), f) 69 | all_condition_labels = [] 70 | for _, row in tqdm(database[condition_cols].iterrows(), total=len(database)): 71 | row = list(row) 72 | row = ['[BOS]'] + row + ['[EOS]'] 73 | all_condition_labels.append([all_data2idx[x] for x in row]) 74 | 75 | all_condition_labels_fpath = os.path.join( 76 | final_condition_data_path, 77 | '{}_condition_labels.pkl'.format(database_fname.split('.')[0])) 78 | with open(all_condition_labels_fpath, 'wb') as f: 79 | pickle.dump((all_condition_labels), f) 80 | print('Done!') 81 | -------------------------------------------------------------------------------- /preprocess/uspto_script/extract_nosmiles.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | # from show_coincidence import read_json 5 | from rdkit import Chem 6 | from rdkit import RDLogger 7 | 8 | from utils import canonicalize_smiles, read_pickle 9 | RDLogger.DisableLog('rdApp.*') 10 | 11 | 12 | 13 | if __name__ == '__main__': 14 | all_data = {} 15 | for fname in os.listdir('../'): 16 | name, ext = os.path.splitext(fname) 17 | fpath = os.path.join('..', fname) 18 | if ext == '.pickle': 19 | all_data['{}'.format(name.replace('_dict', ''))] = read_pickle(fpath) 20 | all_reaxys_name = [] 21 | all_data_reaxys_name_comp = {} 22 | all_data_rm = {} 23 | for k in all_data: 24 | can_err = 0 25 | all_data_rm[k] = {} 26 | all_data_reaxys_name_comp[k] = [] 27 | idx = 0 28 | for comp_idx in all_data[k]: 29 | comp = all_data[k][comp_idx] 30 | if 'Reaxys' not in comp: 31 | if comp == '': 32 | all_data_rm[k][idx] = comp 33 | idx += 1 34 | continue 35 | comp_can = canonicalize_smiles(comp) 36 | if comp_can == '': 37 | can_err += 1 38 | continue 39 | all_data_rm[k][idx] = comp_can 40 | idx += 1 41 | elif 'Reaxys Name' in comp: 42 | all_data_reaxys_name_comp[k].append(comp) 43 | all_reaxys_name.append(comp) 44 | 45 | print(f'{k} canonicalize fail: {can_err}') 46 | print('{}: {}'.format(k, len(all_data_rm[k]))) 47 | print('{}: Reaxys Name {}'.format(k, len(all_data_reaxys_name_comp[k]))) 48 | 49 | with open('../check_data/condition_compound.pkl', 'wb') as f: 50 | 
pickle.dump(all_data_rm, f) 51 | with open('../check_data/reaxys_name_comp.pkl', 'wb') as f: 52 | pickle.dump(all_data_reaxys_name_comp, f) 53 | 54 | with open('../check_data/condition_compound.json', 'w', encoding='utf-8') as f: 55 | json.dump(all_data_rm, f) 56 | with open('../check_data/reaxys_name_comp.json', 'w', encoding='utf-8') as f: 57 | json.dump(all_data_reaxys_name_comp, f) 58 | 59 | with open('../check_data/all_reaxys_name.txt', 'w', encoding='utf-8') as f: 60 | f.write('\n'.join(all_reaxys_name)) 61 | 62 | -------------------------------------------------------------------------------- /preprocess/uspto_script/gen_grant_corpus.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | 4 | 5 | corpus_df = pd.read_csv('uspto_rxn_corpus.csv') 6 | 7 | grant_df = corpus_df[corpus_df['patent_type'] == 'grant'] 8 | grant_df.to_csv('USPTO_rxn_grant_corpus.csv', index=False) 9 | -------------------------------------------------------------------------------- /preprocess/uspto_script/get_aug_condition_data.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import os 3 | import pandas as pd 4 | import sys 5 | sys.path.append('../pretrain_data_script/') 6 | from get_pretrain_dataset_augmentation import get_random_rxn 7 | 8 | 9 | ''' 10 | 用于USPTO-Condition数据集的数据增强,在本脚本中我们只增强训练集''' 11 | def generate_one_rxn_condition_aug(tuple_data): 12 | i, len_all, source, rxn, c1, s1, s2, r1, r2, dataset_flag = tuple_data 13 | if i > 0 and i % 1000 == 0: 14 | print(f"Processing {i}th Reaction / {len_all}") 15 | results = [(source, rxn, c1, s1, s2, r1, r2, dataset_flag)] 16 | for i in range(N-1): 17 | random_rxn = get_random_rxn(rxn) 18 | results.append((source, random_rxn, c1, s1, s2, r1, r2, dataset_flag)) 19 | return results 20 | 21 | if __name__ == '__main__': 22 | debug = False 23 | 24 | N = 5 25 | num_workers = 10 26 | 27 | 28 | 29 | source_data_path = '../../dataset/source_dataset/' 30 | final_condition_data_path = os.path.join(source_data_path, 'USPTO_condition_final') 31 | database_fname = 'USPTO_condition.csv' 32 | 33 | 34 | uspto_condition_dataset = pd.read_csv(os.path.join(final_condition_data_path, database_fname)) 35 | 36 | 37 | uspto_condition_train_df = uspto_condition_dataset.loc[uspto_condition_dataset['dataset']=='train'] 38 | uspto_condition_val_test_df = uspto_condition_dataset.loc[uspto_condition_dataset['dataset']!='train'] 39 | if debug: 40 | uspto_condition_train_df = uspto_condition_train_df.sample(3000) 41 | 42 | p = Pool(num_workers) 43 | all_len = len(uspto_condition_train_df) 44 | 45 | augmentation_rxn_condition_train_data = p.imap( 46 | generate_one_rxn_condition_aug, 47 | ((i, all_len, *row.tolist()) for i, (index, row) in enumerate(uspto_condition_train_df.iterrows())) 48 | ) 49 | 50 | p.close() 51 | p.join() 52 | augmentation_rxn_condition_train_data = list(augmentation_rxn_condition_train_data) 53 | augmentation_rxn_condition_train = [] # 一个标准rxn smiles N-1个random smiles 54 | for one in augmentation_rxn_condition_train_data: 55 | augmentation_rxn_condition_train += one 56 | 57 | aug_train_df = pd.DataFrame(augmentation_rxn_condition_train) 58 | aug_train_df.columns = uspto_condition_train_df.columns.tolist() 59 | 60 | aug_uspto_conditon_dataset = aug_train_df.append(uspto_condition_val_test_df).reset_index(drop=True) 61 | aug_uspto_conditon_dataset.to_csv(os.path.join(final_condition_data_path, 'USPTO_condition_aug_n5.csv'), 
index=False) -------------------------------------------------------------------------------- /preprocess/uspto_script/get_dataset_for_condition.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Work in progress: match reagents and other reaction-condition molecules in the SMILES dataset against the existing reaction conditions 3 | ''' 4 | from collections import defaultdict 5 | import pickle 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from rdkit import RDLogger 9 | RDLogger.DisableLog('rdApp.*') 10 | from utils import canonicalize_smiles, covert_to_series 11 | 12 | NONE_PANDDING = '[None]' 13 | 14 | class AssignmentCondition: 15 | def __init__(self, condition_compound_dict) -> None: 16 | self.padding = 0 17 | self.condition_compound_dict = condition_compound_dict 18 | self.condition_compound_df_dict= {} 19 | for k in condition_compound_dict: 20 | new_data = {} 21 | data = condition_compound_dict[k] 22 | max_col = 0 23 | for idx in data: 24 | smi = data[idx] 25 | if smi == '': 26 | smi = NONE_PANDDING 27 | new_data[smi] = smi.split('.') 28 | if max_col < len(new_data[smi]): 29 | max_col = len(new_data[smi]) 30 | for smi in new_data: 31 | new_data[smi] = new_data[smi] + [self.padding] * (max_col - len(new_data[smi])) 32 | self.condition_compound_df_dict[k] = pd.DataFrame(new_data) 33 | 34 | print('Read solvent {}, reagent {}, catalyst {}'.format( 35 | len(condition_compound_dict['s1']), 36 | len(condition_compound_dict['r1']), 37 | len(condition_compound_dict['c1']))) 38 | 39 | def apply(self, reag_ser): 40 | one_condition_dict = defaultdict(list) 41 | for name in ['c1', 'r1', 's1']: 42 | df = self.condition_compound_df_dict[name] 43 | for smi in df.columns.tolist(): 44 | data_series = df[smi][df[smi]!=self.padding] 45 | if data_series.isin(reag_ser).sum() == len(data_series): 46 | one_condition_dict[name].append(smi) 47 | 48 | return one_condition_dict 49 | 50 | if __name__ == '__main__': 51 | debug = True 52 | 53 | condition_compound_fpath = '../check_data/condition_compound_add_name2smiles.pkl' 54 | if not debug: 55 | print('Reading USPTO-1k-tpl...') 56 | uspto_1k_tpl_train = pd.read_csv('../../dataset/source_dataset/uspto_1k_TPL_train_valid.tsv.gzip', compression='gzip', sep='\t', index_col=0) 57 | uspto_1k_tpl_test = pd.read_csv('../../dataset/source_dataset/uspto_1k_TPL_test.tsv.gzip', compression='gzip', sep='\t', index_col=0) 58 | uspto_1k_tpl = uspto_1k_tpl_train.append(uspto_1k_tpl_test) 59 | print('# data: {}'.format(len(uspto_1k_tpl))) 60 | else: 61 | print('Debug...') 62 | print('Reading debug tsv...') 63 | debug_df = pd.read_csv('../../dataset/source_dataset/debug_df.tsv', sep='\t') 64 | uspto_1k_tpl = debug_df 65 | 66 | 67 | uspto_1k_tpl['reagents_series'] = [covert_to_series(canonicalize_smiles(x), none_pandding=NONE_PANDDING) for x in tqdm(uspto_1k_tpl.reagents.tolist())] 68 | with open(condition_compound_fpath, 'rb') as f: 69 | condition_compound_dict = pickle.load(f) 70 | assignment_cls = AssignmentCondition(condition_compound_dict=condition_compound_dict) 71 | for reag_ser in tqdm(uspto_1k_tpl['reagents_series'].tolist()): 72 | assignment_cls.apply(reag_ser) 73 | pass 74 | 75 | -------------------------------------------------------------------------------- /preprocess/uspto_script/get_dummy_model_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import sys 5 | 6 | from tqdm import tqdm 7 | sys.path = [os.path.abspath(os.path.join(os.path.abspath(__file__), '../../../'))] + sys.path 8 | from models.utils
import load_dataset 9 | import torch 10 | 11 | dummy_prediction = [ 12 | [0, 0, 0, 0, 0], # [na, na, na, na, na] 13 | [0, 160, 0, 0, 0], # [na, DCM, na, na, na] 14 | [0, 16, 0, 0, 0], # [na, THF, na, na, na] 15 | [0, 0, 0, 85, 0], # [na, na, na, TEA, na] 16 | [0, 0, 0, 202, 0], # [na, na, na, K2CO3, na] 17 | [0, 160, 0, 85, 0], # [na, DCM, na, TEA, na] 18 | [0, 16, 0, 85, 0], # [na, THF, na, TEA, na] 19 | [0, 160, 0, 202, 0], # [na, DCM, na, K2CO3, na] 20 | [0, 16, 0, 202, 0], # [na, THF, na, K2CO3, na] 21 | [298, 0, 0, 0, 0], # [[Pd], na, na, na, na] 22 | ] 23 | 24 | def _get_accuracy_for_one(one_pred, one_ground_truth, topk_get=[1, 3, 5, 10, 15]): 25 | repeat_number = one_pred.size(0) 26 | hit_mat = one_ground_truth.unsqueeze( 27 | 0).repeat(repeat_number, 1) == one_pred 28 | overall_hit_mat = hit_mat.sum(1) == hit_mat.size(1) 29 | topk_hit_df = pd.DataFrame() 30 | for k in topk_get: 31 | hit_mat_k = hit_mat[:k, :] 32 | overall_hit_mat_k = overall_hit_mat[:k] 33 | topk_hit = [] 34 | for col_idx in range(hit_mat.size(1)): 35 | if hit_mat_k[:, col_idx].sum() != 0: 36 | topk_hit.append(1) 37 | else: 38 | topk_hit.append(0) 39 | if overall_hit_mat_k.sum() != 0: 40 | topk_hit.append(1) 41 | else: 42 | topk_hit.append(0) 43 | topk_hit_df[k] = topk_hit 44 | # topk_hit_df.index = ['c1', 's1', 's2', 'r1', 'r2'] 45 | return topk_hit_df 46 | 47 | def _calculate_batch_topk_hit(batch_preds, batch_ground_truth, topk_get=[1, 3, 5, 10, 15]): 48 | ''' 49 | batch_pred <-- tgt_tokens_list 50 | batch_ground_truth <-- inputs['labels'] 51 | ''' 52 | batch_preds = torch.tensor(batch_preds)[:, :, :].to(torch.device('cpu')) 53 | batch_ground_truth = batch_ground_truth[:, 1:-1] 54 | 55 | one_batch_topk_acc_mat = np.zeros((6, 5)) 56 | # topk_get = [1, 3, 5, 10, 15] 57 | for idx in range(batch_preds.size(0)): 58 | topk_hit_df = _get_accuracy_for_one( 59 | batch_preds[idx], batch_ground_truth[idx], topk_get=topk_get) 60 | one_batch_topk_acc_mat += topk_hit_df.values 61 | return one_batch_topk_acc_mat 62 | 63 | if __name__ == '__main__': 64 | topk_get = [1, 3, 5, 10, 15] 65 | source_data_path = '../../dataset/source_dataset/' 66 | uspto_root = os.path.abspath(os.path.join(source_data_path, 'USPTO_condition_final')) 67 | database_df, condition_label_mapping = load_dataset(dataset_root=uspto_root, database_fname='USPTO_condition.csv', use_temperature=False) 68 | 69 | test_df = database_df.loc[database_df['dataset']=='test'] 70 | topk_acc_mat = np.zeros((6, 5)) 71 | for gt in tqdm(test_df['condition_labels'].tolist()): 72 | one_batch_topk_acc_mat = _calculate_batch_topk_hit(batch_preds=[dummy_prediction], batch_ground_truth=torch.tensor([gt])) 73 | topk_acc_mat += one_batch_topk_acc_mat 74 | topk_acc_mat /= len(test_df['condition_labels'].tolist()) 75 | topk_acc_df = pd.DataFrame(topk_acc_mat) 76 | topk_acc_df.columns = [f'top-{k} accuracy' for k in topk_get] 77 | topk_acc_df.index = ['c1', 's1', 's2', 'r1', 'r2', 'overall'] 78 | print(topk_acc_df) 79 | 80 | ''' 81 | top-1 accuracy top-3 accuracy top-5 accuracy top-10 accuracy top-15 accuracy 82 | c1 0.869600 0.869600 0.869600 0.914682 0.914682 83 | s1 0.010180 0.309805 0.309805 0.309805 0.309805 84 | s2 0.808549 0.808549 0.808549 0.808549 0.808549 85 | r1 0.260859 0.260859 0.377216 0.377216 0.377216 86 | r2 0.746515 0.746515 0.746515 0.746515 0.746515 87 | overall 0.000059 0.043085 0.043085 0.066074 0.066074 88 | ''' -------------------------------------------------------------------------------- /preprocess/uspto_script/get_fragment_from_rxn_dataset.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import pandas as pd 5 | from rdkit import Chem 6 | from rdkit.Chem import Recap 7 | from rdkit.Chem import AllChem as Chem 8 | from multiprocessing import Pool 9 | from rdkit.Chem import BRICS 10 | from tqdm import tqdm 11 | import pickle 12 | from collections import defaultdict 13 | from get_pretrain_dataset_with_rxn_center import timeout 14 | 15 | 16 | def read_txt(fpath): 17 | with open(fpath, 'r', encoding='utf-8') as f: 18 | return [x.strip() for x in f.readlines()] 19 | 20 | 21 | # def get_frag_from_rxn_recap(rxn): 22 | # # one_set = set() 23 | # react, prod = rxn.split('>>') 24 | # reacts = react.split('.') 25 | # prods = prod.split('.') 26 | # # react_mol, prod_mol = Chem.MolFromSmiles(react), Chem.MolFromSmiles(prod) 27 | # one_fragmet_dict = set() 28 | # mols = [Chem.MolFromSmiles(smi) for smi in reacts+prods] 29 | # for mol in mols: 30 | # hierarch = Recap.RecapDecompose(mol) 31 | # one_fragmet_dict.update(set(hierarch.GetLeaves().keys())) 32 | # # one_fragmet_dict = set(list(hierarch_react.GetLeaves().keys())+list(hierarch_prod.GetLeaves().keys())) 33 | # return one_fragmet_dict 34 | 35 | 36 | def get_frag_from_rxn_brics(rxn): 37 | # one_set = set() 38 | react, prod = rxn.split('>>') 39 | reacts = react.split('.') 40 | prods = prod.split('.') 41 | # react_mol, prod_mol = Chem.MolFromSmiles(react), Chem.MolFromSmiles(prod) 42 | one_fragmet_dict = defaultdict(int) 43 | mols = [Chem.MolFromSmiles(smi) for smi in reacts+prods] 44 | for mol in mols: 45 | try: 46 | with timeout(): 47 | 48 | frags = list(BRICS.BRICSDecompose(mol)) 49 | for frag in frags: 50 | frag = re.sub('\[([0-9]+)\*\]', '*', frag) 51 | if frag not in reacts + prods: 52 | one_fragmet_dict[frag] += 1 53 | except Exception as e: 54 | print(e) 55 | pass 56 | 57 | # one_fragmet_dict.update() 58 | return one_fragmet_dict 59 | 60 | 61 | 62 | if __name__ == '__main__': 63 | 64 | 65 | 66 | if not os.path.exists('../check_data/frag.pkl'): 67 | pretrain_dataset_path = '../../dataset/pretrain_data/' 68 | fnames = ['mlm_rxn_train.txt', 'mlm_rxn_val.txt'] 69 | pretrain_reactions = [] 70 | 71 | for fname in fnames: 72 | pretrain_reactions += read_txt(os.path.join(pretrain_dataset_path, fname)) 73 | 74 | 75 | fragments = defaultdict(int) 76 | pool = Pool(12) 77 | for one_fragmet_dict in tqdm(pool.imap(get_frag_from_rxn_brics, pretrain_reactions), total=len(pretrain_reactions)): 78 | # for rxn in tqdm(pretrain_reactions): 79 | # one_fragmet_dict = get_frag_from_rxn_recap(rxn) 80 | for frag in one_fragmet_dict: 81 | fragments[frag] += one_fragmet_dict[frag] 82 | pool.close() 83 | 84 | with open('../check_data/frag.pkl', 'wb') as f: 85 | pickle.dump(fragments, f) 86 | with open('../check_data/frag.json', 'w', encoding='utf-8') as f: 87 | json.dump(fragments, f) 88 | 89 | else: 90 | with open('../check_data/frag.pkl', 'rb') as f: 91 | fragments = pickle.load(f) 92 | print('Fragments #:', len(fragments)) 93 | fragments_items = list(fragments.items()) 94 | fragments_items.sort(key=lambda x:x[1]) 95 | 96 | top_number = 10000 97 | write_top_frag = ['{},{}'.format(x[0], x[1]) for x in fragments_items if x[1] > top_number] 98 | print(f'fragment count > {top_number}: {len(write_top_frag)}') 99 | with open(f'../check_data/frag_cnt_nubmer_{top_number}.txt', 'w', encoding='utf-8') as f: 100 | f.write('\n'.join(write_top_frag)) 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 
-------------------------------------------------------------------------------- /preprocess/uspto_script/merge_comp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | 4 | from utils import canonicalize_smiles, read_pickle 5 | from rdkit import RDLogger 6 | RDLogger.DisableLog('rdApp.*') 7 | 8 | 9 | if __name__ == '__main__': 10 | condition_compound = read_pickle('../check_data/condition_compound.pkl') 11 | 12 | merge_condition_compound = {} 13 | 14 | for key in ['c1', 's1', 'r1']: 15 | merge_condition_compound[key] = condition_compound[key] 16 | idx = len(merge_condition_compound[key]) 17 | old_smiles_list = list(merge_condition_compound[key].values()) 18 | print('{} old: {}'.format(key, len(old_smiles_list))) 19 | with open('../check_data/qurey_reaxys_names_{}_smi.txt'.format(key), 'r', encoding='utf-8') as f: 20 | new_smiles_list = list(set([canonicalize_smiles(x.strip()) for x in f.readlines()])) 21 | for smi in new_smiles_list: 22 | if (smi != '') and (smi not in old_smiles_list): 23 | merge_condition_compound[key][idx] = smi 24 | idx += 1 25 | print('{} now: {}'.format(key, len(merge_condition_compound[key]))) 26 | 27 | with open('../check_data/condition_compound_add_name2smiles.pkl', 'wb') as f: 28 | pickle.dump(merge_condition_compound, f) 29 | with open('../check_data/condition_compound_add_name2smiles.json', 'w', encoding='utf-8') as f: 30 | json.dump(merge_condition_compound, f) 31 | -------------------------------------------------------------------------------- /preprocess/uspto_script/uspto_condition.md: -------------------------------------------------------------------------------- 1 | # USPTO-Condition Curation 2 | ``` 3 | cd Parrot/preprocess_script/uspto_script 4 | mkdir ../../dataset/source_dataset/uspto_org_xml/ 5 | ``` 6 | Download the original USPTO reaction dataset from [here](https://figshare.com/articles/dataset/Chemical_reactions_from_US_patents_1976-Sep2016_/5104873). Put `1976_Sep2016_USPTOgrants_cml.7z` and `2001_Sep2016_USPTOapplications_cml.7z` under `../../dataset/source_dataset/uspto_org_xml/`.
7 | 8 | 9 | Generating USPTO-Condition requires the rxnmapper runtime environment, which is not compatible with parrot_env. Create a separate virtual environment for rxnmapper, for example as sketched below; see the [rxnmapper GitHub repository](https://github.com/rxn4chemistry/rxnmapper) for details.
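A minimal sketch of such an environment (the environment name and Python version below are assumptions, not something this repository specifies; rxnmapper itself is installable from PyPI per its documentation):
```
conda create -n rxnmapper_env python=3.6 -y
conda activate rxnmapper_env
pip install rxnmapper
```
Run the rxnmapper-dependent preprocessing steps inside this environment, then switch back to the main environment for everything else.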
10 | 11 | 12 | Then: 13 | ``` 14 | cd ../../dataset/source_dataset/uspto_org_xml/ 15 | 7z x 1976_Sep2016_USPTOgrants_cml.7z 16 | 7z x 2001_Sep2016_USPTOapplications_cml.7z 17 | cd Parrot/preprocess_script/uspto_script 18 | python 1.get_condition_from_uspto.py 19 | sh 2.0.clean_up_rxn_condition.sh 20 | python 2.1.merge_clean_up_rxn_conditon.py 21 | python 3.0.split_condition_and_slect.py 22 | python 4.0.split_train_val_test.py 23 | python 5.0.convert_context_tokens.py 24 | ``` 25 | Done! 26 | -------------------------------------------------------------------------------- /preprocess/uspto_script/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, namedtuple 2 | import csv 3 | import json 4 | import os 5 | import pickle 6 | from rdkit import Chem 7 | import pandas as pd 8 | from rdkit.Chem.SaltRemover import SaltRemover, InputFormat 9 | 10 | # atapted from https://github.com/Coughy1991/Reaction_condition_recommendation/blob/64f151e302abcb87e0a14088e08e11b4cea8d5ab/scripts/prepare_data_cont_2_rgt_2_slv_1_cat_temp_deploy.py#L658-L690 11 | list_of_metal_atoms = ['Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 12 | 'Cd', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 13 | # 'Cn', 14 | 'Ln', 15 | 'Ce', 16 | 'Pr', 17 | 'Nd', 18 | 'Pm', 19 | 'Sm', 20 | 'Eu', 21 | 'Gd', 22 | 'Tb', 23 | 'Dy', 24 | 'Ho', 25 | 'Er', 26 | 'Tm', 27 | 'Yb', 28 | 'Lu', 29 | 'Ac', 30 | 'Th', 31 | 'Pa', 32 | 'U', 33 | 'Np', 34 | 'Am', 35 | 'Cm', 36 | 'Bk', 37 | 'Cf', 38 | 'Es', 39 | 'Fm', 40 | 'Md', 41 | 'No', 42 | 'Lr', 43 | ] 44 | 45 | mol_charge_class = [ 46 | 'Postive', 47 | 'Negative', 48 | 'Neutral' 49 | ] 50 | 51 | class MolRemover(SaltRemover): 52 | def __init__(self, defnFilename=None, defnData=None, defnFormat=InputFormat.SMARTS): 53 | super().__init__(defnFilename, defnData, defnFormat) 54 | 55 | def _StripMol(self, mol, dontRemoveEverything=False, onlyFrags=False): 56 | 57 | def _applyPattern(m, salt, notEverything, onlyFrags=onlyFrags): 58 | nAts = m.GetNumAtoms() 59 | if not nAts: 60 | return m 61 | res = m 62 | 63 | t = Chem.DeleteSubstructs(res, salt, onlyFrags) 64 | if not t or (notEverything and t.GetNumAtoms() == 0): 65 | return res 66 | res = t 67 | while res.GetNumAtoms() and nAts > res.GetNumAtoms(): 68 | nAts = res.GetNumAtoms() 69 | t = Chem.DeleteSubstructs(res, salt, True) 70 | if notEverything and t.GetNumAtoms() == 0: 71 | break 72 | res = t 73 | return res 74 | 75 | StrippedMol = namedtuple('StrippedMol', ['mol', 'deleted']) 76 | deleted = [] 77 | if dontRemoveEverything and len(Chem.GetMolFrags(mol)) <= 1: 78 | return StrippedMol(mol, deleted) 79 | modified = False 80 | natoms = mol.GetNumAtoms() 81 | for salt in self.salts: 82 | mol = _applyPattern(mol, salt, dontRemoveEverything, onlyFrags) 83 | if natoms != mol.GetNumAtoms(): 84 | natoms = mol.GetNumAtoms() 85 | modified = True 86 | deleted.append(salt) 87 | if dontRemoveEverything and len(Chem.GetMolFrags(mol)) <= 1: 88 | break 89 | if modified and mol.GetNumAtoms() > 0: 90 | Chem.SanitizeMol(mol) 91 | return StrippedMol(mol, deleted) 92 | 93 | def StripMolWithDeleted(self, mol, dontRemoveEverything=False, onlyFrags=False): 94 | return self._StripMol(mol, dontRemoveEverything, onlyFrags=onlyFrags) 95 | 96 | def pickle2json(fname): 97 | name, ext = os.path.splitext(fname) 98 | with open(fname, 'rb') as f: 99 | data = pickle.load(f) 100 | 101 | 
with open(name + '.json', 'w', encoding='utf-8') as f: 102 | json.dump(data, f) 103 | 104 | 105 | def read_json(fpath): 106 | with open(fpath, 'r', encoding='utf-8') as f: 107 | data = json.load(f) 108 | return data 109 | 110 | 111 | def read_pickle(fname): 112 | with open(fname, 'rb') as f: 113 | return pickle.load(f) 114 | 115 | 116 | def canonicalize_smiles(smi, clear_map=False): 117 | if pd.isna(smi): 118 | return '' 119 | mol = Chem.MolFromSmiles(smi) 120 | if mol is not None: 121 | if clear_map: 122 | [atom.ClearProp('molAtomMapNumber') for atom in mol.GetAtoms()] 123 | return Chem.MolToSmiles(mol) 124 | else: 125 | return '' 126 | 127 | 128 | def covert_to_series(smi, none_pandding): 129 | if pd.isna(smi): 130 | return pd.Series([none_pandding]) 131 | elif smi == '': 132 | return pd.Series([none_pandding]) 133 | else: 134 | return pd.Series(smi.split('.')) 135 | 136 | 137 | def get_writer(output_name, header): 138 | # output_name = os.path.join(cmd_args.save_dir, fname) 139 | fout = open(output_name, 'w') 140 | writer = csv.writer(fout) 141 | writer.writerow(header) 142 | return fout, writer 143 | 144 | 145 | def calculate_frequency(input_list, report=True): 146 | output_dict = defaultdict(int) 147 | for x in input_list: 148 | output_dict[x] += 1 149 | 150 | output_items = list(output_dict.items()) 151 | output_items.sort(key=lambda x: x[1], reverse=True) 152 | 153 | if report: 154 | report_threshold = [10000, 5000, 1000, 500, 100, 50, 1] 155 | for t in report_threshold: 156 | t_list = [x for x in output_items if x[1] > t] 157 | print('Frequency >={} : {}'.format(t, len(t_list))) 158 | 159 | return output_items 160 | 161 | 162 | def get_mol_charge(mol): 163 | mol_neutralization = None 164 | positive = [] 165 | negative = [] 166 | for atom in mol.GetAtoms(): 167 | charge = atom.GetFormalCharge() 168 | if charge > 0: 169 | positive.append(charge) 170 | elif charge < 0: 171 | negative.append(charge) 172 | if len(positive) == 0 and len(negative) == 0: 173 | mol_charge_flag = mol_charge_class[2] 174 | mol_neutralization = False 175 | elif len(positive) != 0 and len(negative) == 0: 176 | mol_charge_flag = mol_charge_class[0] 177 | mol_neutralization = False 178 | elif len(positive) == 0 and len(negative) != 0: 179 | mol_charge_flag = mol_charge_class[1] 180 | mol_neutralization = False 181 | elif len(positive) != 0 and len(negative) != 0: 182 | mol_charge = sum(positive) + sum(negative) 183 | if mol_charge > 0: 184 | mol_charge_flag = mol_charge_class[0] 185 | elif mol_charge < 0: 186 | mol_charge_flag = mol_charge_class[1] 187 | else: 188 | mol_charge_flag = mol_charge_class[2] 189 | mol_neutralization = True 190 | return mol_charge_flag, mol_neutralization 191 | 192 | 193 | 194 | 195 | 196 | if __name__ == '__main__': 197 | # smi = 'CC(C)C[Al+]CC(C)C.O=C(O)CC(O)(CC(=O)O)C(=O)O.[H-]' 198 | # smi = 'CCN(CC)CC.CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[B-](F)(F)F' 199 | smi = 'O.[Al+3].[H-].[H-].[H-].[H-].[Li+].[Na+].[OH-]' 200 | mol = Chem.MolFromSmiles(smi) 201 | # print(get_mol_charge(mol)) 202 | # from rdkit.Chem.SaltRemover import SaltRemover 203 | remover = MolRemover(defnFilename='../reagent_Ionic_compound.txt') 204 | remover.StripMolWithDeleted( 205 | mol, dontRemoveEverything=False, onlyFrags=False) 206 | -------------------------------------------------------------------------------- /retrieve/condition_year.sh: -------------------------------------------------------------------------------- 1 | python retrieve_faiss.py \ 2 | --data_path ../data/USPTO_condition_year \ 3 | --train_file 
USPTO_condition_train.csv \ 4 | --valid_file USPTO_condition_val.csv \ 5 | --test_file USPTO_condition_test.csv \ 6 | --field canonical_rxn \ 7 | --output_path output/USPTO_condition_year 8 | -------------------------------------------------------------------------------- /retrieve/convert_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | input_path = '/data/scratch/zhengkai/neural-retrieval/retrieved/USPTO_condition_MIT_smiles/test.jsonl' 4 | output_path = 'contrastive/test.json' 5 | 6 | output = [] 7 | with open(input_path) as f: 8 | for line in f: 9 | data = json.loads(line) 10 | output.append({ 11 | 'id': data['query_id'], 12 | 'nn': [p['docid'] for p in data['negative_passages']] 13 | }) 14 | 15 | with open(output_path, 'w') as f: 16 | json.dump(output, f) 17 | -------------------------------------------------------------------------------- /retrieve/retrieve.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import json 4 | import rdkit 5 | import rdkit.Chem as Chem 6 | import rdkit.Chem.rdChemReactions as rdChemReactions 7 | import rdkit.DataStructs as DataStructs 8 | from tqdm import tqdm 9 | import multiprocessing 10 | 11 | 12 | def reaction_fingerprint(smiles): 13 | rxn = rdChemReactions.ReactionFromSmarts(smiles) 14 | fp = rdChemReactions.CreateDifferenceFingerprintForReaction(rxn) 15 | return fp 16 | 17 | 18 | def reaction_similarity(smiles1=None, smiles2=None, fp1=None, fp2=None): 19 | if fp1 is None and smiles1: 20 | fp1 = reaction_fingerprint(smiles1) 21 | if fp2 is None and smiles2: 22 | fp2 = reaction_fingerprint(smiles2) 23 | assert fp1 is not None 24 | assert fp2 is not None 25 | return DataStructs.TanimotoSimilarity(fp1, fp2) 26 | 27 | 28 | def reaction_similarity_fp_smiles(fp, smiles): 29 | return reaction_similarity(fp1=fp, smiles2=smiles) 30 | 31 | 32 | def compute_reaction_similarities(test_smiles, train_smiles_list, num_workers=64): 33 | test_fp = reaction_fingerprint(test_smiles) 34 | with multiprocessing.Pool(num_workers) as p: 35 | similarities = p.starmap( 36 | reaction_similarity_fp_smiles, 37 | [(test_fp, smiles) for smiles in train_smiles_list], 38 | chunksize=128 39 | ) 40 | return similarities 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | train_df = pd.read_csv('data/USPTO_condition_train.csv') 46 | val_df = pd.read_csv('data/USPTO_condition_val.csv') 47 | test_df = pd.read_csv('data/USPTO_condition_test.csv') 48 | 49 | train_smiles_list = train_df['canonical_rxn'] 50 | 51 | results = {} 52 | # with open('test_nn.json') as f: 53 | # results = json.load(f) 54 | 55 | for i, test_row in tqdm(test_df.iterrows()): 56 | if str(i) in results: 57 | continue 58 | test_smiles = test_row['canonical_rxn'] 59 | similarities = compute_reaction_similarities(test_smiles, train_smiles_list, num_workers=64) 60 | ranks = np.argsort(similarities)[::-1][:100].tolist() 61 | results[i] = { 62 | 'rank': ranks, 63 | 'similarity': [similarities[j] for j in ranks] 64 | } 65 | if i + 1 == 100: 66 | break 67 | 68 | with open('test_nn.json', 'w') as f: 69 | json.dump(results, f) 70 | -------------------------------------------------------------------------------- /retrieve/retrieve_faiss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | import numpy as np 5 | import json 6 | import rdkit 7 | import rdkit.Chem as Chem 8 | 
rdkit.RDLogger.DisableLog('rdApp.*') 9 | import rdkit.Chem.AllChem as AllChem 10 | import rdkit.Chem.rdChemReactions as rdChemReactions 11 | import rdkit.DataStructs as DataStructs 12 | from tqdm import tqdm 13 | import multiprocessing 14 | import faiss 15 | import time 16 | 17 | 18 | def reaction_fingerprint(smiles): 19 | rxn = rdChemReactions.ReactionFromSmarts(smiles) 20 | fp = rdChemReactions.CreateDifferenceFingerprintForReaction(rxn) 21 | return fp 22 | 23 | 24 | def reaction_fingerprint_array(smiles): 25 | fp = reaction_fingerprint(smiles) 26 | array = np.array([x for x in fp]) 27 | return array 28 | 29 | 30 | def compute_reaction_fingerprints(smiles_list, num_workers=64): 31 | with multiprocessing.Pool(num_workers) as p: 32 | fps = p.map(reaction_fingerprint_array, smiles_list, chunksize=128) 33 | return np.array(fps) 34 | 35 | 36 | def morgan_fingerprint(smiles): 37 | try: 38 | mol = Chem.MolFromSmiles(smiles) 39 | fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024) 40 | array = np.zeros((0,), dtype=np.int8) 41 | DataStructs.ConvertToNumpyArray(fp, array) 42 | except: 43 | return morgan_fingerprint('C') 44 | return array 45 | 46 | 47 | def compute_molecule_fingerprints(smiles_list, num_workers=64): 48 | with multiprocessing.Pool(num_workers) as p: 49 | fps = p.map(morgan_fingerprint, smiles_list, chunksize=64) 50 | return np.array(fps) 51 | 52 | 53 | def compare_condition(row1, row2): 54 | for field in ['catalyst1', 'solvent1', 'solvent2', 'reagent1', 'reagent2']: 55 | if type(row1[field]) is not str and type(row2[field]) is not str: 56 | continue 57 | if row1[field] != row2[field]: 58 | return False 59 | return True 60 | 61 | 62 | def index_and_search(train_fps, query_fps): 63 | print('Faiss build index') 64 | d = train_fps.shape[1] 65 | index = faiss.IndexFlatL2(d) 66 | index.add(train_fps) 67 | 68 | print('Faiss nearest neighbor search') 69 | start = time.time() 70 | k = 20 71 | distance, rank = index.search(query_fps, k) 72 | end = time.time() 73 | print(f"{end - start:.2f} s") 74 | return rank 75 | 76 | 77 | if __name__ == '__main__': 78 | 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('--data_path', type=str, default=None, required=True) 81 | parser.add_argument('--train_file', type=str, default=None, required=True) 82 | parser.add_argument('--valid_file', type=str, default=None, required=True) 83 | parser.add_argument('--test_file', type=str, default=None, required=True) 84 | parser.add_argument('--field', type=str, default='canonical_rxn') 85 | parser.add_argument('--before', type=int, default=-1) 86 | parser.add_argument('--output_path', type=str, default=None, required=True) 87 | args = parser.parse_args() 88 | 89 | train_df = pd.read_csv(os.path.join(args.data_path, args.train_file), keep_default_na=False) 90 | val_df = pd.read_csv(os.path.join(args.data_path, args.valid_file), keep_default_na=False) 91 | test_df = pd.read_csv(os.path.join(args.data_path, args.test_file), keep_default_na=False) 92 | 93 | if args.field == 'canonical_rxn': 94 | print('Reaction fingerprint') 95 | fingerprint_fn = compute_reaction_fingerprints 96 | else: 97 | print('Molecule fingerprint') 98 | fingerprint_fn = compute_molecule_fingerprints 99 | 100 | train_fp_file = os.path.join(args.output_path, 'train_fp.pkl') 101 | if not os.path.exists(train_fp_file): 102 | if args.before != -1: 103 | train_df = train_df[train_df['year'] < args.before].reset_index(drop=True) 104 | train_fps = fingerprint_fn(train_df[args.field]) 105 | os.makedirs(args.output_path, 
exist_ok=True) 106 | with open(train_fp_file, 'wb') as f: 107 | np.save(f, train_fps) 108 | else: 109 | with open(train_fp_file, 'rb') as f: 110 | train_fps = np.load(f) 111 | 112 | train_id = train_df['id'] 113 | 114 | query_fps, query_id = train_fps, train_df['id'] 115 | rank = index_and_search(train_fps, query_fps) 116 | result = [{'id': query_id[i], 'nn': [train_id[n] for n in nn]} for i, nn in enumerate(rank)] 117 | with open(os.path.join(args.output_path, 'train.json'), 'w') as f: 118 | json.dump(result, f) 119 | 120 | query_fps, query_id = fingerprint_fn(val_df[args.field]), val_df['id'] 121 | rank = index_and_search(train_fps, query_fps) 122 | result = [{'id': query_id[i], 'nn': [train_id[n] for n in nn]} for i, nn in enumerate(rank)] 123 | with open(os.path.join(args.output_path, 'val.json'), 'w') as f: 124 | json.dump(result, f) 125 | 126 | query_fps, query_id = fingerprint_fn(test_df[args.field]), test_df['id'] 127 | rank = index_and_search(train_fps, query_fps) 128 | result = [{'id': query_id[i], 'nn': [train_id[n] for n in nn]} for i, nn in enumerate(rank)] 129 | with open(os.path.join(args.output_path, 'test.json'), 'w') as f: 130 | json.dump(result, f) 131 | 132 | if args.field == 'canonical_rxn': 133 | cnt = {x: 0 for x in [1, 3, 5, 10, 15]} 134 | for i, nn in enumerate(rank): 135 | test_row = test_df.iloc[i] 136 | train_rows = [train_df.iloc[n] for n in nn] 137 | hit_map = [compare_condition(test_row, train_row) for train_row in train_rows] 138 | for x in cnt: 139 | cnt[x] += np.any(hit_map[:x]) 140 | 141 | print(cnt, len(test_df)) 142 | for x in cnt: 143 | print(f"Top-{x}: {cnt[x] / len(test_df):.4f}", end=' ') 144 | print() 145 | -------------------------------------------------------------------------------- /retrieve/retro.sh: -------------------------------------------------------------------------------- 1 | python retrieve_faiss.py \ 2 | --data_path ../data/USPTO_50K/matched1 \ 3 | --train_file train.csv \ 4 | --valid_file valid.csv \ 5 | --test_file test.csv \ 6 | --field product_smiles \ 7 | --output_path output/USPTO_50K 8 | 9 | #python retrieve_faiss.py \ 10 | # --data_path ../data/USPTO_50K/matched1 \ 11 | # --train_file ../../USPTO_rxn_smiles.csv \ 12 | # --valid_file valid.csv \ 13 | # --test_file test.csv \ 14 | # --field product_smiles \ 15 | # --output_path output/USPTO_50K/full 16 | -------------------------------------------------------------------------------- /retrieve/retro_year.sh: -------------------------------------------------------------------------------- 1 | #python retrieve_faiss.py \ 2 | # --data_path ../data/USPTO_50K_year \ 3 | # --train_file ../USPTO_rxn_smiles.csv \ 4 | # --valid_file valid.csv \ 5 | # --test_file test.csv \ 6 | # --field product_smiles \ 7 | # --output_path output/USPTO_50K_year/full 8 | 9 | python retrieve_faiss.py \ 10 | --data_path ../data/USPTO_50K_year \ 11 | --train_file ../USPTO_rxn_smiles.csv \ 12 | --before 2012 \ 13 | --valid_file valid.csv \ 14 | --test_file test.csv \ 15 | --field product_smiles \ 16 | --output_path output/USPTO_50K_year/corpus_before_2012 17 | 18 | -------------------------------------------------------------------------------- /scripts/train_RCR.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NUM_GPUS_PER_NODE=4 4 | BATCH_SIZE=128 5 | ACCUM_STEP=1 6 | 7 | SAVE_PATH=output/RCR_textreact 8 | NN_PATH=data/Tevatron_output/RCR/ 9 | 10 | mkdir -p ${SAVE_PATH} 11 | 12 | NCCL_P2P_DISABLE=1 python main.py \ 13 | --task condition \ 14 | 
--encoder allenai/scibert_scivocab_uncased \ 15 | --decoder textreact/configs/bert_l6.json \ 16 | --encoder_pretrained \ 17 | --data_path data/RCR/ \ 18 | --train_file train.csv \ 19 | --valid_file val.csv \ 20 | --test_file test.csv \ 21 | --vocab_file textreact/vocab/vocab_condition.txt \ 22 | --corpus_file data/USPTO_rxn_corpus.csv \ 23 | --nn_path ${NN_PATH} \ 24 | --train_nn_file train_rank.json \ 25 | --valid_nn_file val_rank.json \ 26 | --test_nn_file test_rank.json \ 27 | --num_neighbors 3 \ 28 | --use_gold_neighbor \ 29 | --save_path ${SAVE_PATH} \ 30 | --max_length 512 \ 31 | --shuffle_smiles \ 32 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \ 33 | --lr 1e-4 \ 34 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \ 35 | --gradient_accumulation_steps ${ACCUM_STEP} \ 36 | --epochs 20 \ 37 | --warmup 0.02 \ 38 | --do_train --do_valid --do_test \ 39 | --num_beams 15 \ 40 | --precision 16-mixed \ 41 | --gpus ${NUM_GPUS_PER_NODE} 42 | -------------------------------------------------------------------------------- /scripts/train_RCR_TS.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NUM_GPUS_PER_NODE=4 4 | BATCH_SIZE=128 5 | ACCUM_STEP=1 6 | 7 | SAVE_PATH=output/RCR_TS_textreact 8 | NN_PATH=data/Tevatron_output/RCR_TS/ 9 | 10 | mkdir -p ${SAVE_PATH} 11 | 12 | NCCL_P2P_DISABLE=1 python main.py \ 13 | --task condition \ 14 | --encoder allenai/scibert_scivocab_uncased \ 15 | --decoder textreact/configs/bert_l6.json \ 16 | --encoder_pretrained \ 17 | --data_path data/RCR_TS/ \ 18 | --train_file train.csv \ 19 | --valid_file val.csv \ 20 | --test_file test.csv \ 21 | --vocab_file textreact/vocab/vocab_condition.txt \ 22 | --corpus_file data/USPTO_rxn_corpus.csv \ 23 | --nn_path ${NN_PATH} \ 24 | --train_nn_file train_rank.json \ 25 | --valid_nn_file val_rank_full.json \ 26 | --test_nn_file test_rank_full.json \ 27 | --num_neighbors 3 \ 28 | --use_gold_neighbor \ 29 | --save_path ${SAVE_PATH} \ 30 | --max_length 512 \ 31 | --shuffle_smiles \ 32 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \ 33 | --lr 1e-4 \ 34 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \ 35 | --gradient_accumulation_steps ${ACCUM_STEP} \ 36 | --epochs 20 \ 37 | --warmup 0.02 \ 38 | --do_train --do_valid --do_test \ 39 | --num_beams 15 \ 40 | --precision 16-mixed \ 41 | --gpus ${NUM_GPUS_PER_NODE} 42 | -------------------------------------------------------------------------------- /scripts/train_RetroSyn_tb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NUM_GPUS_PER_NODE=4 4 | BATCH_SIZE=128 5 | ACCUM_STEP=1 6 | 7 | SAVE_PATH=output/RetroSyn_tb_textreact 8 | NN_PATH=data/Tevatron_output/RetroSyn/ 9 | 10 | mkdir -p ${SAVE_PATH} 11 | 12 | NCCL_P2P_DISABLE=1 python main.py \ 13 | --task retro \ 14 | --template_based \ 15 | --shuffle_smiles \ 16 | --encoder allenai/scibert_scivocab_uncased \ 17 | --encoder_pretrained \ 18 | --encoder_tokenizer smiles_text \ 19 | --vocab_file textreact/vocab/vocab_smiles.txt \ 20 | --data_path data/RetroSyn/ \ 21 | --template_path data/RetroSyn/template_based \ 22 | --train_file train.csv \ 23 | --valid_file valid.csv \ 24 | --test_file test.csv \ 25 | --corpus_file data/USPTO_rxn_corpus.csv \ 26 | --nn_path ${NN_PATH} \ 27 | --train_nn_file train_rank.json \ 28 | --valid_nn_file valid_rank.json \ 29 | --test_nn_file test_rank.json \ 30 | --num_neighbors 3 \ 31 | --use_gold_neighbor \ 32 | --random_neighbor_ratio 0.2 
\ 33 | --save_path ${SAVE_PATH} \ 34 | --load_ckpt best.ckpt \ 35 | --max_length 512 \ 36 | --max_dec_length 160 \ 37 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \ 38 | --lr 1e-4 \ 39 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \ 40 | --gradient_accumulation_steps ${ACCUM_STEP} \ 41 | --test_batch_size 32 \ 42 | --epochs 200 \ 43 | --eval_per_epoch 10 \ 44 | --warmup 0.02 \ 45 | --do_train --do_valid --do_test \ 46 | --num_beams 20 \ 47 | --precision 16 \ 48 | --gpus ${NUM_GPUS_PER_NODE} 49 | -------------------------------------------------------------------------------- /scripts/train_RetroSyn_tb_TS.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NUM_GPUS_PER_NODE=4 4 | BATCH_SIZE=128 5 | ACCUM_STEP=1 6 | 7 | SAVE_PATH=output/RetroSyn_tb_TS_textreact 8 | NN_PATH=data/Tevatron_output/RetroSyn_TS/ 9 | 10 | mkdir -p ${SAVE_PATH} 11 | 12 | NCCL_P2P_DISABLE=1 python main.py \ 13 | --task retro \ 14 | --template_based \ 15 | --shuffle_smiles \ 16 | --encoder allenai/scibert_scivocab_uncased \ 17 | --encoder_pretrained \ 18 | --encoder_tokenizer smiles_text \ 19 | --vocab_file textreact/vocab/vocab_smiles.txt \ 20 | --data_path data/RetroSyn_TS/ \ 21 | --template_path data/RetroSyn_TS/template_based \ 22 | --train_file train.csv \ 23 | --valid_file valid.csv \ 24 | --test_file test.csv \ 25 | --corpus_file data/USPTO_rxn_corpus.csv \ 26 | --nn_path ${NN_PATH} \ 27 | --train_nn_file train_rank.json \ 28 | --valid_nn_file valid_rank_full.json \ 29 | --test_nn_file test_rank_full.json \ 30 | --num_neighbors 3 \ 31 | --use_gold_neighbor \ 32 | --random_neighbor_ratio 0.2 \ 33 | --save_path ${SAVE_PATH} \ 34 | --load_ckpt best.ckpt \ 35 | --max_length 512 \ 36 | --max_dec_length 160 \ 37 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \ 38 | --lr 1e-4 \ 39 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \ 40 | --gradient_accumulation_steps ${ACCUM_STEP} \ 41 | --test_batch_size 32 \ 42 | --epochs 200 \ 43 | --eval_per_epoch 10 \ 44 | --warmup 0.02 \ 45 | --do_train --do_valid --do_test \ 46 | --num_beams 20 \ 47 | --precision 16 \ 48 | --gpus ${NUM_GPUS_PER_NODE} 49 | -------------------------------------------------------------------------------- /scripts/train_RetroSyn_tf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NUM_GPUS_PER_NODE=4 4 | BATCH_SIZE=128 5 | ACCUM_STEP=1 6 | 7 | SAVE_PATH=output/RetroSyn_tf_textreact 8 | NN_PATH=data/Tevatron_output/RetroSyn/ 9 | 10 | mkdir -p ${SAVE_PATH} 11 | 12 | NCCL_P2P_DISABLE=1 python main.py \ 13 | --task retro \ 14 | --encoder allenai/scibert_scivocab_uncased \ 15 | --decoder textreact/configs/bert_l6.json \ 16 | --encoder_pretrained \ 17 | --vocab_file textreact/vocab/vocab_smiles.txt \ 18 | --data_path data/RetroSyn/ \ 19 | --train_file train.csv \ 20 | --valid_file valid.csv \ 21 | --test_file test.csv \ 22 | --corpus_file data/USPTO_rxn_corpus.csv \ 23 | --nn_path ${NN_PATH} \ 24 | --train_nn_file train_rank.json \ 25 | --valid_nn_file valid_rank.json \ 26 | --test_nn_file test_rank.json \ 27 | --num_neighbors 3 \ 28 | --use_gold_neighbor \ 29 | --random_neighbor_ratio 0.2 \ 30 | --save_path ${SAVE_PATH} \ 31 | --load_ckpt best.ckpt \ 32 | --max_length 512 \ 33 | --max_dec_length 160 \ 34 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \ 35 | --lr 1e-4 \ 36 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \ 37 | --gradient_accumulation_steps 
${ACCUM_STEP} \ 38 | --test_batch_size 32 \ 39 | --epochs 200 \ 40 | --eval_per_epoch 25 \ 41 | --warmup 0.02 \ 42 | --do_train --do_valid --do_test \ 43 | --num_beams 20 \ 44 | --precision 16 \ 45 | --gpus ${NUM_GPUS_PER_NODE} 46 | -------------------------------------------------------------------------------- /scripts/train_RetroSyn_tf_TS.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NUM_GPUS_PER_NODE=4 4 | BATCH_SIZE=128 5 | ACCUM_STEP=1 6 | 7 | SAVE_PATH=output/RetroSyn_tf_TS_textreact 8 | NN_PATH=data/Tevatron_output/RetroSyn_TS/ 9 | 10 | mkdir -p ${SAVE_PATH} 11 | 12 | NCCL_P2P_DISABLE=1 python main.py \ 13 | --task retro \ 14 | --encoder allenai/scibert_scivocab_uncased \ 15 | --decoder textreact/configs/bert_l6.json \ 16 | --encoder_pretrained \ 17 | --vocab_file textreact/vocab/vocab_smiles.txt \ 18 | --data_path data/RetroSyn_TS/ \ 19 | --train_file train.csv \ 20 | --valid_file valid.csv \ 21 | --test_file test.csv \ 22 | --corpus_file data/USPTO_rxn_corpus.csv \ 23 | --nn_path data/Tevatron_output/RetroSyn_TS/ \ 24 | --nn_path ${NN_PATH} \ 25 | --train_nn_file train_rank.json \ 26 | --valid_nn_file valid_rank_full.json \ 27 | --test_nn_file test_rank_full.json \ 28 | --num_neighbors 3 \ 29 | --use_gold_neighbor \ 30 | --random_neighbor_ratio 0.2 \ 31 | --save_path ${SAVE_PATH} \ 32 | --load_ckpt best.ckpt \ 33 | --max_length 512 \ 34 | --max_dec_length 160 \ 35 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \ 36 | --lr 1e-4 \ 37 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \ 38 | --gradient_accumulation_steps ${ACCUM_STEP} \ 39 | --test_batch_size 32 \ 40 | --epochs 200 \ 41 | --eval_per_epoch 25 \ 42 | --warmup 0.02 \ 43 | --do_train --do_valid --do_test \ 44 | --num_beams 20 \ 45 | --precision 16 \ 46 | --gpus ${NUM_GPUS_PER_NODE} 47 | -------------------------------------------------------------------------------- /textreact/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomas0809/textreact/feb62d868a627d293997a56a72f50420377c59b4/textreact/__init__.py -------------------------------------------------------------------------------- /textreact/configs/bert_l6.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "layer_norm_eps": 1e-05, 9 | "max_position_embeddings": 512, 10 | "model_type": "roberta", 11 | "num_attention_heads": 12, 12 | "num_hidden_layers": 6, 13 | "type_vocab_size": 1, 14 | "vocab_size": 600, 15 | "bos_token_id": 12, 16 | "eos_token_id": 13, 17 | "pad_token_id": 0 18 | } -------------------------------------------------------------------------------- /textreact/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import repeat 3 | import numpy as np 4 | import pandas as pd 5 | import multiprocessing 6 | 7 | import rdkit 8 | from rdkit import Chem 9 | rdkit.RDLogger.DisableLog('rdApp.*') 10 | 11 | from .dataset import CONDITION_COLS 12 | from .template_decoder import get_pred_smiles_from_templates 13 | 14 | 15 | def evaluate_reaction_condition(prediction, data_df): 16 | cnt = {x: 0 for x in [1, 3, 5, 10, 15]} 17 | for i, output in prediction.items(): 18 | label = data_df.loc[i, 
CONDITION_COLS].tolist() 19 | hit_map = [pred == label for pred in output['prediction']] 20 | for x in cnt: 21 | cnt[x] += np.any(hit_map[:x]) 22 | num_example = len(data_df) 23 | accuracy = {x: cnt[x] / num_example for x in cnt} 24 | return accuracy 25 | 26 | 27 | def canonical_smiles(smiles): 28 | try: 29 | canon_smiles = Chem.CanonSmiles(smiles) 30 | except: 31 | canon_smiles = smiles 32 | return canon_smiles 33 | 34 | 35 | def _compare_pred_and_gold(pred, gold): 36 | pred = [canonical_smiles(smiles) for smiles in pred] 37 | for i, smiles in enumerate(pred): 38 | if smiles == gold: 39 | return i 40 | return 100000 41 | 42 | 43 | def evaluate_retrosynthesis(prediction, data_df, top_k, template_based=False, template_path=None, num_workers=16): 44 | num_example = len(data_df) 45 | with multiprocessing.Pool(num_workers) as p: 46 | gold_list = p.map(canonical_smiles, data_df['reactant_smiles']) 47 | if template_based: 48 | pred_prob_list = [[(*prediction, score) 49 | for prediction, score in zip(prediction[i]['prediction'], prediction[i]['score'])] 50 | for i in range(num_example)] 51 | atom_templates = pd.read_csv(os.path.join(template_path, 'atom_templates.csv')) 52 | bond_templates = pd.read_csv(os.path.join(template_path, 'bond_templates.csv')) 53 | template_infos = pd.read_csv(os.path.join(template_path, 'template_infos.csv')) 54 | atom_templates = {atom_templates['Class'][i]: atom_templates['Template'][i] for i in atom_templates.index} 55 | bond_templates = {bond_templates['Class'][i]: bond_templates['Template'][i] for i in bond_templates.index} 56 | template_infos = {template_infos['Template'][i]: { 57 | 'edit_site': eval(template_infos['edit_site'][i]), 58 | 'change_H': eval(template_infos['change_H'][i]), 59 | 'change_C': eval(template_infos['change_C'][i]), 60 | 'change_S': eval(template_infos['change_S'][i]) 61 | } for i in template_infos.index} 62 | pred_list = p.starmap(get_pred_smiles_from_templates, 63 | zip(pred_prob_list, data_df['product_smiles'], 64 | repeat(atom_templates), repeat(bond_templates), repeat(template_infos), repeat(top_k))) 65 | else: 66 | pred_list = [prediction[i]['prediction'] for i in range(num_example)] 67 | indices = p.starmap(_compare_pred_and_gold, [(p, g) for p, g in zip(pred_list, gold_list)]) 68 | accuracy = {} 69 | for x in [1, 2, 3, 5, 10, 20]: 70 | accuracy[x] = sum([idx < x for idx in indices]) / num_example 71 | return accuracy 72 | -------------------------------------------------------------------------------- /textreact/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pad_sequence 4 | from transformers import get_scheduler, EncoderDecoderModel, EncoderDecoderConfig, AutoTokenizer, AutoConfig, AutoModel 5 | from transformers.models.bert.modeling_bert import BertLMPredictionHead 6 | from transformers.utils import ModelOutput 7 | from . 
import utils 8 | 9 | 10 | def get_model(args, enc_tokenizer=None, dec_tokenizer=None): 11 | if args.template_based: 12 | assert args.decoder is None and not args.decoder_pretrained 13 | if args.encoder_pretrained: 14 | encoder = AutoModel.from_pretrained(pretrained_model_name_or_path=args.encoder) 15 | else: 16 | encoder_config = AutoConfig.from_pretrained(args.encoder) 17 | encoder = AutoModel.from_config(encoder_config) 18 | template_head = TemplatePredictionHead(encoder.config.hidden_size, len(dec_tokenizer[0]), len(dec_tokenizer[1])) 19 | model = TemplateBasedModel(encoder, template_head) 20 | else: 21 | if args.encoder_pretrained and args.decoder_pretrained: 22 | model = EncoderDecoderModel.from_encoder_decoder_pretrained( 23 | encoder_pretrained_model_name_or_path=args.encoder, decoder_pretrained_model_name_or_path=args.decoder) 24 | else: 25 | encoder_config = AutoConfig.from_pretrained(args.encoder) 26 | decoder_config = AutoConfig.from_pretrained(args.decoder) 27 | config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config) 28 | model = EncoderDecoderModel(config=config) 29 | if args.encoder_pretrained: 30 | encoder = AutoModel.from_pretrained(args.encoder) 31 | model.encoder = encoder 32 | encoder = model.encoder 33 | if args.max_length > encoder.config.max_position_embeddings: 34 | utils.expand_position_embeddings(encoder, args.max_length) 35 | if args.encoder_tokenizer == 'smiles_text': 36 | utils.expand_word_embeddings(encoder, len(enc_tokenizer)) 37 | return model 38 | 39 | 40 | def get_mlm_head(args, model): 41 | if args.mlm_layer == 'linear': 42 | mlm_head = nn.Linear(model.encoder.config.hidden_size, model.encoder.config.vocab_size) 43 | elif args.mlm_layer == 'mlp': 44 | mlm_head = BertLMPredictionHead(model.encoder.config) 45 | else: 46 | raise NotImplementedError 47 | return mlm_head 48 | 49 | 50 | class TemplateBasedModel(nn.Module): 51 | def __init__(self, encoder, template_head): 52 | super().__init__() 53 | self.encoder = encoder 54 | self.template_head = template_head 55 | 56 | def forward(self, **inputs): 57 | encoder_output = self.encoder(**{k: v for k, v in inputs.items() 58 | if not k.startswith('decoder_') and k not in ['atom_indices', 'bonds']}) 59 | atom_hidden_states = [] 60 | for hidden_states, atom_indices in zip(encoder_output.last_hidden_state, inputs['atom_indices']): 61 | atom_hidden_states.append(hidden_states[atom_indices]) 62 | atom_hidden_states = pad_sequence(atom_hidden_states, batch_first=True) 63 | return ModelOutput(logits=self.template_head(atom_hidden_states), encoder_last_hidden_state=encoder_output.last_hidden_state) 64 | 65 | 66 | class TemplatePredictionHead(nn.Module): 67 | def __init__(self, input_size, num_atom_templates, num_bond_templates): 68 | super().__init__() 69 | self.atom_template_head = nn.Linear(input_size, num_atom_templates + 1) 70 | self.bond_template_head = BondTemplatePredictor(input_size, num_bond_templates) 71 | 72 | def forward(self, input_states): 73 | """ 74 | Input: [B x] L x d_in 75 | Output: [B x] L x n_a, [B x] L x L x n_b 76 | """ 77 | return self.atom_template_head(input_states), self.bond_template_head(input_states) 78 | 79 | 80 | class BondTemplatePredictor(nn.Module): 81 | def __init__(self, input_size, num_bond_templates): 82 | super().__init__() 83 | self.linear = nn.Linear(2 * input_size, num_bond_templates + 1) 84 | 85 | def forward(self, input_states): 86 | concat_pair_shape = input_states.shape[:-1] + input_states.shape[-2:] 87 | concat_pairs = 
torch.cat((input_states.unsqueeze(-2).expand(concat_pair_shape), 88 | input_states.unsqueeze(-3).expand(concat_pair_shape)), 89 | dim=-1) 90 | return self.linear(concat_pairs) 91 | -------------------------------------------------------------------------------- /textreact/template_decoder.py: -------------------------------------------------------------------------------- 1 | """ Adapted from Decode_predictions.py and template_decoder.py """ 2 | 3 | import os, re, copy 4 | import pandas as pd 5 | from collections import defaultdict 6 | 7 | import rdkit 8 | from rdkit import Chem, RDLogger 9 | from rdkit.Chem import rdChemReactions 10 | from rdkit.Chem.rdchem import ChiralType 11 | from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers 12 | 13 | RDLogger.DisableLog('rdApp.*') 14 | 15 | chiral_type_map = {ChiralType.CHI_UNSPECIFIED : -1, ChiralType.CHI_TETRAHEDRAL_CW: 1, ChiralType.CHI_TETRAHEDRAL_CCW: 2} 16 | chiral_type_map_inv = {v:k for k, v in chiral_type_map.items()} 17 | 18 | a, b = 'a', 'b' 19 | 20 | def get_pred_smiles_from_templates(template_preds, product, atom_templates, bond_templates, template_infos, top_k): 21 | smiles_preds = [] 22 | for prediction in template_preds: 23 | mol, pred_site, template, template_info, score = read_prediction(product, prediction, atom_templates, bond_templates, template_infos) 24 | local_template = '>>'.join(['(%s)' % smarts for smarts in template.split('_')[0].split('>>')]) 25 | decoded_smiles = decode_localtemplate(mol, pred_site, local_template, template_info) 26 | try: 27 | decoded_smiles = decode_localtemplate(mol, pred_site, local_template, template_info) 28 | if decoded_smiles == None or decoded_smiles in smiles_preds: 29 | continue 30 | except Exception as e: 31 | continue 32 | smiles_preds.append(decoded_smiles) 33 | 34 | if len (smiles_preds) >= top_k: 35 | break 36 | 37 | return smiles_preds 38 | 39 | def get_isomers(smi): 40 | mol = Chem.MolFromSmiles(smi) 41 | isomers = tuple(EnumerateStereoisomers(mol)) 42 | isomers_smi = [Chem.MolToSmiles(x, isomericSmiles=True) for x in isomers] 43 | return isomers_smi 44 | 45 | def get_MaxFrag(smiles): 46 | return max(smiles.split('.'), key=len) 47 | 48 | def isomer_match(preds, reac): 49 | reac_isomers = get_isomers(reac) 50 | for k, pred in enumerate(preds): 51 | try: 52 | pred_isomers = get_isomers(pred) 53 | if set(pred_isomers).issubset(set(reac_isomers)) or set(reac_isomers).issubset(set(pred_isomers)): 54 | return k+1 55 | except Exception as e: 56 | pass 57 | return -1 58 | 59 | def get_idx_map(mol): 60 | for atom in mol.GetAtoms(): 61 | atom.SetAtomMapNum(atom.GetIdx()) 62 | smiles = Chem.MolToSmiles(mol) 63 | num_map = {} 64 | for i, s in enumerate(smiles.split('.')): 65 | m = Chem.MolFromSmiles(s) 66 | for atom in m.GetAtoms(): 67 | num_map[atom.GetAtomMapNum()] = atom.GetIdx() 68 | return num_map 69 | 70 | def get_possible_map(pred_site, change_info): 71 | possible_maps = [] 72 | if type(pred_site) == type(0): 73 | for edit_type, edits in change_info['edit_site'].items(): 74 | if edit_type not in ['A', 'R']: 75 | continue 76 | for edit in edits: 77 | possible_maps.append({edit: pred_site}) 78 | else: 79 | for edit_type, edits in change_info['edit_site'].items(): 80 | if edit_type not in ['B', 'C']: 81 | continue 82 | for edit in edits: 83 | possible_maps.append({e:p for e, p in zip(edit, pred_site)}) 84 | return possible_maps 85 | 86 | def check_idx_match(mols, possible_maps): 87 | matched_maps = [] 88 | found_map = {} 89 | for mol in mols: 90 | for atom in mol.GetAtoms(): 
91 | if atom.HasProp('old_mapno') and atom.HasProp('react_atom_idx'): 92 | found_map[int(atom.GetProp('old_mapno'))] = int(atom.GetProp('react_atom_idx')) 93 | for possible_map in possible_maps: 94 | if possible_map.items() <= found_map.items(): 95 | matched_maps.append(found_map) 96 | return matched_maps 97 | 98 | def fix_aromatic(mol): 99 | for atom in mol.GetAtoms(): 100 | if not atom.IsInRing() and atom.GetIsAromatic(): 101 | atom.SetIsAromatic(False) 102 | 103 | for bond in mol.GetBonds(): 104 | if not bond.IsInRing(): 105 | bond.SetIsAromatic(False) 106 | if str(bond.GetBondType()) == 'AROMATIC': 107 | bond.SetBondType(Chem.rdchem.BondType.SINGLE) 108 | 109 | def validate_mols(mols): 110 | for mol in mols: 111 | if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) == None: 112 | return False 113 | return True 114 | 115 | def fix_reactant_atoms(product, reactants, matched_map, change_info): 116 | H_change, C_change, S_change = change_info['change_H'], change_info['change_C'], change_info['change_S'] 117 | fixed_mols = [] 118 | for mol in reactants: 119 | for atom in mol.GetAtoms(): 120 | if atom.HasProp('old_mapno'): 121 | mapno = int(atom.GetProp('old_mapno')) 122 | if mapno not in matched_map: 123 | return None 124 | product_atom = product.GetAtomWithIdx(matched_map[mapno]) 125 | H_before = product_atom.GetNumExplicitHs() + product_atom.GetNumImplicitHs() 126 | C_before = product_atom.GetFormalCharge() 127 | S_before = chiral_type_map[product_atom.GetChiralTag()] 128 | H_after = H_before + H_change[mapno] 129 | C_after = C_before + C_change[mapno] 130 | S_after = S_change[mapno] 131 | if H_after < 0: 132 | return None 133 | atom.SetNumExplicitHs(H_after) 134 | atom.SetFormalCharge(C_after) 135 | if S_after != 0: 136 | atom.SetChiralTag(chiral_type_map_inv[S_after]) 137 | fix_aromatic(mol) 138 | fixed_mols.append(mol) 139 | if validate_mols(fixed_mols): 140 | return tuple(fixed_mols) 141 | else: 142 | return None 143 | 144 | def demap(mols, stereo = True): 145 | if type(mols) == type((0, 0)): 146 | ss = [] 147 | for mol in mols: 148 | [atom.SetAtomMapNum(0) for atom in mol.GetAtoms()] 149 | mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol, stereo)) 150 | if mol == None: 151 | return None 152 | ss.append(Chem.MolToSmiles(mol)) 153 | return '.'.join(sorted(ss)) 154 | else: 155 | [atom.SetAtomMapNum(0) for atom in mols.GetAtoms()] 156 | return '.'.join(sorted(Chem.MolToSmiles(mols, stereo).split('.'))) 157 | 158 | def read_prediction(smiles, prediction, atom_templates, bond_templates, template_infos): 159 | mol = Chem.MolFromSmiles(smiles) 160 | if len(prediction) == 1: 161 | return mol, None, None, None, 0 162 | else: 163 | edit_type, pred_site, pred_template_class, prediction_score = prediction # (edit_type, pred_site, pred_template_class) 164 | idx_map = get_idx_map(mol) 165 | if edit_type == 'a': 166 | template = atom_templates[pred_template_class] 167 | try: 168 | if len(template.split('>>')[0].split('.')) > 1: pred_site = idx_map[pred_site] 169 | except Exception as e: 170 | raise Exception(f'{smiles}\n{prediction}\n{template}\n{idx_map}\n{pred_site}') from e 171 | else: 172 | template = bond_templates[pred_template_class] 173 | if len(template.split('>>')[0].split('.')) > 1: pred_site= (idx_map[pred_site[0]], idx_map[pred_site[1]]) 174 | [atom.SetAtomMapNum(atom.GetIdx()) for atom in mol.GetAtoms()] 175 | if pred_site == None: 176 | return mol, pred_site, short_template, {}, 0 177 | return mol, pred_site, template, template_infos[template], prediction_score 178 | 179 | def 
decode_localtemplate(product, pred_site, template, template_info): 180 | if pred_site == None: 181 | return None 182 | possible_maps = get_possible_map(pred_site, template_info) 183 | reaction = rdChemReactions.ReactionFromSmarts(template) 184 | reactants = reaction.RunReactants([product]) 185 | decodes = [] 186 | for output in reactants: 187 | if output == None: 188 | continue 189 | matched_maps = check_idx_match(output, possible_maps) 190 | for matched_map in matched_maps: 191 | decoded = fix_reactant_atoms(product, output, matched_map, template_info) 192 | if decoded == None: 193 | continue 194 | else: 195 | return demap(decoded) 196 | return None 197 | 198 | -------------------------------------------------------------------------------- /textreact/tokenizer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import re 4 | from typing import List 5 | import pandas as pd 6 | from transformers import PreTrainedTokenizer, AutoTokenizer, BertTokenizer 7 | 8 | 9 | def load_vocab(vocab_file): 10 | """Loads a vocabulary file into a dictionary.""" 11 | vocab = collections.OrderedDict() 12 | with open(vocab_file, "r", encoding="utf-8") as reader: 13 | tokens = reader.readlines() 14 | for index, token in enumerate(tokens): 15 | token = token.rstrip("\n") 16 | vocab[token] = index 17 | return vocab 18 | 19 | 20 | class ReactionConditionTokenizer(PreTrainedTokenizer): 21 | 22 | def __init__(self, vocab_file): 23 | super().__init__( 24 | pad_token='[PAD]', 25 | bos_token='[BOS]', 26 | eos_token='[EOS]', 27 | mask_token='[MASK]', 28 | unk_token='[UNK]', 29 | sep_token='[SEP]' 30 | ) 31 | self.vocab = load_vocab(vocab_file) 32 | self.ids_to_tokens = {ids: tok for tok, ids in self.vocab.items()} 33 | 34 | def __len__(self): 35 | return len(self.vocab) 36 | 37 | def __call__(self, conditions, **kwargs): 38 | tokens = self.convert_tokens_to_ids(conditions) 39 | return self.prepare_for_model(tokens, **kwargs) 40 | 41 | def _decode(self, token_ids, skip_special_tokens=False, **kwargs): 42 | tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) 43 | return tokens 44 | 45 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 46 | token_ids = token_ids_0 47 | return [self.bos_token_id] + token_ids + [self.eos_token_id] 48 | 49 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): 50 | token_ids = token_ids_0 51 | return [0] * (len(token_ids) + 2) 52 | 53 | def _convert_token_to_id(self, token): 54 | """Converts a token (str) in an id using the vocab.""" 55 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 56 | 57 | def _convert_id_to_token(self, index): 58 | """Converts an index (integer) in a token (str) using the vocab.""" 59 | return self.ids_to_tokens.get(index, self.unk_token) 60 | 61 | 62 | SMI_REGEX_PATTERN = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#" \ 63 | r"|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])" 64 | 65 | 66 | class SmilesTokenizer(BertTokenizer): 67 | """ 68 | Creates the SmilesTokenizer class. The tokenizer heavily inherits from the BertTokenizer 69 | implementation found in Huggingface's transformers library. It runs a WordPiece tokenization 70 | algorithm over SMILES strings using the tokenisation SMILES regex developed by Schwaller et. al. 71 | Please see https://github.com/huggingface/transformers and https://github.com/rxn4chemistry/rxnfp for more details. 
72 | 73 | This class requires huggingface's transformers and tokenizers libraries to be installed. 74 | """ 75 | 76 | def __init__(self, vocab_file: str = '', **kwargs): 77 | """Constructs a SmilesTokenizer. 78 | Parameters 79 | ---------- 80 | vocab_file: str 81 | Path to a SMILES character per line vocabulary file. 82 | Default vocab file is found in deepchem/feat/tests/data/vocab_smiles.txt 83 | """ 84 | 85 | super().__init__(vocab_file, bos_token='[CLS]', eos_token='[SEP]', **kwargs) 86 | 87 | if not os.path.isfile(vocab_file): 88 | raise ValueError("Can't find a vocab file at path '{}'.".format(vocab_file)) 89 | self.vocab = load_vocab(vocab_file) 90 | self.highest_unused_index = max( 91 | [i for i, v in enumerate(self.vocab.keys()) if v.startswith("[unused")]) 92 | self.ids_to_tokens = collections.OrderedDict( 93 | [(ids, tok) for tok, ids in self.vocab.items()]) 94 | self.basic_tokenizer = BasicSmilesTokenizer() 95 | 96 | @property 97 | def vocab_size(self): 98 | return len(self.vocab) 99 | 100 | @property 101 | def vocab_list(self): 102 | return list(self.vocab.keys()) 103 | 104 | def _tokenize(self, text: str): 105 | """ 106 | Tokenize a string into a list of tokens. 107 | Parameters 108 | ---------- 109 | text: str 110 | Input string sequence to be tokenized. 111 | """ 112 | split_tokens = [token for token in self.basic_tokenizer.tokenize(text)] 113 | return split_tokens 114 | 115 | def _convert_token_to_id(self, token): 116 | """ 117 | Converts a token (str/unicode) in an id using the vocab. 118 | Parameters 119 | ---------- 120 | token: str 121 | String token from a larger sequence to be converted to a numerical id. 122 | """ 123 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 124 | 125 | def _convert_id_to_token(self, index): 126 | """ 127 | Converts an index (integer) in a token (string/unicode) using the vocab. 128 | Parameters 129 | ---------- 130 | index: int 131 | Integer index to be converted back to a string-based token as part of a larger sequence. 132 | """ 133 | return self.ids_to_tokens.get(index, self.unk_token) 134 | 135 | def convert_tokens_to_string(self, tokens: List[str]): 136 | """ Converts a sequence of tokens (string) in a single string. 137 | Parameters 138 | ---------- 139 | tokens: List[str] 140 | List of tokens for a given string sequence. 141 | Returns 142 | ------- 143 | out_string: str 144 | Single string from combined tokens. 145 | """ 146 | out_string: str = "".join(tokens).replace(" ##", "").strip() 147 | return out_string 148 | 149 | def add_special_tokens_ids_single_sequence(self, token_ids: List[int]): 150 | """ 151 | Adds special tokens to the a sequence for sequence classification tasks. 152 | A BERT sequence has the following format: [CLS] X [SEP] 153 | Parameters 154 | ---------- 155 | token_ids: list[int] 156 | list of tokenized input ids. Can be obtained using the encode or encode_plus methods. 157 | """ 158 | return [self.cls_token_id] + token_ids + [self.sep_token_id] 159 | 160 | def add_special_tokens_single_sequence(self, tokens: List[str]): 161 | """ 162 | Adds special tokens to the a sequence for sequence classification tasks. 163 | A BERT sequence has the following format: [CLS] X [SEP] 164 | Parameters 165 | ---------- 166 | tokens: List[str] 167 | List of tokens for a given string sequence. 
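        Returns
        ----------
        List[str]
            The tokens wrapped with BERT special tokens, e.g. (illustrative)
            ['C', 'C', 'O'] -> ['[CLS]', 'C', 'C', 'O', '[SEP]'].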
168 | """ 169 | return [self.cls_token] + tokens + [self.sep_token] 170 | 171 | def add_special_tokens_ids_sequence_pair(self, token_ids_0: List[int], 172 | token_ids_1: List[int]) -> List[int]: 173 | """ 174 | Adds special tokens to a sequence pair for sequence classification tasks. 175 | A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP] 176 | Parameters 177 | ---------- 178 | token_ids_0: List[int] 179 | List of ids for the first string sequence in the sequence pair (A). 180 | token_ids_1: List[int] 181 | List of ids for the second string sequence in the sequence pair (B). 182 | """ 183 | sep = [self.sep_token_id] 184 | cls = [self.cls_token_id] 185 | return cls + token_ids_0 + sep + token_ids_1 + sep 186 | 187 | def add_padding_tokens(self, 188 | token_ids: List[int], 189 | length: int, 190 | right: bool = True) -> List[int]: 191 | """ 192 | Adds padding tokens to return a sequence of the requested length. 193 | By default padding tokens are added to the right of the sequence. 194 | Parameters 195 | ---------- 196 | token_ids: list[int] 197 | list of tokenized input ids. Can be obtained using the encode or encode_plus methods. 198 | length: int 199 | Target length of the padded sequence. 200 | right: bool (True by default) 201 | Whether padding tokens are appended (True) or prepended (False). 202 | Returns 203 | ---------- 204 | List[int] 205 | The input ids padded with the pad token id up to the requested length. 206 | """ 207 | padding = [self.pad_token_id] * (length - len(token_ids)) 208 | 209 | if right: 210 | return token_ids + padding 211 | else: 212 | return padding + token_ids 213 | 214 | 215 | class BasicSmilesTokenizer(object): 216 | """ 217 | Run basic SMILES tokenization using the regex pattern developed by Schwaller et al. Use this tokenizer 218 | when a tokenizer that does not depend on HuggingFace's transformers library is required. 219 | """ 220 | 221 | def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN): 222 | """ Constructs a BasicSmilesTokenizer. """ 223 | self.regex_pattern = regex_pattern 224 | self.regex = re.compile(self.regex_pattern) 225 | 226 | def tokenize(self, text): 227 | """ Basic Tokenization of a SMILES.
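        Illustrative example (pure regex splitting, no vocabulary lookup):

            >>> BasicSmilesTokenizer().tokenize('CC(=O)Oc1ccccc1')
            ['C', 'C', '(', '=', 'O', ')', 'O', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1']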
""" 228 | tokens = [token for token in self.regex.findall(text)] 229 | return tokens 230 | 231 | 232 | class SmilesTextTokenizer(PreTrainedTokenizer): 233 | 234 | def __init__(self, text_tokenizer, smiles_tokenizer=None): 235 | super().__init__( 236 | pad_token=text_tokenizer.pad_token, 237 | mask_token=text_tokenizer.mask_token) 238 | if smiles_tokenizer is None: 239 | self.separate = False 240 | self.smiles_tokenizer = text_tokenizer 241 | else: 242 | self.separate = True 243 | self.smiles_tokenizer = smiles_tokenizer 244 | self.text_tokenizer = text_tokenizer 245 | 246 | @property 247 | def smiles_offset(self): 248 | return len(self.text_tokenizer) if self.separate is not None else 0 249 | 250 | def __len__(self): 251 | return len(self.text_tokenizer) + self.smiles_offset 252 | 253 | def __call__(self, text, text_pair, **kwargs): 254 | result = self.smiles_tokenizer(text, **kwargs) 255 | if self.separate: 256 | result['input_ids'] = [v + self.smiles_offset for v in result['input_ids']] 257 | if isinstance(text_pair, str): 258 | result_pair = self.text_tokenizer(text_pair, **kwargs) 259 | for key in result: 260 | result[key] = result[key] + result_pair[key][1:] # skip the CLS token 261 | elif isinstance(text_pair, list): 262 | for t in text_pair: 263 | result_pair = self.text_tokenizer(t, **kwargs) 264 | for key in result: 265 | result[key] = result[key] + result_pair[key][1:] # skip the CLS token 266 | return result 267 | 268 | def _convert_id_to_token(self, index): 269 | if index < len(self.text_tokenizer): 270 | return self.text_tokenizer.convert_ids_to_tokens(index) 271 | else: 272 | return self.smiles_tokenizer.convert_ids_to_tokens(index - len(self.text_tokenizer)) 273 | 274 | def _convert_token_to_id(self, token): 275 | return self.text_tokenizer.convert_tokens_to_ids(token) 276 | 277 | 278 | def get_tokenizers(args): 279 | # Encoder 280 | if args.encoder_tokenizer == 'smiles': 281 | enc_tokenizer = SmilesTokenizer(args.vocab_file) 282 | elif args.encoder_tokenizer == 'text': 283 | text_tokenizer = AutoTokenizer.from_pretrained(args.encoder, use_fast=False) 284 | enc_tokenizer = SmilesTextTokenizer(text_tokenizer) 285 | elif args.encoder_tokenizer == 'smiles_text': 286 | smiles_tokenizer = SmilesTokenizer(args.vocab_file) 287 | text_tokenizer = AutoTokenizer.from_pretrained(args.encoder, use_fast=False) 288 | enc_tokenizer = SmilesTextTokenizer(text_tokenizer, smiles_tokenizer) 289 | else: 290 | raise ValueError 291 | if args.template_based: 292 | assert args.encoder_tokenizer.startswith('smiles') 293 | atom_templates = pd.read_csv(os.path.join(args.template_path, 'atom_templates.csv'))['Template'] 294 | bond_templates = pd.read_csv(os.path.join(args.template_path, 'bond_templates.csv'))['Template'] 295 | dec_tokenizer = atom_templates, bond_templates 296 | else: 297 | assert args.template_path is None 298 | # Decoder 299 | if args.task == 'condition': 300 | dec_tokenizer = ReactionConditionTokenizer(args.vocab_file) 301 | elif args.task == 'retro': 302 | dec_tokenizer = SmilesTokenizer(args.vocab_file) 303 | else: 304 | raise ValueError 305 | return enc_tokenizer, dec_tokenizer 306 | -------------------------------------------------------------------------------- /textreact/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import logging 8 | 9 | from rdkit import Chem 10 | 11 | 12 | metric_to_mode = { 13 | 'val_loss': 'min', 14 | 
'val_acc': 'max' 15 | } 16 | 17 | 18 | def expand_position_embeddings(encoder, max_length): 19 | if encoder.config.model_type in ['bert', 'longformer', 'roberta']: 20 | if max_length <= encoder.config.max_position_embeddings: 21 | return 22 | embeddings = encoder.embeddings 23 | config = encoder.config 24 | old_emb = embeddings.position_embeddings.weight.data.clone() 25 | config.max_position_embeddings = max_length 26 | embeddings.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 27 | embeddings.position_embeddings.weight.data[:old_emb.size(0)] = old_emb 28 | embeddings.register_buffer( 29 | "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 30 | embeddings.register_buffer( 31 | "token_type_ids", torch.zeros((1, config.max_position_embeddings), dtype=torch.long), persistent=False) 32 | else: 33 | raise NotImplementedError 34 | 35 | 36 | def expand_word_embeddings(encoder, vocab_size): 37 | if vocab_size <= encoder.config.vocab_size: 38 | return 39 | embeddings = encoder.embeddings 40 | config = encoder.config 41 | old_emb = embeddings.word_embeddings.weight.data.clone() 42 | config.vocab_size = vocab_size 43 | embeddings.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 44 | embeddings.word_embeddings.weight.data[:old_emb.size(0)] = old_emb 45 | 46 | 47 | def clear_path(path, trainer): 48 | for file in os.listdir(path): 49 | if file.endswith('.ckpt'): 50 | filepath = os.path.join(path, file) 51 | logging.info(f' Remove checkpoint {filepath}') 52 | trainer.strategy.remove_checkpoint(filepath) 53 | 54 | 55 | def gather_prediction_each_neighbor(prediction, num_neighbors): 56 | results = {} 57 | for i, pred in sorted(prediction.items()): 58 | idx = i // num_neighbors 59 | if i % num_neighbors == 0: 60 | results[idx] = pred 61 | else: 62 | for key in results[idx]: 63 | results[idx][key] += pred[key] 64 | return results 65 | 66 | 67 | """ Adapted from localretro_model/get_edit.py """ 68 | 69 | def get_id_template(a, class_n, num_atoms, edit_type, python_numbers=False): 70 | edit_idx = a // class_n 71 | template = a % class_n 72 | if edit_type == 'b': 73 | edit_idx = (edit_idx // num_atoms, edit_idx % num_atoms) 74 | if python_numbers: 75 | edit_idx = edit_idx.item() if edit_type == 'a' else (edit_idx[0].item(), edit_idx[1].item()) 76 | template = template.item() 77 | return edit_idx, template 78 | 79 | def output2edit(out, top_num, edit_type, bonds=None): 80 | num_atoms, class_n = out.shape[-2:] 81 | readout = out.cpu().detach().numpy() 82 | readout = readout.reshape(-1) 83 | output_rank = np.flip(np.argsort(readout)) 84 | filtered_output_rank = [] 85 | for r in output_rank: 86 | edit_idx, template = get_id_template(r, class_n, num_atoms, edit_type) 87 | if (bonds is None or edit_idx in bonds) and template != 0: 88 | filtered_output_rank.append(r) 89 | if len(filtered_output_rank) == top_num: 90 | break 91 | selected_edit = [get_id_template(a, class_n, num_atoms, edit_type, python_numbers=True) for a in filtered_output_rank] 92 | selected_proba = [readout[a].item() for a in filtered_output_rank] 93 | 94 | return selected_edit, selected_proba 95 | 96 | def combined_edit(atom_out, bond_out, bonds, top_num=None): 97 | edit_id_a, edit_proba_a = output2edit(atom_out, top_num, edit_type='a') 98 | edit_id_b, edit_proba_b = output2edit(bond_out, top_num, edit_type='b', bonds=bonds) 99 | edit_id_c = edit_id_a + edit_id_b 100 | edit_type_c = ['a'] * len(edit_proba_a) + ['b'] * 
len(edit_proba_b) 101 | edit_proba_c = edit_proba_a + edit_proba_b 102 | edit_rank_c = np.flip(np.argsort(edit_proba_c)) 103 | if top_num is not None: 104 | edit_rank_c = edit_rank_c[:top_num] 105 | edit_preds_c = [(edit_type_c[r], *edit_id_c[r]) for r in edit_rank_c] 106 | edit_proba_c = [edit_proba_c[r] for r in edit_rank_c] 107 | 108 | return edit_preds_c, edit_proba_c 109 | -------------------------------------------------------------------------------- /textreact/vocab/vocab_condition.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [BOS] 3 | [EOS] 4 | [MASK] 5 | [UNK] 6 | [SEP] 7 | 8 | B 9 | B1C2CCCC1CCC2 10 | Br 11 | BrB(Br)Br 12 | BrBr 13 | Br[Cu]Br 14 | Br[P+](N1CCCC1)(N1CCCC1)N1CCCC1.F[P-](F)(F)(F)(F)F 15 | C 16 | C(=NC1CCCCC1)=NC1CCCCC1 17 | C1=CCCCC1 18 | C1CCC(P(C2CCCCC2)C2CCCCC2)CC1 19 | C1CCC2=NCCCN2CC1 20 | C1CCCCC1 21 | C1CCNC1 22 | C1CCNCC1 23 | C1CCOC1 24 | C1CC[NH2+]CC1.CC(=O)[O-] 25 | C1CN2CCN1CC2 26 | C1COCCN1 27 | C1COCCO1 28 | C1COCCOCCOCCOCCO1 29 | C1COCCOCCOCCOCCOCCO1 30 | CC#N 31 | CC(=O)CC(C)C 32 | CC(=O)Cl 33 | CC(=O)N(C)C 34 | CC(=O)O 35 | CC(=O)OC(C)=O 36 | CC(=O)OC(C)C 37 | CC(=O)OI1(OC(C)=O)(OC(C)=O)OC(=O)c2ccccc21 38 | CC(=O)O[BH-](OC(C)=O)OC(C)=O.[Na+] 39 | CC(=O)O[Pd]OC(C)=O 40 | CC(=O)[O-].[K+] 41 | CC(=O)[O-].[Na+] 42 | CC(C)(C#N)N=NC(C)(C)C#N 43 | CC(C)(C)O 44 | CC(C)(C)OC(=O)N=NC(=O)OC(C)(C)C 45 | CC(C)(C)ON=O 46 | CC(C)(C)O[K] 47 | CC(C)(C)O[Na] 48 | CC(C)(C)[O-].[K+] 49 | CC(C)(C)[P]([Pd][P](C(C)(C)C)(C(C)(C)C)C(C)(C)C)(C(C)(C)C)C(C)(C)C 50 | CC(C)=O 51 | CC(C)CCO 52 | CC(C)CCON=O 53 | CC(C)CO 54 | CC(C)COC(=O)Cl 55 | CC(C)C[Al+]CC(C)C.[H-] 56 | CC(C)C[AlH]CC(C)C 57 | CC(C)N=C=NC(C)C 58 | CC(C)NC(C)C 59 | CC(C)O 60 | CC(C)OC(=O)/N=N/C(=O)OC(C)C 61 | CC(C)OC(=O)N=NC(=O)OC(C)C 62 | CC(C)OC(C)C 63 | CC(C)O[Ti](OC(C)C)(OC(C)C)OC(C)C 64 | CC(C)[Mg]Cl 65 | CC(C)[N-]C(C)C.[Li+] 66 | CC(C)c1cc(C(C)C)c(-c2ccccc2P(C2CCCCC2)C2CCCCC2)c(C(C)C)c1 67 | CC(Cl)Cl 68 | CC(Cl)OC(=O)Cl 69 | CC1(C)C2CCC1(CS(=O)(=O)O)C(=O)C2 70 | CC1(C)CCCC(C)(C)N1 71 | CC1(C)c2cccc(P(c3ccccc3)c3ccccc3)c2Oc2c(P(c3ccccc3)c3ccccc3)cccc21 72 | CC1CCCO1 73 | CC=C(C)C 74 | CCC#N 75 | CCC(C)(C)O 76 | CCC(C)=O 77 | CCC(C)O 78 | CCCCC 79 | CCCCCC 80 | CCCCCCC 81 | CCCCCO 82 | CCCCN(CCCC)CCCC 83 | CCCCO 84 | CCCCP(=CC#N)(CCCC)CCCC 85 | CCCCP(CCCC)CCCC 86 | CCCC[N+](CCCC)(CCCC)CCCC.[F-] 87 | CCCC[SnH](CCCC)CCCC 88 | CCCC[Sn](=O)CCCC 89 | CCCO 90 | CCCP1(=O)OP(=O)(CCC)OP(=O)(CCC)O1 91 | CCN(C(C)C)C(C)C 92 | CCN(CC)CC 93 | CCN=C=NCCCN(C)C 94 | CCNCC 95 | CCO 96 | CCOC(=O)/N=N/C(=O)OCC 97 | CCOC(=O)Cl 98 | CCOC(=O)N=NC(=O)OCC 99 | CCOC(C)=O 100 | CCOCC 101 | CCOCCO 102 | CCOP(=O)(C#N)OCC 103 | CC[Mg]Br 104 | CC[N+](CC)(CC)S(=O)(=O)N=C([O-])OC 105 | CC[SiH](CC)CC 106 | CN 107 | CN(C)C(=N)N(C)C 108 | CN(C)C(N(C)C)=[N+]1N=[N+]([O-])c2ncccc21.F[P-](F)(F)(F)(F)F 109 | CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[B-](F)(F)F 110 | CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[P-](F)(F)(F)(F)F 111 | CN(C)C(On1nnc2cccnc21)=[N+](C)C.F[P-](F)(F)(F)(F)F 112 | CN(C)C=O 113 | CN(C)CCN(C)C 114 | CN(C)P(=O)(N(C)C)N(C)C 115 | CN(C)[P+](On1nnc2ccccc21)(N(C)C)N(C)C.F[P-](F)(F)(F)(F)F 116 | CN(C)c1ccccc1 117 | CN(C)c1ccccn1 118 | CN(C)c1ccncc1 119 | CN1CCCC1=O 120 | CN1CCCN(C)C1=O 121 | CN1CCOCC1 122 | CN1CC[NH+](C)C1Cl.[Cl-] 123 | CNCCNC 124 | CN[C@@H]1CCCC[C@H]1NC 125 | CO 126 | COC(C)(C)C 127 | COCCO 128 | COCCOC 129 | COCCOCCOC 130 | CO[Na] 131 | COc1cccc(OC)c1-c1ccccc1P(C1CCCCC1)C1CCCCC1 132 | COc1ccccc1 133 | COc1nc(OC)nc([N+]2(C)CCOCC2)n1.[Cl-] 134 | CS(=O)(=O)Cl 135 | CS(=O)(=O)O 136 
| CS(C)=O 137 | CSC 138 | C[Al](C)C 139 | C[N+](=O)[O-] 140 | C[N+]1([O-])CCOCC1 141 | C[O-].[Na+] 142 | C[Si](C)(C)Br 143 | C[Si](C)(C)C=[N+]=[N-] 144 | C[Si](C)(C)Cl 145 | C[Si](C)(C)I 146 | C[Si](C)(C)[N-][Si](C)(C)C.[K+] 147 | C[Si](C)(C)[N-][Si](C)(C)C.[Li+] 148 | C[Si](C)(C)[N-][Si](C)(C)C.[Na+] 149 | C[n+]1ccccc1Cl.[I-] 150 | Cc1cc(C)c(N2CCN(c3c(C)cc(C)cc3C)C2=[Ru](Cl)(Cl)(=Cc2ccccc2)[P](C2CCCCC2)(C2CCCCC2)C2CCCCC2)c(C)c1 151 | Cc1cc(C)cc(C)c1 152 | Cc1ccc(C)cc1 153 | Cc1ccc(S(=O)(=O)Cl)cc1 154 | Cc1ccc(S(=O)(=O)O)cc1 155 | Cc1ccc(S(=O)(=O)[O-])cc1.c1cc[nH+]cc1 156 | Cc1cccc(C)n1 157 | Cc1ccccc1 158 | Cc1ccccc1C 159 | Cc1ccccc1P(c1ccccc1C)c1ccccc1C 160 | Cc1ccccc1S(=O)(=O)O 161 | Cc1ccccc1[P](c1ccccc1C)(c1ccccc1C)[Pd](Cl)(Cl)[P](c1ccccc1C)(c1ccccc1C)c1ccccc1C 162 | Cl 163 | ClB(Cl)Cl 164 | ClC(Cl)(Cl)Cl 165 | ClC(Cl)Cl 166 | ClCCCl 167 | ClCCl 168 | ClP(Cl)(Cl)(Cl)Cl 169 | ClP(Cl)Cl 170 | Cl[Cu] 171 | Cl[Cu]Cl 172 | Cl[Fe](Cl)Cl 173 | Cl[Hg]Cl 174 | Cl[Ni]Cl 175 | Cl[Pd](Cl)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P](c1ccccc1)(c1ccccc1)c1ccccc1 176 | Cl[Pd]Cl 177 | Cl[Sn](Cl)(Cl)Cl 178 | Cl[Sn]Cl 179 | Cl[Ti](Cl)(Cl)Cl 180 | Clc1ccccc1 181 | Clc1ccccc1Cl 182 | FB(F)F 183 | I 184 | II 185 | I[Cu]I 186 | N 187 | N#CC1=C(C#N)C(=O)C(Cl)=C(Cl)C1=O 188 | N#N 189 | NCCN 190 | NN 191 | N[C@@H]1CCCC[C@H]1N 192 | O 193 | O=C(CO)O[CH][C@@H](O)CO 194 | O=C(Cl)C(=O)Cl 195 | O=C(N=NC(=O)N1CCCCC1)N1CCCCC1 196 | O=C(O)C(=O)O 197 | O=C(O)C(F)(F)F 198 | O=C(O)CC(O)(CC(=O)O)C(=O)O 199 | O=C(O)O 200 | O=C(O)[C@@H]1CCCN1 201 | O=C(OC(=O)C(F)(F)F)C(F)(F)F 202 | O=C(OO)c1cccc(Cl)c1 203 | O=C(OOC(=O)c1ccccc1)c1ccccc1 204 | O=C([O-])O 205 | O=C([O-])[O-].[Ca+2] 206 | O=C([O-])[O-].[Cs+].[Cs+] 207 | O=C([O-])[O-].[K+].[K+] 208 | O=C([O-])[O-].[Li+] 209 | O=C([O-])[O-].[Na+].[Na+] 210 | O=C(c1ncc[nH]1)c1ncc[nH]1 211 | O=C(n1ccnc1)n1ccnc1 212 | O=C1CCC(=O)N1Br 213 | O=C1OCCN1P(=O)(Cl)N1CCOC1=O 214 | O=C=O 215 | O=CO 216 | O=C[O-].[NH4+] 217 | O=N[O-].[Na+] 218 | O=O 219 | O=P(Cl)(Cl)Cl 220 | O=P(O)(O)O 221 | O=P([O-])([O-])[O-] 222 | O=P([O-])([O-])[O-].[K+].[K+].[K+] 223 | O=P12OP3(=O)OP(=O)(O1)OP(=O)(O2)O3 224 | O=S(=O)(O)C(F)(F)F 225 | O=S(=O)(O)O 226 | O=S(=O)([O-])C(F)(F)F.O=S(=O)([O-])C(F)(F)F.O=S(=O)([O-])C(F)(F)F.[Yb+3] 227 | O=S(=O)([O-])[O-].[Mg+2] 228 | O=S(=O)([O-])[O-].[Na+].[Na+] 229 | O=S(Cl)Cl 230 | O=S([O-])([O-])=S.[Na+].[Na+] 231 | O=S([O-])S(=O)[O-].[Na+].[Na+] 232 | O=S([O-])[O-].[Na+].[Na+] 233 | O=S1(=O)CCCC1 234 | O=[Ag-] 235 | O=[Ag] 236 | O=[Cr](=O)([O-])Cl.c1cc[nH+]cc1 237 | O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].c1cc[nH+]cc1.c1cc[nH+]cc1 238 | O=[Cr](=O)=O 239 | O=[Cu-] 240 | O=[Cu] 241 | O=[Mn]=O 242 | O=[N+]([O-])c1ccccc1 243 | O=[Os](=O)(=O)=O 244 | O=[Pt] 245 | O=[Pt]=O 246 | OC(=O)[O-].[K+] 247 | OC(=O)[O-].[Na+] 248 | OCC(F)(F)F 249 | OCC(O)CO 250 | OCCO 251 | OCCOCCO 252 | OO 253 | OS(=O)[O-].[Na+] 254 | On1nnc2ccccc21 255 | On1nnc2cccnc21 256 | S=C=S 257 | [Al+3].[Cl-].[Cl-].[Cl-] 258 | [Al+3].[H-].[H-].[H-].[H-].[Li+] 259 | [Al] 260 | [BH3-]C#N.[Na+] 261 | [BH4-].[Li+] 262 | [BH4-].[Na+] 263 | [Br-].[K+] 264 | [C-]#N.[Na+] 265 | [Ca+2].[Cl-].[Cl-] 266 | [Cl-] 267 | [Cl-].[Li+] 268 | [Cl-].[NH4+] 269 | [Cl-].[Na+] 270 | [Cs+].[F-] 271 | [Cu] 272 | [Cu]Br 273 | [Cu]I 274 | [F-].[K+] 275 | [Fe] 276 | [H-].[Na+] 277 | [H][H] 278 | [I-].[K+] 279 | [I-].[Li+] 280 | [I-].[Na+] 281 | [K+] 282 | [K+].[OH-] 283 | [K] 284 | [Li+].[OH-] 285 | [Li] 286 | [Li]C(C)(C)C 287 | [Li]C(C)CC 288 | [Li]CCCC 289 | [Li]O 290 | [Mg] 291 | [NH4+] 292 | [NH4+].[OH-] 293 | [Na+] 294 | [Na+].[O-]Cl 295 | 
[Na+].[O-][I+3]([O-])([O-])[O-] 296 | [Na+].[OH-] 297 | [Na] 298 | [Ni] 299 | [O-][Cl+3]([O-])([O-])O 300 | [OH-] 301 | [Pd] 302 | [Pt] 303 | [Rh] 304 | [Ru] 305 | [Zn] 306 | c1c[nH]cn1 307 | c1ccc(Oc2ccccc2)cc1 308 | c1ccc(P(c2ccccc2)c2ccc3ccccc3c2-c2c(P(c3ccccc3)c3ccccc3)ccc3ccccc23)cc1 309 | c1ccc(P(c2ccccc2)c2ccccc2)cc1 310 | c1ccc([PH](c2ccccc2)(c2ccccc2)[Pd-4]([PH](c2ccccc2)(c2ccccc2)c2ccccc2)([PH](c2ccccc2)(c2ccccc2)c2ccccc2)[PH](c2ccccc2)(c2ccccc2)c2ccccc2)cc1 311 | c1ccc([P](c2ccccc2)(c2ccccc2)[Pd]([P](c2ccccc2)(c2ccccc2)c2ccccc2)([P](c2ccccc2)(c2ccccc2)c2ccccc2)[P](c2ccccc2)(c2ccccc2)c2ccccc2)cc1 312 | c1ccc2ncccc2c1 313 | c1ccccc1 314 | c1ccncc1 315 | c1cnc2c(c1)ccc1cccnc12 -------------------------------------------------------------------------------- /textreact/vocab/vocab_smiles.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [unused1] 3 | [unused2] 4 | [unused3] 5 | [unused4] 6 | [unused5] 7 | [unused6] 8 | [unused7] 9 | [unused8] 10 | [unused9] 11 | [unused10] 12 | [UNK] 13 | [CLS] 14 | [SEP] 15 | [MASK] 16 | c 17 | C 18 | ( 19 | ) 20 | O 21 | 1 22 | 2 23 | = 24 | N 25 | . 26 | n 27 | 3 28 | F 29 | Cl 30 | >> 31 | ~ 32 | - 33 | 4 34 | [C@H] 35 | S 36 | [C@@H] 37 | [O-] 38 | Br 39 | # 40 | / 41 | [nH] 42 | [N+] 43 | s 44 | 5 45 | o 46 | P 47 | [Na+] 48 | [Si] 49 | I 50 | [Na] 51 | [Pd] 52 | [K+] 53 | [K] 54 | [P] 55 | B 56 | [C@] 57 | [C@@] 58 | [Cl-] 59 | 6 60 | [OH-] 61 | \ 62 | [N-] 63 | [Li] 64 | [H] 65 | [2H] 66 | [NH4+] 67 | [c-] 68 | [P-] 69 | [Cs+] 70 | [Li+] 71 | [Cs] 72 | [NaH] 73 | [H-] 74 | [O+] 75 | [BH4-] 76 | [Cu] 77 | 7 78 | [Mg] 79 | [Fe+2] 80 | [n+] 81 | [Sn] 82 | [BH-] 83 | [Pd+2] 84 | [CH] 85 | [I-] 86 | [Br-] 87 | [C-] 88 | [Zn] 89 | [B-] 90 | [F-] 91 | [Al] 92 | [P+] 93 | [BH3-] 94 | [Fe] 95 | [C] 96 | [AlH4] 97 | [Ni] 98 | [SiH] 99 | 8 100 | [Cu+2] 101 | [Mn] 102 | [AlH] 103 | [nH+] 104 | [AlH4-] 105 | [O-2] 106 | [Cr] 107 | [Mg+2] 108 | [NH3+] 109 | [S@] 110 | [Pt] 111 | [Al+3] 112 | [S@@] 113 | [S-] 114 | [Ti] 115 | [Zn+2] 116 | [PH] 117 | [NH2+] 118 | [Ru] 119 | [Ag+] 120 | [S+] 121 | [I+3] 122 | [NH+] 123 | [Ca+2] 124 | [Ag] 125 | 9 126 | [Os] 127 | [Se] 128 | [SiH2] 129 | [Ca] 130 | [Ti+4] 131 | [Ac] 132 | [Cu+] 133 | [S] 134 | [Rh] 135 | [Cl+3] 136 | [cH-] 137 | [Zn+] 138 | [O] 139 | [Cl+] 140 | [SH] 141 | [H+] 142 | [Pd+] 143 | [se] 144 | [PH+] 145 | [I] 146 | [Pt+2] 147 | [C+] 148 | [Mg+] 149 | [Hg] 150 | [W] 151 | [SnH] 152 | [SiH3] 153 | [Fe+3] 154 | [NH] 155 | [Mo] 156 | [CH2+] 157 | %10 158 | [CH2-] 159 | [CH2] 160 | [n-] 161 | [Ce+4] 162 | [NH-] 163 | [Co] 164 | [I+] 165 | [PH2] 166 | [Pt+4] 167 | [Ce] 168 | [B] 169 | [Sn+2] 170 | [Ba+2] 171 | %11 172 | [Fe-3] 173 | [18F] 174 | [SH-] 175 | [Pb+2] 176 | [Os-2] 177 | [Zr+4] 178 | [N] 179 | [Ir] 180 | [Bi] 181 | [Ni+2] 182 | [P@] 183 | [Co+2] 184 | [s+] 185 | [As] 186 | [P+3] 187 | [Hg+2] 188 | [Yb+3] 189 | [CH-] 190 | [Zr+2] 191 | [Mn+2] 192 | [CH+] 193 | [In] 194 | [KH] 195 | [Ce+3] 196 | [Zr] 197 | [AlH2-] 198 | [OH2+] 199 | [Ti+3] 200 | [Rh+2] 201 | [Sb] 202 | [S-2] 203 | %12 204 | [P@@] 205 | [Si@H] 206 | [Mn+4] 207 | p 208 | [Ba] 209 | [NH2-] 210 | [Ge] 211 | [Pb+4] 212 | [Cr+3] 213 | [Au] 214 | [LiH] 215 | [Sc+3] 216 | [o+] 217 | [Rh-3] 218 | %13 219 | [Br] 220 | [Sb-] 221 | [S@+] 222 | [I+2] 223 | [Ar] 224 | [V] 225 | [Cu-] 226 | [Al-] 227 | [Te] 228 | [13c] 229 | [13C] 230 | [Cl] 231 | [PH4+] 232 | [SiH4] 233 | [te] 234 | [CH3-] 235 | [S@@+] 236 | [Rh+3] 237 | [SH+] 238 | [Bi+3] 239 | [Br+2] 240 | [La] 241 | [La+3] 242 | [Pt-2] 243 
| [N@@] 244 | [PH3+] 245 | [N@] 246 | [Si+4] 247 | [Sr+2] 248 | [Al+] 249 | [Pb] 250 | [SeH] 251 | [Si-] 252 | [V+5] 253 | [Y+3] 254 | [Re] 255 | [Ru+] 256 | [Sm] 257 | * 258 | [3H] 259 | [NH2] 260 | [Ag-] 261 | [13CH3] 262 | [OH+] 263 | [Ru+3] 264 | [OH] 265 | [Gd+3] 266 | [13CH2] 267 | [In+3] 268 | [Si@@] 269 | [Si@] 270 | [Ti+2] 271 | [Sn+] 272 | [Cl+2] 273 | [AlH-] 274 | [Pd-2] 275 | [SnH3] 276 | [B+3] 277 | [Cu-2] 278 | [Nd+3] 279 | [Pb+3] 280 | [13cH] 281 | [Fe-4] 282 | [Ga] 283 | [Sn+4] 284 | [Hg+] 285 | [11CH3] 286 | [Hf] 287 | [Pr] 288 | [Y] 289 | [S+2] 290 | [Cd] 291 | [Cr+6] 292 | [Zr+3] 293 | [Rh+] 294 | [CH3] 295 | [N-3] 296 | [Hf+2] 297 | [Th] 298 | [Sb+3] 299 | %14 300 | [Cr+2] 301 | [Ru+2] 302 | [Hf+4] 303 | [14C] 304 | [Ta] 305 | [Tl+] 306 | [B+] 307 | [Os+4] 308 | [PdH2] 309 | [Pd-] 310 | [Cd+2] 311 | [Co+3] 312 | [S+4] 313 | [Nb+5] 314 | [123I] 315 | [c+] 316 | [Rb+] 317 | [V+2] 318 | [CH3+] 319 | [Ag+2] 320 | [cH+] 321 | [Mn+3] 322 | [Se-] 323 | [As-] 324 | [Eu+3] 325 | [SH2] 326 | [Sm+3] 327 | [IH+] 328 | %15 329 | [OH3+] 330 | [PH3] 331 | [IH2+] 332 | [SH2+] 333 | [Ir+3] 334 | [AlH3] 335 | [Sc] 336 | [Yb] 337 | [15NH2] 338 | [Lu] 339 | [sH+] 340 | [Gd] 341 | [18F-] 342 | [SH3+] 343 | [SnH4] 344 | [TeH] 345 | [Si@@H] 346 | [Ga+3] 347 | [CaH2] 348 | [Tl] 349 | [Ta+5] 350 | [GeH] 351 | [Br+] 352 | [Sr] 353 | [Tl+3] 354 | [Sm+2] 355 | [PH5] 356 | %16 357 | [N@@+] 358 | [Au+3] 359 | [C-4] 360 | [Nd] 361 | [Ti+] 362 | [IH] 363 | [N@+] 364 | [125I] 365 | [Eu] 366 | [Sn+3] 367 | [Nb] 368 | [Er+3] 369 | [123I-] 370 | [14c] 371 | %17 372 | [SnH2] 373 | [YH] 374 | [Sb+5] 375 | [Pr+3] 376 | [Ir+] 377 | [N+3] 378 | [AlH2] 379 | [19F] 380 | %18 381 | [Tb] 382 | [14CH] 383 | [Mo+4] 384 | [Si+] 385 | [BH] 386 | [Be] 387 | [Rb] 388 | [pH] 389 | %19 390 | %20 391 | [Xe] 392 | [Ir-] 393 | [Be+2] 394 | [C+4] 395 | [RuH2] 396 | [15NH] 397 | [U+2] 398 | [Au-] 399 | %21 400 | %22 401 | [Au+] 402 | [15n] 403 | [Al+2] 404 | [Tb+3] 405 | [15N] 406 | [V+3] 407 | [W+6] 408 | [14CH3] 409 | [Cr+4] 410 | [ClH+] 411 | b 412 | [Ti+6] 413 | [Nd+] 414 | [Zr+] 415 | [PH2+] 416 | [Fm] 417 | [N@H+] 418 | [RuH] 419 | [Dy+3] 420 | %23 421 | [Hf+3] 422 | [W+4] 423 | [11C] 424 | [13CH] 425 | [Er] 426 | [124I] 427 | [LaH] 428 | [F] 429 | [siH] 430 | [Ga+] 431 | [Cm] 432 | [GeH3] 433 | [IH-] 434 | [U+6] 435 | [SeH+] 436 | [32P] 437 | [SeH-] 438 | [Pt-] 439 | [Ir+2] 440 | [se+] 441 | [U] 442 | [F+] 443 | [BH2] 444 | [As+] 445 | [Cf] 446 | [ClH2+] 447 | [Ni+] 448 | [TeH3] 449 | [SbH2] 450 | [Ag+3] 451 | %24 452 | [18O] 453 | [PH4] 454 | [Os+2] 455 | [Na-] 456 | [Sb+2] 457 | [V+4] 458 | [Ho+3] 459 | [68Ga] 460 | [PH-] 461 | [Bi+2] 462 | [Ce+2] 463 | [Pd+3] 464 | [99Tc] 465 | [13C@@H] 466 | [Fe+6] 467 | [c] 468 | [GeH2] 469 | [10B] 470 | [Cu+3] 471 | [Mo+2] 472 | [Cr+] 473 | [Pd+4] 474 | [Dy] 475 | [AsH] 476 | [Ba+] 477 | [SeH2] 478 | [In+] 479 | [TeH2] 480 | [BrH+] 481 | [14cH] 482 | [W+] 483 | [13C@H] 484 | [AsH2] 485 | [In+2] 486 | [N+2] 487 | [N@@H+] 488 | [SbH] 489 | [60Co] 490 | [AsH4+] 491 | [AsH3] 492 | [18OH] 493 | [Ru-2] 494 | [Na-2] 495 | [CuH2] 496 | [31P] 497 | [Ti+5] 498 | [35S] 499 | [P@@H] 500 | [ArH] 501 | [Co+] 502 | [Zr-2] 503 | [BH2-] 504 | [131I] 505 | [SH5] 506 | [VH] 507 | [B+2] 508 | [Yb+2] 509 | [14C@H] 510 | [211At] 511 | [NH3+2] 512 | [IrH] 513 | [IrH2] 514 | [Rh-] 515 | [Cr-] 516 | [Sb+] 517 | [Ni+3] 518 | [TaH3] 519 | [Tl+2] 520 | [64Cu] 521 | [Tc] 522 | [Cd+] 523 | [1H] 524 | [15nH] 525 | [AlH2+] 526 | [FH+2] 527 | [BiH3] 528 | [Ru-] 529 | [Mo+6] 530 | [AsH+] 531 | [BaH2] 532 | [BaH] 
533 | [Fe+4] 534 | [229Th] 535 | [Th+4] 536 | [As+3] 537 | [NH+3] 538 | [P@H] 539 | [Li-] 540 | [7NaH] 541 | [Bi+] 542 | [PtH+2] 543 | [p-] 544 | [Re+5] 545 | [NiH] 546 | [Ni-] 547 | [Xe+] 548 | [Ca+] 549 | [11c] 550 | [Rh+4] 551 | [AcH] 552 | [HeH] 553 | [Sc+2] 554 | [Mn+] 555 | [UH] 556 | [14CH2] 557 | [SiH4+] 558 | [18OH2] 559 | [Ac-] 560 | [Re+4] 561 | [118Sn] 562 | [153Sm] 563 | [P+2] 564 | [9CH] 565 | [9CH3] 566 | [Y-] 567 | [NiH2] 568 | [Si+2] 569 | [Mn+6] 570 | [ZrH2] 571 | [C-2] 572 | [Bi+5] 573 | [24NaH] 574 | [Fr] 575 | [15CH] 576 | [Se+] 577 | [At] 578 | [P-3] 579 | [124I-] 580 | [CuH2-] 581 | [Nb+4] 582 | [Nb+3] 583 | [MgH] 584 | [Ir+4] 585 | [67Ga+3] 586 | [67Ga] 587 | [13N] 588 | [15OH2] 589 | [2NH] 590 | [Ho] 591 | [Cn] --------------------------------------------------------------------------------
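Note on the two vocabularies above: vocab_smiles.txt holds character-level SMILES tokens consumed by SmilesTokenizer (one token per line; the line order determines the integer id), whereas vocab_condition.txt holds whole condition molecules (reagents, solvents, catalysts) consumed by ReactionConditionTokenizer. A minimal usage sketch, with paths assumed relative to the repository root (illustrative only, not part of the training scripts):

```python
from textreact.tokenizer import ReactionConditionTokenizer, SmilesTokenizer

# Character-level tokenizer over reaction SMILES (encoder / retrosynthesis decoder)
smiles_tok = SmilesTokenizer('textreact/vocab/vocab_smiles.txt')

# Whole-molecule tokenizer over reaction conditions (condition-prediction decoder)
cond_tok = ReactionConditionTokenizer('textreact/vocab/vocab_condition.txt')

# Conditions are passed as a list of SMILES strings drawn from vocab_condition.txt;
# the tokenizer maps each to its vocabulary id and wraps the sequence with [BOS]/[EOS].
encoded = cond_tok(['CCO', 'O'])  # e.g. ethanol and water as conditions
```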