├── .gitignore
├── README.md
├── assets
│   └── textreact.png
├── environment.yml
├── main.py
├── preprocess
│   ├── dedup_corpus.py
│   ├── gen_uspto.py
│   ├── get_templates.py
│   ├── preprocess_retrosynthesis.py
│   ├── raw_retro_year_split.py
│   ├── reagent_Ionic_compound.txt
│   ├── reagent_unknown.txt
│   ├── retro_year_split.py
│   ├── template_extraction
│   │   ├── template_extract_utils.py
│   │   └── template_extractor.py
│   └── uspto_script
│       ├── 1.get_condition_from_uspto.py
│       ├── 2.0.clean_up_rxn_condition.py
│       ├── 2.0.clean_up_rxn_condition.sh
│       ├── 2.1.merge_clean_up_rxn_conditon.py
│       ├── 3.0.split_condition_and_slect.py
│       ├── 4.0.split_train_val_test.py
│       ├── 5.0.convert_context_tokens.py
│       ├── condition_classfication.ipynb
│       ├── extract_nosmiles.py
│       ├── gen_grant_corpus.py
│       ├── get_aug_condition_data.py
│       ├── get_dataset_for_condition.py
│       ├── get_dummy_model_results.py
│       ├── get_fragment_from_rxn_dataset.py
│       ├── merge_comp.py
│       ├── uspto_condition.md
│       └── utils.py
├── retrieve
│   ├── condition_year.sh
│   ├── convert_format.py
│   ├── retrieve.py
│   ├── retrieve_faiss.py
│   ├── retro.sh
│   └── retro_year.sh
├── scripts
│   ├── train_RCR.sh
│   ├── train_RCR_TS.sh
│   ├── train_RetroSyn_tb.sh
│   ├── train_RetroSyn_tb_TS.sh
│   ├── train_RetroSyn_tf.sh
│   └── train_RetroSyn_tf_TS.sh
└── textreact
    ├── __init__.py
    ├── configs
    │   └── bert_l6.json
    ├── dataset.py
    ├── evaluate.py
    ├── model.py
    ├── template_decoder.py
    ├── tokenizer.py
    ├── utils.py
    └── vocab
        ├── vocab_condition.txt
        └── vocab_smiles.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
2 | scripts_old/*
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TextReact
2 |
3 | This repository contains the code for [TextReact](https://aclanthology.org/2023.emnlp-main.784/), a novel method that directly augments
4 | predictive chemistry with text retrieval.
5 |
6 | 
7 |
8 | ```
9 | @inproceedings{TextReact,
10 | author = {Yujie Qian and
11 | Zhening Li and
12 | Zhengkai Tu and
13 | Connor W. Coley and
14 | Regina Barzilay},
15 | title = {Predictive Chemistry Augmented with Text Retrieval},
16 | booktitle = {Proceedings of the 2023 Conference on Empirical Methods in Natural
17 | Language Processing, {EMNLP} 2023, Singapore, December 6-10, 2023},
18 | pages = {12731--12745},
19 | publisher = {Association for Computational Linguistics},
20 | year = {2023},
21 | url = {https://aclanthology.org/2023.emnlp-main.784}
22 | }
23 | ```
24 |
25 | ## Requirements
26 | We implement the code with `torch==1.11.0`, `pytorch-lightning==2.0.0`, and `transformers==4.27.3`.
27 | To reproduce our experiments, we recommend creating a conda environment with the same dependencies:
28 | ```bash
29 | conda env create -f environment.yml -n textreact
30 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
31 | ```
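
To double-check that the key packages match the versions above (an optional sanity check, not part of the original setup):
```bash
python -c "import torch, pytorch_lightning, transformers; print(torch.__version__, pytorch_lightning.__version__, transformers.__version__)"
```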
32 |
33 | ## Data
34 |
35 | Run the following commands to download and unzip the preprocessed datasets:
36 | ```bash
37 | git clone https://huggingface.co/datasets/yujieq/TextReact data
38 | cd data
39 | unzip '*'
40 | ```
41 |
42 | ## Training Scripts
43 |
44 | TextReact consists of two modules: a SMILES-to-text retriever and a
45 | text-augmented predictor. This repository contains only the code for
46 | training the predictor; the code for the retriever is available in
47 | a separate repository: https://github.com/thomas0809/tevatron.
48 |
49 | The training scripts are located under [`scripts`](scripts):
50 | * [`train_RCR.sh`](scripts/train_RCR.sh) trains a model for reaction condition recommendation (RCR)
51 | on the random split of the USPTO dataset.
52 | * [`train_RetroSyn_tf.sh`](scripts/train_RetroSyn_tf.sh) trains a template-free model for retrosynthesis
53 | on the random split of the USPTO-50K dataset.
54 | * [`train_RetroSyn_tb.sh`](scripts/train_RetroSyn_tb.sh) trains a template-based model for retrosynthesis
55 | on the random split of the USPTO-50K dataset.
56 | In addition, [`train_RCR_TS.sh`](scripts/train_RCR_TS.sh), [`train_RetroSyn_tf_TS.sh`](scripts/train_RetroSyn_tf_TS.sh)
57 | and [`train_RetroSyn_tb_TS.sh`](scripts/train_RetroSyn_tb_TS.sh) train the corresponding models
58 | on the time-based split of the dataset.
59 |
60 | If you are working on a distributed file system, it is recommended to
61 | add a `--cache_path` option to the script, specifying a local directory, to reduce network overhead.
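
For example, the `python main.py ...` command inside a training script could be extended as follows. This is a hypothetical excerpt: `--cache_path` is defined in `main.py`'s `get_args`, but the exact command and the `/tmp/textreact_cache` directory are illustrative.
```bash
python main.py --do_train --do_valid --do_test \
    --cache_path /tmp/textreact_cache  # plus the data and model options already in the script
```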
62 |
63 | To run the script `scripts/train_MODEL.sh`, run the following command from the root of the repository:
64 | ```bash
65 | bash scripts/train_MODEL.sh
66 | ```
67 |
68 | At the end of training, two dictionaries are printed with the top-k test accuracies.
69 | The first one corresponds to retrieving from the full corpus
70 | and the second one corresponds to retrieving from the gold-removed corpus.
71 |
72 | Models and test predictions are stored under the path specified by the `SAVE_PATH` variable in the script.
73 | * `best.ckpt` is the checkpoint with the highest validation accuracy.
74 | * `last.ckpt` is the most recent checkpoint.
75 | * `prediction_test_0.json` contains the test predictions when retrieving from the full corpus.
76 | * `prediction_test_1.json` contains the predictions when retrieving from the gold-removed corpus.
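
Each prediction file is a plain JSON object mapping an example index to its ranked predictions and scores (the `prediction` and `score` fields written by `test_step` in `main.py`). To take a quick look (the path below is illustrative; substitute your own `SAVE_PATH`):
```bash
python -m json.tool output/RCR/prediction_test_0.json | head -n 20
```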
77 |
--------------------------------------------------------------------------------
/assets/textreact.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thomas0809/textreact/feb62d868a627d293997a56a72f50420377c59b4/assets/textreact.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: textreact
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=conda_forge
7 | - _openmp_mutex=4.5=2_gnu
8 | - bzip2=1.0.8=h7f98852_4
9 | - ca-certificates=2022.12.7=ha878542_0
10 | - ld_impl_linux-64=2.40=h41732ed_0
11 | - libffi=3.4.2=h7f98852_5
12 | - libgcc-ng=12.2.0=h65d4601_19
13 | - libgomp=12.2.0=h65d4601_19
14 | - libnsl=2.0.0=h7f98852_0
15 | - libsqlite=3.40.0=h753d276_0
16 | - libuuid=2.32.1=h7f98852_1000
17 | - libzlib=1.2.13=h166bdaf_4
18 | - ncurses=6.3=h27087fc_1
19 | - openssl=3.1.0=h0b41bf4_0
20 | - pip=23.0.1=pyhd8ed1ab_0
21 | - python=3.8.16=he550d4f_1_cpython
22 | - readline=8.2=h8228510_1
23 | - setuptools=67.6.0=pyhd8ed1ab_0
24 | - tk=8.6.12=h27826a3_0
25 | - wheel=0.40.0=pyhd8ed1ab_0
26 | - xz=5.2.6=h166bdaf_0
27 | - pip:
28 | - aiohttp==3.8.4
29 | - aiosignal==1.3.1
30 | - anyio==3.6.2
31 | - appdirs==1.4.4
32 | - arrow==1.2.3
33 | - asttokens==2.2.1
34 | - async-timeout==4.0.2
35 | - attrs==22.2.0
36 | - backcall==0.2.0
37 | - beautifulsoup4==4.12.0
38 | - blessed==1.20.0
39 | - certifi==2022.12.7
40 | - charset-normalizer==3.1.0
41 | - click==8.1.3
42 | - croniter==1.3.8
43 | - dateutils==0.6.12
44 | - decorator==5.1.1
45 | - deepdiff==6.3.0
46 | - dnspython==2.3.0
47 | - docker-pycreds==0.4.0
48 | - email-validator==1.3.1
49 | - executing==1.2.0
50 | - fastapi==0.88.0
51 | - filelock==3.10.7
52 | - frozenlist==1.3.3
53 | - fsspec==2023.3.0
54 | - gitdb==4.0.10
55 | - gitpython==3.1.31
56 | - h11==0.14.0
57 | - httpcore==0.16.3
58 | - httptools==0.5.0
59 | - httpx==0.23.3
60 | - huggingface-hub==0.13.3
61 | - idna==3.4
62 | - inquirer==3.1.3
63 | - ipython==8.12.2
64 | - itsdangerous==2.1.2
65 | - jedi==0.18.2
66 | - jinja2==3.1.2
67 | - lightning==2.0.0
68 | - lightning-cloud==0.5.32
69 | - lightning-utilities==0.8.0
70 | - markdown-it-py==2.2.0
71 | - markupsafe==2.1.2
72 | - matplotlib-inline==0.1.6
73 | - mdurl==0.1.2
74 | - multidict==6.0.4
75 | - numpy==1.24.2
76 | - ordered-set==4.1.0
77 | - orjson==3.8.8
78 | - packaging==23.0
79 | - pandas==1.5.3
80 | - parso==0.8.3
81 | - pathtools==0.1.2
82 | - pexpect==4.8.0
83 | - pickleshare==0.7.5
84 | - pillow==9.4.0
85 | - prompt-toolkit==3.0.38
86 | - protobuf==4.22.1
87 | - psutil==5.9.4
88 | - ptyprocess==0.7.0
89 | - pure-eval==0.2.2
90 | - pydantic==1.10.7
91 | - pygments==2.14.0
92 | - pyjwt==2.6.0
93 | - python-dateutil==2.8.2
94 | - python-dotenv==1.0.0
95 | - python-editor==1.0.4
96 | - python-multipart==0.0.6
97 | - pytorch-lightning==2.0.0
98 | - pytz==2023.2
99 | - pyyaml==6.0
100 | - rdkit-pypi==2022.9.5
101 | - readchar==4.0.5
102 | - regex==2023.3.23
103 | - requests==2.28.2
104 | - rfc3986==1.5.0
105 | - rich==13.3.3
106 | - sentry-sdk==1.18.0
107 | - setproctitle==1.3.2
108 | - six==1.16.0
109 | - smmap==5.0.0
110 | - sniffio==1.3.0
111 | - soupsieve==2.4
112 | - stack-data==0.6.2
113 | - starlette==0.22.0
114 | - starsessions==1.3.0
115 | - tokenizers==0.13.2
116 | - tqdm==4.65.0
117 | - traitlets==5.9.0
118 | - transformers==4.27.3
119 | - typing-extensions==4.5.0
120 | - ujson==5.7.0
121 | - urllib3==1.26.15
122 | - uvicorn==0.21.1
123 | - uvloop==0.17.0
124 | - wandb==0.14.0
125 | - watchfiles==0.19.0
126 | - wcwidth==0.2.6
127 | - websocket-client==1.5.1
128 | - websockets==10.4
129 | - yarl==1.8.2
130 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | import copy
5 | import random
6 | import argparse
7 | import collections
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.distributed as dist
14 | import pytorch_lightning as pl
15 | from pytorch_lightning import LightningModule, LightningDataModule
16 | from pytorch_lightning.strategies.ddp import DDPStrategy
17 | from transformers import get_scheduler, EncoderDecoderModel, EncoderDecoderConfig, AutoTokenizer, AutoConfig, AutoModel
18 |
19 | from textreact.tokenizer import get_tokenizers
20 | from textreact.model import get_model, get_mlm_head
21 | from textreact.dataset import ReactionConditionDataset, RetrosynthesisDataset, read_corpus, generate_train_label_corpus
22 | from textreact.evaluate import evaluate_reaction_condition, evaluate_retrosynthesis
23 | import textreact.utils as utils
24 |
25 |
26 | def get_args(notebook=False):
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--task', type=str, default='condition')
29 | parser.add_argument('--do_train', action='store_true')
30 | parser.add_argument('--do_valid', action='store_true')
31 | parser.add_argument('--do_test', action='store_true')
32 | parser.add_argument('--precision', type=str, default='32')
33 | parser.add_argument('--seed', type=int, default=42)
34 | parser.add_argument('--gpus', type=int, default=1)
35 | parser.add_argument('--print_freq', type=int, default=200)
36 | parser.add_argument('--debug', action='store_true')
37 | # Model
38 | parser.add_argument('--template_based', action='store_true')
39 | parser.add_argument('--unattend_nonbonds', action='store_true')
40 | parser.add_argument('--encoder', type=str, default=None)
41 | parser.add_argument('--decoder', type=str, default=None)
42 | parser.add_argument('--encoder_pretrained', action='store_true')
43 | parser.add_argument('--decoder_pretrained', action='store_true')
44 | parser.add_argument('--share_embedding', action='store_true')
45 | parser.add_argument('--encoder_tokenizer', type=str, default='text')
46 | # Data
47 | parser.add_argument('--data_path', type=str, default=None)
48 | parser.add_argument('--template_path', type=str, default=None)
49 | parser.add_argument('--train_file', type=str, default=None)
50 | parser.add_argument('--valid_file', type=str, default=None)
51 | parser.add_argument('--test_file', type=str, default=None)
52 | parser.add_argument('--vocab_file', type=str, default=None)
53 | parser.add_argument('--corpus_file', type=str, default=None)
54 | parser.add_argument('--train_label_corpus', action='store_true')
55 | parser.add_argument('--cache_path', type=str, default=None)
56 | parser.add_argument('--nn_path', type=str, default=None)
57 | parser.add_argument('--train_nn_file', type=str, default=None)
58 | parser.add_argument('--valid_nn_file', type=str, default=None)
59 | parser.add_argument('--test_nn_file', type=str, default=None)
60 | parser.add_argument('--max_length', type=int, default=128)
61 | parser.add_argument('--max_dec_length', type=int, default=128)
62 | parser.add_argument('--num_workers', type=int, default=8)
63 | parser.add_argument('--shuffle_smiles', action='store_true')
64 | parser.add_argument('--no_smiles', action='store_true')
65 | parser.add_argument('--num_neighbors', type=int, default=-1)
66 | parser.add_argument('--use_gold_neighbor', action='store_true')
67 | parser.add_argument('--max_num_neighbors', type=int, default=10)
68 | parser.add_argument('--random_neighbor_ratio', type=float, default=0.8)
69 | parser.add_argument('--mlm', action='store_true')
70 | parser.add_argument('--mlm_ratio', type=float, default=0.15)
71 | parser.add_argument('--mlm_layer', type=str, default='linear')
72 | parser.add_argument('--mlm_lambda', type=float, default=1)
73 | # Training
74 | parser.add_argument('--epochs', type=int, default=8)
75 | parser.add_argument('--batch_size', type=int, default=256)
76 | parser.add_argument('--lr', type=float, default=1e-4)
77 | parser.add_argument('--weight_decay', type=float, default=0.01)
78 | parser.add_argument('--max_grad_norm', type=float, default=5.)
79 | parser.add_argument('--scheduler', type=str, choices=['cosine', 'constant'], default='cosine')
80 | parser.add_argument('--warmup_ratio', type=float, default=0)
81 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
82 | parser.add_argument('--load_ckpt', type=str, default='best.ckpt')
83 | parser.add_argument('--eval_per_epoch', type=int, default=1)
84 | parser.add_argument('--val_metric', type=str, default='val_acc')
85 | parser.add_argument('--save_path', type=str, default='output/')
86 | parser.add_argument('--overwrite', action='store_true')
87 | parser.add_argument('--num_train_example', type=int, default=None)
88 | parser.add_argument('--label_smoothing', type=float, default=0.0)
89 | # Inference
90 | parser.add_argument('--test_batch_size', type=int, default=64)
91 | parser.add_argument('--num_beams', type=int, default=1)
92 | parser.add_argument('--test_each_neighbor', action='store_true')
93 | parser.add_argument('--test_num_neighbors', type=int, default=1)
94 |
95 | args = parser.parse_args([]) if notebook else parser.parse_args()
96 |
97 | return args
98 |
99 |
100 | class ReactionConditionRecommender(LightningModule):
101 |
102 | def __init__(self, args):
103 | super().__init__()
104 | self.args = args
105 | self.enc_tokenizer, self.dec_tokenizer = get_tokenizers(args)
106 | self.model = get_model(args, self.enc_tokenizer, self.dec_tokenizer)
107 | if args.mlm:
108 | self.mlm_head = get_mlm_head(args, self.model)
109 | self.validation_outputs = collections.defaultdict(dict)
110 | self.test_outputs = collections.defaultdict(dict)
111 |
112 | def compute_loss(self, logits, batch, reduction='mean'):
113 | if self.args.template_based:
114 | atom_logits, bond_logits = logits
115 | batch_size, max_len, atom_vocab_size = atom_logits.size()
116 | bond_vocab_size = bond_logits.size()[-1]
117 | atom_template_loss = F.cross_entropy(input=atom_logits.reshape(-1, atom_vocab_size),
118 | target=batch['decoder_atom_template_labels'].reshape(-1),
119 | reduction=reduction)
120 | bond_template_loss = F.cross_entropy(input=bond_logits.reshape(-1, bond_vocab_size),
121 | target=batch['decoder_bond_template_labels'].reshape(-1),
122 | reduction=reduction)
123 | if reduction == 'none':
124 | atom_template_loss = atom_template_loss.view(batch_size, -1).mean(dim=1)
125 | bond_template_loss = bond_template_loss.view(batch_size, -1).mean(dim=1)
126 | loss = atom_template_loss + bond_template_loss
127 | else:
128 | batch_size, max_len, vocab_size = logits.size()
129 | labels = batch['decoder_input_ids'][:, 1:]
130 | loss = F.cross_entropy(input=logits[:, :-1].reshape(-1, vocab_size), target=labels.reshape(-1),
131 | ignore_index=self.dec_tokenizer.pad_token_id, reduction=reduction)
132 | if reduction == 'none':
133 | loss = loss.view(batch_size, -1).mean(dim=1)
134 | return loss
135 |
136 | def compute_acc(self, logits, batch, reduction='mean'):
137 | # This accuracy is equivalent to greedy search accuracy
138 | if self.args.template_based:
139 | atom_logits_batch, bond_logits_batch = logits
140 | atom_probs_batch = F.softmax(atom_logits_batch, dim=-1)
141 | bond_probs_batch = F.softmax(bond_logits_batch, dim=-1)
142 | atom_probs_batch[batch['decoder_atom_template_labels'] == -100] = 0
143 | bond_probs_batch[batch['decoder_bond_template_labels'] == -100] = 0
144 | acc = []
145 | for atom_probs, bond_probs, bonds, raw_template_labels in zip(
146 | atom_probs_batch, bond_probs_batch, batch['bonds'], batch['decoder_raw_template_labels']):
147 | edit_pred = utils.combined_edit(atom_probs, bond_probs, bonds, 1)[0][0]
148 | acc.append(float(edit_pred in raw_template_labels) / max(len(raw_template_labels), 1))
149 | acc = torch.tensor(acc)
150 | else:
151 | preds = logits.argmax(dim=-1)[:, :-1]
152 | labels = batch['decoder_input_ids'][:, 1:]
153 | acc = torch.logical_or(preds.eq(labels), labels.eq(self.dec_tokenizer.pad_token_id)).all(dim=-1)
154 | if reduction == 'mean':
155 | acc = acc.mean()
156 | return acc
157 |
158 | def compute_mlm_loss(self, encoder_last_hidden_state, labels):
159 | batch_size, trunc_len = labels.size()
160 | trunc_hidden_state = encoder_last_hidden_state[:, :trunc_len].contiguous()
161 | logits = self.mlm_head(trunc_hidden_state)
162 | return F.cross_entropy(input=logits.view(batch_size * trunc_len, -1), target=labels.view(-1))
163 |
164 | def training_step(self, batch, batch_idx):
165 | indices, batch_in, batch_out = batch
166 | output = self.model(**batch_in)
167 | loss = self.compute_loss(output.logits, batch_in)
168 | self.log('train_loss', loss)
169 | total_loss = loss
170 | if self.args.mlm:
171 | mlm_loss = self.compute_mlm_loss(output.encoder_last_hidden_state, batch_out['mlm_labels'])
172 | total_loss += mlm_loss * self.args.mlm_lambda
173 | self.log('mlm_loss', mlm_loss)
174 | self.log('total_loss', total_loss)
175 | return total_loss
176 |
177 | def validation_step(self, batch, batch_idx, dataloader_idx=0):
178 | indices, batch_in, batch_out = batch
179 | output = self.model(**batch_in)
180 | if self.args.val_metric == 'val_loss':
181 | scores = self.compute_loss(output.logits, batch_in, reduction='none').tolist()
182 | elif self.args.val_metric == 'val_acc':
183 | scores = self.compute_acc(output.logits, batch_in, reduction='none').tolist()
184 | else:
185 | raise ValueError
186 | for idx, score in zip(indices, scores):
187 | self.validation_outputs[dataloader_idx][idx] = score
188 | return output
189 |
190 | def on_validation_epoch_end(self):
191 | for dataloader_idx in self.validation_outputs:
192 | validation_outputs = self.gather_outputs(self.validation_outputs[dataloader_idx])
193 | val_score = np.mean([v for v in validation_outputs.values()])
194 | metric_name = self.args.val_metric if dataloader_idx == 0 else f'{self.args.val_metric}/{dataloader_idx}'
195 | self.log(metric_name, val_score, prog_bar=True, rank_zero_only=True)
196 | self.validation_outputs.clear()
197 |
198 | def test_step(self, batch, batch_idx, dataloader_idx=0):
199 | indices, batch_in, batch_out = batch
200 | num_beams = self.args.num_beams
201 | if self.args.template_based:
202 | atom_logits_batch, bond_logits_batch = self.model(**batch_in).logits
203 | atom_probs_batch = F.softmax(atom_logits_batch, dim=-1)
204 | bond_probs_batch = F.softmax(bond_logits_batch, dim=-1)
205 | atom_probs_batch[batch_in['decoder_atom_template_labels'] == -100] = 0
206 | bond_probs_batch[batch_in['decoder_bond_template_labels'] == -100] = 0
207 | acc = []
208 | for idx, atom_probs, bond_probs, bonds, raw_template_labels in zip(
209 | indices, atom_probs_batch, bond_probs_batch, batch_in['bonds'], batch_in['decoder_raw_template_labels']):
210 | edit_pred, edit_prob = utils.combined_edit(atom_probs, bond_probs, bonds, top_num=500)
211 | self.test_outputs[dataloader_idx][idx] = {
212 | 'prediction': edit_pred,
213 | 'score': edit_prob,
214 | 'raw_template_labels': raw_template_labels,
215 | 'top1_template_match': edit_pred[0] in raw_template_labels
216 | }
217 | else:
218 | output = self.model.generate(
219 | **batch_in, num_beams=num_beams, num_return_sequences=num_beams,
220 | max_length=self.args.max_dec_length, length_penalty=0,
221 | bos_token_id=self.dec_tokenizer.bos_token_id, eos_token_id=self.dec_tokenizer.eos_token_id,
222 | pad_token_id=self.dec_tokenizer.pad_token_id,
223 | return_dict_in_generate=True, output_scores=True)
224 | predictions = self.dec_tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
225 | if 'sequences_scores' in output:
226 | scores = output.sequences_scores.tolist()
227 | else:
228 | scores = [0] * len(predictions)
229 | for i, idx in enumerate(indices):
230 | self.test_outputs[dataloader_idx][idx] = {
231 | 'prediction': predictions[i * num_beams: (i + 1) * num_beams],
232 | 'score': scores[i * num_beams: (i + 1) * num_beams]
233 | }
234 | return
235 |
236 | def on_test_epoch_end(self):
237 | for dataloader_idx in self.test_outputs:
238 | test_outputs = self.gather_outputs(self.test_outputs[dataloader_idx])
239 | if self.args.test_each_neighbor:
240 | test_outputs = utils.gather_prediction_each_neighbor(test_outputs, self.args.test_num_neighbors)
241 | if self.trainer.is_global_zero:
242 | # Save prediction
243 | with open(os.path.join(self.args.save_path,
244 | f'prediction_{self.test_dataset.name}_{dataloader_idx}.json'), 'w') as f:
245 | json.dump(test_outputs, f)
246 | # Evaluate
247 | if self.args.task == 'condition':
248 | accuracy = evaluate_reaction_condition(test_outputs, self.test_dataset.data_df)
249 | elif self.args.task == 'retro':
250 | accuracy = evaluate_retrosynthesis(test_outputs, self.test_dataset.data_df, self.args.num_beams,
251 | template_based=self.args.template_based,
252 | template_path=self.args.template_path)
253 | else:
254 | accuracy = []
255 | self.print(self.ckpt_path)
256 | self.print(json.dumps(accuracy))
257 | self.test_outputs.clear()
258 |
259 | def gather_outputs(self, outputs):
260 | if self.trainer.num_devices > 1:
261 | gathered = [{} for i in range(self.trainer.num_devices)]
262 | dist.all_gather_object(gathered, outputs)
263 | gathered_outputs = {}
264 | for outputs in gathered:
265 | gathered_outputs.update(outputs)
266 | else:
267 | gathered_outputs = outputs
268 | return gathered_outputs
269 |
270 | def configure_optimizers(self):
271 | num_training_steps = self.trainer.num_training_steps
272 | self.print(f'Num training steps: {num_training_steps}')
273 | num_warmup_steps = int(num_training_steps * self.args.warmup_ratio)
274 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
275 | scheduler = get_scheduler(self.args.scheduler, optimizer, num_warmup_steps, num_training_steps)
276 | return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
277 |
278 |
279 | class ReactionConditionDataModule(LightningDataModule):
280 |
281 | DATASET_CLS = {
282 | 'condition': ReactionConditionDataset,
283 | 'retro': RetrosynthesisDataset,
284 | }
285 |
286 | def __init__(self, args, model):
287 | super().__init__()
288 | self.args = args
289 | self.enc_tokenizer = model.enc_tokenizer
290 | self.dec_tokenizer = model.dec_tokenizer
291 | self.train_dataset, self.val_dataset, self.test_dataset = None, None, None
292 |
293 | def prepare_data(self):
294 | args = self.args
295 | dataset_cls = self.DATASET_CLS[args.task]
296 | if args.do_train:
297 | data_file = os.path.join(args.data_path, args.train_file)
298 | self.train_dataset = dataset_cls(
299 | args, data_file, self.enc_tokenizer, self.dec_tokenizer, split='train')
300 | print(f'Train dataset: {len(self.train_dataset)}')
301 | if args.do_train or args.do_valid:
302 | data_file = os.path.join(args.data_path, args.valid_file)
303 | self.val_dataset = dataset_cls(
304 | args, data_file, self.enc_tokenizer, self.dec_tokenizer, split='val')
305 | print(f'Valid dataset: {len(self.val_dataset)}')
306 | if args.do_test:
307 | data_file = os.path.join(args.data_path, args.test_file)
308 | self.test_dataset = dataset_cls(
309 | args, data_file, self.enc_tokenizer, self.dec_tokenizer, split='test')
310 | print(f'Test dataset: {len(self.test_dataset)}')
311 | if args.corpus_file:
312 | if args.train_label_corpus:
313 | assert args.task == 'condition'
314 | corpus = generate_train_label_corpus(os.path.join(args.data_path, args.train_file))
315 | else:
316 | corpus = read_corpus(args.corpus_file, args.cache_path)
317 | if self.train_dataset is not None:
318 | self.train_dataset.load_corpus(corpus, os.path.join(args.nn_path, args.train_nn_file))
319 | self.train_dataset.print_example()
320 | if self.val_dataset is not None:
321 | self.val_dataset.load_corpus(corpus, os.path.join(args.nn_path, args.valid_nn_file))
322 | if self.test_dataset is not None:
323 | self.test_dataset.load_corpus(corpus, os.path.join(args.nn_path, args.test_nn_file))
324 |
325 | def train_dataloader(self):
326 | return torch.utils.data.DataLoader(
327 | self.train_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
328 | collate_fn=self.train_dataset.collator)
329 |
330 | def get_eval_dataloaders(self, dataset):
331 | args = self.args
332 | dataloader = torch.utils.data.DataLoader(
333 | dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=dataset.collator)
334 | if args.corpus_file is None:
335 | return dataloader
336 | dataset_skip_gold = copy.copy(dataset)
337 | dataset_skip_gold.skip_gold_neighbor = True
338 | dataloader_skip_gold = torch.utils.data.DataLoader(
339 | dataset_skip_gold, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=dataset.collator)
340 | return [dataloader, dataloader_skip_gold]
341 |
342 | def val_dataloader(self):
343 | return self.get_eval_dataloaders(self.val_dataset)
344 |
345 | def test_dataloader(self):
346 | return self.get_eval_dataloaders(self.test_dataset)
347 |
348 |
349 | def main():
350 | args = get_args()
351 | pl.seed_everything(args.seed, workers=True)
352 |
353 | model = ReactionConditionRecommender(args)
354 |
355 | dm = ReactionConditionDataModule(args, model)
356 | dm.prepare_data()
357 |
358 | checkpoint = pl.callbacks.ModelCheckpoint(
359 | monitor=args.val_metric, mode=utils.metric_to_mode[args.val_metric], save_top_k=1, filename='best',
360 | save_last=True, dirpath=args.save_path, auto_insert_metric_name=False)
361 | lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
362 | if args.do_train and not args.debug:
363 | project_name = 'TextReact'
364 | if args.task == 'retro':
365 | project_name += '_retro'
366 | logger = pl.loggers.WandbLogger(
367 | project=project_name, save_dir=args.save_path, name=os.path.basename(args.save_path))
368 | else:
369 | logger = None
370 |
371 | trainer = pl.Trainer(
372 | strategy=DDPStrategy(find_unused_parameters=True),
373 | accelerator='gpu',
374 | devices=args.gpus,
375 | precision=args.precision,
376 | logger=logger,
377 | default_root_dir=args.save_path,
378 | callbacks=[checkpoint, lr_monitor],
379 | max_epochs=args.epochs,
380 | gradient_clip_val=args.max_grad_norm,
381 | accumulate_grad_batches=args.gradient_accumulation_steps,
382 | check_val_every_n_epoch=args.eval_per_epoch,
383 | log_every_n_steps=10,
384 | deterministic=True)
385 |
386 | if args.do_train:
387 | trainer.num_training_steps = math.ceil(
388 | len(dm.train_dataset) / (args.batch_size * args.gpus * args.gradient_accumulation_steps)) * args.epochs
389 | # Load or delete existing checkpoint
390 | if args.overwrite:
391 | utils.clear_path(args.save_path, trainer)
392 | ckpt_path = None
393 | else:
394 | ckpt_path = os.path.join(args.save_path, args.load_ckpt)
395 | ckpt_path = ckpt_path if checkpoint.file_exists(ckpt_path, trainer) else None
396 | # Train
397 | trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)
398 | best_model_path = checkpoint.best_model_path
399 | else:
400 | best_model_path = os.path.join(args.save_path, args.load_ckpt)
401 |
402 | if args.do_valid or args.do_test:
403 | print('Load model checkpoint:', best_model_path)
404 | model = ReactionConditionRecommender.load_from_checkpoint(best_model_path, strict=False, args=args)
405 | model.ckpt_path = best_model_path
406 |
407 | if args.do_valid:
408 | trainer.validate(model, datamodule=dm)
409 |
410 | if args.do_test:
411 | model.test_dataset = dm.test_dataset
412 | trainer.test(model, datamodule=dm)
413 |
414 |
415 | if __name__ == "__main__":
416 | main()
417 |
--------------------------------------------------------------------------------
/preprocess/dedup_corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pandas as pd
4 |
5 | # Dedup
6 |
7 | # corpus_df = pd.read_csv('uspto_script/uspto_rxn_corpus.csv')
8 | # text_to_corpus_id = {}
9 | # id_to_corpus_id = {}
10 | # dedup_flags = [False] * len(corpus_df)
11 | # for i, (idx, text) in enumerate(zip(corpus_df['id'], corpus_df['paragraph_text'])):
12 | # if text not in text_to_corpus_id:
13 | # text_to_corpus_id[text] = idx
14 | # dedup_flags[i] = True
15 | # id_to_corpus_id[idx] = text_to_corpus_id[text]
16 | #
17 | # dedup_df = corpus_df[dedup_flags]
18 | # dedup_df.to_csv('uspto_script/uspto_rxn_corpus_dedup.csv', index=False)
19 | #
20 | # with open('id_to_corpus_id.json', 'w') as f:
21 | # json.dump(id_to_corpus_id, f)
22 |
23 |
24 | # Add corpus id
25 | with open('id_to_corpus_id.json') as f:
26 | id_to_corpus_id = json.load(f)
27 |
28 | # data_path = '../data/USPTO_condition/'
29 | # data_path = '../data/USPTO_condition_year/'
30 | # for file in ['USPTO_condition_train.csv', 'USPTO_condition_val.csv', 'USPTO_condition_test.csv']:
31 |
32 | # data_path = '../data/USPTO_50K/matched1/'
33 | data_path = '../data/USPTO_50K_year/'
34 | for file in ['train.csv', 'valid.csv', 'test.csv']:
35 | path = os.path.join(data_path, file)
36 | print(path)
37 | df = pd.read_csv(path)
38 | corpus_id = [id_to_corpus_id.get(idx, idx) for idx in df['id']]
39 | df['corpus_id'] = corpus_id
40 | cols = ['id', 'corpus_id']
41 | for col in df.columns:
42 | if col not in cols:
43 | cols.append(col)
44 | df = df[cols]
45 | df.to_csv(path, index=False)
46 |
--------------------------------------------------------------------------------
/preprocess/gen_uspto.py:
--------------------------------------------------------------------------------
1 | import urllib.request
2 | import zipfile
3 | import os
4 | import re
5 | import sys
6 | import glob
7 | import shutil
8 | import multiprocessing
9 | import pandas as pd
10 | from tqdm import tqdm
11 | from collections import Counter
12 | import json
13 | from json import encoder
14 | encoder.FLOAT_REPR = lambda o: format(o, '.3f')
15 | import numpy as np
16 |
17 |
18 | BASE = '/scratch/yujieq/uspto_grant_red/'
19 | BASE_TXT = '/scratch/yujieq/uspto_grant_fulltext/'
20 | BASE_TXT_ZIP = '/scratch/yujieq/uspto_grant_fulltext_zip/'
21 |
22 |
23 | # Download data
24 | def _download_file(url, output):
25 | if not os.path.exists(output):
26 | urllib.request.urlretrieve(url, output)
27 |
28 |
29 | def download():
30 | for year in range(2016, 2017):
31 | url = f"https://bulkdata.uspto.gov/data/patent/grant/redbook/{year}/"
32 | f = urllib.request.urlopen(url)
33 | content = f.read().decode('utf-8')
34 | print(url)
35 | zip_files = re.findall(r"href=\"(I*\d\d\d\d\d\d\d\d(.ZIP|.zip|.tar))\"", content)
36 | print(zip_files)
37 | path = os.path.join(BASE, str(year))
38 | os.makedirs(path, exist_ok=True)
39 | args = []
40 | for file, ext in zip_files:
41 | output = os.path.join(path, file)
42 | args.append((url + file, output))
43 | # with multiprocessing.Pool(8) as p:
44 | # p.starmap(_download_file, args)
45 | for url, output in args:
46 | print(url)
47 | _download_file(url, output)
48 |
49 |
50 | def download_fulltext():
51 | for year in range(2002, 2017):
52 | url = f'https://bulkdata.uspto.gov/data/patent/grant/redbook/fulltext/{year}/'
53 | f = urllib.request.urlopen(url)
54 | content = f.read().decode('utf-8')
55 | print(url)
56 | zip_files = re.findall(r"href=\"(\w*(.ZIP|.zip|.tar))\"", content)
57 | print(zip_files)
58 | path = os.path.join(BASE_TXT, str(year))
59 | os.makedirs(path, exist_ok=True)
60 | args = []
61 | for file, ext in zip_files:
62 | output = os.path.join(path, file)
63 | args.append((url + file, output))
64 | # with multiprocessing.Pool(8) as p:
65 | # p.starmap(_download_file, args)
66 | for url, output in args:
67 | print(url)
68 | _download_file(url, output)
69 |
70 |
71 | # Unzip
72 | def is_zip(file):
73 | return file[-4:] in ['.zip', '.ZIP']
74 |
75 |
76 | def unzip():
77 | for year in range(1976, 2017):
78 | path = os.path.join(BASE_TXT_ZIP, str(year))
79 | outpath = os.path.join(BASE_TXT, str(year))
80 | for datefile in sorted(os.listdir(path)):
81 | if is_zip(datefile):
82 | print(os.path.join(path, datefile))
83 | with zipfile.ZipFile(os.path.join(path, datefile), 'r') as zipobj:
84 | zipobj.extractall(outpath)
85 |
86 |
87 | if __name__ == "__main__":
88 | if sys.argv[1] == 'download':
89 | download()
90 | elif sys.argv[1] == 'download_fulltext':
91 | download_fulltext()
92 | elif sys.argv[1] == 'unzip':
93 | unzip()
94 |
--------------------------------------------------------------------------------
/preprocess/get_templates.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import logging
3 | from copy import deepcopy
4 | import os
5 | import pandas as pd
6 | import re
7 | from abc import ABC, abstractmethod
8 | from collections import defaultdict
9 | # from models.localretro_model.Extract_from_train_data import get_full_template
10 | # from models.localretro_model.LocalTemplate.template_extractor import extract_from_reaction
11 | # from models.localretro_model.Run_preprocessing import get_edit_site_retro
12 | from typing import Dict, List
13 | from rdkit import Chem
14 | from rdkit.Chem import AllChem
15 | from template_extraction.template_extractor import extract_from_reaction
16 | from template_extraction.template_extract_utils import get_bonds_from_smiles
17 |
18 | import sys
19 | sys.path.append('../')
20 | from textreact.tokenizer import BasicSmilesTokenizer
21 |
22 | log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
23 | logging.basicConfig(format = log_format, level = logging.INFO)
24 |
25 | logger = logging.getLogger()
26 |
27 |
28 | ATOM_REGEX = re.compile(r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p")
29 |
30 |
31 | def get_full_template(template, H_change, Charge_change, Chiral_change):
32 | H_code = ''.join([str(H_change[k+1]) for k in range(len(H_change))])
33 | Charge_code = ''.join([str(Charge_change[k+1]) for k in range(len(Charge_change))])
34 | Chiral_code = ''.join([str(Chiral_change[k+1]) for k in range(len(Chiral_change))])
35 | if Chiral_code == '':
36 | return '_'.join([template, H_code, Charge_code])
37 | else:
38 | return '_'.join([template, H_code, Charge_code, Chiral_code])
39 |
40 |
41 | def canonicalize_smiles(smiles):
42 | mol = Chem.MolFromSmiles(smiles)
43 | for a in mol.GetAtoms():
44 | a.SetAtomMapNum(0)
45 | canon_smile = Chem.MolToSmiles(mol)
46 | canon_perm = eval(mol.GetProp("_smilesAtomOutputOrder"))
47 | # atomidx2stridx = [i
48 | # for i, token in enumerate(BasicSmilesTokenizer().tokenize(canon_smile))
49 | # if ATOM_REGEX.fullmatch(token) is not None]
50 | # atomidx2canonstridx = [0 for _ in range(len(canon_perm))]
51 | # for canon_idx, orig_idx in enumerate(canon_perm):
52 | # atomidx2canonstridx[orig_idx] = atomidx2stridx[canon_idx]
53 | atomidx2canonidx = [None for _ in range(len(canon_perm))]
54 | for canon_idx, orig_idx in enumerate(canon_perm):
55 | atomidx2canonidx[orig_idx] = canon_idx
56 | return canon_smile, atomidx2canonidx
57 |
58 |
59 | class Processor(ABC):
60 | """Base class for processor"""
61 |
62 | @abstractmethod
63 | def __init__(self,
64 | model_name: str,
65 | model_args, # let's enforce everything to be passed in args
66 | model_config: Dict[str, any], # or config
67 | data_name: str,
68 | raw_data_files: List[str],
69 | processed_data_path: str):
70 | self.model_name = model_name
71 | self.model_args = model_args
72 | self.model_config = model_config
73 | self.data_name = data_name
74 | self.raw_data_files = raw_data_files
75 | self.processed_data_path = processed_data_path
76 |
77 | os.makedirs(self.processed_data_path, exist_ok=True)
78 |
79 | self.check_count = 100
80 |
81 | def check_data_format(self) -> None:
82 | """Check that all files exists and the data format is correct for the first few lines"""
83 | logger.info(f"Checking the first {self.check_count} entries for each file")
84 | for fn in self.raw_data_files:
85 | if not fn:
86 | continue
87 | assert os.path.exists(fn), f"{fn} does not exist!"
88 |
89 | with open(fn, "r") as csv_file:
90 | csv_reader = csv.DictReader(csv_file)
91 | for i, row in enumerate(csv_reader):
92 | if i > self.check_count: # check the first few rows
93 | break
94 |
95 | assert all(c in row for c in ["class", "rxn_smiles"]), \
96 | f"Error processing file {fn} line {i}, ensure columns 'class' and " \
97 | f"'rxn_smiles' are included!"
98 |
99 | reactants, reagents, products = row["rxn_smiles"].split(">")
100 | Chem.MolFromSmiles(reactants) # simply ensures that SMILES can be parsed
101 | Chem.MolFromSmiles(products) # simply ensures that SMILES can be parsed
102 |
103 | logger.info("Data format check passed")
104 |
105 | @abstractmethod
106 | def preprocess(self) -> None:
107 | """Actual file-based preprocessing"""
108 | pass
109 |
110 | # Adapted from Extract_from_train_data.py
111 | class LocalRetroProcessor(Processor):
112 | """Class for LocalRetro Preprocessing"""
113 |
114 | def __init__(self,
115 | model_name: str,
116 | model_args,
117 | model_config: Dict[str, any],
118 | data_name: str,
119 | raw_data_files: List[str],
120 | processed_data_path: str,
121 | num_cores: int = None):
122 | super().__init__(model_name=model_name,
123 | model_args=model_args,
124 | model_config=model_config,
125 | data_name=data_name,
126 | raw_data_files=raw_data_files,
127 | processed_data_path=processed_data_path)
128 | self.train_file, self.val_file, self.test_file = raw_data_files
129 | self.num_cores = num_cores
130 | self.setting = {'verbose': False, 'use_stereo': True, 'use_symbol': True,
131 | 'max_unmap': 5, 'retro': True, 'remote': True, 'least_atom_num': 2,
132 | "max_edit_n": 8, "min_template_n": 1}
133 | self.RXNHASCLASS = False
134 |
135 | def preprocess(self) -> None:
136 | """Actual file-based preprocessing"""
137 | self.extract_templates()
138 | self.match_templates()
139 |
140 | def extract_templates(self):
141 | """Adapted from Extract_from_train_data.py"""
142 | if all(os.path.exists(os.path.join(self.processed_data_path, fn))
143 | for fn in ["template_infos.csv", "atom_templates.csv", "bond_templates.csv"]):
144 | logger.info(f"Extracted templates found at {self.processed_data_path}, skipping extraction.")
145 | return
146 |
147 | logger.info(f"Extracting templates from {self.train_file}")
148 |
149 | with open(self.train_file, "r") as csv_file:
150 | csv_reader = csv.DictReader(csv_file)
151 | rxns = [row["rxn_smiles"].strip() for row in csv_reader]
152 |
153 | TemplateEdits = {}
154 | TemplateCs, TemplateHs, TemplateSs = {}, {}, {}
155 | TemplateFreq, templates_A, templates_B = defaultdict(int), defaultdict(int), defaultdict(int)
156 |
157 | for i, rxn in enumerate(rxns):
158 | try:
159 | rxn = {'reactants': rxn.split('>')[0], 'products': rxn.split('>')[-1], '_id': i}
160 | result = extract_from_reaction(rxn, self.setting)
161 | if 'reactants' not in result or 'reaction_smarts' not in result.keys():
162 | logger.info(f'\ntemplate problem: id: {i}')
163 | continue
164 | reactant = result['reactants']
165 | template = result['reaction_smarts']
166 | edits = result['edits']
167 | H_change = result['H_change']
168 | Charge_change = result['Charge_change']
169 | Chiral_change = result["Chiral_change"] if self.setting["use_stereo"] else {}
170 |
171 | template_H = get_full_template(template, H_change, Charge_change, Chiral_change)
172 | if template_H not in TemplateHs.keys():
173 | TemplateEdits[template_H] = {edit_type: edits[edit_type][2] for edit_type in edits}
174 | TemplateHs[template_H] = H_change
175 | TemplateCs[template_H] = Charge_change
176 | TemplateSs[template_H] = Chiral_change
177 |
178 | TemplateFreq[template_H] += 1
179 | for edit_type, bonds in edits.items():
180 | bonds = bonds[0]
181 | if len(bonds) > 0:
182 | if edit_type in ['A', 'R']:
183 | templates_A[template_H] += 1
184 | else:
185 | templates_B[template_H] += 1
186 |
187 | except Exception as e:
188 | logger.info(f'{i}: {e}')
189 |
190 | if i % 1000 == 0:
191 | logger.info(f'\r i = {i}, # of template: {len(TemplateFreq)}, '
192 | f'# of atom template: {len(templates_A)}, '
193 | f'# of bond template: {len(templates_B)}')
194 | logger.info('\n total # of template: %s' % len(TemplateFreq))
195 |
196 | derived_templates = {'atom': templates_A, 'bond': templates_B}
197 |
198 | ofn = os.path.join(self.processed_data_path, "template_infos.csv")
199 | TemplateInfos = pd.DataFrame(
200 | {'Template': k,
201 | 'edit_site': TemplateEdits[k],
202 | 'change_H': TemplateHs[k],
203 | 'change_C': TemplateCs[k],
204 | 'change_S': TemplateSs[k],
205 | 'Frequency': TemplateFreq[k]} for k in TemplateHs.keys())
206 | TemplateInfos.to_csv(ofn)
207 |
208 | for k, local_templates in derived_templates.items():
209 | ofn = os.path.join(self.processed_data_path, f"{k}_templates.csv")
210 | with open(ofn, "w") as of:
211 | writer = csv.writer(of)
212 | header = ['Template', 'Frequency', 'Class']
213 | writer.writerow(header)
214 |
215 | sorted_tuples = sorted(local_templates.items(), key=lambda item: item[1])
216 | for i, (template, template_freq) in enumerate(sorted_tuples):
217 | writer.writerow([template, template_freq, i + 1])
218 |
219 | def match_templates(self):
220 | """Adapted from Run_preprocessing.py"""
221 | # load_templates()
222 | template_dicts = {}
223 |
224 | for site in ['atom', 'bond']:
225 | fn = os.path.join(self.processed_data_path, f"{site}_templates.csv")
226 | with open(fn, "r") as csv_file:
227 | csv_reader = csv.DictReader(csv_file)
228 | template_dict = {row["Template"].strip(): int(row["Class"]) for row in csv_reader}
229 | logger.info(f'loaded {len(template_dict)} {site} templates')
230 | template_dicts[site] = template_dict
231 |
232 | fn = os.path.join(self.processed_data_path, "template_infos.csv")
233 | with open(fn, "r") as csv_file:
234 | csv_reader = csv.DictReader(csv_file)
235 | template_infos = {
236 | row["Template"]: {
237 | "edit_site": eval(row["edit_site"]),
238 | "frequency": int(row["Frequency"])
239 | } for row in csv_reader
240 | }
241 | logger.info('loaded total %s templates' % len(template_infos))
242 |
243 | # labeling_dataset()
244 | dfs = {}
245 | for phase, fn in [("train", self.train_file),
246 | ("val", self.val_file),
247 | ("test", self.test_file)]:
248 | with open(fn, "r") as csv_file:
249 | csv_reader = csv.DictReader(csv_file)
250 | rxns = [row["rxn_smiles"].strip() for row in csv_reader]
251 | reactants, products, reagents = [], [], []
252 | labels, frequency = [], []
253 | product_canon_smiles = []
254 | # product_atomidx2canonstridxs = []
255 | product_atomidx2canonidxs = []
256 | product_canon_bondss = []
257 | success = 0
258 | num_canon_smiles_mismatch = 0
259 |
260 | for i, rxn in enumerate(rxns):
261 | reactant, _, product = rxn.split(">")
262 | reagent = ''
263 | rxn_labels = []
264 |
265 | # get canonical permutation of atoms
266 | product_canon_smile, product_atomidx2canonidx = canonicalize_smiles(product)
267 | product_canon_bonds = get_bonds_from_smiles(product_canon_smile)
268 |
269 | # parse reaction and store results
270 | try:
271 | rxn = {'reactants': reactant, 'products': product, '_id': i}
272 | result = extract_from_reaction(rxn, self.setting)
273 |
274 | template = result['reaction_smarts']
275 | reactant = result['reactants']
276 | product = result['products']
277 | extracted_product_canon_smile, product_atomidx2canonidx = canonicalize_smiles(product)
278 | num_canon_smiles_mismatch += int(extracted_product_canon_smile != product_canon_smile)
279 | reagent = '.'.join(result['necessary_reagent'])
280 | edits = {edit_type: edit_bond[0] for edit_type, edit_bond in result['edits'].items()}
281 | H_change, Charge_change, Chiral_change = \
282 | result['H_change'], result['Charge_change'], result['Chiral_change']
283 | template_H = get_full_template(template, H_change, Charge_change, Chiral_change)
284 |
285 | if template_H not in template_infos.keys():
286 | reactants.append(reactant)
287 | products.append(product)
288 | reagents.append(reagent)
289 | labels.append(rxn_labels)
290 | frequency.append(0)
291 | product_canon_smiles.append(product_canon_smile)
292 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx)
293 | product_atomidx2canonidxs.append(product_atomidx2canonidx)
294 | product_canon_bondss.append(product_canon_bonds)
295 | continue
296 |
297 | except Exception as e:
298 | logger.info(f'{i}: {e}')
299 | reactants.append(reactant)
300 | products.append(product)
301 | reagents.append(reagent)
302 | labels.append(rxn_labels)
303 | frequency.append(0)
304 | product_canon_smiles.append(product_canon_smile)
305 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx)
306 | product_atomidx2canonidxs.append(product_atomidx2canonidx)
307 | product_canon_bondss.append(product_canon_bonds)
308 | continue
309 |
310 | edit_n = 0
311 | for edit_type in edits:
312 | if edit_type == 'C':
313 | edit_n += len(edits[edit_type]) / 2
314 | else:
315 | edit_n += len(edits[edit_type])
316 |
317 | if edit_n <= self.setting['max_edit_n']:
318 | try:
319 | success += 1
320 | for edit_type, edit in edits.items():
321 | for e in edit:
322 | if edit_type in ['A', 'R']:
323 | rxn_labels.append(
324 | ('a', e, template_dicts['atom'][template_H]))
325 | else:
326 | rxn_labels.append(
327 | ('b', e, template_dicts['bond'][template_H]))
328 | reactants.append(reactant)
329 | products.append(product)
330 | reagents.append(reagent)
331 | labels.append(rxn_labels)
332 | frequency.append(template_infos[template_H]['frequency'])
333 | product_canon_smiles.append(product_canon_smile)
334 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx)
335 | product_atomidx2canonidxs.append(product_atomidx2canonidx)
336 | product_canon_bondss.append(product_canon_bonds)
337 |
338 | except Exception as e:
339 | logger.info(f'{i}: {e}')
340 | reactants.append(reactant)
341 | products.append(product)
342 | reagents.append(reagent)
343 | labels.append(rxn_labels)
344 | frequency.append(0)
345 | product_canon_smiles.append(product_canon_smile)
346 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx)
347 | product_atomidx2canonidxs.append(product_atomidx2canonidx)
348 | product_canon_bondss.append(product_canon_bonds)
349 | continue
350 |
351 | if i % 1000 == 0:
352 | logger.info(f'\r Processing {self.data_name} {phase} data..., '
353 | f'success {success} data ({i}/{len(rxns)})')
354 | else:
355 | logger.info(f'\nReaction # {i} has too many edits ({edit_n})...may be wrong mapping!')
356 | reactants.append(reactant)
357 | products.append(product)
358 | reagents.append(reagent)
359 | labels.append(rxn_labels)
360 | frequency.append(0)
361 | product_canon_smiles.append(product_canon_smile)
362 | # product_atomidx2canonstridxs.append(product_atomidx2canonstridx)
363 | product_atomidx2canonidxs.append(product_atomidx2canonidx)
364 | product_canon_bondss.append(product_canon_bonds)
365 |
366 | logger.info(f'\nDerived templates cover {success / len(rxns): .3f} of {phase} data reactions')
367 | logger.info(f'\nNumber of canonical smiles mismatches: {num_canon_smiles_mismatch} / {len(rxns)}')
368 | ofn = os.path.join(self.processed_data_path, f"preprocessed_{phase}.csv")
369 | dfs[phase] = pd.DataFrame(
370 | {'Reactants': reactants,
371 | 'Products': products,
372 | 'Reagents': reagents,
373 | 'Labels': labels,
374 | 'Frequency': frequency,
375 | 'ProductCanonSmiles': product_canon_smiles,
376 | # 'ProductAtomIdx2CanonStrIdx': product_atomidx2canonstridxs,
377 | 'ProductAtomIdx2CanonIdx': product_atomidx2canonidxs,
378 | 'ProductCanonBonds': product_canon_bondss})
379 | dfs[phase].to_csv(ofn)
380 |
381 | # make_simulate_output()
382 | df = dfs["test"]
383 | ofn = os.path.join(self.processed_data_path, "simulate_output.txt")
384 | with open(ofn, 'w') as of:
385 | of.write('Test_id\tReactant\tProduct\t%s\n' % '\t'.join(
386 | [f'Edit {i + 1}\tProba {i + 1}' for i in range(self.setting['max_edit_n'])]))
387 | for i in df.index:
388 | labels = []
389 | for y in df['Labels'][i]:
390 | if y != 0:
391 | labels.append(y)
392 | if not labels:
393 | labels = [(0, 0)]
394 | string_labels = '\t'.join([f'{l}\t{1.0}' for l in labels])
395 | of.write('%s\t%s\t%s\t%s\n' % (i, df['Reactants'][i], df['Products'][i], string_labels))
396 |
397 | # combine_preprocessed_data()
398 | dfs["train"]['Split'] = ['train'] * len(dfs["train"])
399 | dfs["val"]['Split'] = ['val'] * len(dfs["val"])
400 | dfs["test"]['Split'] = ['test'] * len(dfs["test"])
401 | all_valid = dfs["train"]._append(dfs["val"], ignore_index=True)
402 | all_valid = all_valid._append(dfs["test"], ignore_index=True)
403 | all_valid['Mask'] = [int(f >= self.setting['min_template_n']) for f in all_valid['Frequency']]
404 | ofn = os.path.join(self.processed_data_path, "labeled_data.csv")
405 | all_valid.to_csv(ofn, index=None)
406 | logger.info(f'Valid data size: {len(all_valid)}')
407 |
408 |
409 | if __name__ == "__main__":
410 | from argparse import Namespace
411 |
412 | model_name = "localretro"
413 | model_args = Namespace()
414 | model_config = {}
415 | data_name = "USPTO_50K_year"
416 | raw_data_files = ["train.csv", "valid.csv", "test.csv"]
417 | raw_data_files = [os.path.join("../data_USPTO_50K_year_raw/", file_name) for file_name in raw_data_files]
418 | processed_data_path = "../data_template/USPTO_50K_year/"
419 | preprocessor = LocalRetroProcessor(model_name, model_args, model_config, data_name, raw_data_files, processed_data_path)
420 | preprocessor.check_data_format()
421 | preprocessor.preprocess()
422 |
--------------------------------------------------------------------------------
/preprocess/preprocess_retrosynthesis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | import multiprocessing
5 | import numpy as np
6 | import pandas as pd
7 | from tqdm import tqdm
8 |
9 | import rdkit
10 | from rdkit import Chem
11 | rdkit.RDLogger.DisableLog('rdApp.*')
12 | import rdkit.Chem.rdChemReactions as rdChemReactions
13 | import rdkit.DataStructs as DataStructs
14 |
15 |
16 | BASE = 'data/USPTO_50K'
17 |
18 |
19 | def canonical_rxn_smiles(rxn_smiles):
20 | reactants, reagents, products = rxn_smiles.split(">")
21 | try:
22 | mols_r = Chem.MolFromSmiles(reactants)
23 | mols_p = Chem.MolFromSmiles(products)
24 | [a.ClearProp('molAtomMapNumber') for a in mols_r.GetAtoms()]
25 | [a.ClearProp('molAtomMapNumber') for a in mols_p.GetAtoms()]
26 | cano_smi_r = Chem.MolToSmiles(mols_r, isomericSmiles=True, canonical=True)
27 | cano_smi_p = Chem.MolToSmiles(mols_p, isomericSmiles=True, canonical=True)
28 | return cano_smi_r + '>>' + cano_smi_p, cano_smi_r, cano_smi_p, True
29 | except:
30 | return rxn_smiles, reactants, products, False
31 |
32 |
33 | def reaction_fingerprint(smiles):
34 | rxn = rdChemReactions.ReactionFromSmarts(smiles)
35 | fp = rdChemReactions.CreateDifferenceFingerprintForReaction(rxn)
36 | return fp
37 |
38 |
39 | def reaction_similarity(smiles1=None, smiles2=None, fp1=None, fp2=None):
40 | if fp1 is None and smiles1:
41 | fp1 = reaction_fingerprint(smiles1)
42 | if fp2 is None and smiles2:
43 | fp2 = reaction_fingerprint(smiles2)
44 | assert fp1 is not None
45 | assert fp2 is not None
46 | return DataStructs.TanimotoSimilarity(fp1, fp2)
47 |
48 |
49 | # Canonical SMILES
50 |
51 | # for split in ['train', 'valid', 'test']:
52 | # df = pd.read_csv(os.path.join(BASE, f'{split}.csv'))
53 | # invalid_count = 0
54 | # for i, row in df.iterrows():
55 | # try:
56 | # reactants, reagents, products = row["rxn_smiles"].split(">")
57 | # mols_r = Chem.MolFromSmiles(reactants)
58 | # mols_p = Chem.MolFromSmiles(products)
59 | # if mols_r is None or mols_p is None:
60 | # invalid_count += 1
61 | # continue
62 |
63 | # [a.ClearProp('molAtomMapNumber') for a in mols_r.GetAtoms()]
64 | # [a.ClearProp('molAtomMapNumber') for a in mols_p.GetAtoms()]
65 |
66 | # cano_smi_r = Chem.MolToSmiles(mols_r, isomericSmiles=True, canonical=True)
67 | # cano_smi_p = Chem.MolToSmiles(mols_p, isomericSmiles=True, canonical=True)
68 |
69 | # df.loc[i, 'reactant_smiles'] = cano_smi_r
70 | # df.loc[i, 'product_smiles'] = cano_smi_p
71 | # except Exception as e:
72 | # logging.info(e)
73 | # logging.info(row["rxn_smiles"].split(">"))
74 | # invalid_count += 1
75 |
76 | # logging.info(f"Invalid count: {invalid_count}")
77 | # df.drop(columns='rxn_smiles', inplace=True)
78 | # df.to_csv(os.path.join(BASE, f'processed/{split}.csv'), index=False)
79 |
80 |
81 | # rxn_df = pd.read_csv('data/USPTO_rxn_condition.csv')
82 | # with multiprocessing.Pool(32) as p:
83 | # results = p.map(canonical_rxn_smiles, rxn_df['rxn_smiles'], chunksize=128)
84 | # canonical_rxn, reactants, products, success = zip(*results)
85 | #
86 | # print(np.mean(success))
87 | # rxn_df['canonical_rxn'] = canonical_rxn
88 | # rxn_df['reactants'] = reactants
89 | # rxn_df['products'] = products
90 | # rxn_df = rxn_df[['id', 'source', 'year', 'patent_type', 'canonical_rxn', 'reactants', 'products']]
91 | # rxn_df.to_csv('data/USPTO_rxn_smiles.csv', index=False)
92 |
93 |
94 | # Match id
95 |
96 | corpus_df = pd.read_csv('preprocess/uspto_script/uspto_rxn_condition_remapped_and_reassign_condition_role.csv')
97 | rxn_smiles_to_id = {}
98 | for i, row in tqdm(corpus_df.iterrows()):
99 | canonical_rxn = row['canonical_rxn']
100 | if canonical_rxn not in rxn_smiles_to_id:
101 | rxn_smiles_to_id[canonical_rxn] = []
102 | rxn_smiles_to_id[canonical_rxn].append(row['id'])
103 |
104 | for split in ['train', 'valid', 'test']:
105 | df = pd.read_csv(f'data/USPTO_50K/processed/{split}.csv')
106 | cnt = 0
107 | match_patent_cnt = 0
108 | nomatch_cnt = 0
109 | matched_ids = []
110 | f = open('tmp.txt', 'w')
111 | for i, row in tqdm(df.iterrows()):
112 | rxn_smiles = row['reactant_smiles'] + '>>' + row['product_smiles']
113 | if rxn_smiles in rxn_smiles_to_id:
114 | rxn_id = rxn_smiles_to_id[rxn_smiles][0]
115 | for idx in rxn_smiles_to_id[rxn_smiles]:
116 | if idx.startswith(row['id']):
117 | rxn_id = idx
118 | match_patent_cnt += 1
119 | break
120 | cnt += 1
121 | else:
122 | patent_df = corpus_df.loc[corpus_df['source'] == row['id']]
123 | f.write(row['id'] + '\n')
124 | rxn_id = f'unk_{split}_{i}'
125 | if len(patent_df) == 0:
126 | nomatch_cnt += 1
127 | f.write('No match\n\n')
128 | else:
129 | fp = reaction_fingerprint(rxn_smiles)
130 | # patent_rxn_smiles = [canonical_rxn_smiles(smiles)[0] for smiles in patent_df['rxn_smiles']]
131 | patent_rxn_smiles = patent_df['canonical_rxn'].tolist()
132 | similarities = [reaction_similarity(fp1=fp, smiles2=smiles) for smiles in patent_rxn_smiles]
133 | nearest_idx = np.argmax(similarities)
134 | nearest_row = patent_df.iloc[nearest_idx]
135 | if similarities[nearest_idx] > 0.9:
136 | rxn_id = nearest_row['id']
137 | cnt += 1
138 | f.write(rxn_smiles + '\n')
139 | f.write(patent_rxn_smiles[nearest_idx] + '\n')
140 | f.write(json.dumps(similarities) + '\n')
141 | f.write(f'{similarities[nearest_idx]}\n')
142 | f.write('\n')
143 | f.flush()
144 | matched_ids.append(rxn_id)
145 | f.close()
146 | df['source'] = df['id']
147 | df['id'] = matched_ids
148 | os.makedirs('data/USPTO_50K/matched1/', exist_ok=True)
149 | df.to_csv(f'data/USPTO_50K/matched1/{split}.csv', index=False)
150 | print(cnt, match_patent_cnt, nomatch_cnt, len(df))
151 |
--------------------------------------------------------------------------------
/preprocess/raw_retro_year_split.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | import json
4 |
5 | with open('/Mounts/rbg-storage1/users/yujieq/textreact/preprocess/uspto_script/patent_info.json') as f:
6 | patent_info = json.load(f)
7 |
8 | train_df = pd.read_csv('data/USPTO_50K/train.csv')
9 | valid_df = pd.read_csv('data/USPTO_50K/valid.csv')
10 | test_df = pd.read_csv('data/USPTO_50K/test.csv')
11 | train_df_proc = pd.read_csv('data/USPTO_50K/matched1/train.csv')
12 | valid_df_proc = pd.read_csv('data/USPTO_50K/matched1/valid.csv')
13 | test_df_proc = pd.read_csv('data/USPTO_50K/matched1/test.csv')
14 |
15 | df = pd.concat([train_df, valid_df, test_df]).reindex()
16 | df_proc = pd.concat([train_df_proc, valid_df_proc, test_df_proc]).reindex()
17 |
18 | train_idx = []
19 | valid_idx = []
20 | test_idx = []
21 | train_bad = 0
22 | valid_bad = 0
23 | test_bad = 0
24 | unk = 0
25 | for i, (proc_rxn_id, source, rxn_id) in enumerate(zip(df_proc['id'], df_proc['source'], df['id'])):
26 | p = proc_rxn_id.split('_')[0]
27 | if p != 'unk':
28 | bad = source != rxn_id
29 | if p in patent_info:
30 | year = patent_info[p]['year']
31 | else:
32 | year = -1
33 | if year < 2012:
34 | train_idx.append(i)
35 | train_bad += int(bad)
36 | elif year in [2012, 2013]:
37 | valid_idx.append(i)
38 | valid_bad += int(bad)
39 | else:
40 | test_idx.append(i)
41 | test_bad += int(bad)
42 | else:
43 | unk += 1
44 | print(train_bad, valid_bad, test_bad)
45 | print(unk)
46 |
47 | os.makedirs('data_USPTO_50K_year_raw_alt/', exist_ok=True)
48 | train_df = df.iloc[train_idx].reindex()
49 | train_df.to_csv('data_USPTO_50K_year_raw_alt/train.csv', index=False)
50 | valid_df = df.iloc[valid_idx].reindex()
51 | valid_df.to_csv('data_USPTO_50K_year_raw_alt/valid.csv', index=False)
52 | test_df = df.iloc[test_idx].reindex()
53 | test_df.to_csv('data_USPTO_50K_year_raw_alt/test.csv', index=False)
54 |
--------------------------------------------------------------------------------
/preprocess/reagent_Ionic_compound.txt:
--------------------------------------------------------------------------------
1 | [Na+].[OH-]
2 | [Li+].[OH-]
3 | [K+].[OH-]
4 | [NH4+].[OH-]
5 | [H-].[Na+]
6 | [H-].[K+]
7 | [H-].[Cs+]
8 | [BH4-].[Na+]
9 | [BH4-].[Li+]
10 | [Al+3].[H-].[H-].[H-].[H-].[Li+]
11 | [Al+3].[Cl-].[Cl-].[Cl-]
12 | O=P([O-])([O-])[O-].[K+].[K+].[K+]
13 | [BH3-]C#N.[Na+]
14 | [I-].[K+]
15 | [I-].[Li+]
16 | [I-].[Na+]
17 | [I-].[Cs+]
18 | [I-].[NH4+]
19 | [Cl-].[NH4+]
20 | [Cl-].[Na+]
21 | [Cl-].[K+]
22 | [Cl-].[Cs+]
23 | [Cl-].[Li+]
24 | [Cs+].[F-]
25 | [K+].[F-]
26 | [Na+].[F-]
27 | [Na+].[O-]Cl
28 | [Na+].[O-][Cl+][O-]
29 | [Na+].[O-][I+3]([O-])([O-])[O-]
30 | O=S([O-])([O-])=S.[Na+].[Na+]
31 | O=S([O-])S(=O)[O-].[Na+].[Na+]
32 | O=S(=O)([O-])[O-].[Mg+2]
33 | O=S(=O)([O-])[O-].[Na+].[Na+]
34 | O=S([O-])[O-].[Na+].[Na+]
35 | O=S([O-])O.[Na+]
36 | [Fe+2].c1cc[cH-]c1.c1cc[cH-]c1
37 | [Ca+2].[Cl-].[Cl-]
38 | [Br-].[K+]
39 | [C-]#N.[Na+]
40 | CC(C)[N-]C(C)C.[Li+]
41 |
42 |
43 |
44 | O=C([O-])[O-].[K+].[K+]
45 | O=C([O-])[O-].[Na+].[Na+]
46 | O=C([O-])[O-].[Cs+].[Cs+]
47 | O=C([O-])[O-].[Ca+2]
48 | O=C([O-])[O-].[Li+]
49 | O=C[O-].[NH4+]
50 | O=C([O-])O.[Na+]
51 | O=C([O-])O.[K+]
52 | CC(=O)[O-].[Na+]
53 | CC(=O)[O-].[K+]
54 | C[O-].[Na+]
55 | CC(C)(C)[O-].[K+]
56 | O=N[O-].[Na+]
57 | CN(C)C(On1nnc2cccnc21)=[N+](C)C.F[P-](F)(F)(F)(F)F
58 | CN(C)[P+](On1nnc2ccccc21)(N(C)C)N(C)C.F[P-](F)(F)(F)(F)F
59 | CCCC[N+](CCCC)(CCCC)CCCC.[F-]
60 | CC(=O)O[BH-](OC(C)=O)OC(C)=O.[Na+]
61 | C[Si](C)(C)[N-][Si](C)(C)C.[Li+]
62 | O=[Cr](=O)([O-])Cl.c1cc[nH+]cc1
63 | CC(C)C[Al+]CC(C)C.[H-]
64 | CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[B-](F)(F)F
65 | CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[P-](F)(F)(F)(F)F
66 | CN(C)C(On1nnc2cccnc21)=[N+](C)C.F[P-](F)(F)(F)(F)F
67 | Br[P+](N1CCCC1)(N1CCCC1)N1CCCC1.F[P-](F)(F)(F)(F)F
68 | F[P-](F)(F)(F)(F)F.c1ccc2c(c1)nnn2O[P+](N1CCCC1)(N1CCCC1)N1CCCC1
69 | C1CC[NH2+]CC1.CC(=O)[O-]
70 | CN(C)C(N(C)C)=[N+]1N=[N+]([O-])c2ncccc21.F[P-](F)(F)(F)(F)F
71 | CN1CC[NH+](C)C1Cl.[Cl-]
72 | COc1nc(OC)nc([N+]2(C)CCOCC2)n1.[Cl-]
73 | C[Si](C)(C)[N-][Si](C)(C)C.[K+]
74 | C[Si](C)(C)[N-][Si](C)(C)C.[Na+]
75 | C[n+]1ccccc1Cl.[I-]
76 | O=S(=O)([O-])C(F)(F)F.O=S(=O)([O-])C(F)(F)F.O=S(=O)([O-])C(F)(F)F.[Yb+3]
77 | O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].c1cc[nH+]cc1.c1cc[nH+]cc1
78 | Cc1ccc(S(=O)(=O)[O-])cc1.c1cc[nH+]cc1
79 |
--------------------------------------------------------------------------------
/preprocess/reagent_unknown.txt:
--------------------------------------------------------------------------------
1 | [Na+],15
2 | O=C([O-])O,9
3 | CC(=O)O[BH-](OC(C)=O)OC(C)=O,8
4 | O=C([O-])[O-],6
5 | [K+],5
6 | [Cl-],5
7 | [BH4-],3
8 | [OH-],3
9 | [I-],3
10 | [K+].[K+],2
11 | [H-],2
12 | CC(=O)[O-],2
13 | O=S(=O)([O-])[O-].[Al+3].[Li+],2
14 | F[P-](F)(F)(F)(F)F,2
15 | [NH4+],1
16 | F[B-](F)(F)F,1
17 | F[B-](F)(F)F.F[B-](F)(F)F,1
18 | [Li+],1
19 | [BH3-]C#N,1
20 | [Na+].[Na+],1
21 | [Br-],1
22 | [Cs+].[Cs+],1
23 | C[O-],1
24 | CCCC[N+](CCCC)(CCCC)CCCC,1
25 | O=P([O-])(O)O,1
26 |
--------------------------------------------------------------------------------
/preprocess/retro_year_split.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | import json
4 |
5 | with open('preprocess/uspto_script/patent_info.json') as f:
6 | patent_info = json.load(f)
7 |
8 | train_df = pd.read_csv('data/USPTO_50K/matched1/train.csv')
9 | valid_df = pd.read_csv('data/USPTO_50K/matched1/valid.csv')
10 | test_df = pd.read_csv('data/USPTO_50K/matched1/test.csv')
11 |
12 | df = pd.concat([train_df, valid_df, test_df]).reindex()
13 |
14 | train_idx = []
15 | valid_idx = []
16 | test_idx = []
17 | for i, rxn_id in enumerate(df['id']):
18 | p = rxn_id.split('_')[0]
19 | if p in patent_info:
20 | year = patent_info[p]['year']
21 | else:
22 | year = -1
23 | if year < 2012:
24 | train_idx.append(i)
25 | elif year in [2012, 2013]:
26 | valid_idx.append(i)
27 | else:
28 | test_idx.append(i)
29 |
30 | os.makedirs('data/USPTO_50K_year/', exist_ok=True)
31 | train_df = df.iloc[train_idx].reindex()
32 | train_df.to_csv('data/USPTO_50K_year/train.csv', index=False)
33 | valid_df = df.iloc[valid_idx].reindex()
34 | valid_df.to_csv('data/USPTO_50K_year/valid.csv', index=False)
35 | test_df = df.iloc[test_idx].reindex()
36 | test_df.to_csv('data/USPTO_50K_year/test.csv', index=False)
37 |
--------------------------------------------------------------------------------
/preprocess/template_extraction/template_extract_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import copy
3 | from collections import defaultdict
4 |
5 | from rdkit import Chem
6 | from rdkit.Chem import AllChem
7 | from rdkit.Chem.rdchem import ChiralType
8 |
9 | chiral_type_map = {ChiralType.CHI_UNSPECIFIED : 0, ChiralType.CHI_TETRAHEDRAL_CW: 1, ChiralType.CHI_TETRAHEDRAL_CCW: 2}
10 | bond_type_map = {'SINGLE': '-', 'DOUBLE': '=', 'TRIPLE': '#', 'AROMATIC': '@'}
11 |
12 | def get_template_bond(temp_order, bond_smarts):
13 | bond_match = {}
14 | for n, _ in enumerate(temp_order):
15 | bond_match[(temp_order[n], temp_order[n-1])] = bond_smarts[n-1]
16 | bond_match[(temp_order[n-1], temp_order[n])] = bond_smarts[n-1]
17 | return bond_match
18 |
19 | def bond_to_smiles(bond):
20 | a1_label = str(bond.GetBeginAtom().GetAtomicNum())
21 | a2_label = str(bond.GetEndAtom().GetAtomicNum())
22 | if bond.GetBeginAtom().HasProp('molAtomMapNumber'):
23 | a1_label += bond.GetBeginAtom().GetProp('molAtomMapNumber')
24 | if bond.GetEndAtom().HasProp('molAtomMapNumber'):
25 | a2_label += bond.GetEndAtom().GetProp('molAtomMapNumber')
26 | atoms = sorted([a1_label, a2_label])
27 | bond_smarts = bond_type_map[str(bond.GetBondType())]
28 | return '{}{}{}'.format(atoms[0], bond_smarts, atoms[1])
29 |
30 | def check_bond_break(bond1, bond2):
31 | if bond1 == None and bond2 != None:
32 | return False
33 | elif bond1 != None and bond2 == None:
34 | return True
35 | else:
36 | return False
37 |
38 | def check_bond_formed(bond1, bond2):
39 | if bond1 != None and bond2 == None:
40 | return False
41 | elif bond1 == None and bond2 != None:
42 | return True
43 | else:
44 | return False
45 |
46 | def check_bond_change(pbond, rbond):
47 | if pbond == None or rbond == None:
48 | return False
49 | elif bond_to_smiles(pbond) != bond_to_smiles(rbond):
50 | return True
51 | else:
52 | return False
53 |
54 | def atom_neighbors(atom):
55 | neighbor = []
56 | for n in atom.GetNeighbors():
57 | neighbor.append(n.GetAtomMapNum())
58 | return sorted(neighbor)
59 |
60 | def extend_changed_atoms(changed_atom_tags, reactants, max_map):
61 | for reactant in reactants:
62 | extend_idx = []
63 | for atom in reactant.GetAtoms():
64 | if str(atom.GetAtomMapNum()) in changed_atom_tags:
65 | for n in atom.GetNeighbors():
66 | if n.GetAtomMapNum() == 0:
67 | extend_idx.append(n.GetIdx())
68 | for idx in extend_idx:
69 | reactant.GetAtomWithIdx(idx).SetAtomMapNum(max_map)
70 |
71 | def check_atom_change(patom, ratom):
72 | return atom_neighbors(patom) != atom_neighbors(ratom)
73 |
74 | def label_retro_edit_site(products, reactants, edit_num):
75 | edit_num = [int(num) for num in edit_num]
76 | pmol = Chem.MolFromSmiles(products)
77 | rmol = Chem.MolFromSmiles(reactants)
78 | patom_map = {atom.GetAtomMapNum():atom.GetIdx() for atom in pmol.GetAtoms()}
79 | ratom_map = {atom.GetAtomMapNum():atom.GetIdx() for atom in rmol.GetAtoms()}
80 | used_atom = set()
81 | grow_atoms = []
82 | broken_bonds = []
83 | changed_bonds = []
84 |
85 | # cut bond
86 | for a in edit_num:
87 | for b in edit_num:
88 | if a >= b:
89 | continue
90 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b])
91 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b])
92 | if check_bond_break(pbond, rbond): # cut bond
93 | broken_bonds.append((a, b))
94 | used_atom.update([a, b])
95 |
96 | # Add LG
97 | for a in edit_num:
98 | if a in used_atom:
99 | continue
100 | patom = pmol.GetAtomWithIdx(patom_map[a])
101 | ratom = rmol.GetAtomWithIdx(ratom_map[a])
102 | if check_atom_change(patom, ratom):
103 | used_atom.update([a])
104 | grow_atoms.append(a)
105 |
106 | # change bond type
107 | for a in edit_num:
108 | for b in edit_num:
109 | if a >= b:
110 | continue
111 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b])
112 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b])
113 | if check_bond_change(pbond, rbond):
114 | if a not in used_atom and b not in used_atom:
115 | changed_bonds.append((a, b))
116 | changed_bonds.append((b, a))
117 |
118 | used_atoms = set(grow_atoms + [atom for bond in broken_bonds+changed_bonds for atom in bond])
119 | remote_atoms = [atom for atom in edit_num if atom not in used_atoms]
120 | remote_atoms_ = []
121 | for a in remote_atoms:
122 | atom = rmol.GetAtomWithIdx(ratom_map[a])
123 | neighbors_map = [n.GetAtomMapNum() for n in atom.GetNeighbors()]
124 | connected_neighbors = [b for b in used_atoms if b in neighbors_map]
125 | if len(connected_neighbors) > 0:
126 | pass
127 | else:
128 | for n in neighbors_map:
129 | remote_atoms_.append(a)
130 |
131 | return grow_atoms, broken_bonds, changed_bonds, remote_atoms_
132 |
133 | def label_foward_edit_site(reactants, products, edit_num):
134 | edit_num = [int(num) for num in edit_num]
135 | rmol = Chem.MolFromSmiles(reactants)
136 | pmol = Chem.MolFromSmiles(products)
137 | ratom_map = {atom.GetAtomMapNum():atom.GetIdx() for atom in rmol.GetAtoms()}
138 | patom_map = {atom.GetAtomMapNum():atom.GetIdx() for atom in pmol.GetAtoms()}
139 | atom_symbols = {atom.GetAtomMapNum():atom.GetSymbol() for atom in rmol.GetAtoms()}
140 |
141 | formed_bonds = []
142 | broken_bonds = []
143 | changed_bonds = []
144 | acceptors1 = set()
145 | acceptors2 = set()
146 | donors = set()
147 | form_bond = False
148 | break_bond = False
149 | change_bond = False
150 |
151 | # cut bond
152 | for a in edit_num:
153 | for b in edit_num:
154 | if a >= b:
155 | continue
156 | try:
157 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b])
158 | except:
159 | pbond = None
160 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b])
161 | if check_bond_break(rbond, pbond):
162 | if a in patom_map:
163 | broken_bonds.append((a, b))
164 | acceptors1.add(a)
165 | if b in patom_map:
166 | broken_bonds.append((b, a))
167 | acceptors1.add(b)
168 | break_bond = True
169 |
170 | # change bond
171 | for a in edit_num:
172 | for b in edit_num:
173 | if a >= b:
174 | continue
175 | try:
176 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b])
177 | except:
178 | pbond = None
179 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b])
180 | if check_bond_change(rbond, pbond):
181 | changed_bonds.append((a, b))
182 | changed_bonds.append((b, a))
183 | change_bond = True
184 | acceptors2.update([a, b])
185 |
186 | symmetric = True
187 | # form bond
188 | for a in edit_num:
189 | for b in edit_num:
190 | if a >= b:
191 | continue
192 | try:
193 | pbond = pmol.GetBondBetweenAtoms(patom_map[a], patom_map[b])
194 | except:
195 | pbond = None
196 | rbond = rmol.GetBondBetweenAtoms(ratom_map[a], ratom_map[b])
197 | if check_bond_formed(rbond, pbond): # cut bond
198 | form_bond = True
199 | if a not in acceptors1 and b not in acceptors1 and a not in acceptors2 and b not in acceptors2 :
200 | formed_bonds.append((a, b))
201 | formed_bonds.append((b, a))
202 | elif a in acceptors1 and b in acceptors1:
203 | symmetric = False
204 | formed_bonds.append((a, b))
205 | formed_bonds.append((b, a))
206 | else:
207 | symmetric = False
208 | if a in acceptors1:
209 | formed_bonds.append((b, a))
210 | elif a in acceptors2 and b not in acceptors1:
211 | formed_bonds.append((b, a))
212 | if b in acceptors1:
213 | formed_bonds.append((a, b))
214 | elif b in acceptors2 and a not in acceptors1:
215 | formed_bonds.append((a, b))
216 |
217 | if not symmetric:
218 | new_changed_bonds = []
219 | # electron acceptor propagation
220 | acceptors = set([bond[1] for bond in formed_bonds]).union(acceptors1)
221 | for atom in acceptors:
222 | for bond in changed_bonds:
223 | if bond[0] == atom:
224 | new_changed_bonds.append(bond)
225 | donors = set([bond[0] for bond in formed_bonds])
226 | for atom in donors:
227 | for bond in changed_bonds:
228 | if bond[1] == atom:
229 | new_changed_bonds.append(bond)
230 | changed_bonds = list(set(new_changed_bonds))
231 |
232 | used_atoms = set([atom for bond in formed_bonds+broken_bonds+changed_bonds for atom in bond])
233 | remote_atoms = [atom for atom in edit_num if atom not in used_atoms]
234 | remote_bonds = []
235 | for a in remote_atoms:
236 | atom = rmol.GetAtomWithIdx(ratom_map[a])
237 | neighbors_map = [n.GetAtomMapNum() for n in atom.GetNeighbors()]
238 | connected_neighbors = [b for b in used_atoms if b in neighbors_map]
239 | if len(connected_neighbors) > 0:
240 | pass
241 | else:
242 | for n in neighbors_map:
243 | remote_bonds.append((a, n))
244 | return formed_bonds, broken_bonds, changed_bonds, remote_bonds
245 |
246 | def label_CHS_change(smiles1, smiles2, edit_num, replacement_dict, use_stereo):
247 | mol1, mol2 = Chem.MolFromSmiles(smiles1), Chem.MolFromSmiles(smiles2)
248 | atom_map_dict1 = {atom.GetAtomMapNum():atom.GetIdx() for atom in mol1.GetAtoms()}
249 | atom_map_dict2 = {atom.GetAtomMapNum():atom.GetIdx() for atom in mol2.GetAtoms()}
250 | H_dict = defaultdict(dict)
251 | C_dict = defaultdict(dict)
252 | S_dict = defaultdict(dict)
253 | for atom_map in edit_num:
254 | atom_map = int(atom_map)
255 | if atom_map in atom_map_dict2:
256 | atom1, atom2 = mol1.GetAtomWithIdx(atom_map_dict1[atom_map]), mol2.GetAtomWithIdx(atom_map_dict2[atom_map])
257 | H_dict[atom_map]['smiles1'], C_dict[atom_map]['smiles1'], S_dict[atom_map]['smiles1'] = atom1.GetNumExplicitHs(), int(atom1.GetFormalCharge()), chiral_type_map[atom1.GetChiralTag()]
258 | H_dict[atom_map]['smiles2'], C_dict[atom_map]['smiles2'], S_dict[atom_map]['smiles2'] = atom2.GetNumExplicitHs(), int(atom2.GetFormalCharge()), chiral_type_map[atom2.GetChiralTag()]
259 |
260 | H_change = {replacement_dict[k]:v['smiles2'] - v['smiles1'] for k, v in H_dict.items()}
261 | Charge_change = {replacement_dict[k]:v['smiles2'] - v['smiles1'] for k, v in C_dict.items()}
262 | Chiral_change = {replacement_dict[k]:v['smiles2'] - v['smiles1'] for k, v in S_dict.items()}
263 | for k, v in S_dict.items():
264 | if v['smiles2'] == v['smiles1'] or not use_stereo: # no chiral change
265 | Chiral_change[replacement_dict[k]] = 0
266 | # elif v['smiles1'] != 0: # opposite the stereo bond
267 | # Chiral_change[replacement_dict[k]] = 3
268 | else:
269 | Chiral_change[replacement_dict[k]] = v['smiles2']
270 | return atom_map_dict1, H_change, Charge_change, Chiral_change
271 |
272 | def bondmap2idx(bond_maps, idx_dict, temp_dict, sort = False, remote = False):
273 | bond_idxs = [(idx_dict[bond_map[0]], idx_dict[bond_map[1]]) for bond_map in bond_maps]
274 | if remote:
275 | bond_temps = list(set([(temp_dict[bond_map[0]], -1) for bond_map in bond_maps]))
276 | return (bond_idxs, bond_maps, bond_temps)
277 | else:
278 | bond_temps = [(temp_dict[bond_map[0]], temp_dict[bond_map[1]]) for bond_map in bond_maps]
279 | if not sort:
280 | return (bond_idxs, bond_maps, bond_temps)
281 | else:
282 | sort_bond_idxs = []
283 | sort_bond_maps = []
284 | sort_bond_temps = []
285 | for bond1, bond2, bond3 in zip(bond_idxs, bond_maps, bond_temps):
286 | if bond3[0] < bond3[1]:
287 | sort_bond_idxs.append(bond1)
288 | sort_bond_maps.append(bond2)
289 | sort_bond_temps.append(bond3)
290 | else:
291 | sort_bond_idxs.append(tuple(bond1[::-1]))
292 | sort_bond_maps.append(tuple(bond2[::-1]))
293 | sort_bond_temps.append(tuple(bond3[::-1]))
294 | return (sort_bond_idxs, sort_bond_maps, sort_bond_temps)
295 |
296 | def atommap2idx(atom_maps, idx_dict, temp_dict):
297 | atom_idxs = [idx_dict[atom_map] for atom_map in atom_maps]
298 | atom_temps = [temp_dict[atom_map] for atom_map in atom_maps]
299 | return (atom_idxs, atom_maps, atom_temps)
300 |
301 | def match_label(reactants, products, replacement_dict, edit_num, retro = True, remote = True, use_stereo = True):
302 | if retro:
303 | smiles1 = products
304 | smiles2 = reactants
305 | else:
306 | smiles1 = reactants
307 | smiles2 = products
308 |
309 | replacement_dict = {int(k): int(v) for k, v in replacement_dict.items()}
310 | atom_map_dict, H_change, Charge_change, Chiral_change = label_CHS_change(smiles1, smiles2, edit_num, replacement_dict, use_stereo)
311 |
312 | if retro:
313 | ALG_atoms, broken_bonds, changed_bonds, remote_atoms = label_retro_edit_site(smiles1, smiles2, edit_num)
314 | edits = {'A': atommap2idx(ALG_atoms, atom_map_dict, replacement_dict),
315 | 'B': bondmap2idx(broken_bonds, atom_map_dict, replacement_dict, True),
316 | 'C': bondmap2idx(changed_bonds, atom_map_dict, replacement_dict)}
317 | if remote:
318 | edits['R'] = atommap2idx(remote_atoms, atom_map_dict, replacement_dict)
319 | else:
320 | formed_bonds, broken_bonds, changed_bonds, remote_bonds = label_foward_edit_site(smiles1, smiles2, edit_num)
321 | edits = {'A': bondmap2idx(formed_bonds, atom_map_dict, replacement_dict),
322 | 'B': bondmap2idx(broken_bonds, atom_map_dict, replacement_dict),
323 | 'C': bondmap2idx(changed_bonds, atom_map_dict, replacement_dict)}
324 | if remote:
325 | edits['R'] = bondmap2idx(remote_bonds, atom_map_dict, replacement_dict, False, True)
326 | return edits, H_change, Charge_change, Chiral_change
327 |
328 | def get_bonds_from_smiles(smiles):
329 | """ Adapted from get_edit_site_retro from in Run_preprocessing.py """
330 | mol = Chem.MolFromSmiles(smiles)
331 | B = set()
332 | for atom in mol.GetAtoms():
333 | others = []
334 | bonds = atom.GetBonds()
335 | for bond in bonds:
336 | atoms = [bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()]
337 | other = [a for a in atoms if a != atom.GetIdx()][0]
338 | others.append(other)
339 | B.update((atom.GetIdx(), other) for other in sorted(others))
340 | return B
341 |
--------------------------------------------------------------------------------
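As a quick illustration of `get_bonds_from_smiles` above: it returns every (atom index, neighbor index) pair in both directions, since each atom enumerates its own bonds. Assuming the function from `template_extract_utils.py` is in scope:

```python
# Ethanol: atoms C(0)-C(1)-O(2)
bonds = get_bonds_from_smiles('CCO')
print(sorted(bonds))  # [(0, 1), (1, 0), (1, 2), (2, 1)]
```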
/preprocess/uspto_script/1.get_condition_from_uspto.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import xmltodict
3 | from collections import OrderedDict, Counter
4 | import pandas as pd
5 | import os
6 | import copy
7 | import json
8 |
9 | from utils import get_writer
10 |
11 |
12 | patent_cnt = Counter()
13 | patent_info = {}
14 | CONDITION_DICT = {
15 | 'id': [],
16 | 'source': [],
17 | 'year': [],
18 | 'patent_type': [],
19 | 'rxn_smiles': [],
20 | 'solvent': [],
21 | 'catalyst': [],
22 | 'reagent': [],
23 | }
24 | CORPUS_DICT = {
25 | 'id': [],
26 | 'year': [],
27 | 'patent_type': [],
28 | 'xml': [],
29 | 'heading_text': [],
30 | 'paragraph_text': [],
31 | }
32 |
33 |
34 | def read_xml2dict(xml_fpath):
35 | with open(xml_fpath, 'r') as f:
36 | data = xmltodict.parse(f.read())
37 | reaction_and_condition_dict = copy.deepcopy(CONDITION_DICT)
38 | corpus_dict = copy.deepcopy(CORPUS_DICT)
39 |
40 | try:
41 | reaction_data_list = data['reactionList']['reaction']
42 | except:
43 | return reaction_and_condition_dict, corpus_dict
44 |
45 |
46 | for rxn_data in reaction_data_list:
47 |
48 | try:
49 | patent_id = rxn_data['dl:source']['dl:documentId']
50 | heading_text = rxn_data['dl:source'].get('dl:headingText', '')
51 | paragraph_text = rxn_data['dl:source'].get('dl:paragraphText', '')
52 | year = os.path.dirname(xml_fpath).split('/')[-1]
53 | patent_type = 'grant' if 'grants' in xml_fpath else 'application'
54 | patent_info[patent_id] = {
55 | 'year': int(year),
56 | 'type': patent_type
57 | }
58 | except:
59 | print(xml_fpath)
60 | print(rxn_data)
61 | continue
62 |
63 | if type(rxn_data) is str:
64 | continue
65 | if rxn_data['spectatorList'] is None:
66 | continue
67 | if rxn_data['dl:reactionSmiles'] is None:
68 | continue
69 | try:
70 | spectator_obj = rxn_data['spectatorList']['spectator']
71 | except:
72 | continue
73 |
74 | s_list = []
75 | c_list = []
76 | r_list = []
77 | if type(spectator_obj) is list:
78 | pass
79 | elif type(spectator_obj) is OrderedDict:
80 | spectator_obj = [spectator_obj]
81 | else:
82 | print('Warning spectator_obj is not in (list, OrderedDict)!!!')
83 |
84 | for one_spectator in spectator_obj:
85 | if 'identifier' not in one_spectator:
86 | continue
87 | if one_spectator['@role'] == 'solvent':
88 | if type(one_spectator['identifier']) is list:
89 | for identifier in one_spectator['identifier']:
90 | if identifier['@dictRef'] == 'cml:smiles':
91 | s_list.append(identifier['@value'])
92 | elif type(one_spectator['identifier']) is OrderedDict:
93 | identifier = one_spectator['identifier']
94 | if identifier['@dictRef'] == 'cml:smiles':
95 | s_list.append(identifier['@value'])
96 | elif one_spectator['@role'] == 'catalyst':
97 | if type(one_spectator['identifier']) is list:
98 | for identifier in one_spectator['identifier']:
99 | if identifier['@dictRef'] == 'cml:smiles':
100 | c_list.append(identifier['@value'])
101 | elif type(one_spectator['identifier']) is OrderedDict:
102 | identifier = one_spectator['identifier']
103 | if identifier['@dictRef'] == 'cml:smiles':
104 | c_list.append(identifier['@value'])
105 | elif one_spectator['@role'] == 'reagent':
106 | if type(one_spectator['identifier']) is list:
107 | for identifier in one_spectator['identifier']:
108 | if identifier['@dictRef'] == 'cml:smiles':
109 | r_list.append(identifier['@value'])
110 | elif type(one_spectator['identifier']) is OrderedDict:
111 | identifier = one_spectator['identifier']
112 | if identifier['@dictRef'] == 'cml:smiles':
113 | r_list.append(identifier['@value'])
114 | else:
115 | print(one_spectator['@role'])
116 |
117 | rxn_id = patent_id + '_' + str(patent_cnt[patent_id])
118 | patent_cnt[patent_id] += 1
119 |
120 | s_list = list(set(s_list))
121 | c_list = list(set(c_list))
122 | r_list = list(set(r_list))
123 | reaction_and_condition_dict['solvent'].append('.'.join(s_list))
124 | reaction_and_condition_dict['catalyst'].append('.'.join(c_list))
125 | reaction_and_condition_dict['reagent'].append('.'.join(r_list))
126 | reaction_and_condition_dict['rxn_smiles'].append(rxn_data['dl:reactionSmiles'])
127 | reaction_and_condition_dict['source'].append(patent_id)
128 | reaction_and_condition_dict['id'].append(rxn_id)
129 | reaction_and_condition_dict['year'].append(year)
130 | reaction_and_condition_dict['patent_type'].append(patent_type)
131 |
132 | corpus_dict['id'].append(rxn_id)
133 | corpus_dict['xml'].append(os.path.basename(xml_fpath))
134 | corpus_dict['heading_text'].append(heading_text)
135 | corpus_dict['paragraph_text'].append(paragraph_text)
136 | corpus_dict['year'].append(year)
137 | corpus_dict['patent_type'].append(patent_type)
138 |
139 | return reaction_and_condition_dict, corpus_dict
140 |
141 |
142 | if __name__ == '__main__':
143 | uspto_org_path = '/Mounts/rbg-storage1/users/yujieq/USPTO/'
144 |
145 | # reaction_and_condition_df = pd.DataFrame()
146 | xml_path_list = []
147 | for root, _, files in os.walk(uspto_org_path, topdown=False):
148 | if '.ipynb_checkpoints' in root:
149 | continue
150 | for fname in files:
151 | if fname.endswith('.xml'):
152 | xml_path_list.append(os.path.join(root, fname))
153 | xml_path_list = sorted(xml_path_list)
154 | # fout, writer = get_writer('uspto_rxn_condition.csv', CONDITION_DICT.keys())
155 | # corpus_fout, corpus_writer = get_writer('uspto_rxn_corpus.csv', CORPUS_DICT.keys())
156 | cnt = 0
157 | for i, path in tqdm(enumerate(xml_path_list), total=len(xml_path_list)):
158 | reaction_and_condition_dict, corpus_dict = read_xml2dict(path)
159 | reaction_and_condition_df = pd.DataFrame(reaction_and_condition_dict)
160 | # for row in reaction_and_condition_df.itertuples():
161 | # writer.writerow(list(row)[1:])
162 | # fout.flush()
163 | corpus_df = pd.DataFrame(corpus_dict)
164 | # for row in corpus_df.itertuples():
165 | # corpus_writer.writerow(list(row)[1:])
166 | # corpus_fout.flush()
167 | cnt += len(reaction_and_condition_df)
168 | if i % 100 == 0:
169 | print(f'step {i}: {cnt} data')
170 | # fout.close()
171 | # corpus_fout.close()
172 |
173 | for patent_id in patent_cnt:
174 | patent_info[patent_id]['num_rxn'] = patent_cnt[patent_id]
175 | with open('patent_info.json', 'w') as f:
176 | json.dump(patent_info, f)
177 |
178 | print('Done')
179 |
180 |
--------------------------------------------------------------------------------
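The dictionary keys accessed above (`dl:source`, `dl:documentId`, `spectatorList`, `@role`, `cml:smiles`, ...) mirror the CML markup in the USPTO XML files. A toy, hypothetical snippet showing how `xmltodict` exposes that structure (abbreviated for illustration; the real files carry many more fields):

```python
import xmltodict

xml = """
<reaction>
  <dl:source>
    <dl:documentId>US06887874</dl:documentId>
    <dl:paragraphText>The mixture was stirred in THF ...</dl:paragraphText>
  </dl:source>
  <spectatorList>
    <spectator role="solvent">
      <identifier dictRef="cml:smiles" value="C1CCOC1"/>
    </spectator>
  </spectatorList>
</reaction>
"""

data = xmltodict.parse(xml)
spectator = data['reaction']['spectatorList']['spectator']
print(spectator['@role'])                 # solvent
print(spectator['identifier']['@value'])  # C1CCOC1
```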
/preprocess/uspto_script/2.0.clean_up_rxn_condition.py:
--------------------------------------------------------------------------------
1 | import re
2 | from collections import OrderedDict
3 | from joblib import Parallel, delayed
4 | import pandas as pd
5 | import multiprocessing
6 | import os
7 | import argparse
8 | import torch
9 | from tqdm import tqdm
10 | from utils import canonicalize_smiles, get_writer
11 | from rxnmapper import RXNMapper
12 |
13 |
14 | debug = False
15 |
16 |
17 | def remap_and_reassign_condition_role(org_rxn, org_solvent, org_catalyst, org_reagent):
18 | if len(org_rxn.split('>')) == 1:
19 | return None
20 | if '|' in org_rxn:
21 | rxn, frag = org_rxn.split(' ')
22 | else:
23 | rxn, frag = org_rxn, ''
24 |
25 | org_solvent, org_catalyst, org_reagent = [
26 | canonicalize_smiles(x) for x in [org_solvent, org_catalyst, org_reagent]
27 | ]
28 | try:
29 | results = rxn_mapper.get_attention_guided_atom_maps([rxn])[0]
30 | except Exception as e:
31 | print('\n'+rxn+'\n')
32 | print(e)
33 | return None
34 |
35 | remapped_rxn = results['mapped_rxn']
36 | confidence = results['confidence']
37 |
38 | new_precursors, new_products = remapped_rxn.split('>>')
39 |
40 | pt = re.compile(r':(\d+)]')
41 | new_react_list = []
42 | new_reag_list = []
43 | for precursor in new_precursors.split('.'):
44 | if re.findall(pt, precursor):
45 | new_react_list.append(precursor) # has atom mapping --> reactant
46 | else:
47 | new_reag_list.append(precursor) # no atom mapping --> reagent
48 |
49 | new_reactants = '.'.join(new_react_list)
50 | react_maps = sorted(re.findall(pt, new_reactants))
51 | prod_maps = sorted(re.findall(pt, new_products))
52 | if react_maps != prod_maps:
53 | return None
54 | new_reagent_list = []
55 | c_list = org_catalyst.split('.')
56 | s_list = org_solvent.split('.')
57 | r_list = org_reagent.split('.')
58 | for r in new_reag_list:
59 | if (r not in c_list + s_list) and (r not in r_list):
60 | new_reagent_list.append(r)
61 | new_reagent_list += [x for x in r_list if x != '']
62 | catalyst = org_catalyst
63 | solvent = org_solvent
64 | reagent = '.'.join(new_reagent_list)
65 | can_react = canonicalize_smiles(new_reactants, clear_map=True)
66 | can_prod = canonicalize_smiles(new_products, clear_map=True)
67 | can_rxn = '{}>>{}'.format(can_react, can_prod)
68 | results = OrderedDict()
69 | results['remapped_rxn'] = remapped_rxn # remapped_rxn still contains the reaction conditions
70 | results['fragment'] = frag
71 | results['confidence'] = confidence
72 | results['canonical_rxn'] = can_rxn # can_rxn has no conditions, only the reactants and products whose atoms contribute to the reaction
73 | results['catalyst'] = catalyst
74 | results['solvent'] = solvent
75 | results['reagent'] = reagent
76 |
77 | return results
78 |
79 |
80 | def run_tasks(task):
81 | idx, rxn, solvent, catalyst, reagent, source = task
82 | if pd.isna(solvent):
83 | solvent = ''
84 | if pd.isna(catalyst):
85 | catalyst = ''
86 | if pd.isna(reagent):
87 | reagent = ''
88 | results = remap_and_reassign_condition_role(rxn, solvent, catalyst, reagent)
89 |
90 | return idx, results, source
91 |
92 |
93 | if __name__ == '__main__':
94 | parser = argparse.ArgumentParser()
95 | parser.add_argument('--gpu', type=int, default=0)
96 | parser.add_argument('--split_group', type=int, default=4)
97 | parser.add_argument('--group', type=int, default=1)
98 | args = parser.parse_args()
99 |
100 | assert args.group <= args.split_group-1
101 |
102 | print('Debug:', debug)
103 | print('Split group: {}'.format(args.split_group))
104 | print('Group number: {}'.format(args.group))
105 | print('GPU index: {}'.format(args.gpu))
106 |
107 | # device = torch.device('cuda:{}'.format(args.gpu) if args.gpu >= 0 else 'cpu')
108 | rxn_mapper = RXNMapper()
109 | source_data_path = '.'
110 | rxn_condition_fname = 'uspto_rxn_condition.csv'
111 | new_database_fpath = os.path.join(
112 | source_data_path, 'uspto_rxn_condition_remapped_and_reassign_condition_role_group_{}.csv'.format(args.group))
113 | # n_core = 14
114 |
115 | # pool = multiprocessing.Pool(n_core)
116 | if debug:
117 | database = pd.read_csv(os.path.join(source_data_path, rxn_condition_fname), nrows=10001)
118 | else:
119 | database = pd.read_csv(os.path.join(source_data_path, rxn_condition_fname))
120 | print('All data number: {}'.format(len(database)))
121 |
122 | group_size = len(database) // args.split_group
123 |
124 | if args.group >= args.split_group-1:
125 | database = database.iloc[args.group * group_size:]
126 | else:
127 | database = database.iloc[args.group * group_size:(args.group+1) * group_size]
128 |
129 | print('Calculating index {} to {}'.format(database.index.min(), database.index.max()))
130 |
131 | # rxn_smiles = database['rxn_smiles'].tolist()
132 |
133 | # tasks = [(idx, rxn, database.iloc[idx].solvent, database.iloc[idx].catalyst, database.iloc[idx].reagent, database.iloc[idx].source)
134 | # for idx, rxn in tqdm(enumerate(rxn_smiles), total=len(rxn_smiles))]
135 | header = [
136 | 'id',
137 | 'source',
138 | 'org_rxn',
139 | 'fragment',
140 | 'remapped_rxn',
141 | 'confidence',
142 | 'canonical_rxn',
143 | 'catalyst',
144 | 'solvent',
145 | 'reagent',
146 | ]
147 | fout, writer = get_writer(new_database_fpath, header=header)
148 | all_results = []
149 | for row in tqdm(database.itertuples(), total=len(database)):
150 | task = (row.id, row.rxn_smiles, row.solvent, row.catalyst, row.reagent, row.source)
151 | try:
152 | run_results = run_tasks(task)
153 | idx, results, source = run_results
154 | if results:
155 | results['id'] = row.id
156 | results['source'] = source
157 | results['org_rxn'] = row.rxn_smiles
158 | assert len(results) == len(header)
159 | writer.writerow([results[key] for key in header])
160 | fout.flush()
161 | except Exception as e:
162 | print(e)
163 | pass
164 | # for results in tqdm(pool.imap_unordered(run_tasks, tasks), total=len(tasks)):
165 | # all_results.append(results)
166 | # all_results = Parallel(n_jobs=n_core, verbose=1)(
167 | # delayed(run_tasks)(task) for task in tqdm(tasks))
168 | fout.close()
169 | # new_database = pd.read_csv(new_database_fpath)
170 | # reset_header = [
171 | # 'source',
172 | # 'org_rxn',
173 | # 'fragment',
174 | # 'remapped_rxn',
175 | # 'confidence',
176 | # 'canonical_rxn',
177 | # 'catalyst',
178 | # 'solvent',
179 | # 'reagent',
180 | # ]
181 | # new_database = new_database[reset_header]
182 | # new_database.to_csv(new_database_fpath, index=False)
183 | print('Done!')
184 |
--------------------------------------------------------------------------------
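`canonicalize_smiles` is imported from the local `utils.py`, which is not shown in this section. Judging from how it is called above (optionally clearing atom maps before canonicalization), a minimal RDKit-based sketch might look like this, as an assumption rather than the actual implementation:

```python
from rdkit import Chem

def canonicalize_smiles(smiles, clear_map=False):
    if not smiles or not isinstance(smiles, str):
        return ''
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return ''
    if clear_map:
        # Drop atom-map numbers so the canonical SMILES is mapping-free
        for atom in mol.GetAtoms():
            atom.SetAtomMapNum(0)
    return Chem.MolToSmiles(mol)
```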
/preprocess/uspto_script/2.0.clean_up_rxn_condition.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p log
3 | CUDA_VISIBLE_DEVICES=2 nohup python -u 2.0.clean_up_rxn_condition.py --split_group 4 --group 0 > log/get_rxn_condition_uspto_0.log &
4 | CUDA_VISIBLE_DEVICES=3 nohup python -u 2.0.clean_up_rxn_condition.py --split_group 4 --group 1 > log/get_rxn_condition_uspto_1.log &
5 | CUDA_VISIBLE_DEVICES=6 nohup python -u 2.0.clean_up_rxn_condition.py --split_group 4 --group 2 > log/get_rxn_condition_uspto_2.log &
6 | CUDA_VISIBLE_DEVICES=7 nohup python -u 2.0.clean_up_rxn_condition.py --split_group 4 --group 3 > log/get_rxn_condition_uspto_3.log &
7 |
--------------------------------------------------------------------------------
/preprocess/uspto_script/2.1.merge_clean_up_rxn_conditon.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import os
3 | import pandas as pd
4 | from tqdm import tqdm
5 |
6 | from utils import calculate_frequency, get_writer
7 |
8 | def write_freq(fpath, freq_data):
9 | fout, writer = get_writer(fpath, ['smiles', 'freq_cnt'])
10 | for data in freq_data:
11 | writer.writerow(data)
12 | fout.flush()
13 | fout.close()
14 |
15 | if __name__ == '__main__':
16 | debug = False
17 |
18 | source_data_path = '.'
19 | merge_data_fname = 'uspto_rxn_condition_remapped_and_reassign_condition_role.csv'
20 | duplicate_removal_fname = 'uspto_rxn_condition_remapped_and_reassign_condition_role_rm_duplicate.csv'
21 | freq_info_path = os.path.join(source_data_path, 'freq_info')
22 | if not os.path.exists(freq_info_path):
23 | os.makedirs(freq_info_path)
24 | if not os.path.exists(os.path.join(source_data_path, merge_data_fname)):
25 | database = pd.DataFrame()
26 | for group_idx in [0, 1, 2, 3]:
27 | database = database.append(pd.read_csv(os.path.join(
28 | source_data_path, f'uspto_rxn_condition_remapped_and_reassign_condition_role_group_{group_idx}.csv')))
29 | database.reset_index(inplace=True, drop=True)
30 | database.to_csv(os.path.join(os.path.join(
31 | source_data_path, merge_data_fname)), index=False)
32 | else:
33 | if not debug:
34 | database = pd.read_csv(os.path.join(
35 | source_data_path, merge_data_fname))
36 | else:
37 | database = pd.read_csv(os.path.join(
38 | source_data_path, merge_data_fname), nrows=10000)
39 |
40 | # Deduplicate using remapped_rxn + canonical_rxn + catalyst + solvent + reagent as the key; for the source, org_rxn, fragment, and confidence columns, keep the first row of each group
41 | info_row_name = ['remapped_rxn', 'canonical_rxn', 'catalyst', 'solvent', 'reagent']
42 | database_duplicate_removal = database.drop_duplicates(subset=info_row_name, keep='first')
43 | database_duplicate_removal.reset_index(inplace=True, drop=True)
44 | print()
45 | print('catalyst count:', len(set(database_duplicate_removal['catalyst'])))
46 | catalyst_freq = calculate_frequency(database_duplicate_removal['catalyst'].tolist())
47 | write_freq(os.path.join(freq_info_path, 'catalyst_freq.csv'), catalyst_freq)
48 | print()
49 | print('solvent count:', len(set(database_duplicate_removal['solvent'])))
50 | solvent_freq = calculate_frequency(database_duplicate_removal['solvent'].tolist())
51 | write_freq(os.path.join(freq_info_path, 'solvent_freq.csv'), solvent_freq)
52 | print()
53 | print('reagent count:', len(set(database_duplicate_removal['reagent'])))
54 | reagent_freq = calculate_frequency(database_duplicate_removal['reagent'].tolist())
55 | write_freq(os.path.join(freq_info_path, 'reagent_freq.csv'), reagent_freq)
56 | print()
57 |
58 | print('All dataset count:', len(database_duplicate_removal))
59 | database_duplicate_removal.to_csv(os.path.join(source_data_path, duplicate_removal_fname), index=False)
60 |
61 |
--------------------------------------------------------------------------------
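`calculate_frequency` also lives in `utils.py`. Given that `write_freq` writes its output as `(smiles, freq_cnt)` rows, it presumably returns value/count pairs sorted by count; a Counter-based sketch under that assumption:

```python
from collections import Counter

def calculate_frequency(values):
    # most_common() yields (value, count) pairs, highest count first
    return Counter(values).most_common()
```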
/preprocess/uspto_script/3.0.split_condition_and_slect.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, namedtuple
2 | import os
3 | import pandas as pd
4 | from tqdm import tqdm
5 | from rdkit import Chem
6 | from utils import MolRemover, get_mol_charge, list_of_metal_atoms, mol_charge_class
7 | from rdkit import RDLogger
8 | RDLogger.DisableLog('rdApp.*')
9 | # if any(metal in cat for metal in list_of_metals):
10 | # metal_cat_dict[key] = True
11 |
12 |
13 | if __name__ == '__main__':
14 | debug = False
15 | unknown_check = False
16 | split_token = '分'
17 | print('Debug:', debug)
18 | remove_threshold = 100
19 | source_data_path = '.'
20 | duplicate_removal_fname = 'uspto_rxn_condition_remapped_and_reassign_condition_role_rm_duplicate.csv'
21 | freq_info_path = os.path.join(source_data_path, 'freq_info')
22 |
23 | if debug:
24 | database = pd.read_csv(os.path.join(source_data_path, duplicate_removal_fname), nrows=10000)
25 | else:
26 | database = pd.read_csv(os.path.join(source_data_path, duplicate_removal_fname))
27 |
28 | # Remove data whose condition frequency is below the threshold
29 | print('Removing conditions that occur fewer than remove_threshold times...')
30 | print(f'remove_threshold = {remove_threshold}')
31 | print('data number before remove: {}'.format(len(database)))
32 | condition_roles = ['catalyst', 'solvent', 'reagent']
33 | remove_index = pd.isna(database.index)
34 | for role in condition_roles:
35 | df = pd.read_csv(os.path.join(freq_info_path, f'{role}_freq.csv'))
36 | remove_compound = df[df['freq_cnt'] < remove_threshold]['smiles']
37 | remove_index = remove_index | database[role].isin(remove_compound)
38 | database_remove_below_threshold = database.loc[~remove_index].reset_index(drop=True)
39 | print('data number after remove: {}'.format(len(database_remove_below_threshold)))
40 | # database_remove_below_threshold.to_csv
41 | remover = MolRemover(defnFilename='../reagent_Ionic_compound.txt')
42 | reagent2index_dict = defaultdict(list)
43 | for idx, reagent in tqdm(enumerate(database_remove_below_threshold.reagent.tolist()),
44 | total=len(database_remove_below_threshold)):
45 | reagent2index_dict[reagent].append(idx)
46 |
47 | if unknown_check:
48 | # Check reagents that contain ions and drop combinations that are not charge-neutral
49 | unknown_combination = defaultdict(int)
50 | reagent2single_reagent_dict = {}
51 | for reagent in tqdm(reagent2index_dict):
52 | if pd.isna(reagent):
53 | reagent_namedtuple = namedtuple('reagent', ['known', 'unknown'])
54 | reagent2single_reagent_dict[reagent] = reagent_namedtuple([], [])
55 | continue
56 | reagent_mol = Chem.MolFromSmiles(reagent)
57 | reagent_mol_after_rm, remove_mol = remover.StripMolWithDeleted(reagent_mol, onlyFrags=False)
58 | reagent_smiles_after_rm = Chem.MolToSmiles(reagent_mol_after_rm)
59 | reagent_smiles_after_rm_list = reagent_smiles_after_rm.split('.')
60 | reagent_mols_after_rm = [Chem.MolFromSmiles(x) for x in reagent_smiles_after_rm_list]
61 | known_ionic = [Chem.MolToSmiles(x) for x in remove_mol]
62 | _unknown = []
63 | reagent_charge_neutral = []
64 | for mol in reagent_mols_after_rm:
65 | mol_charge_flag, mol_neutralization = get_mol_charge(mol)
66 | if mol_charge_flag != mol_charge_class[2]:
67 | smi = Chem.MolToSmiles(mol)
68 | if smi != '':
69 | _unknown.append(smi)
70 | else:
71 | smi = Chem.MolToSmiles(mol)
72 | if smi != '':
73 | reagent_charge_neutral.append(smi)
74 | reagent_namedtuple = namedtuple('reagent', ['known', 'unknown'])
75 | _known = reagent_charge_neutral + known_ionic
76 | if _unknown:
77 | unknown_combination['.'.join(_unknown)] += 1
78 | reagent2single_reagent_dict[reagent] = reagent_namedtuple(_known, _unknown)
79 | unknown_combination = list(unknown_combination.items())
80 | unknown_combination.sort(key=lambda x:x[1], reverse=True)
81 | print('Unknown reagent data count:', sum([x[1] for x in unknown_combination]))
82 | with open('../reagent_unknown.txt', 'w', encoding='utf-8') as f:
83 | for line in unknown_combination:
84 | f.write('{},{}\n'.format(line[0], line[1]))
85 | else:
86 | block_unknown_combination = pd.read_csv('../reagent_unknown.txt', header=None)
87 | block_unknown_combination.columns = ['smiles', 'cnt']
88 | print('Will block {} reagent combinations; a total of {} records will be deleted.'.format(
89 | len(block_unknown_combination), block_unknown_combination['cnt'].sum()))
90 |
91 | reagent2single_reagent_dict = defaultdict(list)
92 | for reagent in tqdm(reagent2index_dict):
93 | if pd.isna(reagent):
94 | reagent2single_reagent_dict[reagent].append('')
95 | continue
96 | reagent_mol = Chem.MolFromSmiles(reagent)
97 | reagent_mol_after_rm, remove_mol = remover.StripMolWithDeleted(reagent_mol, onlyFrags=False)
98 | reagent_smiles_after_rm = Chem.MolToSmiles(reagent_mol_after_rm)
99 | reagent_smiles_after_rm_list = reagent_smiles_after_rm.split('.')
100 | reagent_mols_after_rm = [Chem.MolFromSmiles(x) for x in reagent_smiles_after_rm_list]
101 | known_ionic = [Chem.MolToSmiles(x) for x in remove_mol]
102 | _unknown = []
103 | reagent_charge_neutral = []
104 | for mol in reagent_mols_after_rm:
105 | mol_charge_flag, mol_neutralization = get_mol_charge(mol)
106 | if mol_charge_flag != mol_charge_class[2]:
107 | smi = Chem.MolToSmiles(mol)
108 | if smi != '':
109 | _unknown.append(smi)
110 | else:
111 | smi = Chem.MolToSmiles(mol)
112 | if smi != '':
113 | reagent_charge_neutral.append(smi)
114 | if _unknown:
115 | _unknown_smiles = '.'.join(_unknown)
116 | assert _unknown_smiles in block_unknown_combination['smiles'].tolist()
117 | _known = reagent_charge_neutral + known_ionic
118 | reagent2single_reagent_dict[reagent] += _known
119 | unknown_drop_index = []
120 | for reagent in reagent2single_reagent_dict:
121 | if not reagent2single_reagent_dict[reagent]:
122 | unknown_drop_index.extend(reagent2index_dict[reagent])
123 | reagent2single_reagent_dict = {k: v for k, v in reagent2single_reagent_dict.items() if v}
124 | database_remove_below_threshold = database_remove_below_threshold.drop(unknown_drop_index)
125 | database_remove_below_threshold = database_remove_below_threshold.reset_index(drop=True)
126 | reagent2index_dict = defaultdict(list)
127 | for idx, reagent in tqdm(enumerate(database_remove_below_threshold.reagent.tolist()),
128 | total=len(database_remove_below_threshold)):
129 | reagent2index_dict[reagent].append(idx)
130 |
131 | # Following https://pubs.acs.org/doi/10.1021/acscentsci.8b00357, exclude data with more than 1 catalyst, more than 2 solvents, or more than 2 reagents
132 | print('Exceeding one catalyst, two solvents, or two reagents --> remove')
133 | remove_index_for_excess = pd.isna(database_remove_below_threshold.index)
134 | # reagent > 2 remove!
135 | for reagent in reagent2index_dict:
136 | if len(reagent2single_reagent_dict[reagent]) > 2:
137 | remove_idx = reagent2index_dict[reagent]
138 | for _idx in remove_idx:
139 | remove_index_for_excess[_idx] = True
140 | for _idx, catalyst in enumerate(database_remove_below_threshold.catalyst.tolist()):
141 | if not pd.isna(catalyst):
142 | if len(catalyst.split('.')) > 1:
143 | remove_index_for_excess[_idx] = True
144 | for _idx, solvent in enumerate(database_remove_below_threshold.solvent.tolist()):
145 | if not pd.isna(solvent):
146 | if len(solvent.split('.')) > 2:
147 | remove_index_for_excess[_idx] = True
148 | database_remove_below_threshold = database_remove_below_threshold.loc[~remove_index_for_excess].reset_index(drop=True)
149 | reagent2index_dict = defaultdict(list)
150 | for idx, reagent in tqdm(enumerate(database_remove_below_threshold.reagent.tolist()),
151 | total=len(database_remove_below_threshold)):
152 | reagent2index_dict[reagent].append(idx)
153 | print('Splitting conditions...')
154 | database_remove_below_threshold['catalyst_split'] = database_remove_below_threshold['catalyst']
155 | database_remove_below_threshold['solvent_split'] = ['']*len(database_remove_below_threshold)
156 | database_remove_below_threshold['reagent_split'] = ['']*len(database_remove_below_threshold)
157 | for reagent in reagent2index_dict:
158 | write_index = reagent2index_dict[reagent]
159 | write_value = split_token.join(reagent2single_reagent_dict[reagent])
160 | database_remove_below_threshold.loc[write_index, 'reagent_split'] = write_value
161 |
162 | solvent_split = []
163 | for x in database_remove_below_threshold['solvent'].tolist():
164 | if pd.isna(x):
165 | solvent_split.append('')
166 | continue
167 | solvent_split.append(split_token.join(x.split('.')))
168 | database_remove_below_threshold['solvent_split'] = solvent_split
169 | database_remove_below_threshold.to_csv(
170 | os.path.join(
171 | source_data_path,
172 | 'uspto_rxn_condition_remapped_and_reassign_condition_role_rm_duplicate_rm_excess.csv'
173 | ),
174 | index=False)
175 | print('Remaining data in the end:', len(database_remove_below_threshold))
176 | print('Unique canonical reaction:', len(set(database_remove_below_threshold['canonical_rxn'])))
177 | print('Unique remapped reaction:', len(set(database_remove_below_threshold['remapped_rxn'])))
178 | print('Unique catalyst:', len(set(database_remove_below_threshold['catalyst'])))
179 | print('Unique solvent:', len(set(database_remove_below_threshold['solvent'])))
180 | print('Unique reagent:', len(set(database_remove_below_threshold['reagent'])))
181 | print('Done!')
182 |
--------------------------------------------------------------------------------
/preprocess/uspto_script/4.0.split_train_val_test.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import os
3 | import json
4 | import pandas as pd
5 | import random
6 | from tqdm import tqdm
7 |
8 |
9 | if __name__ == '__main__':
10 | debug = False
11 | print('Debug:', debug)
12 | seed = 123
13 | random.seed(seed)
14 | # Dataset: canonical_rxn --> catalyst, solvent1, solvent2, reagent1, reagent2
15 | split_token = '分'
16 | split_frac = (0.8, 0.1, 0.1) # train:val:test
17 | source_data_path = '.'
18 | database_remove_below_threshold_fname = 'uspto_rxn_condition_remapped_and_reassign_condition_role_rm_duplicate_rm_excess.csv'
19 | final_condition_data_path = os.path.join(source_data_path, 'USPTO_condition_final')
20 | if not os.path.exists(final_condition_data_path):
21 | os.makedirs(final_condition_data_path)
22 | if debug:
23 | database = pd.read_csv(os.path.join(source_data_path, database_remove_below_threshold_fname), nrows=10000)
24 | else:
25 | database = pd.read_csv(os.path.join(source_data_path, database_remove_below_threshold_fname))
26 |
27 | database = database[['id', 'source', 'canonical_rxn', 'catalyst_split', 'solvent_split', 'reagent_split']]
28 | database['catalyst1'] = database['catalyst_split']
29 | split_solvent = database['solvent_split'].str.split(split_token, n=1, expand=True)
30 | database['solvent1'], database['solvent2'] = split_solvent[0], split_solvent[1]
31 | split_reagent = database['reagent_split'].str.split(split_token, n=1, expand=True)
32 | database['reagent1'], database['reagent2'] = split_reagent[0], split_reagent[1]
33 | columns = ['id', 'source', 'canonical_rxn', 'catalyst1', 'solvent1', 'solvent2', 'reagent1', 'reagent2']
34 | database = database[columns]
35 | database_sample = database.sample(frac=1, random_state=seed)
36 |
37 | # Random split (no rxn overlap)
38 | can_rxn2idx_dict = defaultdict(list)
39 | for idx, row in tqdm(database_sample.iterrows(), total=len(database_sample)):
40 | can_rxn2idx_dict[row.canonical_rxn].append(idx)
41 | train_idx, val_idx, test_idx = [], [], []
42 | can_rxn2idx_dict_items = list(can_rxn2idx_dict.items())
43 | random.shuffle(can_rxn2idx_dict_items)
44 | all_data_number = len(database_sample)
45 | for rxn, idx_list in tqdm(can_rxn2idx_dict_items):
46 | if len(idx_list) == 1:
47 | if len(test_idx) < split_frac[2] * all_data_number:
48 | test_idx += idx_list
49 | elif len(val_idx) < split_frac[1] * all_data_number:
50 | val_idx += idx_list
51 | else:
52 | train_idx += idx_list
53 | else:
54 | train_idx += idx_list
55 |
56 | database_sample.loc[train_idx, 'dataset'] = 'train'
57 | database_sample.loc[val_idx, 'dataset'] = 'val'
58 | database_sample.loc[test_idx, 'dataset'] = 'test'
59 | database_sample.to_csv(os.path.join(final_condition_data_path, 'USPTO_condition.csv'), index=False)
60 |
61 | # Time split
62 | with open('patent_info.json') as f:
63 | patent_info = json.load(f)
64 | train_idx, val_idx, test_idx = [], [], []
65 | for idx, patent_id in enumerate(database_sample['source']):
66 | if patent_info[patent_id]['year'] in [2016]:
67 | test_idx.append(idx)
68 | elif patent_info[patent_id]['year'] in [2015]:
69 | val_idx.append(idx)
70 | else:
71 | train_idx.append(idx)
72 | year_data_path = 'USPTO_condition_year'
73 | os.makedirs(year_data_path, exist_ok=True)
74 | train_df = database_sample.iloc[train_idx]
75 | train_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_train.csv'), index=False)
76 | val_df = database_sample.iloc[val_idx]
77 | val_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_val.csv'), index=False)
78 | test_df = database_sample.iloc[test_idx]
79 | test_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_test.csv'), index=False)
80 |
81 | # Grant time split
82 | with open('patent_info.json') as f:
83 | patent_info = json.load(f)
84 | train_idx, val_idx, test_idx = [], [], []
85 | for idx, patent_id in enumerate(database_sample['source']):
86 | if patent_info[patent_id]['type'] != 'grant':
87 | continue
88 | if patent_info[patent_id]['year'] in [2016]:
89 | test_idx.append(idx)
90 | elif patent_info[patent_id]['year'] in [2015]:
91 | val_idx.append(idx)
92 | else:
93 | train_idx.append(idx)
94 | year_data_path = 'USPTO_condition_grant_year'
95 | os.makedirs(year_data_path, exist_ok=True)
96 | train_df = database_sample.iloc[train_idx]
97 | train_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_train.csv'), index=False)
98 | val_df = database_sample.iloc[val_idx]
99 | val_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_val.csv'), index=False)
100 | test_df = database_sample.iloc[test_idx]
101 | test_df.to_csv(os.path.join(year_data_path, 'USPTO_condition_test.csv'), index=False)
102 |
--------------------------------------------------------------------------------
/preprocess/uspto_script/5.0.convert_context_tokens.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from tqdm import tqdm
4 | import pickle
5 |
6 | BOS, EOS, PAD, MASK = '[BOS]', '[EOS]', '[PAD]', '[MASK]'
7 | UNK, SEP = '[UNK]', '[SEP]'
8 |
9 |
10 | def get_condition2idx_mapping(all_condition_data: pd.DataFrame):
11 | col_unique_data = [BOS, EOS, PAD, MASK]
12 | for col in all_condition_data.columns.tolist():
13 | one_col_unique = list(set(all_condition_data[col].tolist()))
14 | col_unique_data.extend(one_col_unique)
15 | col_unique_data = list(set(col_unique_data))
16 | col_unique_data.sort()
17 | idx2data = {i: x for i, x in enumerate(col_unique_data)}
18 | data2idx = {x: i for i, x in enumerate(col_unique_data)}
19 | return idx2data, data2idx
20 |
21 |
22 | def get_condition_vocab(all_condition_data: pd.DataFrame):
23 | col_unique_data = []
24 | for col in all_condition_data.columns.tolist():
25 | one_col_unique = list(set(all_condition_data[col].tolist()))
26 | col_unique_data.extend(one_col_unique)
27 | col_unique_data = list(set(col_unique_data))
28 | col_unique_data.sort()
29 | vocab = [PAD, BOS, EOS, MASK, UNK, SEP] + col_unique_data
30 | return vocab
31 |
32 |
33 | if __name__ == '__main__':
34 | calculate_fps = False
35 | convert_aug_data = False
36 |
37 | debug = False
38 | print('Debug:', debug, 'Convert data:', convert_aug_data)
39 | fp_size = 16384
40 | source_data_path = '.'
41 | final_condition_data_path = os.path.join(source_data_path, 'USPTO_condition_final')
42 | if convert_aug_data:
43 | database_fname = 'USPTO_condition_aug_n5.csv'
44 | else:
45 | database_fname = 'USPTO_condition.csv'
46 |
47 | if debug:
48 | database = pd.read_csv(os.path.join(final_condition_data_path, database_fname), nrows=10000)
49 | final_condition_data_path = os.path.join(source_data_path, 'USPTO_condition_final_debug')
50 | if not os.path.exists(final_condition_data_path):
51 | os.makedirs(final_condition_data_path)
52 | database.to_csv(os.path.join(final_condition_data_path, database_fname), index=False)
53 | else:
54 | database = pd.read_csv(os.path.join(final_condition_data_path, database_fname), keep_default_na=False)
55 |
56 | condition_cols = ['catalyst1', 'solvent1', 'solvent2', 'reagent1', 'reagent2']
57 |
58 | vocab = get_condition_vocab(database[condition_cols])
59 | vocab_path = os.path.join(final_condition_data_path, 'vocab_smiles.txt')
60 | with open(vocab_path, 'w') as f:
61 | f.write('\n'.join(vocab))
62 |
63 | all_idx2data, all_data2idx = get_condition2idx_mapping(database[condition_cols])
64 | all_idx_mapping_data_fpath = os.path.join(
65 | final_condition_data_path,
66 | '{}_alldata_idx.pkl'.format(database_fname.split('.')[0]))
67 | with open(all_idx_mapping_data_fpath, 'wb') as f:
68 | pickle.dump((all_idx2data, all_data2idx), f)
69 | all_condition_labels = []
70 | for _, row in tqdm(database[condition_cols].iterrows(), total=len(database)):
71 | row = list(row)
72 | row = ['[BOS]'] + row + ['[EOS]']
73 | all_condition_labels.append([all_data2idx[x] for x in row])
74 |
75 | all_condition_labels_fpath = os.path.join(
76 | final_condition_data_path,
77 | '{}_condition_labels.pkl'.format(database_fname.split('.')[0]))
78 | with open(all_condition_labels_fpath, 'wb') as f:
79 | pickle.dump((all_condition_labels), f)
80 | print('Done!')
81 |
--------------------------------------------------------------------------------
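The pickled artifacts can be loaded back directly; for instance, to recover the condition SMILES of the first reaction from its label indices (a usage sketch with the default, non-augmented file names produced above):

```python
import pickle

with open('USPTO_condition_final/USPTO_condition_alldata_idx.pkl', 'rb') as f:
    idx2data, data2idx = pickle.load(f)
with open('USPTO_condition_final/USPTO_condition_condition_labels.pkl', 'rb') as f:
    condition_labels = pickle.load(f)

# Each label sequence is [BOS, catalyst1, solvent1, solvent2, reagent1, reagent2, EOS]
print([idx2data[i] for i in condition_labels[0]])
```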
/preprocess/uspto_script/extract_nosmiles.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 | # from show_coincidence import read_json
5 | from rdkit import Chem
6 | from rdkit import RDLogger
7 |
8 | from utils import canonicalize_smiles, read_pickle
9 | RDLogger.DisableLog('rdApp.*')
10 |
11 |
12 |
13 | if __name__ == '__main__':
14 | all_data = {}
15 | for fname in os.listdir('../'):
16 | name, ext = os.path.splitext(fname)
17 | fpath = os.path.join('..', fname)
18 | if ext == '.pickle':
19 | all_data['{}'.format(name.replace('_dict', ''))] = read_pickle(fpath)
20 | all_reaxys_name = []
21 | all_data_reaxys_name_comp = {}
22 | all_data_rm = {}
23 | for k in all_data:
24 | can_err = 0
25 | all_data_rm[k] = {}
26 | all_data_reaxys_name_comp[k] = []
27 | idx = 0
28 | for comp_idx in all_data[k]:
29 | comp = all_data[k][comp_idx]
30 | if 'Reaxys' not in comp:
31 | if comp == '':
32 | all_data_rm[k][idx] = comp
33 | idx += 1
34 | continue
35 | comp_can = canonicalize_smiles(comp)
36 | if comp_can == '':
37 | can_err += 1
38 | continue
39 | all_data_rm[k][idx] = comp_can
40 | idx += 1
41 | elif 'Reaxys Name' in comp:
42 | all_data_reaxys_name_comp[k].append(comp)
43 | all_reaxys_name.append(comp)
44 |
45 | print(f'{k} canonicalize fail: {can_err}')
46 | print('{}: {}'.format(k, len(all_data_rm[k])))
47 | print('{}: Reaxys Name {}'.format(k, len(all_data_reaxys_name_comp[k])))
48 |
49 | with open('../check_data/condition_compound.pkl', 'wb') as f:
50 | pickle.dump(all_data_rm, f)
51 | with open('../check_data/reaxys_name_comp.pkl', 'wb') as f:
52 | pickle.dump(all_data_reaxys_name_comp, f)
53 |
54 | with open('../check_data/condition_compound.json', 'w', encoding='utf-8') as f:
55 | json.dump(all_data_rm, f)
56 | with open('../check_data/reaxys_name_comp.json', 'w', encoding='utf-8') as f:
57 | json.dump(all_data_reaxys_name_comp, f)
58 |
59 | with open('../check_data/all_reaxys_name.txt', 'w', encoding='utf-8') as f:
60 | f.write('\n'.join(all_reaxys_name))
61 |
62 |
--------------------------------------------------------------------------------
/preprocess/uspto_script/gen_grant_corpus.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 |
4 |
5 | corpus_df = pd.read_csv('uspto_rxn_corpus.csv')
6 |
7 | grant_df = corpus_df[corpus_df['patent_type'] == 'grant']
8 | grant_df.to_csv('USPTO_rxn_grant_corpus.csv', index=False)
9 |
--------------------------------------------------------------------------------
/preprocess/uspto_script/get_aug_condition_data.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import os
3 | import pandas as pd
4 | import sys
5 | sys.path.append('../pretrain_data_script/')
6 | from get_pretrain_dataset_augmentation import get_random_rxn
7 |
8 |
9 | '''
10 | Data augmentation for the USPTO-Condition dataset; only the training set is augmented in this script (see the note at the end of this file).'''
11 | def generate_one_rxn_condition_aug(tuple_data):
12 | i, len_all, source, rxn, c1, s1, s2, r1, r2, dataset_flag = tuple_data
13 | if i > 0 and i % 1000 == 0:
14 | print(f"Processing {i}th Reaction / {len_all}")
15 | results = [(source, rxn, c1, s1, s2, r1, r2, dataset_flag)]
16 | for i in range(N-1):
17 | random_rxn = get_random_rxn(rxn)
18 | results.append((source, random_rxn, c1, s1, s2, r1, r2, dataset_flag))
19 | return results
20 |
21 | if __name__ == '__main__':
22 | debug = False
23 |
24 | N = 5
25 | num_workers = 10
26 |
27 |
28 |
29 | source_data_path = '../../dataset/source_dataset/'
30 | final_condition_data_path = os.path.join(source_data_path, 'USPTO_condition_final')
31 | database_fname = 'USPTO_condition.csv'
32 |
33 |
34 | uspto_condition_dataset = pd.read_csv(os.path.join(final_condition_data_path, database_fname))
35 |
36 |
37 | uspto_condition_train_df = uspto_condition_dataset.loc[uspto_condition_dataset['dataset']=='train']
38 | uspto_condition_val_test_df = uspto_condition_dataset.loc[uspto_condition_dataset['dataset']!='train']
39 | if debug:
40 | uspto_condition_train_df = uspto_condition_train_df.sample(3000)
41 |
42 | p = Pool(num_workers)
43 | all_len = len(uspto_condition_train_df)
44 |
45 | augmentation_rxn_condition_train_data = p.imap(
46 | generate_one_rxn_condition_aug,
47 | ((i, all_len, *row.tolist()) for i, (index, row) in enumerate(uspto_condition_train_df.iterrows()))
48 | )
49 |
50 | p.close()
51 | p.join()
52 | augmentation_rxn_condition_train_data = list(augmentation_rxn_condition_train_data)
53 |     augmentation_rxn_condition_train = []  # one canonical rxn SMILES plus N-1 randomized SMILES per training reaction
54 | for one in augmentation_rxn_condition_train_data:
55 | augmentation_rxn_condition_train += one
56 |
57 | aug_train_df = pd.DataFrame(augmentation_rxn_condition_train)
58 | aug_train_df.columns = uspto_condition_train_df.columns.tolist()
59 |
60 | aug_uspto_conditon_dataset = aug_train_df.append(uspto_condition_val_test_df).reset_index(drop=True)
61 | aug_uspto_conditon_dataset.to_csv(os.path.join(final_condition_data_path, 'USPTO_condition_aug_n5.csv'), index=False)
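62 | 
63 | # Resulting layout for N = 5: every training reaction contributes one row with its original
64 | # canonical rxn SMILES plus N-1 rows whose reaction SMILES are randomized by get_random_rxn
65 | # (defined in ../pretrain_data_script/); validation and test rows are passed through unchanged.
66 | # As a rough illustration (an assumption about get_random_rxn, not its actual code), a randomized
67 | # SMILES for a single molecule can be obtained with Chem.MolToSmiles(mol, doRandom=True) in RDKit.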
--------------------------------------------------------------------------------
/preprocess/uspto_script/get_dataset_for_condition.py:
--------------------------------------------------------------------------------
1 | '''
2 | Work in progress: match reagent and other reaction-condition molecules in a SMILES reaction dataset against the already-extracted condition compounds (see the note at the end of this file).
3 | '''
4 | from collections import defaultdict
5 | import pickle
6 | import pandas as pd
7 | from tqdm import tqdm
8 | from rdkit import RDLogger
9 | RDLogger.DisableLog('rdApp.*')
10 | from utils import canonicalize_smiles, covert_to_series
11 |
12 | NONE_PANDDING = '[None]'
13 |
14 | class AssignmentCondition:
15 | def __init__(self, condition_compound_dict) -> None:
16 | self.padding = 0
17 | self.condition_compound_dict = condition_compound_dict
18 | self.condition_compound_df_dict= {}
19 | for k in condition_compound_dict:
20 | new_data = {}
21 | data = condition_compound_dict[k]
22 | max_col = 0
23 | for idx in data:
24 | smi = data[idx]
25 | if smi == '':
26 | smi = NONE_PANDDING
27 | new_data[smi] = smi.split('.')
28 | if max_col < len(new_data[smi]):
29 | max_col = len(new_data[smi])
30 | for smi in new_data:
31 | new_data[smi] = new_data[smi] + [self.padding] * (max_col - len(new_data[smi]))
32 | self.condition_compound_df_dict[k] = pd.DataFrame(new_data)
33 |
34 |         print('Read solvent {}, reagent {}, catalyst {}'.format(
35 | len(condition_compound_dict['s1']),
36 | len(condition_compound_dict['r1']),
37 | len(condition_compound_dict['c1'])))
38 |
39 | def apply(self, reag_ser):
40 | one_condition_dict = defaultdict(list)
41 | for name in ['c1', 'r1', 's1']:
42 | df = self.condition_compound_df_dict[name]
43 | for smi in df.columns.tolist():
44 | data_series = df[smi][df[smi]!=self.padding]
45 | if data_series.isin(reag_ser).sum() == len(data_series):
46 | one_condition_dict[name].append(smi)
47 |
48 | return one_condition_dict
49 |
50 | if __name__ == '__main__':
51 | debug = True
52 |
53 | condition_compound_fpath = '../check_data/condition_compound_add_name2smiles.pkl'
54 | if not debug:
55 | print('Reading USPTO-1k-tpl...')
56 | uspto_1k_tpl_train = pd.read_csv('../../dataset/source_dataset/uspto_1k_TPL_train_valid.tsv.gzip', compression='gzip', sep='\t', index_col=0)
57 | uspto_1k_tpl_test = pd.read_csv('../../dataset/source_dataset/uspto_1k_TPL_test.tsv.gzip', compression='gzip', sep='\t', index_col=0)
58 | uspto_1k_tpl = uspto_1k_tpl_train.append(uspto_1k_tpl_test)
59 | print('# data: {}'.format(len(uspto_1k_tpl)))
60 | else:
61 | print('Debug...')
62 | print('Reading debug tsv...')
63 | debug_df = pd.read_csv('../../dataset/source_dataset/debug_df.tsv', sep='\t')
64 | uspto_1k_tpl = debug_df
65 |
66 |
67 | uspto_1k_tpl['reagents_series'] = [covert_to_series(canonicalize_smiles(x), none_pandding=NONE_PANDDING) for x in tqdm(uspto_1k_tpl.reagents.tolist())]
68 | with open(condition_compound_fpath, 'rb') as f:
69 | condition_compound_dict = pickle.load(f)
70 | assignment_cls = AssignmentCondition(condition_compound_dict=condition_compound_dict)
71 | for reag_ser in tqdm(uspto_1k_tpl['reagents_series'].tolist()):
72 | assignment_cls.apply(reag_ser)
73 | pass
74 |
75 |
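76 | # How the matching works: each known condition compound (catalysts 'c1', reagents 'r1',
77 | # solvents 's1') is split into its '.'-separated fragments and stored as a padded column of a
78 | # per-type DataFrame. AssignmentCondition.apply(reag_ser) then reports a compound as present in a
79 | # reaction whenever every one of its fragments occurs in that reaction's reagent series.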
--------------------------------------------------------------------------------
/preprocess/uspto_script/get_dummy_model_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import sys
5 |
6 | from tqdm import tqdm
7 | sys.path = [os.path.abspath(os.path.join(os.path.abspath(__file__), '../../../'))] + sys.path
8 | from models.utils import load_dataset
9 | import torch
10 |
11 | dummy_prediction = [
12 | [0, 0, 0, 0, 0], # [na, na, na, na, na]
13 | [0, 160, 0, 0, 0], # [na, DCM, na, na, na]
14 | [0, 16, 0, 0, 0], # [na, THF, na, na, na]
15 | [0, 0, 0, 85, 0], # [na, na, na, TEA, na]
16 | [0, 0, 0, 202, 0], # [na, na, na, K2CO3, na]
17 | [0, 160, 0, 85, 0], # [na, DCM, na, TEA, na]
18 | [0, 16, 0, 85, 0], # [na, THF, na, TEA, na]
19 | [0, 160, 0, 202, 0], # [na, DCM, na, K2CO3, na]
20 | [0, 16, 0, 202, 0], # [na, THF, na, K2CO3, na]
21 | [298, 0, 0, 0, 0], # [[Pd], na, na, na, na]
22 | ]
23 |
24 | def _get_accuracy_for_one(one_pred, one_ground_truth, topk_get=[1, 3, 5, 10, 15]):
25 | repeat_number = one_pred.size(0)
26 | hit_mat = one_ground_truth.unsqueeze(
27 | 0).repeat(repeat_number, 1) == one_pred
28 | overall_hit_mat = hit_mat.sum(1) == hit_mat.size(1)
29 | topk_hit_df = pd.DataFrame()
30 | for k in topk_get:
31 | hit_mat_k = hit_mat[:k, :]
32 | overall_hit_mat_k = overall_hit_mat[:k]
33 | topk_hit = []
34 | for col_idx in range(hit_mat.size(1)):
35 | if hit_mat_k[:, col_idx].sum() != 0:
36 | topk_hit.append(1)
37 | else:
38 | topk_hit.append(0)
39 | if overall_hit_mat_k.sum() != 0:
40 | topk_hit.append(1)
41 | else:
42 | topk_hit.append(0)
43 | topk_hit_df[k] = topk_hit
44 | # topk_hit_df.index = ['c1', 's1', 's2', 'r1', 'r2']
45 | return topk_hit_df
46 |
47 | def _calculate_batch_topk_hit(batch_preds, batch_ground_truth, topk_get=[1, 3, 5, 10, 15]):
48 | '''
49 | batch_pred <-- tgt_tokens_list
50 | batch_ground_truth <-- inputs['labels']
51 | '''
52 | batch_preds = torch.tensor(batch_preds)[:, :, :].to(torch.device('cpu'))
53 | batch_ground_truth = batch_ground_truth[:, 1:-1]
54 |
55 | one_batch_topk_acc_mat = np.zeros((6, 5))
56 | # topk_get = [1, 3, 5, 10, 15]
57 | for idx in range(batch_preds.size(0)):
58 | topk_hit_df = _get_accuracy_for_one(
59 | batch_preds[idx], batch_ground_truth[idx], topk_get=topk_get)
60 | one_batch_topk_acc_mat += topk_hit_df.values
61 | return one_batch_topk_acc_mat
62 |
63 | if __name__ == '__main__':
64 | topk_get = [1, 3, 5, 10, 15]
65 | source_data_path = '../../dataset/source_dataset/'
66 | uspto_root = os.path.abspath(os.path.join(source_data_path, 'USPTO_condition_final'))
67 | database_df, condition_label_mapping = load_dataset(dataset_root=uspto_root, database_fname='USPTO_condition.csv', use_temperature=False)
68 |
69 | test_df = database_df.loc[database_df['dataset']=='test']
70 | topk_acc_mat = np.zeros((6, 5))
71 | for gt in tqdm(test_df['condition_labels'].tolist()):
72 | one_batch_topk_acc_mat = _calculate_batch_topk_hit(batch_preds=[dummy_prediction], batch_ground_truth=torch.tensor([gt]))
73 | topk_acc_mat += one_batch_topk_acc_mat
74 | topk_acc_mat /= len(test_df['condition_labels'].tolist())
75 | topk_acc_df = pd.DataFrame(topk_acc_mat)
76 | topk_acc_df.columns = [f'top-{k} accuracy' for k in topk_get]
77 | topk_acc_df.index = ['c1', 's1', 's2', 'r1', 'r2', 'overall']
78 | print(topk_acc_df)
79 |
80 | '''
81 | top-1 accuracy top-3 accuracy top-5 accuracy top-10 accuracy top-15 accuracy
82 | c1 0.869600 0.869600 0.869600 0.914682 0.914682
83 | s1 0.010180 0.309805 0.309805 0.309805 0.309805
84 | s2 0.808549 0.808549 0.808549 0.808549 0.808549
85 | r1 0.260859 0.260859 0.377216 0.377216 0.377216
86 | r2 0.746515 0.746515 0.746515 0.746515 0.746515
87 | overall 0.000059 0.043085 0.043085 0.066074 0.066074
88 | '''
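89 | 
90 | # Reading the table above: rows are the five condition slots (c1 = catalyst, s1/s2 = solvents,
91 | # r1/r2 = reagents) plus an "overall" exact-match row; columns are top-k accuracies for
92 | # k in {1, 3, 5, 10, 15}. The dummy model always emits the same ten condition combinations
93 | # listed in dummy_prediction, so the table is the accuracy of a fixed-guess baseline.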
--------------------------------------------------------------------------------
/preprocess/uspto_script/get_fragment_from_rxn_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | import pandas as pd
5 | from rdkit import Chem
6 | from rdkit.Chem import Recap
7 | from rdkit.Chem import AllChem as Chem
8 | from multiprocessing import Pool
9 | from rdkit.Chem import BRICS
10 | from tqdm import tqdm
11 | import pickle
12 | from collections import defaultdict
13 | from get_pretrain_dataset_with_rxn_center import timeout
14 |
15 |
16 | def read_txt(fpath):
17 | with open(fpath, 'r', encoding='utf-8') as f:
18 | return [x.strip() for x in f.readlines()]
19 |
20 |
21 | # def get_frag_from_rxn_recap(rxn):
22 | # # one_set = set()
23 | # react, prod = rxn.split('>>')
24 | # reacts = react.split('.')
25 | # prods = prod.split('.')
26 | # # react_mol, prod_mol = Chem.MolFromSmiles(react), Chem.MolFromSmiles(prod)
27 | # one_fragmet_dict = set()
28 | # mols = [Chem.MolFromSmiles(smi) for smi in reacts+prods]
29 | # for mol in mols:
30 | # hierarch = Recap.RecapDecompose(mol)
31 | # one_fragmet_dict.update(set(hierarch.GetLeaves().keys()))
32 | # # one_fragmet_dict = set(list(hierarch_react.GetLeaves().keys())+list(hierarch_prod.GetLeaves().keys()))
33 | # return one_fragmet_dict
34 |
35 |
36 | def get_frag_from_rxn_brics(rxn):
37 | # one_set = set()
38 | react, prod = rxn.split('>>')
39 | reacts = react.split('.')
40 | prods = prod.split('.')
41 | # react_mol, prod_mol = Chem.MolFromSmiles(react), Chem.MolFromSmiles(prod)
42 | one_fragmet_dict = defaultdict(int)
43 | mols = [Chem.MolFromSmiles(smi) for smi in reacts+prods]
44 | for mol in mols:
45 | try:
46 | with timeout():
47 |
48 | frags = list(BRICS.BRICSDecompose(mol))
49 | for frag in frags:
50 | frag = re.sub('\[([0-9]+)\*\]', '*', frag)
51 | if frag not in reacts + prods:
52 | one_fragmet_dict[frag] += 1
53 | except Exception as e:
54 | print(e)
55 | pass
56 |
57 | # one_fragmet_dict.update()
58 | return one_fragmet_dict
59 |
60 |
61 |
62 | if __name__ == '__main__':
63 |
64 |
65 |
66 | if not os.path.exists('../check_data/frag.pkl'):
67 | pretrain_dataset_path = '../../dataset/pretrain_data/'
68 | fnames = ['mlm_rxn_train.txt', 'mlm_rxn_val.txt']
69 | pretrain_reactions = []
70 |
71 | for fname in fnames:
72 | pretrain_reactions += read_txt(os.path.join(pretrain_dataset_path, fname))
73 |
74 |
75 | fragments = defaultdict(int)
76 | pool = Pool(12)
77 | for one_fragmet_dict in tqdm(pool.imap(get_frag_from_rxn_brics, pretrain_reactions), total=len(pretrain_reactions)):
78 | # for rxn in tqdm(pretrain_reactions):
79 | # one_fragmet_dict = get_frag_from_rxn_recap(rxn)
80 | for frag in one_fragmet_dict:
81 | fragments[frag] += one_fragmet_dict[frag]
82 | pool.close()
83 |
84 | with open('../check_data/frag.pkl', 'wb') as f:
85 | pickle.dump(fragments, f)
86 | with open('../check_data/frag.json', 'w', encoding='utf-8') as f:
87 | json.dump(fragments, f)
88 |
89 | else:
90 | with open('../check_data/frag.pkl', 'rb') as f:
91 | fragments = pickle.load(f)
92 | print('Fragments #:', len(fragments))
93 | fragments_items = list(fragments.items())
94 | fragments_items.sort(key=lambda x:x[1])
95 |
96 | top_number = 10000
97 | write_top_frag = ['{},{}'.format(x[0], x[1]) for x in fragments_items if x[1] > top_number]
98 | print(f'fragment count > {top_number}: {len(write_top_frag)}')
99 |     with open(f'../check_data/frag_cnt_number_{top_number}.txt', 'w', encoding='utf-8') as f:
100 | f.write('\n'.join(write_top_frag))
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
--------------------------------------------------------------------------------
/preprocess/uspto_script/merge_comp.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 |
4 | from utils import canonicalize_smiles, read_pickle
5 | from rdkit import RDLogger
6 | RDLogger.DisableLog('rdApp.*')
7 |
8 |
9 | if __name__ == '__main__':
10 | condition_compound = read_pickle('../check_data/condition_compound.pkl')
11 |
12 | merge_condition_compound = {}
13 |
14 | for key in ['c1', 's1', 'r1']:
15 | merge_condition_compound[key] = condition_compound[key]
16 | idx = len(merge_condition_compound[key])
17 | old_smiles_list = list(merge_condition_compound[key].values())
18 | print('{} old: {}'.format(key, len(old_smiles_list)))
19 | with open('../check_data/qurey_reaxys_names_{}_smi.txt'.format(key), 'r', encoding='utf-8') as f:
20 | new_smiles_list = list(set([canonicalize_smiles(x.strip()) for x in f.readlines()]))
21 | for smi in new_smiles_list:
22 | if (smi != '') and (smi not in old_smiles_list):
23 | merge_condition_compound[key][idx] = smi
24 | idx += 1
25 | print('{} now: {}'.format(key, len(merge_condition_compound[key])))
26 |
27 | with open('../check_data/condition_compound_add_name2smiles.pkl', 'wb') as f:
28 | pickle.dump(merge_condition_compound, f)
29 | with open('../check_data/condition_compound_add_name2smiles.json', 'w', encoding='utf-8') as f:
30 | json.dump(merge_condition_compound, f)
31 |
--------------------------------------------------------------------------------
/preprocess/uspto_script/uspto_condition.md:
--------------------------------------------------------------------------------
1 | # USPTO-Condition Curation
2 | ```
3 | cd Parrot/preprocess_script/uspto_script
4 | mkdir ../../dataset/source_dataset/uspto_org_xml/
5 | ```
6 | Download the original USPTO reaction dataset from [here](https://figshare.com/articles/dataset/Chemical_reactions_from_US_patents_1976-Sep2016_/5104873). Put `1976_Sep2016_USPTOgrants_cml.7z` and `2001_Sep2016_USPTOapplications_cml.7z` under `../../dataset/source_dataset/uspto_org_xml/`.
7 |
8 |
9 | A working rxnmapper environment is required to generate USPTO-Condition, and it is not compatible with `parrot_env`, so create a separate virtual environment for rxnmapper. See the [rxnmapper GitHub repository](https://github.com/rxn4chemistry/rxnmapper) for details, and the example sketch at the end of this document.
10 |
11 |
12 | Then:
13 | ```
14 | cd ../../dataset/source_dataset/uspto_org_xml/
15 | 7z x 1976_Sep2016_USPTOgrants_cml.7z
16 | 7z x 2001_Sep2016_USPTOapplications_cml.7z
17 | cd Parrot/preprocess_script/uspto_script
18 | python 1.get_condition_from_uspto.py
19 | sh 2.0.clean_up_rxn_condition.sh
20 | python 2.1.merge_clean_up_rxn_conditon.py
21 | python 3.0.split_condition_and_slect.py
22 | python 4.0.split_train_val_test.py
23 | python 5.0.convert_context_tokens.py
24 | ```
25 | Done!
26 |
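27 | For reference, a minimal sketch of creating the separate rxnmapper environment mentioned above (the Python version here is an assumption; follow the rxnmapper README for the currently supported versions):
28 | ```bash
29 | conda create -n rxnmapper_env python=3.6 -y
30 | conda activate rxnmapper_env
31 | pip install rxnmapper
32 | ```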
--------------------------------------------------------------------------------
/preprocess/uspto_script/utils.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, namedtuple
2 | import csv
3 | import json
4 | import os
5 | import pickle
6 | from rdkit import Chem
7 | import pandas as pd
8 | from rdkit.Chem.SaltRemover import SaltRemover, InputFormat
9 |
10 | # adapted from https://github.com/Coughy1991/Reaction_condition_recommendation/blob/64f151e302abcb87e0a14088e08e11b4cea8d5ab/scripts/prepare_data_cont_2_rgt_2_slv_1_cat_temp_deploy.py#L658-L690
11 | list_of_metal_atoms = ['Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag',
12 | 'Cd', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg',
13 | # 'Cn',
14 | 'Ln',
15 | 'Ce',
16 | 'Pr',
17 | 'Nd',
18 | 'Pm',
19 | 'Sm',
20 | 'Eu',
21 | 'Gd',
22 | 'Tb',
23 | 'Dy',
24 | 'Ho',
25 | 'Er',
26 | 'Tm',
27 | 'Yb',
28 | 'Lu',
29 | 'Ac',
30 | 'Th',
31 | 'Pa',
32 | 'U',
33 | 'Np',
34 | 'Am',
35 | 'Cm',
36 | 'Bk',
37 | 'Cf',
38 | 'Es',
39 | 'Fm',
40 | 'Md',
41 | 'No',
42 | 'Lr',
43 | ]
44 |
45 | mol_charge_class = [
46 |     'Positive',
47 | 'Negative',
48 | 'Neutral'
49 | ]
50 |
51 | class MolRemover(SaltRemover):
52 | def __init__(self, defnFilename=None, defnData=None, defnFormat=InputFormat.SMARTS):
53 | super().__init__(defnFilename, defnData, defnFormat)
54 |
55 | def _StripMol(self, mol, dontRemoveEverything=False, onlyFrags=False):
56 |
57 | def _applyPattern(m, salt, notEverything, onlyFrags=onlyFrags):
58 | nAts = m.GetNumAtoms()
59 | if not nAts:
60 | return m
61 | res = m
62 |
63 | t = Chem.DeleteSubstructs(res, salt, onlyFrags)
64 | if not t or (notEverything and t.GetNumAtoms() == 0):
65 | return res
66 | res = t
67 | while res.GetNumAtoms() and nAts > res.GetNumAtoms():
68 | nAts = res.GetNumAtoms()
69 | t = Chem.DeleteSubstructs(res, salt, True)
70 | if notEverything and t.GetNumAtoms() == 0:
71 | break
72 | res = t
73 | return res
74 |
75 | StrippedMol = namedtuple('StrippedMol', ['mol', 'deleted'])
76 | deleted = []
77 | if dontRemoveEverything and len(Chem.GetMolFrags(mol)) <= 1:
78 | return StrippedMol(mol, deleted)
79 | modified = False
80 | natoms = mol.GetNumAtoms()
81 | for salt in self.salts:
82 | mol = _applyPattern(mol, salt, dontRemoveEverything, onlyFrags)
83 | if natoms != mol.GetNumAtoms():
84 | natoms = mol.GetNumAtoms()
85 | modified = True
86 | deleted.append(salt)
87 | if dontRemoveEverything and len(Chem.GetMolFrags(mol)) <= 1:
88 | break
89 | if modified and mol.GetNumAtoms() > 0:
90 | Chem.SanitizeMol(mol)
91 | return StrippedMol(mol, deleted)
92 |
93 | def StripMolWithDeleted(self, mol, dontRemoveEverything=False, onlyFrags=False):
94 | return self._StripMol(mol, dontRemoveEverything, onlyFrags=onlyFrags)
95 |
96 | def pickle2json(fname):
97 | name, ext = os.path.splitext(fname)
98 | with open(fname, 'rb') as f:
99 | data = pickle.load(f)
100 |
101 | with open(name + '.json', 'w', encoding='utf-8') as f:
102 | json.dump(data, f)
103 |
104 |
105 | def read_json(fpath):
106 | with open(fpath, 'r', encoding='utf-8') as f:
107 | data = json.load(f)
108 | return data
109 |
110 |
111 | def read_pickle(fname):
112 | with open(fname, 'rb') as f:
113 | return pickle.load(f)
114 |
115 |
116 | def canonicalize_smiles(smi, clear_map=False):
117 | if pd.isna(smi):
118 | return ''
119 | mol = Chem.MolFromSmiles(smi)
120 | if mol is not None:
121 | if clear_map:
122 | [atom.ClearProp('molAtomMapNumber') for atom in mol.GetAtoms()]
123 | return Chem.MolToSmiles(mol)
124 | else:
125 | return ''
126 |
127 |
128 | def covert_to_series(smi, none_pandding):
129 | if pd.isna(smi):
130 | return pd.Series([none_pandding])
131 | elif smi == '':
132 | return pd.Series([none_pandding])
133 | else:
134 | return pd.Series(smi.split('.'))
135 |
136 |
137 | def get_writer(output_name, header):
138 | # output_name = os.path.join(cmd_args.save_dir, fname)
139 | fout = open(output_name, 'w')
140 | writer = csv.writer(fout)
141 | writer.writerow(header)
142 | return fout, writer
143 |
144 |
145 | def calculate_frequency(input_list, report=True):
146 | output_dict = defaultdict(int)
147 | for x in input_list:
148 | output_dict[x] += 1
149 |
150 | output_items = list(output_dict.items())
151 | output_items.sort(key=lambda x: x[1], reverse=True)
152 |
153 | if report:
154 | report_threshold = [10000, 5000, 1000, 500, 100, 50, 1]
155 | for t in report_threshold:
156 | t_list = [x for x in output_items if x[1] > t]
157 | print('Frequency >={} : {}'.format(t, len(t_list)))
158 |
159 | return output_items
160 |
161 |
162 | def get_mol_charge(mol):
163 | mol_neutralization = None
164 | positive = []
165 | negative = []
166 | for atom in mol.GetAtoms():
167 | charge = atom.GetFormalCharge()
168 | if charge > 0:
169 | positive.append(charge)
170 | elif charge < 0:
171 | negative.append(charge)
172 | if len(positive) == 0 and len(negative) == 0:
173 | mol_charge_flag = mol_charge_class[2]
174 | mol_neutralization = False
175 | elif len(positive) != 0 and len(negative) == 0:
176 | mol_charge_flag = mol_charge_class[0]
177 | mol_neutralization = False
178 | elif len(positive) == 0 and len(negative) != 0:
179 | mol_charge_flag = mol_charge_class[1]
180 | mol_neutralization = False
181 | elif len(positive) != 0 and len(negative) != 0:
182 | mol_charge = sum(positive) + sum(negative)
183 | if mol_charge > 0:
184 | mol_charge_flag = mol_charge_class[0]
185 | elif mol_charge < 0:
186 | mol_charge_flag = mol_charge_class[1]
187 | else:
188 | mol_charge_flag = mol_charge_class[2]
189 | mol_neutralization = True
190 | return mol_charge_flag, mol_neutralization
191 |
192 |
193 |
194 |
195 |
196 | if __name__ == '__main__':
197 | # smi = 'CC(C)C[Al+]CC(C)C.O=C(O)CC(O)(CC(=O)O)C(=O)O.[H-]'
198 | # smi = 'CCN(CC)CC.CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[B-](F)(F)F'
199 | smi = 'O.[Al+3].[H-].[H-].[H-].[H-].[Li+].[Na+].[OH-]'
200 | mol = Chem.MolFromSmiles(smi)
201 | # print(get_mol_charge(mol))
202 | # from rdkit.Chem.SaltRemover import SaltRemover
203 | remover = MolRemover(defnFilename='../reagent_Ionic_compound.txt')
204 | remover.StripMolWithDeleted(
205 | mol, dontRemoveEverything=False, onlyFrags=False)
206 |
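207 | # Illustrative usage (comments only):
208 | #   canonicalize_smiles('CCO')                             -> 'CCO'
209 | #   canonicalize_smiles('[CH3:1][OH:2]', clear_map=True)   -> 'CO'    (atom maps removed first)
210 | #   get_mol_charge(Chem.MolFromSmiles('[NH4+].[Cl-]'))     -> ('Neutral', True): positive and
211 | #     negative charges are both present and cancel, so the molecule counts as neutralized.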
--------------------------------------------------------------------------------
/retrieve/condition_year.sh:
--------------------------------------------------------------------------------
1 | python retrieve_faiss.py \
2 | --data_path ../data/USPTO_condition_year \
3 | --train_file USPTO_condition_train.csv \
4 | --valid_file USPTO_condition_val.csv \
5 | --test_file USPTO_condition_test.csv \
6 | --field canonical_rxn \
7 | --output_path output/USPTO_condition_year
8 |
--------------------------------------------------------------------------------
/retrieve/convert_format.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | input_path = '/data/scratch/zhengkai/neural-retrieval/retrieved/USPTO_condition_MIT_smiles/test.jsonl'
4 | output_path = 'contrastive/test.json'
5 |
6 | output = []
7 | with open(input_path) as f:
8 | for line in f:
9 | data = json.loads(line)
10 | output.append({
11 | 'id': data['query_id'],
12 | 'nn': [p['docid'] for p in data['negative_passages']]
13 | })
14 |
15 | with open(output_path, 'w') as f:
16 | json.dump(output, f)
17 |
--------------------------------------------------------------------------------
/retrieve/retrieve.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import json
4 | import rdkit
5 | import rdkit.Chem as Chem
6 | import rdkit.Chem.rdChemReactions as rdChemReactions
7 | import rdkit.DataStructs as DataStructs
8 | from tqdm import tqdm
9 | import multiprocessing
10 |
11 |
12 | def reaction_fingerprint(smiles):
13 | rxn = rdChemReactions.ReactionFromSmarts(smiles)
14 | fp = rdChemReactions.CreateDifferenceFingerprintForReaction(rxn)
15 | return fp
16 |
17 |
18 | def reaction_similarity(smiles1=None, smiles2=None, fp1=None, fp2=None):
19 | if fp1 is None and smiles1:
20 | fp1 = reaction_fingerprint(smiles1)
21 | if fp2 is None and smiles2:
22 | fp2 = reaction_fingerprint(smiles2)
23 | assert fp1 is not None
24 | assert fp2 is not None
25 | return DataStructs.TanimotoSimilarity(fp1, fp2)
26 |
27 |
28 | def reaction_similarity_fp_smiles(fp, smiles):
29 | return reaction_similarity(fp1=fp, smiles2=smiles)
30 |
31 |
32 | def compute_reaction_similarities(test_smiles, train_smiles_list, num_workers=64):
33 | test_fp = reaction_fingerprint(test_smiles)
34 | with multiprocessing.Pool(num_workers) as p:
35 | similarities = p.starmap(
36 | reaction_similarity_fp_smiles,
37 | [(test_fp, smiles) for smiles in train_smiles_list],
38 | chunksize=128
39 | )
40 | return similarities
41 |
42 |
43 | if __name__ == '__main__':
44 |
45 | train_df = pd.read_csv('data/USPTO_condition_train.csv')
46 | val_df = pd.read_csv('data/USPTO_condition_val.csv')
47 | test_df = pd.read_csv('data/USPTO_condition_test.csv')
48 |
49 | train_smiles_list = train_df['canonical_rxn']
50 |
51 | results = {}
52 | # with open('test_nn.json') as f:
53 | # results = json.load(f)
54 |
55 | for i, test_row in tqdm(test_df.iterrows()):
56 | if str(i) in results:
57 | continue
58 | test_smiles = test_row['canonical_rxn']
59 | similarities = compute_reaction_similarities(test_smiles, train_smiles_list, num_workers=64)
60 | ranks = np.argsort(similarities)[::-1][:100].tolist()
61 | results[i] = {
62 | 'rank': ranks,
63 | 'similarity': [similarities[j] for j in ranks]
64 | }
65 |         if i + 1 == 100:  # only the first 100 test reactions are scored
66 | break
67 |
68 | with open('test_nn.json', 'w') as f:
69 | json.dump(results, f)
70 |
--------------------------------------------------------------------------------
/retrieve/retrieve_faiss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pandas as pd
4 | import numpy as np
5 | import json
6 | import rdkit
7 | import rdkit.Chem as Chem
8 | rdkit.RDLogger.DisableLog('rdApp.*')
9 | import rdkit.Chem.AllChem as AllChem
10 | import rdkit.Chem.rdChemReactions as rdChemReactions
11 | import rdkit.DataStructs as DataStructs
12 | from tqdm import tqdm
13 | import multiprocessing
14 | import faiss
15 | import time
16 |
17 |
18 | def reaction_fingerprint(smiles):
19 | rxn = rdChemReactions.ReactionFromSmarts(smiles)
20 | fp = rdChemReactions.CreateDifferenceFingerprintForReaction(rxn)
21 | return fp
22 |
23 |
24 | def reaction_fingerprint_array(smiles):
25 | fp = reaction_fingerprint(smiles)
26 | array = np.array([x for x in fp])
27 | return array
28 |
29 |
30 | def compute_reaction_fingerprints(smiles_list, num_workers=64):
31 | with multiprocessing.Pool(num_workers) as p:
32 | fps = p.map(reaction_fingerprint_array, smiles_list, chunksize=128)
33 | return np.array(fps)
34 |
35 |
36 | def morgan_fingerprint(smiles):
37 | try:
38 | mol = Chem.MolFromSmiles(smiles)
39 | fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
40 | array = np.zeros((0,), dtype=np.int8)
41 | DataStructs.ConvertToNumpyArray(fp, array)
42 | except:
43 | return morgan_fingerprint('C')
44 | return array
45 |
46 |
47 | def compute_molecule_fingerprints(smiles_list, num_workers=64):
48 | with multiprocessing.Pool(num_workers) as p:
49 | fps = p.map(morgan_fingerprint, smiles_list, chunksize=64)
50 | return np.array(fps)
51 |
52 |
53 | def compare_condition(row1, row2):
54 | for field in ['catalyst1', 'solvent1', 'solvent2', 'reagent1', 'reagent2']:
55 | if type(row1[field]) is not str and type(row2[field]) is not str:
56 | continue
57 | if row1[field] != row2[field]:
58 | return False
59 | return True
60 |
61 |
62 | def index_and_search(train_fps, query_fps):
63 | print('Faiss build index')
64 | d = train_fps.shape[1]
65 | index = faiss.IndexFlatL2(d)
66 | index.add(train_fps)
67 |
68 | print('Faiss nearest neighbor search')
69 | start = time.time()
70 | k = 20
71 | distance, rank = index.search(query_fps, k)
72 | end = time.time()
73 | print(f"{end - start:.2f} s")
74 | return rank
75 |
76 |
77 | if __name__ == '__main__':
78 |
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument('--data_path', type=str, default=None, required=True)
81 | parser.add_argument('--train_file', type=str, default=None, required=True)
82 | parser.add_argument('--valid_file', type=str, default=None, required=True)
83 | parser.add_argument('--test_file', type=str, default=None, required=True)
84 | parser.add_argument('--field', type=str, default='canonical_rxn')
85 | parser.add_argument('--before', type=int, default=-1)
86 | parser.add_argument('--output_path', type=str, default=None, required=True)
87 | args = parser.parse_args()
88 |
89 | train_df = pd.read_csv(os.path.join(args.data_path, args.train_file), keep_default_na=False)
90 | val_df = pd.read_csv(os.path.join(args.data_path, args.valid_file), keep_default_na=False)
91 | test_df = pd.read_csv(os.path.join(args.data_path, args.test_file), keep_default_na=False)
92 |
93 | if args.field == 'canonical_rxn':
94 | print('Reaction fingerprint')
95 | fingerprint_fn = compute_reaction_fingerprints
96 | else:
97 | print('Molecule fingerprint')
98 | fingerprint_fn = compute_molecule_fingerprints
99 |
100 | train_fp_file = os.path.join(args.output_path, 'train_fp.pkl')
101 | if not os.path.exists(train_fp_file):
102 | if args.before != -1:
103 | train_df = train_df[train_df['year'] < args.before].reset_index(drop=True)
104 | train_fps = fingerprint_fn(train_df[args.field])
105 | os.makedirs(args.output_path, exist_ok=True)
106 | with open(train_fp_file, 'wb') as f:
107 | np.save(f, train_fps)
108 | else:
109 | with open(train_fp_file, 'rb') as f:
110 | train_fps = np.load(f)
111 |
112 | train_id = train_df['id']
113 |
114 | query_fps, query_id = train_fps, train_df['id']
115 | rank = index_and_search(train_fps, query_fps)
116 | result = [{'id': query_id[i], 'nn': [train_id[n] for n in nn]} for i, nn in enumerate(rank)]
117 | with open(os.path.join(args.output_path, 'train.json'), 'w') as f:
118 | json.dump(result, f)
119 |
120 | query_fps, query_id = fingerprint_fn(val_df[args.field]), val_df['id']
121 | rank = index_and_search(train_fps, query_fps)
122 | result = [{'id': query_id[i], 'nn': [train_id[n] for n in nn]} for i, nn in enumerate(rank)]
123 | with open(os.path.join(args.output_path, 'val.json'), 'w') as f:
124 | json.dump(result, f)
125 |
126 | query_fps, query_id = fingerprint_fn(test_df[args.field]), test_df['id']
127 | rank = index_and_search(train_fps, query_fps)
128 | result = [{'id': query_id[i], 'nn': [train_id[n] for n in nn]} for i, nn in enumerate(rank)]
129 | with open(os.path.join(args.output_path, 'test.json'), 'w') as f:
130 | json.dump(result, f)
131 |
132 | if args.field == 'canonical_rxn':
133 | cnt = {x: 0 for x in [1, 3, 5, 10, 15]}
134 | for i, nn in enumerate(rank):
135 | test_row = test_df.iloc[i]
136 | train_rows = [train_df.iloc[n] for n in nn]
137 | hit_map = [compare_condition(test_row, train_row) for train_row in train_rows]
138 | for x in cnt:
139 | cnt[x] += np.any(hit_map[:x])
140 |
141 | print(cnt, len(test_df))
142 | for x in cnt:
143 | print(f"Top-{x}: {cnt[x] / len(test_df):.4f}", end=' ')
144 | print()
145 |
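146 | # Retrieval flow in this script: queries and the training corpus are embedded as fixed-length
147 | # fingerprints (difference reaction fingerprints for full reactions, 1024-bit Morgan fingerprints
148 | # for single molecules), an exact faiss.IndexFlatL2 index is built over the training fingerprints,
149 | # and the k = 20 nearest training examples for each train/val/test query are written out as
150 | # JSON records of the form {'id': query_id, 'nn': [neighbor ids]}. On binary Morgan vectors,
151 | # squared L2 distance equals the Hamming distance, so the ranking follows bit differences.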
--------------------------------------------------------------------------------
/retrieve/retro.sh:
--------------------------------------------------------------------------------
1 | python retrieve_faiss.py \
2 | --data_path ../data/USPTO_50K/matched1 \
3 | --train_file train.csv \
4 | --valid_file valid.csv \
5 | --test_file test.csv \
6 | --field product_smiles \
7 | --output_path output/USPTO_50K
8 |
9 | #python retrieve_faiss.py \
10 | # --data_path ../data/USPTO_50K/matched1 \
11 | # --train_file ../../USPTO_rxn_smiles.csv \
12 | # --valid_file valid.csv \
13 | # --test_file test.csv \
14 | # --field product_smiles \
15 | # --output_path output/USPTO_50K/full
16 |
--------------------------------------------------------------------------------
/retrieve/retro_year.sh:
--------------------------------------------------------------------------------
1 | #python retrieve_faiss.py \
2 | # --data_path ../data/USPTO_50K_year \
3 | # --train_file ../USPTO_rxn_smiles.csv \
4 | # --valid_file valid.csv \
5 | # --test_file test.csv \
6 | # --field product_smiles \
7 | # --output_path output/USPTO_50K_year/full
8 |
9 | python retrieve_faiss.py \
10 | --data_path ../data/USPTO_50K_year \
11 | --train_file ../USPTO_rxn_smiles.csv \
12 | --before 2012 \
13 | --valid_file valid.csv \
14 | --test_file test.csv \
15 | --field product_smiles \
16 | --output_path output/USPTO_50K_year/corpus_before_2012
17 |
18 |
--------------------------------------------------------------------------------
/scripts/train_RCR.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NUM_GPUS_PER_NODE=4
4 | BATCH_SIZE=128
5 | ACCUM_STEP=1
6 |
7 | SAVE_PATH=output/RCR_textreact
8 | NN_PATH=data/Tevatron_output/RCR/
9 |
10 | mkdir -p ${SAVE_PATH}
11 |
12 | NCCL_P2P_DISABLE=1 python main.py \
13 | --task condition \
14 | --encoder allenai/scibert_scivocab_uncased \
15 | --decoder textreact/configs/bert_l6.json \
16 | --encoder_pretrained \
17 | --data_path data/RCR/ \
18 | --train_file train.csv \
19 | --valid_file val.csv \
20 | --test_file test.csv \
21 | --vocab_file textreact/vocab/vocab_condition.txt \
22 | --corpus_file data/USPTO_rxn_corpus.csv \
23 | --nn_path ${NN_PATH} \
24 | --train_nn_file train_rank.json \
25 | --valid_nn_file val_rank.json \
26 | --test_nn_file test_rank.json \
27 | --num_neighbors 3 \
28 | --use_gold_neighbor \
29 | --save_path ${SAVE_PATH} \
30 | --max_length 512 \
31 | --shuffle_smiles \
32 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \
33 | --lr 1e-4 \
34 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \
35 | --gradient_accumulation_steps ${ACCUM_STEP} \
36 | --epochs 20 \
37 | --warmup 0.02 \
38 | --do_train --do_valid --do_test \
39 | --num_beams 15 \
40 | --precision 16-mixed \
41 | --gpus ${NUM_GPUS_PER_NODE}
42 |
--------------------------------------------------------------------------------
/scripts/train_RCR_TS.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NUM_GPUS_PER_NODE=4
4 | BATCH_SIZE=128
5 | ACCUM_STEP=1
6 |
7 | SAVE_PATH=output/RCR_TS_textreact
8 | NN_PATH=data/Tevatron_output/RCR_TS/
9 |
10 | mkdir -p ${SAVE_PATH}
11 |
12 | NCCL_P2P_DISABLE=1 python main.py \
13 | --task condition \
14 | --encoder allenai/scibert_scivocab_uncased \
15 | --decoder textreact/configs/bert_l6.json \
16 | --encoder_pretrained \
17 | --data_path data/RCR_TS/ \
18 | --train_file train.csv \
19 | --valid_file val.csv \
20 | --test_file test.csv \
21 | --vocab_file textreact/vocab/vocab_condition.txt \
22 | --corpus_file data/USPTO_rxn_corpus.csv \
23 | --nn_path ${NN_PATH} \
24 | --train_nn_file train_rank.json \
25 | --valid_nn_file val_rank_full.json \
26 | --test_nn_file test_rank_full.json \
27 | --num_neighbors 3 \
28 | --use_gold_neighbor \
29 | --save_path ${SAVE_PATH} \
30 | --max_length 512 \
31 | --shuffle_smiles \
32 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \
33 | --lr 1e-4 \
34 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \
35 | --gradient_accumulation_steps ${ACCUM_STEP} \
36 | --epochs 20 \
37 | --warmup 0.02 \
38 | --do_train --do_valid --do_test \
39 | --num_beams 15 \
40 | --precision 16-mixed \
41 | --gpus ${NUM_GPUS_PER_NODE}
42 |
--------------------------------------------------------------------------------
/scripts/train_RetroSyn_tb.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NUM_GPUS_PER_NODE=4
4 | BATCH_SIZE=128
5 | ACCUM_STEP=1
6 |
7 | SAVE_PATH=output/RetroSyn_tb_textreact
8 | NN_PATH=data/Tevatron_output/RetroSyn/
9 |
10 | mkdir -p ${SAVE_PATH}
11 |
12 | NCCL_P2P_DISABLE=1 python main.py \
13 | --task retro \
14 | --template_based \
15 | --shuffle_smiles \
16 | --encoder allenai/scibert_scivocab_uncased \
17 | --encoder_pretrained \
18 | --encoder_tokenizer smiles_text \
19 | --vocab_file textreact/vocab/vocab_smiles.txt \
20 | --data_path data/RetroSyn/ \
21 | --template_path data/RetroSyn/template_based \
22 | --train_file train.csv \
23 | --valid_file valid.csv \
24 | --test_file test.csv \
25 | --corpus_file data/USPTO_rxn_corpus.csv \
26 | --nn_path ${NN_PATH} \
27 | --train_nn_file train_rank.json \
28 | --valid_nn_file valid_rank.json \
29 | --test_nn_file test_rank.json \
30 | --num_neighbors 3 \
31 | --use_gold_neighbor \
32 | --random_neighbor_ratio 0.2 \
33 | --save_path ${SAVE_PATH} \
34 | --load_ckpt best.ckpt \
35 | --max_length 512 \
36 | --max_dec_length 160 \
37 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \
38 | --lr 1e-4 \
39 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \
40 | --gradient_accumulation_steps ${ACCUM_STEP} \
41 | --test_batch_size 32 \
42 | --epochs 200 \
43 | --eval_per_epoch 10 \
44 | --warmup 0.02 \
45 | --do_train --do_valid --do_test \
46 | --num_beams 20 \
47 | --precision 16 \
48 | --gpus ${NUM_GPUS_PER_NODE}
49 |
--------------------------------------------------------------------------------
/scripts/train_RetroSyn_tb_TS.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NUM_GPUS_PER_NODE=4
4 | BATCH_SIZE=128
5 | ACCUM_STEP=1
6 |
7 | SAVE_PATH=output/RetroSyn_tb_TS_textreact
8 | NN_PATH=data/Tevatron_output/RetroSyn_TS/
9 |
10 | mkdir -p ${SAVE_PATH}
11 |
12 | NCCL_P2P_DISABLE=1 python main.py \
13 | --task retro \
14 | --template_based \
15 | --shuffle_smiles \
16 | --encoder allenai/scibert_scivocab_uncased \
17 | --encoder_pretrained \
18 | --encoder_tokenizer smiles_text \
19 | --vocab_file textreact/vocab/vocab_smiles.txt \
20 | --data_path data/RetroSyn_TS/ \
21 | --template_path data/RetroSyn_TS/template_based \
22 | --train_file train.csv \
23 | --valid_file valid.csv \
24 | --test_file test.csv \
25 | --corpus_file data/USPTO_rxn_corpus.csv \
26 | --nn_path ${NN_PATH} \
27 | --train_nn_file train_rank.json \
28 | --valid_nn_file valid_rank_full.json \
29 | --test_nn_file test_rank_full.json \
30 | --num_neighbors 3 \
31 | --use_gold_neighbor \
32 | --random_neighbor_ratio 0.2 \
33 | --save_path ${SAVE_PATH} \
34 | --load_ckpt best.ckpt \
35 | --max_length 512 \
36 | --max_dec_length 160 \
37 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \
38 | --lr 1e-4 \
39 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \
40 | --gradient_accumulation_steps ${ACCUM_STEP} \
41 | --test_batch_size 32 \
42 | --epochs 200 \
43 | --eval_per_epoch 10 \
44 | --warmup 0.02 \
45 | --do_train --do_valid --do_test \
46 | --num_beams 20 \
47 | --precision 16 \
48 | --gpus ${NUM_GPUS_PER_NODE}
49 |
--------------------------------------------------------------------------------
/scripts/train_RetroSyn_tf.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NUM_GPUS_PER_NODE=4
4 | BATCH_SIZE=128
5 | ACCUM_STEP=1
6 |
7 | SAVE_PATH=output/RetroSyn_tf_textreact
8 | NN_PATH=data/Tevatron_output/RetroSyn/
9 |
10 | mkdir -p ${SAVE_PATH}
11 |
12 | NCCL_P2P_DISABLE=1 python main.py \
13 | --task retro \
14 | --encoder allenai/scibert_scivocab_uncased \
15 | --decoder textreact/configs/bert_l6.json \
16 | --encoder_pretrained \
17 | --vocab_file textreact/vocab/vocab_smiles.txt \
18 | --data_path data/RetroSyn/ \
19 | --train_file train.csv \
20 | --valid_file valid.csv \
21 | --test_file test.csv \
22 | --corpus_file data/USPTO_rxn_corpus.csv \
23 | --nn_path ${NN_PATH} \
24 | --train_nn_file train_rank.json \
25 | --valid_nn_file valid_rank.json \
26 | --test_nn_file test_rank.json \
27 | --num_neighbors 3 \
28 | --use_gold_neighbor \
29 | --random_neighbor_ratio 0.2 \
30 | --save_path ${SAVE_PATH} \
31 | --load_ckpt best.ckpt \
32 | --max_length 512 \
33 | --max_dec_length 160 \
34 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \
35 | --lr 1e-4 \
36 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \
37 | --gradient_accumulation_steps ${ACCUM_STEP} \
38 | --test_batch_size 32 \
39 | --epochs 200 \
40 | --eval_per_epoch 25 \
41 | --warmup 0.02 \
42 | --do_train --do_valid --do_test \
43 | --num_beams 20 \
44 | --precision 16 \
45 | --gpus ${NUM_GPUS_PER_NODE}
46 |
--------------------------------------------------------------------------------
/scripts/train_RetroSyn_tf_TS.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NUM_GPUS_PER_NODE=4
4 | BATCH_SIZE=128
5 | ACCUM_STEP=1
6 |
7 | SAVE_PATH=output/RetroSyn_tf_TS_textreact
8 | NN_PATH=data/Tevatron_output/RetroSyn_TS/
9 |
10 | mkdir -p ${SAVE_PATH}
11 |
12 | NCCL_P2P_DISABLE=1 python main.py \
13 | --task retro \
14 | --encoder allenai/scibert_scivocab_uncased \
15 | --decoder textreact/configs/bert_l6.json \
16 | --encoder_pretrained \
17 | --vocab_file textreact/vocab/vocab_smiles.txt \
18 | --data_path data/RetroSyn_TS/ \
19 | --train_file train.csv \
20 | --valid_file valid.csv \
21 | --test_file test.csv \
22 | --corpus_file data/USPTO_rxn_corpus.csv \
23 |     --nn_path ${NN_PATH} \
25 | --train_nn_file train_rank.json \
26 | --valid_nn_file valid_rank_full.json \
27 | --test_nn_file test_rank_full.json \
28 | --num_neighbors 3 \
29 | --use_gold_neighbor \
30 | --random_neighbor_ratio 0.2 \
31 | --save_path ${SAVE_PATH} \
32 | --load_ckpt best.ckpt \
33 | --max_length 512 \
34 | --max_dec_length 160 \
35 | --mlm --mlm_ratio 0.15 --mlm_layer mlp --mlm_lambda 0.1 \
36 | --lr 1e-4 \
37 | --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \
38 | --gradient_accumulation_steps ${ACCUM_STEP} \
39 | --test_batch_size 32 \
40 | --epochs 200 \
41 | --eval_per_epoch 25 \
42 | --warmup 0.02 \
43 | --do_train --do_valid --do_test \
44 | --num_beams 20 \
45 | --precision 16 \
46 | --gpus ${NUM_GPUS_PER_NODE}
47 |
--------------------------------------------------------------------------------
/textreact/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thomas0809/textreact/feb62d868a627d293997a56a72f50420377c59b4/textreact/__init__.py
--------------------------------------------------------------------------------
/textreact/configs/bert_l6.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "hidden_dropout_prob": 0.1,
5 | "hidden_size": 768,
6 | "initializer_range": 0.02,
7 | "intermediate_size": 3072,
8 | "layer_norm_eps": 1e-05,
9 | "max_position_embeddings": 512,
10 | "model_type": "roberta",
11 | "num_attention_heads": 12,
12 | "num_hidden_layers": 6,
13 | "type_vocab_size": 1,
14 | "vocab_size": 600,
15 | "bos_token_id": 12,
16 | "eos_token_id": 13,
17 | "pad_token_id": 0
18 | }
--------------------------------------------------------------------------------
/textreact/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import repeat
3 | import numpy as np
4 | import pandas as pd
5 | import multiprocessing
6 |
7 | import rdkit
8 | from rdkit import Chem
9 | rdkit.RDLogger.DisableLog('rdApp.*')
10 |
11 | from .dataset import CONDITION_COLS
12 | from .template_decoder import get_pred_smiles_from_templates
13 |
14 |
15 | def evaluate_reaction_condition(prediction, data_df):
16 | cnt = {x: 0 for x in [1, 3, 5, 10, 15]}
17 | for i, output in prediction.items():
18 | label = data_df.loc[i, CONDITION_COLS].tolist()
19 | hit_map = [pred == label for pred in output['prediction']]
20 | for x in cnt:
21 | cnt[x] += np.any(hit_map[:x])
22 | num_example = len(data_df)
23 | accuracy = {x: cnt[x] / num_example for x in cnt}
24 | return accuracy
25 |
26 |
27 | def canonical_smiles(smiles):
28 | try:
29 | canon_smiles = Chem.CanonSmiles(smiles)
30 | except:
31 | canon_smiles = smiles
32 | return canon_smiles
33 |
34 |
35 | def _compare_pred_and_gold(pred, gold):
36 | pred = [canonical_smiles(smiles) for smiles in pred]
37 | for i, smiles in enumerate(pred):
38 | if smiles == gold:
39 | return i
40 | return 100000
41 |
42 |
43 | def evaluate_retrosynthesis(prediction, data_df, top_k, template_based=False, template_path=None, num_workers=16):
44 | num_example = len(data_df)
45 | with multiprocessing.Pool(num_workers) as p:
46 | gold_list = p.map(canonical_smiles, data_df['reactant_smiles'])
47 | if template_based:
48 |         pred_prob_list = [[(*pred, score)
49 |                            for pred, score in zip(prediction[i]['prediction'], prediction[i]['score'])]
50 |                           for i in range(num_example)]
51 | atom_templates = pd.read_csv(os.path.join(template_path, 'atom_templates.csv'))
52 | bond_templates = pd.read_csv(os.path.join(template_path, 'bond_templates.csv'))
53 | template_infos = pd.read_csv(os.path.join(template_path, 'template_infos.csv'))
54 | atom_templates = {atom_templates['Class'][i]: atom_templates['Template'][i] for i in atom_templates.index}
55 | bond_templates = {bond_templates['Class'][i]: bond_templates['Template'][i] for i in bond_templates.index}
56 | template_infos = {template_infos['Template'][i]: {
57 | 'edit_site': eval(template_infos['edit_site'][i]),
58 | 'change_H': eval(template_infos['change_H'][i]),
59 | 'change_C': eval(template_infos['change_C'][i]),
60 | 'change_S': eval(template_infos['change_S'][i])
61 | } for i in template_infos.index}
62 | pred_list = p.starmap(get_pred_smiles_from_templates,
63 | zip(pred_prob_list, data_df['product_smiles'],
64 | repeat(atom_templates), repeat(bond_templates), repeat(template_infos), repeat(top_k)))
65 | else:
66 | pred_list = [prediction[i]['prediction'] for i in range(num_example)]
67 | indices = p.starmap(_compare_pred_and_gold, [(p, g) for p, g in zip(pred_list, gold_list)])
68 | accuracy = {}
69 | for x in [1, 2, 3, 5, 10, 20]:
70 | accuracy[x] = sum([idx < x for idx in indices]) / num_example
71 | return accuracy
72 |
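73 | # Expected shape of `prediction` (inferred from the usage above, not a fixed API): a dict keyed by
74 | # example index, each value holding a 'prediction' list of candidates ordered best-first (a list of
75 | # the five condition values per candidate for the condition task; a reactant SMILES or a template
76 | # prediction for retrosynthesis) and, for template-based retro, a parallel 'score' list.
77 | # Condition accuracy counts a top-k hit only when an entire predicted tuple matches the gold row.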
--------------------------------------------------------------------------------
/textreact/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import pad_sequence
4 | from transformers import get_scheduler, EncoderDecoderModel, EncoderDecoderConfig, AutoTokenizer, AutoConfig, AutoModel
5 | from transformers.models.bert.modeling_bert import BertLMPredictionHead
6 | from transformers.utils import ModelOutput
7 | from . import utils
8 |
9 |
10 | def get_model(args, enc_tokenizer=None, dec_tokenizer=None):
11 | if args.template_based:
12 | assert args.decoder is None and not args.decoder_pretrained
13 | if args.encoder_pretrained:
14 | encoder = AutoModel.from_pretrained(pretrained_model_name_or_path=args.encoder)
15 | else:
16 | encoder_config = AutoConfig.from_pretrained(args.encoder)
17 | encoder = AutoModel.from_config(encoder_config)
18 | template_head = TemplatePredictionHead(encoder.config.hidden_size, len(dec_tokenizer[0]), len(dec_tokenizer[1]))
19 | model = TemplateBasedModel(encoder, template_head)
20 | else:
21 | if args.encoder_pretrained and args.decoder_pretrained:
22 | model = EncoderDecoderModel.from_encoder_decoder_pretrained(
23 | encoder_pretrained_model_name_or_path=args.encoder, decoder_pretrained_model_name_or_path=args.decoder)
24 | else:
25 | encoder_config = AutoConfig.from_pretrained(args.encoder)
26 | decoder_config = AutoConfig.from_pretrained(args.decoder)
27 | config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
28 | model = EncoderDecoderModel(config=config)
29 | if args.encoder_pretrained:
30 | encoder = AutoModel.from_pretrained(args.encoder)
31 | model.encoder = encoder
32 | encoder = model.encoder
33 | if args.max_length > encoder.config.max_position_embeddings:
34 | utils.expand_position_embeddings(encoder, args.max_length)
35 | if args.encoder_tokenizer == 'smiles_text':
36 | utils.expand_word_embeddings(encoder, len(enc_tokenizer))
37 | return model
38 |
39 |
40 | def get_mlm_head(args, model):
41 | if args.mlm_layer == 'linear':
42 | mlm_head = nn.Linear(model.encoder.config.hidden_size, model.encoder.config.vocab_size)
43 | elif args.mlm_layer == 'mlp':
44 | mlm_head = BertLMPredictionHead(model.encoder.config)
45 | else:
46 | raise NotImplementedError
47 | return mlm_head
48 |
49 |
50 | class TemplateBasedModel(nn.Module):
51 | def __init__(self, encoder, template_head):
52 | super().__init__()
53 | self.encoder = encoder
54 | self.template_head = template_head
55 |
56 | def forward(self, **inputs):
57 | encoder_output = self.encoder(**{k: v for k, v in inputs.items()
58 | if not k.startswith('decoder_') and k not in ['atom_indices', 'bonds']})
59 | atom_hidden_states = []
60 | for hidden_states, atom_indices in zip(encoder_output.last_hidden_state, inputs['atom_indices']):
61 | atom_hidden_states.append(hidden_states[atom_indices])
62 | atom_hidden_states = pad_sequence(atom_hidden_states, batch_first=True)
63 | return ModelOutput(logits=self.template_head(atom_hidden_states), encoder_last_hidden_state=encoder_output.last_hidden_state)
64 |
65 |
66 | class TemplatePredictionHead(nn.Module):
67 | def __init__(self, input_size, num_atom_templates, num_bond_templates):
68 | super().__init__()
69 | self.atom_template_head = nn.Linear(input_size, num_atom_templates + 1)
70 | self.bond_template_head = BondTemplatePredictor(input_size, num_bond_templates)
71 |
72 | def forward(self, input_states):
73 | """
74 | Input: [B x] L x d_in
75 | Output: [B x] L x n_a, [B x] L x L x n_b
76 | """
77 | return self.atom_template_head(input_states), self.bond_template_head(input_states)
78 |
79 |
80 | class BondTemplatePredictor(nn.Module):
81 | def __init__(self, input_size, num_bond_templates):
82 | super().__init__()
83 | self.linear = nn.Linear(2 * input_size, num_bond_templates + 1)
84 |
85 | def forward(self, input_states):
86 | concat_pair_shape = input_states.shape[:-1] + input_states.shape[-2:]
87 | concat_pairs = torch.cat((input_states.unsqueeze(-2).expand(concat_pair_shape),
88 | input_states.unsqueeze(-3).expand(concat_pair_shape)),
89 | dim=-1)
90 | return self.linear(concat_pairs)
91 |
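92 | # Shape sketch for TemplatePredictionHead (illustrative values, comments only):
93 | #   head = TemplatePredictionHead(input_size=768, num_atom_templates=100, num_bond_templates=50)
94 | #   atom_logits, bond_logits = head(torch.randn(2, 5, 768))   # batch of 2, 5 atoms each
95 | #   atom_logits.shape == (2, 5, 101)      # per-atom scores over num_atom_templates + 1 classes
96 | #   bond_logits.shape == (2, 5, 5, 51)    # per atom-pair scores over num_bond_templates + 1 classes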
--------------------------------------------------------------------------------
/textreact/template_decoder.py:
--------------------------------------------------------------------------------
1 | """ Adapted from Decode_predictions.py and template_decoder.py """
2 |
3 | import os, re, copy
4 | import pandas as pd
5 | from collections import defaultdict
6 |
7 | import rdkit
8 | from rdkit import Chem, RDLogger
9 | from rdkit.Chem import rdChemReactions
10 | from rdkit.Chem.rdchem import ChiralType
11 | from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
12 |
13 | RDLogger.DisableLog('rdApp.*')
14 |
15 | chiral_type_map = {ChiralType.CHI_UNSPECIFIED : -1, ChiralType.CHI_TETRAHEDRAL_CW: 1, ChiralType.CHI_TETRAHEDRAL_CCW: 2}
16 | chiral_type_map_inv = {v:k for k, v in chiral_type_map.items()}
17 |
18 | a, b = 'a', 'b'
19 |
20 | def get_pred_smiles_from_templates(template_preds, product, atom_templates, bond_templates, template_infos, top_k):
21 | smiles_preds = []
22 | for prediction in template_preds:
23 | mol, pred_site, template, template_info, score = read_prediction(product, prediction, atom_templates, bond_templates, template_infos)
24 | local_template = '>>'.join(['(%s)' % smarts for smarts in template.split('_')[0].split('>>')])
25 |         # decoding can fail on some molecules/templates, so run it inside try/except
26 |         try:
27 |             decoded_smiles = decode_localtemplate(mol, pred_site, local_template, template_info)
28 | if decoded_smiles == None or decoded_smiles in smiles_preds:
29 | continue
30 | except Exception as e:
31 | continue
32 | smiles_preds.append(decoded_smiles)
33 |
34 | if len (smiles_preds) >= top_k:
35 | break
36 |
37 | return smiles_preds
38 |
39 | def get_isomers(smi):
40 | mol = Chem.MolFromSmiles(smi)
41 | isomers = tuple(EnumerateStereoisomers(mol))
42 | isomers_smi = [Chem.MolToSmiles(x, isomericSmiles=True) for x in isomers]
43 | return isomers_smi
44 |
45 | def get_MaxFrag(smiles):
46 | return max(smiles.split('.'), key=len)
47 |
48 | def isomer_match(preds, reac):
49 | reac_isomers = get_isomers(reac)
50 | for k, pred in enumerate(preds):
51 | try:
52 | pred_isomers = get_isomers(pred)
53 | if set(pred_isomers).issubset(set(reac_isomers)) or set(reac_isomers).issubset(set(pred_isomers)):
54 | return k+1
55 | except Exception as e:
56 | pass
57 | return -1
58 |
59 | def get_idx_map(mol):
60 | for atom in mol.GetAtoms():
61 | atom.SetAtomMapNum(atom.GetIdx())
62 | smiles = Chem.MolToSmiles(mol)
63 | num_map = {}
64 | for i, s in enumerate(smiles.split('.')):
65 | m = Chem.MolFromSmiles(s)
66 | for atom in m.GetAtoms():
67 | num_map[atom.GetAtomMapNum()] = atom.GetIdx()
68 | return num_map
69 |
70 | def get_possible_map(pred_site, change_info):
71 | possible_maps = []
72 | if type(pred_site) == type(0):
73 | for edit_type, edits in change_info['edit_site'].items():
74 | if edit_type not in ['A', 'R']:
75 | continue
76 | for edit in edits:
77 | possible_maps.append({edit: pred_site})
78 | else:
79 | for edit_type, edits in change_info['edit_site'].items():
80 | if edit_type not in ['B', 'C']:
81 | continue
82 | for edit in edits:
83 | possible_maps.append({e:p for e, p in zip(edit, pred_site)})
84 | return possible_maps
85 |
86 | def check_idx_match(mols, possible_maps):
87 | matched_maps = []
88 | found_map = {}
89 | for mol in mols:
90 | for atom in mol.GetAtoms():
91 | if atom.HasProp('old_mapno') and atom.HasProp('react_atom_idx'):
92 | found_map[int(atom.GetProp('old_mapno'))] = int(atom.GetProp('react_atom_idx'))
93 | for possible_map in possible_maps:
94 | if possible_map.items() <= found_map.items():
95 | matched_maps.append(found_map)
96 | return matched_maps
97 |
98 | def fix_aromatic(mol):
99 | for atom in mol.GetAtoms():
100 | if not atom.IsInRing() and atom.GetIsAromatic():
101 | atom.SetIsAromatic(False)
102 |
103 | for bond in mol.GetBonds():
104 | if not bond.IsInRing():
105 | bond.SetIsAromatic(False)
106 | if str(bond.GetBondType()) == 'AROMATIC':
107 | bond.SetBondType(Chem.rdchem.BondType.SINGLE)
108 |
109 | def validate_mols(mols):
110 | for mol in mols:
111 | if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) == None:
112 | return False
113 | return True
114 |
115 | def fix_reactant_atoms(product, reactants, matched_map, change_info):
116 | H_change, C_change, S_change = change_info['change_H'], change_info['change_C'], change_info['change_S']
117 | fixed_mols = []
118 | for mol in reactants:
119 | for atom in mol.GetAtoms():
120 | if atom.HasProp('old_mapno'):
121 | mapno = int(atom.GetProp('old_mapno'))
122 | if mapno not in matched_map:
123 | return None
124 | product_atom = product.GetAtomWithIdx(matched_map[mapno])
125 | H_before = product_atom.GetNumExplicitHs() + product_atom.GetNumImplicitHs()
126 | C_before = product_atom.GetFormalCharge()
127 | S_before = chiral_type_map[product_atom.GetChiralTag()]
128 | H_after = H_before + H_change[mapno]
129 | C_after = C_before + C_change[mapno]
130 | S_after = S_change[mapno]
131 | if H_after < 0:
132 | return None
133 | atom.SetNumExplicitHs(H_after)
134 | atom.SetFormalCharge(C_after)
135 | if S_after != 0:
136 | atom.SetChiralTag(chiral_type_map_inv[S_after])
137 | fix_aromatic(mol)
138 | fixed_mols.append(mol)
139 | if validate_mols(fixed_mols):
140 | return tuple(fixed_mols)
141 | else:
142 | return None
143 |
144 | def demap(mols, stereo = True):
145 | if type(mols) == type((0, 0)):
146 | ss = []
147 | for mol in mols:
148 | [atom.SetAtomMapNum(0) for atom in mol.GetAtoms()]
149 | mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol, stereo))
150 | if mol == None:
151 | return None
152 | ss.append(Chem.MolToSmiles(mol))
153 | return '.'.join(sorted(ss))
154 | else:
155 | [atom.SetAtomMapNum(0) for atom in mols.GetAtoms()]
156 | return '.'.join(sorted(Chem.MolToSmiles(mols, stereo).split('.')))
157 |
158 | def read_prediction(smiles, prediction, atom_templates, bond_templates, template_infos):
159 | mol = Chem.MolFromSmiles(smiles)
160 | if len(prediction) == 1:
161 | return mol, None, None, None, 0
162 | else:
163 | edit_type, pred_site, pred_template_class, prediction_score = prediction # (edit_type, pred_site, pred_template_class)
164 | idx_map = get_idx_map(mol)
165 | if edit_type == 'a':
166 | template = atom_templates[pred_template_class]
167 | try:
168 | if len(template.split('>>')[0].split('.')) > 1: pred_site = idx_map[pred_site]
169 | except Exception as e:
170 | raise Exception(f'{smiles}\n{prediction}\n{template}\n{idx_map}\n{pred_site}') from e
171 | else:
172 | template = bond_templates[pred_template_class]
173 | if len(template.split('>>')[0].split('.')) > 1: pred_site = (idx_map[pred_site[0]], idx_map[pred_site[1]])
174 | [atom.SetAtomMapNum(atom.GetIdx()) for atom in mol.GetAtoms()]
175 | if pred_site is None:
176 | return mol, pred_site, template, {}, 0
177 | return mol, pred_site, template, template_infos[template], prediction_score
178 |
179 | def decode_localtemplate(product, pred_site, template, template_info):
180 | if pred_site is None:
181 | return None
182 | possible_maps = get_possible_map(pred_site, template_info)
183 | reaction = rdChemReactions.ReactionFromSmarts(template)
184 | reactants = reaction.RunReactants([product])
185 | decodes = []
186 | for output in reactants:
187 | if output is None:
188 | continue
189 | matched_maps = check_idx_match(output, possible_maps)
190 | for matched_map in matched_maps:
191 | decoded = fix_reactant_atoms(product, output, matched_map, template_info)
192 | if decoded is None:
193 | continue
194 | else:
195 | return demap(decoded)
196 | return None
197 |
198 |
--------------------------------------------------------------------------------
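The decoding helpers above (in `textreact/template_decoder.py`) turn a LocalRetro-style prediction into reactant SMILES: `read_prediction` resolves the predicted edit site and template class into an RDKit mol, a site, and a template string, and `decode_localtemplate` applies the template, repairs hydrogens, charges, and stereo, and returns canonical reactant SMILES. A minimal usage sketch follows; the product SMILES, the prediction tuple, and the file paths are illustrative assumptions, not values fixed by the code above.

```python
# Illustrative sketch only: the product, the prediction tuple, and the file
# paths below are assumptions; template_infos must map each template string
# to its change_info dict (produced during template extraction).
import json
import pandas as pd
from textreact.template_decoder import read_prediction, decode_localtemplate

atom_templates = pd.read_csv('data/atom_templates.csv')['Template']   # assumed path
bond_templates = pd.read_csv('data/bond_templates.csv')['Template']   # assumed path
with open('data/template_infos.json') as f:                           # assumed path
    template_infos = json.load(f)

product = 'CC(=O)Oc1ccccc1'          # hypothetical product SMILES
prediction = ('a', 3, 17, 0.92)      # (edit_type, pred_site, template_class, score)

mol, pred_site, template, template_info, score = read_prediction(
    product, prediction, atom_templates, bond_templates, template_infos)
reactants = decode_localtemplate(mol, pred_site, template, template_info)
print(reactants)                     # canonical reactant SMILES, or None if decoding fails
```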
/textreact/tokenizer.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import re
4 | from typing import List
5 | import pandas as pd
6 | from transformers import PreTrainedTokenizer, AutoTokenizer, BertTokenizer
7 |
8 |
9 | def load_vocab(vocab_file):
10 | """Loads a vocabulary file into a dictionary."""
11 | vocab = collections.OrderedDict()
12 | with open(vocab_file, "r", encoding="utf-8") as reader:
13 | tokens = reader.readlines()
14 | for index, token in enumerate(tokens):
15 | token = token.rstrip("\n")
16 | vocab[token] = index
17 | return vocab
18 |
19 |
20 | class ReactionConditionTokenizer(PreTrainedTokenizer):
21 |
22 | def __init__(self, vocab_file):
23 | super().__init__(
24 | pad_token='[PAD]',
25 | bos_token='[BOS]',
26 | eos_token='[EOS]',
27 | mask_token='[MASK]',
28 | unk_token='[UNK]',
29 | sep_token='[SEP]'
30 | )
31 | self.vocab = load_vocab(vocab_file)
32 | self.ids_to_tokens = {ids: tok for tok, ids in self.vocab.items()}
33 |
34 | def __len__(self):
35 | return len(self.vocab)
36 |
37 | def __call__(self, conditions, **kwargs):
38 | tokens = self.convert_tokens_to_ids(conditions)
39 | return self.prepare_for_model(tokens, **kwargs)
40 |
41 | def _decode(self, token_ids, skip_special_tokens=False, **kwargs):
42 | tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
43 | return tokens
44 |
45 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
46 | token_ids = token_ids_0
47 | return [self.bos_token_id] + token_ids + [self.eos_token_id]
48 |
49 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
50 | token_ids = token_ids_0
51 | return [0] * (len(token_ids) + 2)
52 |
53 | def _convert_token_to_id(self, token):
54 | """Converts a token (str) in an id using the vocab."""
55 | return self.vocab.get(token, self.vocab.get(self.unk_token))
56 |
57 | def _convert_id_to_token(self, index):
58 | """Converts an index (integer) in a token (str) using the vocab."""
59 | return self.ids_to_tokens.get(index, self.unk_token)
60 |
61 |
62 | SMI_REGEX_PATTERN = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#" \
63 | r"|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
64 |
65 |
66 | class SmilesTokenizer(BertTokenizer):
67 | """
68 | Creates the SmilesTokenizer class. The tokenizer heavily inherits from the BertTokenizer
69 | implementation found in Huggingface's transformers library. Instead of BERT's WordPiece algorithm,
70 | it splits SMILES strings with the SMILES tokenization regex developed by Schwaller et al.
71 | Please see https://github.com/huggingface/transformers and https://github.com/rxn4chemistry/rxnfp for more details.
72 |
73 | This class requires huggingface's transformers and tokenizers libraries to be installed.
74 | """
75 |
76 | def __init__(self, vocab_file: str = '', **kwargs):
77 | """Constructs a SmilesTokenizer.
78 | Parameters
79 | ----------
80 | vocab_file: str
81 | Path to a vocabulary file with one SMILES token per line.
82 | The vocabulary used in this repository is textreact/vocab/vocab_smiles.txt.
83 | """
84 |
85 | super().__init__(vocab_file, bos_token='[CLS]', eos_token='[SEP]', **kwargs)
86 |
87 | if not os.path.isfile(vocab_file):
88 | raise ValueError("Can't find a vocab file at path '{}'.".format(vocab_file))
89 | self.vocab = load_vocab(vocab_file)
90 | self.highest_unused_index = max(
91 | [i for i, v in enumerate(self.vocab.keys()) if v.startswith("[unused")])
92 | self.ids_to_tokens = collections.OrderedDict(
93 | [(ids, tok) for tok, ids in self.vocab.items()])
94 | self.basic_tokenizer = BasicSmilesTokenizer()
95 |
96 | @property
97 | def vocab_size(self):
98 | return len(self.vocab)
99 |
100 | @property
101 | def vocab_list(self):
102 | return list(self.vocab.keys())
103 |
104 | def _tokenize(self, text: str):
105 | """
106 | Tokenize a string into a list of tokens.
107 | Parameters
108 | ----------
109 | text: str
110 | Input string sequence to be tokenized.
111 | """
112 | split_tokens = [token for token in self.basic_tokenizer.tokenize(text)]
113 | return split_tokens
114 |
115 | def _convert_token_to_id(self, token):
116 | """
117 | Converts a token (str/unicode) into an id using the vocab.
118 | Parameters
119 | ----------
120 | token: str
121 | String token from a larger sequence to be converted to a numerical id.
122 | """
123 | return self.vocab.get(token, self.vocab.get(self.unk_token))
124 |
125 | def _convert_id_to_token(self, index):
126 | """
127 | Converts an index (integer) into a token (string/unicode) using the vocab.
128 | Parameters
129 | ----------
130 | index: int
131 | Integer index to be converted back to a string-based token as part of a larger sequence.
132 | """
133 | return self.ids_to_tokens.get(index, self.unk_token)
134 |
135 | def convert_tokens_to_string(self, tokens: List[str]):
136 | """ Converts a sequence of tokens (string) in a single string.
137 | Parameters
138 | ----------
139 | tokens: List[str]
140 | List of tokens for a given string sequence.
141 | Returns
142 | -------
143 | out_string: str
144 | Single string from combined tokens.
145 | """
146 | out_string: str = "".join(tokens).replace(" ##", "").strip()
147 | return out_string
148 |
149 | def add_special_tokens_ids_single_sequence(self, token_ids: List[int]):
150 | """
151 | Adds special tokens to a sequence for sequence classification tasks.
152 | A BERT sequence has the following format: [CLS] X [SEP]
153 | Parameters
154 | ----------
155 | token_ids: list[int]
156 | list of tokenized input ids. Can be obtained using the encode or encode_plus methods.
157 | """
158 | return [self.cls_token_id] + token_ids + [self.sep_token_id]
159 |
160 | def add_special_tokens_single_sequence(self, tokens: List[str]):
161 | """
162 | Adds special tokens to a sequence for sequence classification tasks.
163 | A BERT sequence has the following format: [CLS] X [SEP]
164 | Parameters
165 | ----------
166 | tokens: List[str]
167 | List of tokens for a given string sequence.
168 | """
169 | return [self.cls_token] + tokens + [self.sep_token]
170 |
171 | def add_special_tokens_ids_sequence_pair(self, token_ids_0: List[int],
172 | token_ids_1: List[int]) -> List[int]:
173 | """
174 | Adds special tokens to a sequence pair for sequence classification tasks.
175 | A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
176 | Parameters
177 | ----------
178 | token_ids_0: List[int]
179 | List of ids for the first string sequence in the sequence pair (A).
180 | token_ids_1: List[int]
181 | List of tokens for the second string sequence in the sequence pair (B).
182 | """
183 | sep = [self.sep_token_id]
184 | cls = [self.cls_token_id]
185 | return cls + token_ids_0 + sep + token_ids_1 + sep
186 |
187 | def add_padding_tokens(self,
188 | token_ids: List[int],
189 | length: int,
190 | right: bool = True) -> List[int]:
191 | """
192 | Adds padding tokens to return a sequence of the requested length.
193 | By default padding tokens are added to the right of the sequence.
194 | Parameters
195 | ----------
196 | token_ids: list[int]
197 | list of tokenized input ids. Can be obtained using the encode or encode_plus methods.
198 | length: int
199 | Target length of the padded sequence.
200 | right: bool (True by default)
201 | If True, padding tokens are appended to the right; otherwise they are prepended to the left.
202 | Returns
203 | -------
204 | token_ids: list[int]
205 | The input ids padded with the pad token id up to the requested length.
206 | """
207 | padding = [self.pad_token_id] * (length - len(token_ids))
208 |
209 | if right:
210 | return token_ids + padding
211 | else:
212 | return padding + token_ids
213 |
214 |
215 | class BasicSmilesTokenizer(object):
216 | """
217 | Run basic SMILES tokenization using the regex pattern developed by Schwaller et al. Use this tokenizer
218 | when a tokenizer that does not depend on HuggingFace's transformers library is needed.
219 | """
220 |
221 | def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN):
222 | """ Constructs a BasicSMILESTokenizer. """
223 | self.regex_pattern = regex_pattern
224 | self.regex = re.compile(self.regex_pattern)
225 |
226 | def tokenize(self, text):
227 | """ Basic Tokenization of a SMILES. """
228 | tokens = [token for token in self.regex.findall(text)]
229 | return tokens
230 |
231 |
232 | class SmilesTextTokenizer(PreTrainedTokenizer):
233 |
234 | def __init__(self, text_tokenizer, smiles_tokenizer=None):
235 | super().__init__(
236 | pad_token=text_tokenizer.pad_token,
237 | mask_token=text_tokenizer.mask_token)
238 | if smiles_tokenizer is None:
239 | self.separate = False
240 | self.smiles_tokenizer = text_tokenizer
241 | else:
242 | self.separate = True
243 | self.smiles_tokenizer = smiles_tokenizer
244 | self.text_tokenizer = text_tokenizer
245 |
246 | @property
247 | def smiles_offset(self):
248 | return len(self.text_tokenizer) if self.separate else 0
249 |
250 | def __len__(self):
251 | return len(self.text_tokenizer) + self.smiles_offset
252 |
253 | def __call__(self, text, text_pair, **kwargs):
254 | result = self.smiles_tokenizer(text, **kwargs)
255 | if self.separate:
256 | result['input_ids'] = [v + self.smiles_offset for v in result['input_ids']]
257 | if isinstance(text_pair, str):
258 | result_pair = self.text_tokenizer(text_pair, **kwargs)
259 | for key in result:
260 | result[key] = result[key] + result_pair[key][1:] # skip the CLS token
261 | elif isinstance(text_pair, list):
262 | for t in text_pair:
263 | result_pair = self.text_tokenizer(t, **kwargs)
264 | for key in result:
265 | result[key] = result[key] + result_pair[key][1:] # skip the CLS token
266 | return result
267 |
268 | def _convert_id_to_token(self, index):
269 | if index < len(self.text_tokenizer):
270 | return self.text_tokenizer.convert_ids_to_tokens(index)
271 | else:
272 | return self.smiles_tokenizer.convert_ids_to_tokens(index - len(self.text_tokenizer))
273 |
274 | def _convert_token_to_id(self, token):
275 | return self.text_tokenizer.convert_tokens_to_ids(token)
276 |
277 |
278 | def get_tokenizers(args):
279 | # Encoder
280 | if args.encoder_tokenizer == 'smiles':
281 | enc_tokenizer = SmilesTokenizer(args.vocab_file)
282 | elif args.encoder_tokenizer == 'text':
283 | text_tokenizer = AutoTokenizer.from_pretrained(args.encoder, use_fast=False)
284 | enc_tokenizer = SmilesTextTokenizer(text_tokenizer)
285 | elif args.encoder_tokenizer == 'smiles_text':
286 | smiles_tokenizer = SmilesTokenizer(args.vocab_file)
287 | text_tokenizer = AutoTokenizer.from_pretrained(args.encoder, use_fast=False)
288 | enc_tokenizer = SmilesTextTokenizer(text_tokenizer, smiles_tokenizer)
289 | else:
290 | raise ValueError
291 | if args.template_based:
292 | assert args.encoder_tokenizer.startswith('smiles')
293 | atom_templates = pd.read_csv(os.path.join(args.template_path, 'atom_templates.csv'))['Template']
294 | bond_templates = pd.read_csv(os.path.join(args.template_path, 'bond_templates.csv'))['Template']
295 | dec_tokenizer = atom_templates, bond_templates
296 | else:
297 | assert args.template_path is None
298 | # Decoder
299 | if args.task == 'condition':
300 | dec_tokenizer = ReactionConditionTokenizer(args.vocab_file)
301 | elif args.task == 'retro':
302 | dec_tokenizer = SmilesTokenizer(args.vocab_file)
303 | else:
304 | raise ValueError
305 | return enc_tokenizer, dec_tokenizer
306 |
--------------------------------------------------------------------------------
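As a quick illustration of the tokenization defined in `textreact/tokenizer.py`: `BasicSmilesTokenizer` splits a SMILES (or reaction SMILES) with `SMI_REGEX_PATTERN`, `SmilesTokenizer` looks the resulting tokens up in `textreact/vocab/vocab_smiles.txt`, and when `SmilesTextTokenizer` wraps a separate SMILES tokenizer, the SMILES ids are shifted by `len(text_tokenizer)` so the two vocabularies do not collide. Below is a minimal sketch of the regex split, with the pattern copied from above and an arbitrary example molecule.

```python
# Minimal sketch of the regex-based SMILES tokenization; the example SMILES
# (ethyl chloroformate) is arbitrary. The pattern is copied from
# SMI_REGEX_PATTERN in textreact/tokenizer.py.
import re

SMI_REGEX_PATTERN = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#" \
                    r"|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"

tokens = re.compile(SMI_REGEX_PATTERN).findall('CCOC(=O)Cl')
print(tokens)   # ['C', 'C', 'O', 'C', '(', '=', 'O', ')', 'Cl']
```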
/textreact/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import logging
8 |
9 | from rdkit import Chem
10 |
11 |
12 | metric_to_mode = {
13 | 'val_loss': 'min',
14 | 'val_acc': 'max'
15 | }
16 |
17 |
18 | def expand_position_embeddings(encoder, max_length):
19 | if encoder.config.model_type in ['bert', 'longformer', 'roberta']:
20 | if max_length <= encoder.config.max_position_embeddings:
21 | return
22 | embeddings = encoder.embeddings
23 | config = encoder.config
24 | old_emb = embeddings.position_embeddings.weight.data.clone()
25 | config.max_position_embeddings = max_length
26 | embeddings.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
27 | embeddings.position_embeddings.weight.data[:old_emb.size(0)] = old_emb
28 | embeddings.register_buffer(
29 | "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
30 | embeddings.register_buffer(
31 | "token_type_ids", torch.zeros((1, config.max_position_embeddings), dtype=torch.long), persistent=False)
32 | else:
33 | raise NotImplementedError
34 |
35 |
36 | def expand_word_embeddings(encoder, vocab_size):
37 | if vocab_size <= encoder.config.vocab_size:
38 | return
39 | embeddings = encoder.embeddings
40 | config = encoder.config
41 | old_emb = embeddings.word_embeddings.weight.data.clone()
42 | config.vocab_size = vocab_size
43 | embeddings.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
44 | embeddings.word_embeddings.weight.data[:old_emb.size(0)] = old_emb
45 |
46 |
47 | def clear_path(path, trainer):
48 | for file in os.listdir(path):
49 | if file.endswith('.ckpt'):
50 | filepath = os.path.join(path, file)
51 | logging.info(f' Remove checkpoint {filepath}')
52 | trainer.strategy.remove_checkpoint(filepath)
53 |
54 |
55 | def gather_prediction_each_neighbor(prediction, num_neighbors):
56 | results = {}
57 | for i, pred in sorted(prediction.items()):
58 | idx = i // num_neighbors
59 | if i % num_neighbors == 0:
60 | results[idx] = pred
61 | else:
62 | for key in results[idx]:
63 | results[idx][key] += pred[key]
64 | return results
65 |
66 |
67 | """ Adapted from localretro_model/get_edit.py """
68 |
69 | def get_id_template(a, class_n, num_atoms, edit_type, python_numbers=False):
70 | edit_idx = a // class_n
71 | template = a % class_n
72 | if edit_type == 'b':
73 | edit_idx = (edit_idx // num_atoms, edit_idx % num_atoms)
74 | if python_numbers:
75 | edit_idx = edit_idx.item() if edit_type == 'a' else (edit_idx[0].item(), edit_idx[1].item())
76 | template = template.item()
77 | return edit_idx, template
78 |
79 | def output2edit(out, top_num, edit_type, bonds=None):
80 | num_atoms, class_n = out.shape[-2:]
81 | readout = out.cpu().detach().numpy()
82 | readout = readout.reshape(-1)
83 | output_rank = np.flip(np.argsort(readout))
84 | filtered_output_rank = []
85 | for r in output_rank:
86 | edit_idx, template = get_id_template(r, class_n, num_atoms, edit_type)
87 | if (bonds is None or edit_idx in bonds) and template != 0:
88 | filtered_output_rank.append(r)
89 | if len(filtered_output_rank) == top_num:
90 | break
91 | selected_edit = [get_id_template(a, class_n, num_atoms, edit_type, python_numbers=True) for a in filtered_output_rank]
92 | selected_proba = [readout[a].item() for a in filtered_output_rank]
93 |
94 | return selected_edit, selected_proba
95 |
96 | def combined_edit(atom_out, bond_out, bonds, top_num=None):
97 | edit_id_a, edit_proba_a = output2edit(atom_out, top_num, edit_type='a')
98 | edit_id_b, edit_proba_b = output2edit(bond_out, top_num, edit_type='b', bonds=bonds)
99 | edit_id_c = edit_id_a + edit_id_b
100 | edit_type_c = ['a'] * len(edit_proba_a) + ['b'] * len(edit_proba_b)
101 | edit_proba_c = edit_proba_a + edit_proba_b
102 | edit_rank_c = np.flip(np.argsort(edit_proba_c))
103 | if top_num is not None:
104 | edit_rank_c = edit_rank_c[:top_num]
105 | edit_preds_c = [(edit_type_c[r], *edit_id_c[r]) for r in edit_rank_c]
106 | edit_proba_c = [edit_proba_c[r] for r in edit_rank_c]
107 |
108 | return edit_preds_c, edit_proba_c
109 |
--------------------------------------------------------------------------------
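One helper in `textreact/utils.py` that is easy to misread is `gather_prediction_each_neighbor`: when every input example is duplicated once per retrieved neighbor, it regroups the per-index predictions back to one entry per original example by concatenating the list-valued fields of consecutive indices. A toy illustration follows, with the function copied verbatim from above and made-up field names (`prediction`, `score`).

```python
# Toy illustration; the dict keys 'prediction' and 'score' are made up.
# The function is copied verbatim from textreact/utils.py so the snippet
# runs without the package's heavier dependencies.
def gather_prediction_each_neighbor(prediction, num_neighbors):
    results = {}
    for i, pred in sorted(prediction.items()):
        idx = i // num_neighbors
        if i % num_neighbors == 0:
            results[idx] = pred
        else:
            for key in results[idx]:
                results[idx][key] += pred[key]
    return results

prediction = {
    0: {'prediction': ['A'], 'score': [0.9]},   # example 0, neighbor 0
    1: {'prediction': ['B'], 'score': [0.7]},   # example 0, neighbor 1
    2: {'prediction': ['C'], 'score': [0.8]},   # example 1, neighbor 0
    3: {'prediction': ['D'], 'score': [0.6]},   # example 1, neighbor 1
}
print(gather_prediction_each_neighbor(prediction, 2))
# {0: {'prediction': ['A', 'B'], 'score': [0.9, 0.7]},
#  1: {'prediction': ['C', 'D'], 'score': [0.8, 0.6]}}
```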
/textreact/vocab/vocab_condition.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | [BOS]
3 | [EOS]
4 | [MASK]
5 | [UNK]
6 | [SEP]
7 |
8 | B
9 | B1C2CCCC1CCC2
10 | Br
11 | BrB(Br)Br
12 | BrBr
13 | Br[Cu]Br
14 | Br[P+](N1CCCC1)(N1CCCC1)N1CCCC1.F[P-](F)(F)(F)(F)F
15 | C
16 | C(=NC1CCCCC1)=NC1CCCCC1
17 | C1=CCCCC1
18 | C1CCC(P(C2CCCCC2)C2CCCCC2)CC1
19 | C1CCC2=NCCCN2CC1
20 | C1CCCCC1
21 | C1CCNC1
22 | C1CCNCC1
23 | C1CCOC1
24 | C1CC[NH2+]CC1.CC(=O)[O-]
25 | C1CN2CCN1CC2
26 | C1COCCN1
27 | C1COCCO1
28 | C1COCCOCCOCCOCCO1
29 | C1COCCOCCOCCOCCOCCO1
30 | CC#N
31 | CC(=O)CC(C)C
32 | CC(=O)Cl
33 | CC(=O)N(C)C
34 | CC(=O)O
35 | CC(=O)OC(C)=O
36 | CC(=O)OC(C)C
37 | CC(=O)OI1(OC(C)=O)(OC(C)=O)OC(=O)c2ccccc21
38 | CC(=O)O[BH-](OC(C)=O)OC(C)=O.[Na+]
39 | CC(=O)O[Pd]OC(C)=O
40 | CC(=O)[O-].[K+]
41 | CC(=O)[O-].[Na+]
42 | CC(C)(C#N)N=NC(C)(C)C#N
43 | CC(C)(C)O
44 | CC(C)(C)OC(=O)N=NC(=O)OC(C)(C)C
45 | CC(C)(C)ON=O
46 | CC(C)(C)O[K]
47 | CC(C)(C)O[Na]
48 | CC(C)(C)[O-].[K+]
49 | CC(C)(C)[P]([Pd][P](C(C)(C)C)(C(C)(C)C)C(C)(C)C)(C(C)(C)C)C(C)(C)C
50 | CC(C)=O
51 | CC(C)CCO
52 | CC(C)CCON=O
53 | CC(C)CO
54 | CC(C)COC(=O)Cl
55 | CC(C)C[Al+]CC(C)C.[H-]
56 | CC(C)C[AlH]CC(C)C
57 | CC(C)N=C=NC(C)C
58 | CC(C)NC(C)C
59 | CC(C)O
60 | CC(C)OC(=O)/N=N/C(=O)OC(C)C
61 | CC(C)OC(=O)N=NC(=O)OC(C)C
62 | CC(C)OC(C)C
63 | CC(C)O[Ti](OC(C)C)(OC(C)C)OC(C)C
64 | CC(C)[Mg]Cl
65 | CC(C)[N-]C(C)C.[Li+]
66 | CC(C)c1cc(C(C)C)c(-c2ccccc2P(C2CCCCC2)C2CCCCC2)c(C(C)C)c1
67 | CC(Cl)Cl
68 | CC(Cl)OC(=O)Cl
69 | CC1(C)C2CCC1(CS(=O)(=O)O)C(=O)C2
70 | CC1(C)CCCC(C)(C)N1
71 | CC1(C)c2cccc(P(c3ccccc3)c3ccccc3)c2Oc2c(P(c3ccccc3)c3ccccc3)cccc21
72 | CC1CCCO1
73 | CC=C(C)C
74 | CCC#N
75 | CCC(C)(C)O
76 | CCC(C)=O
77 | CCC(C)O
78 | CCCCC
79 | CCCCCC
80 | CCCCCCC
81 | CCCCCO
82 | CCCCN(CCCC)CCCC
83 | CCCCO
84 | CCCCP(=CC#N)(CCCC)CCCC
85 | CCCCP(CCCC)CCCC
86 | CCCC[N+](CCCC)(CCCC)CCCC.[F-]
87 | CCCC[SnH](CCCC)CCCC
88 | CCCC[Sn](=O)CCCC
89 | CCCO
90 | CCCP1(=O)OP(=O)(CCC)OP(=O)(CCC)O1
91 | CCN(C(C)C)C(C)C
92 | CCN(CC)CC
93 | CCN=C=NCCCN(C)C
94 | CCNCC
95 | CCO
96 | CCOC(=O)/N=N/C(=O)OCC
97 | CCOC(=O)Cl
98 | CCOC(=O)N=NC(=O)OCC
99 | CCOC(C)=O
100 | CCOCC
101 | CCOCCO
102 | CCOP(=O)(C#N)OCC
103 | CC[Mg]Br
104 | CC[N+](CC)(CC)S(=O)(=O)N=C([O-])OC
105 | CC[SiH](CC)CC
106 | CN
107 | CN(C)C(=N)N(C)C
108 | CN(C)C(N(C)C)=[N+]1N=[N+]([O-])c2ncccc21.F[P-](F)(F)(F)(F)F
109 | CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[B-](F)(F)F
110 | CN(C)C(On1nnc2ccccc21)=[N+](C)C.F[P-](F)(F)(F)(F)F
111 | CN(C)C(On1nnc2cccnc21)=[N+](C)C.F[P-](F)(F)(F)(F)F
112 | CN(C)C=O
113 | CN(C)CCN(C)C
114 | CN(C)P(=O)(N(C)C)N(C)C
115 | CN(C)[P+](On1nnc2ccccc21)(N(C)C)N(C)C.F[P-](F)(F)(F)(F)F
116 | CN(C)c1ccccc1
117 | CN(C)c1ccccn1
118 | CN(C)c1ccncc1
119 | CN1CCCC1=O
120 | CN1CCCN(C)C1=O
121 | CN1CCOCC1
122 | CN1CC[NH+](C)C1Cl.[Cl-]
123 | CNCCNC
124 | CN[C@@H]1CCCC[C@H]1NC
125 | CO
126 | COC(C)(C)C
127 | COCCO
128 | COCCOC
129 | COCCOCCOC
130 | CO[Na]
131 | COc1cccc(OC)c1-c1ccccc1P(C1CCCCC1)C1CCCCC1
132 | COc1ccccc1
133 | COc1nc(OC)nc([N+]2(C)CCOCC2)n1.[Cl-]
134 | CS(=O)(=O)Cl
135 | CS(=O)(=O)O
136 | CS(C)=O
137 | CSC
138 | C[Al](C)C
139 | C[N+](=O)[O-]
140 | C[N+]1([O-])CCOCC1
141 | C[O-].[Na+]
142 | C[Si](C)(C)Br
143 | C[Si](C)(C)C=[N+]=[N-]
144 | C[Si](C)(C)Cl
145 | C[Si](C)(C)I
146 | C[Si](C)(C)[N-][Si](C)(C)C.[K+]
147 | C[Si](C)(C)[N-][Si](C)(C)C.[Li+]
148 | C[Si](C)(C)[N-][Si](C)(C)C.[Na+]
149 | C[n+]1ccccc1Cl.[I-]
150 | Cc1cc(C)c(N2CCN(c3c(C)cc(C)cc3C)C2=[Ru](Cl)(Cl)(=Cc2ccccc2)[P](C2CCCCC2)(C2CCCCC2)C2CCCCC2)c(C)c1
151 | Cc1cc(C)cc(C)c1
152 | Cc1ccc(C)cc1
153 | Cc1ccc(S(=O)(=O)Cl)cc1
154 | Cc1ccc(S(=O)(=O)O)cc1
155 | Cc1ccc(S(=O)(=O)[O-])cc1.c1cc[nH+]cc1
156 | Cc1cccc(C)n1
157 | Cc1ccccc1
158 | Cc1ccccc1C
159 | Cc1ccccc1P(c1ccccc1C)c1ccccc1C
160 | Cc1ccccc1S(=O)(=O)O
161 | Cc1ccccc1[P](c1ccccc1C)(c1ccccc1C)[Pd](Cl)(Cl)[P](c1ccccc1C)(c1ccccc1C)c1ccccc1C
162 | Cl
163 | ClB(Cl)Cl
164 | ClC(Cl)(Cl)Cl
165 | ClC(Cl)Cl
166 | ClCCCl
167 | ClCCl
168 | ClP(Cl)(Cl)(Cl)Cl
169 | ClP(Cl)Cl
170 | Cl[Cu]
171 | Cl[Cu]Cl
172 | Cl[Fe](Cl)Cl
173 | Cl[Hg]Cl
174 | Cl[Ni]Cl
175 | Cl[Pd](Cl)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P](c1ccccc1)(c1ccccc1)c1ccccc1
176 | Cl[Pd]Cl
177 | Cl[Sn](Cl)(Cl)Cl
178 | Cl[Sn]Cl
179 | Cl[Ti](Cl)(Cl)Cl
180 | Clc1ccccc1
181 | Clc1ccccc1Cl
182 | FB(F)F
183 | I
184 | II
185 | I[Cu]I
186 | N
187 | N#CC1=C(C#N)C(=O)C(Cl)=C(Cl)C1=O
188 | N#N
189 | NCCN
190 | NN
191 | N[C@@H]1CCCC[C@H]1N
192 | O
193 | O=C(CO)O[CH][C@@H](O)CO
194 | O=C(Cl)C(=O)Cl
195 | O=C(N=NC(=O)N1CCCCC1)N1CCCCC1
196 | O=C(O)C(=O)O
197 | O=C(O)C(F)(F)F
198 | O=C(O)CC(O)(CC(=O)O)C(=O)O
199 | O=C(O)O
200 | O=C(O)[C@@H]1CCCN1
201 | O=C(OC(=O)C(F)(F)F)C(F)(F)F
202 | O=C(OO)c1cccc(Cl)c1
203 | O=C(OOC(=O)c1ccccc1)c1ccccc1
204 | O=C([O-])O
205 | O=C([O-])[O-].[Ca+2]
206 | O=C([O-])[O-].[Cs+].[Cs+]
207 | O=C([O-])[O-].[K+].[K+]
208 | O=C([O-])[O-].[Li+]
209 | O=C([O-])[O-].[Na+].[Na+]
210 | O=C(c1ncc[nH]1)c1ncc[nH]1
211 | O=C(n1ccnc1)n1ccnc1
212 | O=C1CCC(=O)N1Br
213 | O=C1OCCN1P(=O)(Cl)N1CCOC1=O
214 | O=C=O
215 | O=CO
216 | O=C[O-].[NH4+]
217 | O=N[O-].[Na+]
218 | O=O
219 | O=P(Cl)(Cl)Cl
220 | O=P(O)(O)O
221 | O=P([O-])([O-])[O-]
222 | O=P([O-])([O-])[O-].[K+].[K+].[K+]
223 | O=P12OP3(=O)OP(=O)(O1)OP(=O)(O2)O3
224 | O=S(=O)(O)C(F)(F)F
225 | O=S(=O)(O)O
226 | O=S(=O)([O-])C(F)(F)F.O=S(=O)([O-])C(F)(F)F.O=S(=O)([O-])C(F)(F)F.[Yb+3]
227 | O=S(=O)([O-])[O-].[Mg+2]
228 | O=S(=O)([O-])[O-].[Na+].[Na+]
229 | O=S(Cl)Cl
230 | O=S([O-])([O-])=S.[Na+].[Na+]
231 | O=S([O-])S(=O)[O-].[Na+].[Na+]
232 | O=S([O-])[O-].[Na+].[Na+]
233 | O=S1(=O)CCCC1
234 | O=[Ag-]
235 | O=[Ag]
236 | O=[Cr](=O)([O-])Cl.c1cc[nH+]cc1
237 | O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].c1cc[nH+]cc1.c1cc[nH+]cc1
238 | O=[Cr](=O)=O
239 | O=[Cu-]
240 | O=[Cu]
241 | O=[Mn]=O
242 | O=[N+]([O-])c1ccccc1
243 | O=[Os](=O)(=O)=O
244 | O=[Pt]
245 | O=[Pt]=O
246 | OC(=O)[O-].[K+]
247 | OC(=O)[O-].[Na+]
248 | OCC(F)(F)F
249 | OCC(O)CO
250 | OCCO
251 | OCCOCCO
252 | OO
253 | OS(=O)[O-].[Na+]
254 | On1nnc2ccccc21
255 | On1nnc2cccnc21
256 | S=C=S
257 | [Al+3].[Cl-].[Cl-].[Cl-]
258 | [Al+3].[H-].[H-].[H-].[H-].[Li+]
259 | [Al]
260 | [BH3-]C#N.[Na+]
261 | [BH4-].[Li+]
262 | [BH4-].[Na+]
263 | [Br-].[K+]
264 | [C-]#N.[Na+]
265 | [Ca+2].[Cl-].[Cl-]
266 | [Cl-]
267 | [Cl-].[Li+]
268 | [Cl-].[NH4+]
269 | [Cl-].[Na+]
270 | [Cs+].[F-]
271 | [Cu]
272 | [Cu]Br
273 | [Cu]I
274 | [F-].[K+]
275 | [Fe]
276 | [H-].[Na+]
277 | [H][H]
278 | [I-].[K+]
279 | [I-].[Li+]
280 | [I-].[Na+]
281 | [K+]
282 | [K+].[OH-]
283 | [K]
284 | [Li+].[OH-]
285 | [Li]
286 | [Li]C(C)(C)C
287 | [Li]C(C)CC
288 | [Li]CCCC
289 | [Li]O
290 | [Mg]
291 | [NH4+]
292 | [NH4+].[OH-]
293 | [Na+]
294 | [Na+].[O-]Cl
295 | [Na+].[O-][I+3]([O-])([O-])[O-]
296 | [Na+].[OH-]
297 | [Na]
298 | [Ni]
299 | [O-][Cl+3]([O-])([O-])O
300 | [OH-]
301 | [Pd]
302 | [Pt]
303 | [Rh]
304 | [Ru]
305 | [Zn]
306 | c1c[nH]cn1
307 | c1ccc(Oc2ccccc2)cc1
308 | c1ccc(P(c2ccccc2)c2ccc3ccccc3c2-c2c(P(c3ccccc3)c3ccccc3)ccc3ccccc23)cc1
309 | c1ccc(P(c2ccccc2)c2ccccc2)cc1
310 | c1ccc([PH](c2ccccc2)(c2ccccc2)[Pd-4]([PH](c2ccccc2)(c2ccccc2)c2ccccc2)([PH](c2ccccc2)(c2ccccc2)c2ccccc2)[PH](c2ccccc2)(c2ccccc2)c2ccccc2)cc1
311 | c1ccc([P](c2ccccc2)(c2ccccc2)[Pd]([P](c2ccccc2)(c2ccccc2)c2ccccc2)([P](c2ccccc2)(c2ccccc2)c2ccccc2)[P](c2ccccc2)(c2ccccc2)c2ccccc2)cc1
312 | c1ccc2ncccc2c1
313 | c1ccccc1
314 | c1ccncc1
315 | c1cnc2c(c1)ccc1cccnc12
--------------------------------------------------------------------------------
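Unlike the SMILES vocabulary below, each line of `vocab_condition.txt` is a single whole token: an entire reagent, solvent, or catalyst SMILES (plus the special tokens at the top), so `ReactionConditionTokenizer` maps each predicted condition directly to one id via `load_vocab`. A small sketch, assuming the repository's dependencies are installed and the command is run from the repository root:

```python
# Sketch: each vocabulary line is one whole condition token.
from textreact.tokenizer import load_vocab

vocab = load_vocab('textreact/vocab/vocab_condition.txt')
print(vocab['CCN(CC)CC'])   # triethylamine maps to a single id (91 in this file)
print(vocab['[PAD]'])       # 0
```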
/textreact/vocab/vocab_smiles.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | [unused1]
3 | [unused2]
4 | [unused3]
5 | [unused4]
6 | [unused5]
7 | [unused6]
8 | [unused7]
9 | [unused8]
10 | [unused9]
11 | [unused10]
12 | [UNK]
13 | [CLS]
14 | [SEP]
15 | [MASK]
16 | c
17 | C
18 | (
19 | )
20 | O
21 | 1
22 | 2
23 | =
24 | N
25 | .
26 | n
27 | 3
28 | F
29 | Cl
30 | >>
31 | ~
32 | -
33 | 4
34 | [C@H]
35 | S
36 | [C@@H]
37 | [O-]
38 | Br
39 | #
40 | /
41 | [nH]
42 | [N+]
43 | s
44 | 5
45 | o
46 | P
47 | [Na+]
48 | [Si]
49 | I
50 | [Na]
51 | [Pd]
52 | [K+]
53 | [K]
54 | [P]
55 | B
56 | [C@]
57 | [C@@]
58 | [Cl-]
59 | 6
60 | [OH-]
61 | \
62 | [N-]
63 | [Li]
64 | [H]
65 | [2H]
66 | [NH4+]
67 | [c-]
68 | [P-]
69 | [Cs+]
70 | [Li+]
71 | [Cs]
72 | [NaH]
73 | [H-]
74 | [O+]
75 | [BH4-]
76 | [Cu]
77 | 7
78 | [Mg]
79 | [Fe+2]
80 | [n+]
81 | [Sn]
82 | [BH-]
83 | [Pd+2]
84 | [CH]
85 | [I-]
86 | [Br-]
87 | [C-]
88 | [Zn]
89 | [B-]
90 | [F-]
91 | [Al]
92 | [P+]
93 | [BH3-]
94 | [Fe]
95 | [C]
96 | [AlH4]
97 | [Ni]
98 | [SiH]
99 | 8
100 | [Cu+2]
101 | [Mn]
102 | [AlH]
103 | [nH+]
104 | [AlH4-]
105 | [O-2]
106 | [Cr]
107 | [Mg+2]
108 | [NH3+]
109 | [S@]
110 | [Pt]
111 | [Al+3]
112 | [S@@]
113 | [S-]
114 | [Ti]
115 | [Zn+2]
116 | [PH]
117 | [NH2+]
118 | [Ru]
119 | [Ag+]
120 | [S+]
121 | [I+3]
122 | [NH+]
123 | [Ca+2]
124 | [Ag]
125 | 9
126 | [Os]
127 | [Se]
128 | [SiH2]
129 | [Ca]
130 | [Ti+4]
131 | [Ac]
132 | [Cu+]
133 | [S]
134 | [Rh]
135 | [Cl+3]
136 | [cH-]
137 | [Zn+]
138 | [O]
139 | [Cl+]
140 | [SH]
141 | [H+]
142 | [Pd+]
143 | [se]
144 | [PH+]
145 | [I]
146 | [Pt+2]
147 | [C+]
148 | [Mg+]
149 | [Hg]
150 | [W]
151 | [SnH]
152 | [SiH3]
153 | [Fe+3]
154 | [NH]
155 | [Mo]
156 | [CH2+]
157 | %10
158 | [CH2-]
159 | [CH2]
160 | [n-]
161 | [Ce+4]
162 | [NH-]
163 | [Co]
164 | [I+]
165 | [PH2]
166 | [Pt+4]
167 | [Ce]
168 | [B]
169 | [Sn+2]
170 | [Ba+2]
171 | %11
172 | [Fe-3]
173 | [18F]
174 | [SH-]
175 | [Pb+2]
176 | [Os-2]
177 | [Zr+4]
178 | [N]
179 | [Ir]
180 | [Bi]
181 | [Ni+2]
182 | [P@]
183 | [Co+2]
184 | [s+]
185 | [As]
186 | [P+3]
187 | [Hg+2]
188 | [Yb+3]
189 | [CH-]
190 | [Zr+2]
191 | [Mn+2]
192 | [CH+]
193 | [In]
194 | [KH]
195 | [Ce+3]
196 | [Zr]
197 | [AlH2-]
198 | [OH2+]
199 | [Ti+3]
200 | [Rh+2]
201 | [Sb]
202 | [S-2]
203 | %12
204 | [P@@]
205 | [Si@H]
206 | [Mn+4]
207 | p
208 | [Ba]
209 | [NH2-]
210 | [Ge]
211 | [Pb+4]
212 | [Cr+3]
213 | [Au]
214 | [LiH]
215 | [Sc+3]
216 | [o+]
217 | [Rh-3]
218 | %13
219 | [Br]
220 | [Sb-]
221 | [S@+]
222 | [I+2]
223 | [Ar]
224 | [V]
225 | [Cu-]
226 | [Al-]
227 | [Te]
228 | [13c]
229 | [13C]
230 | [Cl]
231 | [PH4+]
232 | [SiH4]
233 | [te]
234 | [CH3-]
235 | [S@@+]
236 | [Rh+3]
237 | [SH+]
238 | [Bi+3]
239 | [Br+2]
240 | [La]
241 | [La+3]
242 | [Pt-2]
243 | [N@@]
244 | [PH3+]
245 | [N@]
246 | [Si+4]
247 | [Sr+2]
248 | [Al+]
249 | [Pb]
250 | [SeH]
251 | [Si-]
252 | [V+5]
253 | [Y+3]
254 | [Re]
255 | [Ru+]
256 | [Sm]
257 | *
258 | [3H]
259 | [NH2]
260 | [Ag-]
261 | [13CH3]
262 | [OH+]
263 | [Ru+3]
264 | [OH]
265 | [Gd+3]
266 | [13CH2]
267 | [In+3]
268 | [Si@@]
269 | [Si@]
270 | [Ti+2]
271 | [Sn+]
272 | [Cl+2]
273 | [AlH-]
274 | [Pd-2]
275 | [SnH3]
276 | [B+3]
277 | [Cu-2]
278 | [Nd+3]
279 | [Pb+3]
280 | [13cH]
281 | [Fe-4]
282 | [Ga]
283 | [Sn+4]
284 | [Hg+]
285 | [11CH3]
286 | [Hf]
287 | [Pr]
288 | [Y]
289 | [S+2]
290 | [Cd]
291 | [Cr+6]
292 | [Zr+3]
293 | [Rh+]
294 | [CH3]
295 | [N-3]
296 | [Hf+2]
297 | [Th]
298 | [Sb+3]
299 | %14
300 | [Cr+2]
301 | [Ru+2]
302 | [Hf+4]
303 | [14C]
304 | [Ta]
305 | [Tl+]
306 | [B+]
307 | [Os+4]
308 | [PdH2]
309 | [Pd-]
310 | [Cd+2]
311 | [Co+3]
312 | [S+4]
313 | [Nb+5]
314 | [123I]
315 | [c+]
316 | [Rb+]
317 | [V+2]
318 | [CH3+]
319 | [Ag+2]
320 | [cH+]
321 | [Mn+3]
322 | [Se-]
323 | [As-]
324 | [Eu+3]
325 | [SH2]
326 | [Sm+3]
327 | [IH+]
328 | %15
329 | [OH3+]
330 | [PH3]
331 | [IH2+]
332 | [SH2+]
333 | [Ir+3]
334 | [AlH3]
335 | [Sc]
336 | [Yb]
337 | [15NH2]
338 | [Lu]
339 | [sH+]
340 | [Gd]
341 | [18F-]
342 | [SH3+]
343 | [SnH4]
344 | [TeH]
345 | [Si@@H]
346 | [Ga+3]
347 | [CaH2]
348 | [Tl]
349 | [Ta+5]
350 | [GeH]
351 | [Br+]
352 | [Sr]
353 | [Tl+3]
354 | [Sm+2]
355 | [PH5]
356 | %16
357 | [N@@+]
358 | [Au+3]
359 | [C-4]
360 | [Nd]
361 | [Ti+]
362 | [IH]
363 | [N@+]
364 | [125I]
365 | [Eu]
366 | [Sn+3]
367 | [Nb]
368 | [Er+3]
369 | [123I-]
370 | [14c]
371 | %17
372 | [SnH2]
373 | [YH]
374 | [Sb+5]
375 | [Pr+3]
376 | [Ir+]
377 | [N+3]
378 | [AlH2]
379 | [19F]
380 | %18
381 | [Tb]
382 | [14CH]
383 | [Mo+4]
384 | [Si+]
385 | [BH]
386 | [Be]
387 | [Rb]
388 | [pH]
389 | %19
390 | %20
391 | [Xe]
392 | [Ir-]
393 | [Be+2]
394 | [C+4]
395 | [RuH2]
396 | [15NH]
397 | [U+2]
398 | [Au-]
399 | %21
400 | %22
401 | [Au+]
402 | [15n]
403 | [Al+2]
404 | [Tb+3]
405 | [15N]
406 | [V+3]
407 | [W+6]
408 | [14CH3]
409 | [Cr+4]
410 | [ClH+]
411 | b
412 | [Ti+6]
413 | [Nd+]
414 | [Zr+]
415 | [PH2+]
416 | [Fm]
417 | [N@H+]
418 | [RuH]
419 | [Dy+3]
420 | %23
421 | [Hf+3]
422 | [W+4]
423 | [11C]
424 | [13CH]
425 | [Er]
426 | [124I]
427 | [LaH]
428 | [F]
429 | [siH]
430 | [Ga+]
431 | [Cm]
432 | [GeH3]
433 | [IH-]
434 | [U+6]
435 | [SeH+]
436 | [32P]
437 | [SeH-]
438 | [Pt-]
439 | [Ir+2]
440 | [se+]
441 | [U]
442 | [F+]
443 | [BH2]
444 | [As+]
445 | [Cf]
446 | [ClH2+]
447 | [Ni+]
448 | [TeH3]
449 | [SbH2]
450 | [Ag+3]
451 | %24
452 | [18O]
453 | [PH4]
454 | [Os+2]
455 | [Na-]
456 | [Sb+2]
457 | [V+4]
458 | [Ho+3]
459 | [68Ga]
460 | [PH-]
461 | [Bi+2]
462 | [Ce+2]
463 | [Pd+3]
464 | [99Tc]
465 | [13C@@H]
466 | [Fe+6]
467 | [c]
468 | [GeH2]
469 | [10B]
470 | [Cu+3]
471 | [Mo+2]
472 | [Cr+]
473 | [Pd+4]
474 | [Dy]
475 | [AsH]
476 | [Ba+]
477 | [SeH2]
478 | [In+]
479 | [TeH2]
480 | [BrH+]
481 | [14cH]
482 | [W+]
483 | [13C@H]
484 | [AsH2]
485 | [In+2]
486 | [N+2]
487 | [N@@H+]
488 | [SbH]
489 | [60Co]
490 | [AsH4+]
491 | [AsH3]
492 | [18OH]
493 | [Ru-2]
494 | [Na-2]
495 | [CuH2]
496 | [31P]
497 | [Ti+5]
498 | [35S]
499 | [P@@H]
500 | [ArH]
501 | [Co+]
502 | [Zr-2]
503 | [BH2-]
504 | [131I]
505 | [SH5]
506 | [VH]
507 | [B+2]
508 | [Yb+2]
509 | [14C@H]
510 | [211At]
511 | [NH3+2]
512 | [IrH]
513 | [IrH2]
514 | [Rh-]
515 | [Cr-]
516 | [Sb+]
517 | [Ni+3]
518 | [TaH3]
519 | [Tl+2]
520 | [64Cu]
521 | [Tc]
522 | [Cd+]
523 | [1H]
524 | [15nH]
525 | [AlH2+]
526 | [FH+2]
527 | [BiH3]
528 | [Ru-]
529 | [Mo+6]
530 | [AsH+]
531 | [BaH2]
532 | [BaH]
533 | [Fe+4]
534 | [229Th]
535 | [Th+4]
536 | [As+3]
537 | [NH+3]
538 | [P@H]
539 | [Li-]
540 | [7NaH]
541 | [Bi+]
542 | [PtH+2]
543 | [p-]
544 | [Re+5]
545 | [NiH]
546 | [Ni-]
547 | [Xe+]
548 | [Ca+]
549 | [11c]
550 | [Rh+4]
551 | [AcH]
552 | [HeH]
553 | [Sc+2]
554 | [Mn+]
555 | [UH]
556 | [14CH2]
557 | [SiH4+]
558 | [18OH2]
559 | [Ac-]
560 | [Re+4]
561 | [118Sn]
562 | [153Sm]
563 | [P+2]
564 | [9CH]
565 | [9CH3]
566 | [Y-]
567 | [NiH2]
568 | [Si+2]
569 | [Mn+6]
570 | [ZrH2]
571 | [C-2]
572 | [Bi+5]
573 | [24NaH]
574 | [Fr]
575 | [15CH]
576 | [Se+]
577 | [At]
578 | [P-3]
579 | [124I-]
580 | [CuH2-]
581 | [Nb+4]
582 | [Nb+3]
583 | [MgH]
584 | [Ir+4]
585 | [67Ga+3]
586 | [67Ga]
587 | [13N]
588 | [15OH2]
589 | [2NH]
590 | [Ho]
591 | [Cn]
--------------------------------------------------------------------------------