├── models
    ├── model_utils.py
    ├── graphfeat.py
    ├── dgcn.py
    ├── VAR
    │   ├── utils.py
    │   └── var_dec.py
    ├── seq2seq.py
    ├── dgat.py
    ├── attention_xl.py
    └── graph2seq_series_rel.py
├── scripts
    ├── setup.sh
    ├── preprocess.sh
    ├── predict.sh
    ├── validate.sh
    ├── train_g2s.sh
    └── download_raw_data.py
├── LICENSE
├── .gitignore
├── utils
    ├── train_utils.py
    ├── chem_utils.py
    ├── rxn_graphs.py
    ├── parsing.py
    └── data_utils.py
├── README.md
├── data
    └── create_1toN_map.ipynb
├── validate.py
├── predict.py
├── preprocess.py
└── train.py
/models/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def index_scatter(sub_data, all_data, index):
5 |     d0, d1 = all_data.size()
6 |     buf = torch.zeros_like(all_data).scatter_(0, index.repeat(d1, 1).t(), sub_data)
7 |     mask = torch.ones(d0, device=all_data.device).scatter_(0, index, 0)
8 | 
9 |     return all_data * mask.unsqueeze(-1) + buf
10 | 
11 | 
12 | def index_select_ND(source, dim, index):
13 |     index_size = index.size()
14 |     suffix_dim = source.size()[1:]
15 |     final_size = index_size + suffix_dim
16 |     target = source.index_select(dim, index.view(-1))
17 | 
18 |     return target.view(final_size)
19 | 
--------------------------------------------------------------------------------
/scripts/setup.sh:
--------------------------------------------------------------------------------
1 | conda create -y -n retro python=3.6 tqdm
2 | conda activate retro
3 | # CUDA 10.1
4 | conda install -y pytorch=1.6.0 torchvision cudatoolkit=10.1 torchtext -c pytorch
5 | conda install -y rdkit -c conda-forge
6 | 
7 | # pip dependencies
8 | pip install gdown OpenNMT-py==1.2.0 networkx==2.5 selfies==1.0.3
9 | 
10 | 
11 | # # CUDA 11.1 #https://www.zhaoyabo.com/?p=8291
12 | # conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge
13 | 
14 | # # install rdkit
15 | # conda install -y rdkit -c conda-forge
16 | 
17 | # # install opennmt
18 | # pip install OpenNMT-py==1.2.0
19 | # # pip install xxx --ignore-installed
20 | 
21 | # pip install networkx==2.5 selfies==1.0.3
22 | 
--------------------------------------------------------------------------------
/scripts/preprocess.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # DATASET=USPTO_50k
4 | # DATASET=USPTO_DIVERSE
5 | DATASET=$1
6 | MODEL=g2s_series_rel
7 | TASK=retrosynthesis
8 | REPR_START=smiles
9 | REPR_END=smiles
10 | N_WORKERS=8
11 | 
12 | PREFIX=${DATASET}_${MODEL}_${REPR_START}_${REPR_END}
13 | 
14 | python preprocess.py \
15 |     --model="$MODEL" \
16 |     --data_name="$DATASET" \
17 |     --task="$TASK" \
18 |     --representation_start=$REPR_START \
19 |     --representation_end=$REPR_END \
20 |     --train_src="./data/$DATASET/src-train.txt" \
21 |     --train_tgt="./data/$DATASET/tgt-train.txt" \
22 |     --val_src="./data/$DATASET/src-val.txt" \
23 |     --val_tgt="./data/$DATASET/tgt-val.txt" \
24 |     --test_src="./data/$DATASET/src-test.txt" \
25 |     --test_tgt="./data/$DATASET/tgt-test.txt" \
26 |     --log_file="$PREFIX.preprocess.log" \
27 |     --preprocess_output_path="./preprocessed/$PREFIX/" \
28 |     --seed=42 \
29 |     --max_src_len=1024 \
30 |     --max_tgt_len=1024 \
31 |     --num_workers="$N_WORKERS"
32 | 
--------------------------------------------------------------------------------
/scripts/predict.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | MODEL=g2s_series_rel
4 | 
5 | EXP_NO=$1
6 | DATASET=$2
7 | # CKPT=model.175000_34
8 | CKPT=model.200000_39
9 | 
CHECKPOINT=./checkpoints/${DATASET}_g2s_series_rel_smiles_smiles.$EXP_NO/$CKPT.pt 10 | 11 | BS=30 12 | T=1.0 13 | NBEST=30 14 | MPN_TYPE=dgcn 15 | 16 | REPR_START=smiles 17 | REPR_END=smiles 18 | 19 | PREFIX=${DATASET}_${MODEL}_${REPR_START}_${REPR_END} 20 | 21 | python predict.py \ 22 | --do_predict \ 23 | --do_score \ 24 | --model="$MODEL" \ 25 | --data_name="$DATASET" \ 26 | --test_bin="./preprocessed/$PREFIX/test_0.npz" \ 27 | --test_tgt="./data/$DATASET/tgt-test.txt" \ 28 | --result_file="./results/${DATASET}/$PREFIX.$EXP_NO.$CKPT.result.txt" \ 29 | --log_file="$PREFIX.predict.$EXP_NO.log" \ 30 | --load_from="$CHECKPOINT" \ 31 | --mpn_type="$MPN_TYPE" \ 32 | --rel_pos="$REL_POS" \ 33 | --seed=42 \ 34 | --batch_type=tokens \ 35 | --predict_batch_size=2048 \ 36 | --beam_size="$BS" \ 37 | --n_best="$NBEST" \ 38 | --temperature="$T" \ 39 | --predict_min_len=1 \ 40 | --predict_max_len=512 \ 41 | --log_iter=100 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 MIRA@USTC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /scripts/validate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL=g2s_series_rel 4 | 5 | EXP_NO=$1 6 | DATASET=$2 7 | # BATCH_SIZE=2048 8 | BATCH_SIZE=4096 9 | CHECKPOINT=./checkpoints/${DATASET}_g2s_series_rel_smiles_smiles.$EXP_NO/ 10 | FIRST_STEP=10000 11 | LAST_STEP=200000 12 | 13 | BS=30 14 | T=1.0 15 | NBEST=30 16 | MPN_TYPE=dgcn 17 | 18 | 19 | REPR_START=smiles 20 | REPR_END=smiles 21 | 22 | PREFIX=${DATASET}_${MODEL}_${REPR_START}_${REPR_END} 23 | 24 | python validate.py \ 25 | --model="$MODEL" \ 26 | --data_name="$DATASET" \ 27 | --valid_bin="./preprocessed/$PREFIX/val_0.npz" \ 28 | --val_tgt="./data/$DATASET/tgt-val.txt" \ 29 | --result_file="./results/$PREFIX.$EXP_NO.result.txt" \ 30 | --log_file="$PREFIX.validate.$EXP_NO.log" \ 31 | --load_from="$CHECKPOINT" \ 32 | --checkpoint_step_start="$FIRST_STEP" \ 33 | --checkpoint_step_end="$LAST_STEP" \ 34 | --mpn_type="$MPN_TYPE" \ 35 | --rel_pos="$REL_POS" \ 36 | --seed=42 \ 37 | --batch_type=tokens \ 38 | --predict_batch_size="$BATCH_SIZE" \ 39 | --beam_size="$BS" \ 40 | --n_best="$NBEST" \ 41 | --temperature="$T" \ 42 | --predict_min_len=1 \ 43 | --predict_max_len=512 \ 44 | --log_iter=100 45 | -------------------------------------------------------------------------------- /models/graphfeat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.dgat import DGATEncoder 4 | from models.dgcn import DGCNEncoder 5 | from typing import Tuple 6 | from utils.data_utils import G2SBatch 7 | 8 | 9 | class GraphFeatEncoder(nn.Module): 10 | """ 11 | GraphFeatEncoder encodes molecules by using features of atoms and bonds, 12 | instead of a vocabulary, which is used for generation tasks. 13 | Adapted from Somnath et al. (2020): https://grlplus.github.io/papers/61.pdf 14 | """ 15 | 16 | def __init__(self, args, n_atom_feat: int, n_bond_feat: int): 17 | super().__init__() 18 | self.args = args 19 | 20 | self.n_atom_feat = n_atom_feat 21 | self.n_bond_feat = n_bond_feat 22 | 23 | if args.mpn_type == "dgcn": 24 | MPNClass = DGCNEncoder 25 | elif args.mpn_type == "dgat": 26 | MPNClass = DGATEncoder 27 | else: 28 | raise NotImplemented(f"Unsupported mpn_type: {args.mpn_type}!") 29 | 30 | self.mpn = MPNClass( 31 | args, 32 | input_size=n_atom_feat + n_bond_feat, 33 | node_fdim=n_atom_feat 34 | ) 35 | 36 | def forward(self, reaction_batch: G2SBatch) -> Tuple[torch.Tensor, None]: 37 | """ 38 | Forward pass of the graph encoder. First the feature vectors are extracted, 39 | and then encoded. 
This has been modified to pass data via the G2SBatch datatype 40 | """ 41 | fnode = reaction_batch.fnode 42 | fmess = reaction_batch.fmess 43 | agraph = reaction_batch.agraph 44 | bgraph = reaction_batch.bgraph 45 | 46 | # embed graph, note that for directed graph, fess[any, 0:2] = u, v 47 | hnode = fnode.clone() 48 | fmess1 = hnode.index_select(index=fmess[:, 0].long(), dim=0) 49 | fmess2 = fmess[:, 2:].clone() 50 | hmess = torch.cat([fmess1, fmess2], dim=-1) # hmess = x = [x_u; x_uv] 51 | 52 | # encode 53 | hatom, _ = self.mpn(hnode, hmess, agraph, bgraph, mask=None) 54 | hmol = None 55 | 56 | return hatom, hmol 57 | -------------------------------------------------------------------------------- /scripts/train_g2s.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # LOAD_FROM="./checkpoints/USPTO_50k_g2s_series_rel_smiles_smiles.cvae-gru-K20/model.195000_38.pt" 4 | LOAD_FROM="" 5 | MODEL=g2s_series_rel 6 | TASK=retrosynthesis 7 | EXP_NO=$1 8 | DATASET=$2 9 | # MPN_TYPE=dgat 10 | MPN_TYPE=dgcn 11 | MAX_REL_POS=4 12 | ACCUM_COUNT=4 13 | ENC_PE=none 14 | ENC_H=256 15 | # BATCH_SIZE=4096 16 | BATCH_SIZE=512 17 | ENC_EMB_SCALE=sqrt 18 | MAX_STEP=5000000 19 | ENC_LAYER=4 20 | BATCH_TYPE=tokens 21 | REL_BUCKETS=10 22 | 23 | K_SIZE=$3 24 | REL_POS=emb_only 25 | ATTN_LAYER=6 26 | LR=4 27 | DROPOUT=0.3 28 | 29 | REPR_START=smiles 30 | REPR_END=smiles 31 | 32 | PREFIX=${DATASET}_${MODEL}_${REPR_START}_${REPR_END} 33 | 34 | python train.py \ 35 | --model="$MODEL" \ 36 | --data_name="$DATASET" \ 37 | --task="$TASK" \ 38 | --representation_end=$REPR_END \ 39 | --load_from="$LOAD_FROM" \ 40 | --train_bin="./preprocessed/$PREFIX/train_0.npz" \ 41 | --valid_bin="./preprocessed/$PREFIX/val_0.npz" \ 42 | --log_file="$PREFIX.train.$EXP_NO.log" \ 43 | --vocab_file="./preprocessed/$PREFIX/vocab_$REPR_END.txt" \ 44 | --save_dir="./checkpoints/$PREFIX.$EXP_NO" \ 45 | --embed_size=256 \ 46 | --mpn_type="$MPN_TYPE" \ 47 | --encoder_num_layers="$ENC_LAYER" \ 48 | --encoder_hidden_size="$ENC_H" \ 49 | --encoder_norm="$ENC_NORM" \ 50 | --encoder_skip_connection="$ENC_SC" \ 51 | --encoder_positional_encoding="$ENC_PE" \ 52 | --encoder_emb_scale="$ENC_EMB_SCALE" \ 53 | --attn_enc_num_layers="$ATTN_LAYER" \ 54 | --attn_enc_hidden_size=256 \ 55 | --attn_enc_heads=8 \ 56 | --attn_enc_filter_size=2048 \ 57 | --rel_pos="$REL_POS" \ 58 | --rel_pos_buckets="$REL_BUCKETS" \ 59 | --decoder_num_layers=6 \ 60 | --decoder_hidden_size=256 \ 61 | --decoder_attn_heads=8 \ 62 | --decoder_filter_size=2048 \ 63 | --dropout="$DROPOUT" \ 64 | --attn_dropout="$DROPOUT" \ 65 | --max_relative_positions="$MAX_REL_POS" \ 66 | --seed=42 \ 67 | --epoch=2000 \ 68 | --max_steps="$MAX_STEP" \ 69 | --warmup_steps=8000 \ 70 | --lr="$LR" \ 71 | --weight_decay=0.0 \ 72 | --clip_norm=20.0 \ 73 | --batch_type="$BATCH_TYPE" \ 74 | --train_batch_size="$BATCH_SIZE" \ 75 | --valid_batch_size="$BATCH_SIZE" \ 76 | --predict_batch_size="$BATCH_SIZE" \ 77 | --accumulation_count="$ACCUM_COUNT" \ 78 | --num_workers=0 \ 79 | --beam_size=5 \ 80 | --predict_min_len=1 \ 81 | --predict_max_len=512 \ 82 | --log_iter=100 \ 83 | --eval_iter=2000 \ 84 | --save_iter=5000 \ 85 | --keep_last_ckpt=10 \ 86 | --compute_graph_distance \ 87 | --variational_num_layers=0 \ 88 | --latent_K="$K_SIZE" \ 89 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # user defined 2 | .idea/ 3 | checkpoints/* 
4 | data/USPTO_50k 5 | data/USPTO_DIVERSE 6 | logs/* 7 | results/* 8 | preprocessed/* 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | */*/__pycache__ 13 | */__pycache__ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import numpy as np 4 | import os 5 | import random 6 | import sys 7 | import torch 8 | import torch.nn as nn 9 | from datetime import datetime 10 | from rdkit import RDLogger 11 | from torch.optim.lr_scheduler import _LRScheduler 12 | 13 | 14 | def param_count(model: nn.Module) -> int: 15 | return sum(param.numel() for param in model.parameters() if param.requires_grad) 16 | 17 | 18 | def param_norm(m): 19 | return math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()])) 20 | 21 | 22 | def grad_norm(m): 23 | return math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None])) 24 | 25 | 26 | def get_lr(optimizer): 27 | for param_group in optimizer.param_groups: 28 | return param_group["lr"] 29 | 30 | 31 | def set_seed(seed): 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | torch.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | np.random.seed(seed) 37 | random.seed(seed) 38 | 39 | 40 | def setup_logger(args, warning_off: bool = False): 41 | if warning_off: 42 | RDLogger.DisableLog("rdApp.*") 43 | else: 44 | RDLogger.DisableLog("rdApp.warning") 45 | 46 | os.makedirs(f"./logs/{args.data_name}", exist_ok=True) 47 | dt = datetime.strftime(datetime.now(), "%y%m%d-%H%Mh") 48 | 49 | logger = logging.getLogger() 50 | logger.setLevel(logging.INFO) 51 | fh = logging.FileHandler(f"./logs/{args.data_name}/{args.log_file}.{dt}") 52 | sh = logging.StreamHandler(sys.stdout) 53 | fh.setLevel(logging.INFO) 54 | sh.setLevel(logging.INFO) 55 | logger.addHandler(fh) 56 | logger.addHandler(sh) 57 | 58 | return logger 59 | 60 | 61 | def log_tensor(tensor, tensor_name: str): 62 | logging.info(f"--------------------------{tensor_name}--------------------------") 63 | logging.info(tensor) 64 | if isinstance(tensor, torch.Tensor): 65 | logging.info(tensor.shape) 66 | elif isinstance(tensor, np.ndarray): 67 | logging.info(tensor.shape) 68 | elif isinstance(tensor, list): 69 | try: 70 | for item in tensor: 71 | logging.info(item.shape) 72 | except Exception as e: 73 | logging.info(f"Error: {e}") 74 | logging.info("List items are not tensors, skip shape logging.") 75 | 76 | 77 | class NoamLR(_LRScheduler): 78 | """ 79 | Adapted from https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py 80 | 81 | Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate 82 | linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally 83 | to the inverse square root of the step number, scaled by the inverse square root of the 84 | dimensionality of the model. Time will tell if this is just madness or it's actually important. 
85 | Parameters 86 | ---------- 87 | warmup_steps: ``int``, required. 88 | The number of steps to linearly increase the learning rate. 89 | """ 90 | def __init__(self, optimizer, model_size, warmup_steps): 91 | self.model_size = model_size 92 | self.warmup_steps = warmup_steps 93 | super().__init__(optimizer) 94 | 95 | def get_lr(self): 96 | step = max(1, self._step_count) 97 | scale = self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup_steps**(-1.5)) 98 | 99 | return [base_lr * scale for base_lr in self.base_lrs] 100 | -------------------------------------------------------------------------------- /scripts/download_raw_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gdown 3 | import os 4 | 5 | 6 | urls_fns_dict = { 7 | "USPTO_DIVERSE": [ 8 | ("https://drive.google.com/uc?id=1lK5r6IdId2khzP8g_ok_TtGNq45eNXva", "src-train.txt"), 9 | ("https://drive.google.com/uc?id=1GA2-GRaaO_G6ZhO-vbgAXjTeBLMtuDrH", "tgt-train.txt"), 10 | ("https://drive.google.com/uc?id=13xB68GfinjeFw58omG0WCFrl8QlTsZtv", "src-val.txt"), 11 | ("https://drive.google.com/uc?id=1Gy5upUwvoJCyftUCycbo7jDFJsW_VqeE", "tgt-val.txt"), 12 | ("https://drive.google.com/uc?id=1fGZudxMy3dftmaJZBK8ZiodojyjKVjYc", "src-test.txt"), 13 | ("https://drive.google.com/uc?id=1UunYH61-qaHJYMHKhehdMCwfqa6doa-J", "tgt-test.txt") 14 | ], 15 | "USPTO_50k": [ 16 | ("https://drive.google.com/uc?id=1pz-qkfeXzeD_drO9XqZVGmZDSn20CEwr", "src-train.txt"), 17 | ("https://drive.google.com/uc?id=1ZmmCJ-9a0nHeQam300NG5i9GJ3k5lnUl", "tgt-train.txt"), 18 | ("https://drive.google.com/uc?id=1NqLI3xpy30kH5fbVC0l8bMsMxLKgO-5n", "src-val.txt"), 19 | ("https://drive.google.com/uc?id=19My9evSNc6dlk9od5OrwkWauBpzL_Qgy", "tgt-val.txt"), 20 | ("https://drive.google.com/uc?id=1l7jSqYfIr0sL5Ad6TUxsythqVFjFudIx", "src-test.txt"), 21 | ("https://drive.google.com/uc?id=17ozyajoqPFeVjfViI59-QpVid1M0zyKN", "tgt-test.txt") 22 | ], 23 | "USPTO_full": [ 24 | ("https://drive.google.com/uc?id=1PbHoIYbm7-69yPOvRA0CrcjojGxVCJCj", "src-train.txt"), 25 | ("https://drive.google.com/uc?id=1RRveZmyXAxufTEix-WRjnfdSq81V9Ud9", "tgt-train.txt"), 26 | ("https://drive.google.com/uc?id=1jOIA-20zFhQ-x9fco1H7Q10R6CfxYeZo", "src-val.txt"), 27 | ("https://drive.google.com/uc?id=19ZNyw7hLJaoyEPot5ntKBxz_o-_R14QP", "tgt-val.txt"), 28 | ("https://drive.google.com/uc?id=1ErtNB29cpSld8o_gr84mKYs51eRat0H9", "src-test.txt"), 29 | ("https://drive.google.com/uc?id=1kV9p1_KJm8EqK6OejSOcqRsO8DwOgjL_", "tgt-test.txt") 30 | ], 31 | "USPTO_480k": [ 32 | ("https://drive.google.com/uc?id=1RysNBvB2rsMP0Ap9XXi02XiiZkEXCrA8", "src-train.txt"), 33 | ("https://drive.google.com/uc?id=1CxxcVqtmOmHE2nhmqPFA6bilavzpcIlb", "tgt-train.txt"), 34 | ("https://drive.google.com/uc?id=1FFN1nz2yB4VwrpWaBuiBDzFzdX3ONBsy", "src-val.txt"), 35 | ("https://drive.google.com/uc?id=1pYCjWkYvgp1ZQ78EKQBArOvt_2P1KnmI", "tgt-val.txt"), 36 | ("https://drive.google.com/uc?id=10t6pHj9yR8Tp3kDvG0KMHl7Bt_TUbQ8W", "src-test.txt"), 37 | ("https://drive.google.com/uc?id=1FeGuiGuz0chVBRgePMu0pGJA4FVReA-b", "tgt-test.txt") 38 | ], 39 | "USPTO_STEREO": [ 40 | ("https://drive.google.com/uc?id=1r3_7WMEor7-CgN34Foj-ET-uFco0fURU", "src-train.txt"), 41 | ("https://drive.google.com/uc?id=1HUBLDtqEQc6MQ-FZQqNhh2YBtdc63xdG", "tgt-train.txt"), 42 | ("https://drive.google.com/uc?id=1WwCH8ASgBM1yOmZe0cJ46bj6kPSYYIRc", "src-val.txt"), 43 | ("https://drive.google.com/uc?id=19OsSpXxWJ-XWuDwfG04VTYzcKAJ28MTw", "tgt-val.txt"), 44 | 
("https://drive.google.com/uc?id=1FcbWZnyixhptaO6DIVjCjm_CeTomiCQJ", "src-test.txt"), 45 | ("https://drive.google.com/uc?id=1rVWvbmoVC90jyGml_t-r3NhaoWVVSKLe", "tgt-test.txt") 46 | ] 47 | } 48 | 49 | 50 | def parse_args(): 51 | parser = argparse.ArgumentParser("download_raw_data.py", conflict_handler="resolve") 52 | parser.add_argument("--data_name", help="data name", type=str, default="", 53 | choices=["USPTO_50k", "USPTO_full", "USPTO_480k", "USPTO_STEREO", "USPTO_DIVERSE"]) 54 | 55 | return parser.parse_args() 56 | 57 | 58 | def main(): 59 | args = parse_args() 60 | data_path = os.path.join("./data", args.data_name) 61 | 62 | os.makedirs(data_path, exist_ok=True) 63 | 64 | for url, fn in urls_fns_dict[args.data_name]: 65 | ofn = os.path.join(data_path, fn) 66 | if not os.path.exists(ofn): 67 | gdown.download(url, ofn, quiet=False) 68 | assert os.path.exists(ofn) 69 | else: 70 | print(f"{ofn} exists, skip downloading") 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modeling Diverse Chemical Reactions for Single-step Retrosynthesis via Discrete Latent Variables 2 | This is the code of paper 3 | **Modeling Diverse Chemical Reactions for Single-step Retrosynthesis via Discrete Latent Variables**. 4 | Huarui He, Jie Wang, Yunfei Liu, Feng Wu. CIKM 2022. 5 | 6 | ## Reproduce the Results 7 | ### 1. Environmental setup 8 | Please ensure that conda has been properly initialized, i.e. **conda activate** is runnable. Then 9 | ``` 10 | bash -i scripts/setup.sh 11 | conda activate retro 12 | ``` 13 | 14 | ### 2. Data preparation 15 | Download the raw (cleaned and tokenized) data from Google Drive by 16 | ``` 17 | python scripts/download_raw_data.py --data_name=$DATASET 18 | ``` 19 | where DATASET is one of [**USPTO_50k**, **USPTO_DIVERSE**]
20 | Run **create_1toN_map.ipynb** in **data/** to derive the 1-to-N answer dict for each dataset.
21 | It is okay to download only the dataset you want.
22 | 
23 | Then run the preprocessing script by
24 | ```
25 | sh scripts/preprocess.sh $DATASET
26 | ```
27 | 
28 | ### 3. Model training and validation
29 | Run the training script by
30 | ```
31 | export CUDA_VISIBLE_DEVICES=7
32 | sh scripts/train_g2s.sh "cvae-gru-K20" "USPTO_DIVERSE" "20"
33 | ```
34 | 
35 | Optionally, run the evaluation script by
36 | ```
37 | sh scripts/validate.sh "cvae-gru-K20" "USPTO_DIVERSE"
38 | ```
39 | Note: the evaluation process performs beam search over the whole validation set for every checkpoint.
40 | It can take tens of hours.
41 | 
42 | 
43 | ### 4. Testing
44 | Run the testing script by
45 | ```
46 | sh scripts/predict.sh "cvae-gru-K20" "USPTO_DIVERSE"
47 | ```
48 | which will first run beam search to generate the results for all the test inputs,
49 | and then compute the average top-k accuracies (a minimal sketch of this matching step is included at the end of this README).
50 | 
65 | 
66 | 
67 | ## File tree
68 | ```
69 | RetroDCVAE
70 | ├─ README.md
71 | ├─ data
72 | │  ├─ USPTO_50k
73 | │  │  ├─ src-train.txt
74 | │  │  ├─ ...
75 | │  │  └─ tgt-test.txt
76 | │  ├─ USPTO_DIVERSE
77 | │  │  ├─ src-train.txt
78 | │  │  ├─ ...
79 | │  │  └─ tgt-test.txt
80 | │  └─ create_1toN_map.ipynb
81 | ├─ models
82 | │  ├─ VAR
83 | │  │  ├─ utils.py
84 | │  │  └─ var_dec.py
85 | │  ├─ attention_xl.py
86 | │  ├─ dgat.py
87 | │  ├─ dgcn.py
88 | │  ├─ graph2seq_series_rel.py
89 | │  ├─ graphfeat.py
90 | │  ├─ model_utils.py
91 | │  └─ seq2seq.py
92 | ├─ predict.py
93 | ├─ preprocess.py
94 | ├─ scripts
95 | │  ├─ download_checkpoints.py
96 | │  ├─ download_raw_data.py
97 | │  ├─ predict.sh
98 | │  ├─ preprocess.sh
99 | │  ├─ setup.sh
100 | │  ├─ train_g2s.sh
101 | │  └─ validate.sh
102 | ├─ train.py
103 | ├─ utils
104 | │  ├─ chem_utils.py
105 | │  ├─ data_utils.py
106 | │  ├─ parsing.py
107 | │  ├─ rxn_graphs.py
108 | │  └─ train_utils.py
109 | └─ validate.py
110 | ```
111 | 
112 | ## Citation
113 | If you find this code useful, please consider citing the following paper.
114 | ```
115 | @inproceedings{CIKM22_RetroDCVAE,
116 |   author={Huarui He and Jie Wang and Yunfei Liu and Feng Wu},
117 |   booktitle={Proc. of CIKM},
118 |   title={Modeling Diverse Chemical Reactions for Single-step Retrosynthesis via Discrete Latent Variables},
119 |   year={2022}
120 | }
121 | ```
122 | 
123 | ## Acknowledgement
124 | We refer to the code of [Graph2SMILES](https://github.com/coleygroup/Graph2SMILES). Thanks for their contributions.
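For reference, the top-k accuracies reported in step 4 come from exact matching against the 1-to-N answer dicts (`raw_test.json` / `raw_val.json`) created by **create_1toN_map.ipynb**. The snippet below is a minimal standalone sketch of that matching step, mirroring the loop in `validate.py`; the result-file name is hypothetical, and it assumes each line holds the comma-separated predictions for one test product, already canonicalized and tokenized like the tgt files.
```
import json
import numpy as np

n_best = 10
data_dir = "./data/USPTO_50k"

# product SMILES -> {"product": ..., "reaction": <count>, "reactant": [tokenized SMILES, ...]}
with open(f"{data_dir}/raw_test.json") as f:
    tgts_dict = json.load(f)
with open(f"{data_dir}/src-test.txt") as f:
    products = [line.strip() for line in f]

accuracies = np.zeros([len(products), n_best], dtype=np.float32)
with open("predictions.txt") as f:  # hypothetical result file, one line per test product
    for i, line in enumerate(f):
        for j, smi in enumerate(line.strip().split(",")[:n_best]):
            if smi in tgts_dict[products[i]]["reactant"]:
                accuracies[i, j:] = 1.0  # a hit at rank j counts for every k >= j
                break

for k in range(n_best):
    print(f"Top {k + 1} accuracy: {accuracies[:, k].mean() * 100:.2f} %")
```
The actual scripts additionally canonicalize every prediction with RDKit and report the fraction of invalid top-1 SMILES; see `validate.py` for the full logic.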
125 | 126 | -------------------------------------------------------------------------------- /models/dgcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.model_utils import index_select_ND 4 | from typing import Tuple 5 | 6 | 7 | class DGCNGRU(nn.Module): 8 | """GRU Message Passing layer.""" 9 | def __init__(self, args, input_size: int, h_size: int, depth: int): 10 | super().__init__() 11 | self.args = args 12 | 13 | self.input_size = input_size 14 | self.h_size = h_size 15 | self.depth = depth 16 | 17 | self._build_layer_components() 18 | 19 | def _build_layer_components(self) -> None: 20 | """Build layer components.""" 21 | self.W_z = nn.Linear(self.input_size + self.h_size, self.h_size) 22 | self.W_r = nn.Linear(self.input_size, self.h_size, bias=False) 23 | self.U_r = nn.Linear(self.h_size, self.h_size) 24 | self.W_h = nn.Linear(self.input_size + self.h_size, self.h_size) 25 | 26 | def GRU(self, x: torch.Tensor, h_nei: torch.Tensor) -> torch.Tensor: 27 | """Implements the GRU gating equations. 28 | 29 | Parameters 30 | ---------- 31 | x: torch.Tensor, input tensor 32 | h_nei: torch.Tensor, hidden states of the neighbors 33 | """ 34 | sum_h = h_nei.sum(dim=1) # (9) 35 | z_input = torch.cat([x, sum_h], dim=1) # x = [x_u; x_uv] 36 | z = torch.sigmoid(self.W_z(z_input)) # (10) 37 | 38 | r_1 = self.W_r(x).view(-1, 1, self.h_size) 39 | r_2 = self.U_r(h_nei) 40 | r = torch.sigmoid(r_1 + r_2) # (11) r_ku = f_r(x; m_ku) = W_r(x) + U_r(m_ku) 41 | 42 | gated_h = r * h_nei 43 | sum_gated_h = gated_h.sum(dim=1) # (12) 44 | h_input = torch.cat([x, sum_gated_h], dim=1) 45 | pre_h = torch.tanh(self.W_h(h_input)) # (13) 46 | new_h = (1.0 - z) * sum_h + z * pre_h # (14) 47 | 48 | return new_h 49 | 50 | def forward(self, fmess: torch.Tensor, bgraph: torch.Tensor) -> torch.Tensor: 51 | """Forward pass of the RNN 52 | 53 | Parameters 54 | ---------- 55 | fmess: torch.Tensor, contains the initial features passed as messages 56 | bgraph: torch.Tensor, bond graph tensor. Contains who passes messages to whom. 57 | """ 58 | h = torch.zeros(fmess.size()[0], self.h_size, device=fmess.device) 59 | mask = torch.ones(h.size()[0], 1, device=h.device) 60 | mask[0, 0] = 0 # first message is padding 61 | 62 | for i in range(self.depth): 63 | h_nei = index_select_ND(h, 0, bgraph) 64 | h = self.GRU(fmess, h_nei) 65 | h = h * mask 66 | return h 67 | 68 | 69 | class DGCNEncoder(nn.Module): 70 | """MessagePassing Network based encoder. Messages are updated using an RNN 71 | and the final message is used to update atom embeddings.""" 72 | def __init__(self, args, input_size: int, node_fdim: int): 73 | super().__init__() 74 | self.args = args 75 | 76 | self.h_size = args.encoder_hidden_size 77 | self.depth = args.encoder_num_layers 78 | self.input_size = input_size 79 | self.node_fdim = node_fdim 80 | 81 | self._build_layers() 82 | 83 | def _build_layers(self) -> None: 84 | """Build layers associated with the MPNEncoder.""" 85 | self.W_o = nn.Sequential(nn.Linear(self.node_fdim + self.h_size, self.h_size), nn.GELU()) 86 | self.rnn = DGCNGRU(self.args, self.input_size, self.h_size, self.depth) 87 | 88 | def forward(self, fnode: torch.Tensor, fmess: torch.Tensor, 89 | agraph: torch.Tensor, bgraph: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, ...]: 90 | """Forward pass of the MPNEncoder. 
91 | 92 | Parameters 93 | ---------- 94 | fnode: torch.Tensor, node feature tensor 95 | fmess: torch.Tensor, message features 96 | agraph: torch.Tensor, neighborhood of an atom 97 | bgraph: torch.Tensor, neighborhood of a bond, 98 | except the directed bond from the destination node to the source node 99 | mask: torch.Tensor, masks on nodes 100 | """ 101 | h = self.rnn(fmess, bgraph) 102 | nei_message = index_select_ND(h, 0, agraph) 103 | nei_message = nei_message.sum(dim=1) 104 | node_hiddens = torch.cat([fnode, nei_message], dim=1) 105 | node_hiddens = self.W_o(node_hiddens) 106 | 107 | if mask is None: 108 | mask = torch.ones(node_hiddens.size(0), 1, device=fnode.device) 109 | mask[0, 0] = 0 # first node is padding 110 | 111 | return node_hiddens * mask, h 112 | -------------------------------------------------------------------------------- /data/create_1toN_map.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Make Answer JSON from TXT" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "5007it [00:00, 633529.82it/s]\n" 20 | ] 21 | }, 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "JSON Done!\n" 27 | ] 28 | }, 29 | { 30 | "name": "stderr", 31 | "output_type": "stream", 32 | "text": [ 33 | "5001it [00:00, 619831.40it/s]" 34 | ] 35 | }, 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "JSON Done!\n", 41 | "2022-08-07 04:46:10\n" 42 | ] 43 | }, 44 | { 45 | "name": "stderr", 46 | "output_type": "stream", 47 | "text": [ 48 | "\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "import csv\n", 54 | "import os\n", 55 | "from tqdm import tqdm\n", 56 | "import sys\n", 57 | "import torch\n", 58 | "import time\n", 59 | "\n", 60 | "\n", 61 | "def creatJSON(data_path, src_file, tgt_file):\n", 62 | " with open(src_file, 'r') as sf, open(tgt_file, 'r') as tf:\n", 63 | " src = sf.readlines()\n", 64 | " tgt = tf.readlines()\n", 65 | " answer = {}\n", 66 | " last_p = None\n", 67 | " for idx, row in tqdm(enumerate(src)):\n", 68 | " smi_p = row.strip()\n", 69 | " smi_r = tgt[idx].strip()\n", 70 | " if smi_p != last_p:\n", 71 | " last_rs = []\n", 72 | " if last_p == smi_p and smi_r in last_rs:\n", 73 | " continue\n", 74 | " if last_p != smi_p:\n", 75 | " answer[smi_p] = {'product':smi_p, 'reaction':0, 'reactant':[]}\n", 76 | " answer[smi_p]['reactant'].append(smi_r)\n", 77 | " answer[smi_p]['reaction']+=1\n", 78 | " last_p = smi_p\n", 79 | " last_rs.append(smi_r)\n", 80 | " import json\n", 81 | " json_str = json.dumps(answer, indent=4)\n", 82 | " if 'val' in src_file:\n", 83 | " fn = os.path.join(data_path, \"raw_val\")\n", 84 | " elif 'test' in src_file:\n", 85 | " fn = os.path.join(data_path, \"raw_test\")\n", 86 | " elif 'train' in src_file:\n", 87 | " fn = os.path.join(data_path, \"raw_train\")\n", 88 | " with open(fn+'.json', 'w') as json_file:\n", 89 | " json_file.write(json_str)\n", 90 | " print(f'JSON Done!')\n", 91 | "\n", 92 | "def main():\n", 93 | " # fp = \"./USPTO_DIVERSE\"\n", 94 | " fp = \"./USPTO_50k\"\n", 95 | " for phase in [\"test\",\"val\"]:\n", 96 | " src_file, tgt_file = os.path.join(fp, f\"src-{phase}.txt\"), os.path.join(fp, f\"tgt-{phase}.txt\")\n", 97 | " creatJSON(fp, src_file, tgt_file)\n", 98 | "\n", 99 | "if __name__ == \"__main__\":\n", 100 | 
" main()\n", 101 | " print(time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()))" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 4, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "CC\n", 114 | "C C\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "smi = 'O=C(Cl)C(=O)Cl.O=C(O)c1ccc2cncc(Br)c2n1>>[Cl-]'\n", 120 | "smi_r, _, smi_p = smi.split(\">\")\n", 121 | "a=canonicalize_smiles(smi_p, remove_atom_number=True)\n", 122 | "b=tokenize_smiles(canonicalize_smiles(smi_p, remove_atom_number=True))\n", 123 | "print(canonicalize_smiles(smi_p, remove_atom_number=True))\n", 124 | "print(tokenize_smiles(canonicalize_smiles(smi_p, remove_atom_number=True)))\n" 125 | ] 126 | } 127 | ], 128 | "metadata": { 129 | "kernelspec": { 130 | "display_name": "Python 3.6.13 ('graph2seq')", 131 | "language": "python", 132 | "name": "python3" 133 | }, 134 | "language_info": { 135 | "codemirror_mode": { 136 | "name": "ipython", 137 | "version": 3 138 | }, 139 | "file_extension": ".py", 140 | "mimetype": "text/x-python", 141 | "name": "python", 142 | "nbconvert_exporter": "python", 143 | "pygments_lexer": "ipython3", 144 | "version": "3.6.13" 145 | }, 146 | "orig_nbformat": 4, 147 | "vscode": { 148 | "interpreter": { 149 | "hash": "e4a186ed28707f3ae002f1b6f68caf91bf4d4c31b0d65fe2b2ab9c5a49f5e9dc" 150 | } 151 | } 152 | }, 153 | "nbformat": 4, 154 | "nbformat_minor": 2 155 | } 156 | -------------------------------------------------------------------------------- /utils/chem_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | from typing import List 4 | 5 | 6 | # Symbols for different atoms 7 | ATOM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 8 | 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 9 | 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 10 | 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Ta', 'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir', 11 | 'Ce', 'Gd', 'Ga', 'Cs', '*', 'unk'] 12 | ATOM_DICT = {symbol: i for i, symbol in enumerate(ATOM_LIST)} 13 | 14 | MAX_NB = 10 15 | DEGREES = list(range(MAX_NB)) 16 | HYBRIDIZATION = [Chem.rdchem.HybridizationType.SP, 17 | Chem.rdchem.HybridizationType.SP2, 18 | Chem.rdchem.HybridizationType.SP3, 19 | Chem.rdchem.HybridizationType.SP3D, 20 | Chem.rdchem.HybridizationType.SP3D2] 21 | HYBRIDIZATION_DICT = {hb: i for i, hb in enumerate(HYBRIDIZATION)} 22 | 23 | FORMAL_CHARGE = [-1, -2, 1, 2, 0] 24 | FC_DICT = {fc: i for i, fc in enumerate(FORMAL_CHARGE)} 25 | 26 | VALENCE = [0, 1, 2, 3, 4, 5, 6] 27 | VALENCE_DICT = {vl: i for i, vl in enumerate(VALENCE)} 28 | 29 | NUM_Hs = [0, 1, 3, 4, 5] 30 | NUM_Hs_DICT = {nH: i for i, nH in enumerate(NUM_Hs)} 31 | 32 | CHIRAL_TAG = [Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, 33 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, 34 | Chem.rdchem.ChiralType.CHI_UNSPECIFIED] 35 | CHIRAL_TAG_DICT = {ct: i for i, ct in enumerate(CHIRAL_TAG)} 36 | 37 | RS_TAG = ["R", "S", "None"] 38 | RS_TAG_DICT = {rs: i for i, rs in enumerate(RS_TAG)} 39 | 40 | BOND_TYPES = [None, 41 | Chem.rdchem.BondType.SINGLE, 42 | Chem.rdchem.BondType.DOUBLE, 43 | Chem.rdchem.BondType.TRIPLE, 44 | Chem.rdchem.BondType.AROMATIC] 45 | BOND_FLOAT_TO_TYPE = { 46 | 0.0: BOND_TYPES[0], 47 | 1.0: BOND_TYPES[1], 48 | 2.0: BOND_TYPES[2], 49 | 3.0: BOND_TYPES[3], 50 | 1.5: 
BOND_TYPES[4], 51 | } 52 | 53 | BOND_STEREO = [Chem.rdchem.BondStereo.STEREOE, 54 | Chem.rdchem.BondStereo.STEREOZ, 55 | Chem.rdchem.BondStereo.STEREONONE] 56 | 57 | BOND_DELTAS = {-3: 0, -2: 1, -1.5: 2, -1: 3, -0.5: 4, 0: 5, 0.5: 6, 1: 7, 1.5: 8, 2: 9, 3: 10} 58 | BOND_FLOATS = [0.0, 1.0, 2.0, 3.0, 1.5] 59 | 60 | RXN_CLASSES = list(range(10)) 61 | 62 | # ATOM_FDIM = len(ATOM_LIST) + len(DEGREES) + len(FORMAL_CHARGE) + len(HYBRIDIZATION) \ 63 | # + len(VALENCE) + len(NUM_Hs) + 1 64 | ATOM_FDIM = [len(ATOM_LIST), len(DEGREES), len(FORMAL_CHARGE), len(HYBRIDIZATION), len(VALENCE), 65 | len(NUM_Hs), len(CHIRAL_TAG), len(RS_TAG), 2] 66 | # BOND_FDIM = 6 67 | BOND_FDIM = 9 68 | BINARY_FDIM = 5 + BOND_FDIM 69 | INVALID_BOND = -1 70 | 71 | 72 | def get_atom_features_sparse(atom: Chem.Atom, rxn_class: int = None, use_rxn_class: bool = False) -> List[int]: 73 | """Get atom features as sparse idx. 74 | 75 | Parameters 76 | ---------- 77 | atom: Chem.Atom, 78 | Atom object from RDKit 79 | rxn_class: int, None 80 | Reaction class the molecule was part of 81 | use_rxn_class: bool, default False, 82 | Whether to use reaction class as additional input 83 | """ 84 | feature_array = [] 85 | symbol = atom.GetSymbol() 86 | symbol_id = ATOM_DICT.get(symbol, ATOM_DICT["unk"]) 87 | feature_array.append(symbol_id) 88 | 89 | if symbol in ["*", "unk"]: 90 | padding = [999999999] * len(ATOM_FDIM) if use_rxn_class else [999999999] * (len(ATOM_FDIM) - 1) 91 | feature_array.extend(padding) 92 | 93 | else: 94 | degree_id = atom.GetDegree() 95 | if degree_id not in DEGREES: 96 | degree_id = 9 97 | formal_charge_id = FC_DICT.get(atom.GetFormalCharge(), 4) 98 | hybridization_id = HYBRIDIZATION_DICT.get(atom.GetHybridization(), 4) 99 | valence_id = VALENCE_DICT.get(atom.GetTotalValence(), 6) 100 | num_h_id = NUM_Hs_DICT.get(atom.GetTotalNumHs(), 4) 101 | chiral_tag_id = CHIRAL_TAG_DICT.get(atom.GetChiralTag(), 2) 102 | 103 | rs_tag = atom.GetPropsAsDict().get("_CIPCode", "None") 104 | rs_tag_id = RS_TAG_DICT.get(rs_tag, 2) 105 | 106 | is_aromatic = int(atom.GetIsAromatic()) 107 | feature_array.extend([degree_id, formal_charge_id, hybridization_id, 108 | valence_id, num_h_id, chiral_tag_id, rs_tag_id, is_aromatic]) 109 | 110 | if use_rxn_class: 111 | feature_array.append(rxn_class) 112 | 113 | return feature_array 114 | 115 | 116 | def get_bond_features(bond: Chem.Bond) -> List[int]: 117 | """Get bond features. 
118 | 119 | Parameters 120 | ---------- 121 | bond: Chem.Bond, 122 | bond object 123 | """ 124 | bt = bond.GetBondType() 125 | bond_features = [int(bt == bond_type) for bond_type in BOND_TYPES[1:]] 126 | bs = bond.GetStereo() 127 | bond_features.extend([int(bs == bond_stereo) for bond_stereo in BOND_STEREO]) 128 | bond_features.extend([int(bond.GetIsConjugated()), int(bond.IsInRing())]) 129 | 130 | return bond_features 131 | -------------------------------------------------------------------------------- /models/VAR/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def gaussian_kld(recog_mu, recog_std, prior_mu, prior_std): 7 | kld = -0.5 * torch.sum(torch.log(torch.div(torch.pow(recog_std, 2), torch.pow(prior_std, 2)) + 1e-6) 8 | - torch.div(torch.pow(recog_std, 2), torch.pow(prior_std, 2)) 9 | - torch.div(torch.pow(recog_mu - prior_mu, 2), torch.pow(prior_std, 2)) + 1, dim=-1) 10 | return kld 11 | 12 | 13 | def onehot_from_logits(logits, eps=0.0): 14 | """ 15 | Given batch of logits, return one-hot sample using epsilon greedy strategy 16 | (based on given epsilon) 17 | """ 18 | # get best (according to current policy) actions in one-hot form 19 | argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float() 20 | if eps == 0.0: 21 | return argmax_acs 22 | 23 | 24 | class layerN(nn.Module): 25 | """ 26 | Layer Normalization class 27 | """ 28 | 29 | def __init__(self, features, eps=1e-6): 30 | super(layerN, self).__init__() 31 | self.a_2 = torch.ones(features) 32 | self.b_2 = torch.zeros(features) 33 | self.eps = eps 34 | 35 | def forward(self, x): 36 | device = x.device 37 | self.a_2, self.b_2 = self.a_2.to(device=device), self.b_2.to(device=device) 38 | mean = x.mean(-1, keepdim=True) 39 | std = x.std(-1, keepdim=True) 40 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 41 | 42 | 43 | class Con_Layer_Norm(nn.Module): 44 | """ 45 | Conditional Layer Normalization class 46 | """ 47 | 48 | def __init__(self, features, d_size, eps=1e-6): 49 | super(Con_Layer_Norm, self).__init__() 50 | self.a_2 = nn.Parameter(torch.ones(features)) 51 | self.b_2 = nn.Parameter(torch.zeros(features)) 52 | self.eps = eps 53 | self.emb_g = nn.Linear(d_size, features, bias=False) 54 | self.emb_b = nn.Linear(d_size, features, bias=False) 55 | nn.init.zeros_(self.emb_g.weight) 56 | nn.init.zeros_(self.emb_b.weight) 57 | 58 | def forward(self, x, condition=None): 59 | mean = x.mean(-1, keepdim=True) 60 | std = x.std(-1, keepdim=True) 61 | if condition is not None: 62 | if condition.size() != x.size(): 63 | condition = torch.mean(condition, dim=1, keepdim=True) 64 | gamma = self.emb_g(condition) 65 | beta = self.emb_b(condition) 66 | return (self.a_2 + gamma) * (x - mean) / (std + self.eps) + (self.b_2 + beta) 67 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 68 | 69 | 70 | class VarFeedForward(nn.Module): 71 | def __init__(self, input_depth, filter_size, output_depth, layer_config='ll', dropout=0.0): 72 | super(VarFeedForward, self).__init__() 73 | 74 | layers = [] 75 | sizes = ([(input_depth, filter_size)] + 76 | [(filter_size, filter_size)] * (len(layer_config) - 2) + 77 | [(filter_size, output_depth)]) 78 | 79 | for lc, s in zip(list(layer_config), sizes): 80 | if lc == 'l': 81 | layers.append(nn.Linear(*s)) 82 | else: 83 | raise ValueError("Unknown layer type {}".format(lc)) 84 | 85 | self.layers = nn.ModuleList(layers) 86 | self.tanh = nn.Tanh() 87 | 
self.drop = nn.Dropout(dropout) 88 | 89 | def forward(self, inputs): 90 | x = inputs 91 | for i, layer in enumerate(self.layers): 92 | x = layer(x) 93 | if i < len(self.layers) - 1: 94 | x = self.tanh(x) 95 | x = self.drop(x) 96 | return x 97 | 98 | 99 | class FeedForwardNet(nn.Module): 100 | def __init__(self, d_model, d_ff, d_out, dropout=0.1): 101 | super(FeedForwardNet, self).__init__() 102 | self.w_1 = nn.Linear(d_model, d_ff) 103 | self.w_2 = nn.Linear(d_ff, d_out) 104 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 105 | self.dropout_1 = nn.Dropout(dropout) 106 | self.relu = nn.ReLU() 107 | self.tanh = nn.Tanh() 108 | self.dropout_2 = nn.Dropout(dropout) 109 | 110 | def forward(self, x): 111 | # inter = self.dropout_1(self.tanh(self.w_1(self.layer_norm(x)))) 112 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 113 | output = self.dropout_2(self.w_2(inter)) 114 | return output 115 | 116 | 117 | # borrow from onmt 118 | def tile(x, count, dim=0): 119 | """ 120 | Tiles x on dimension dim count times. 121 | """ 122 | perm = list(range(len(x.size()))) 123 | if dim != 0: 124 | perm[0], perm[dim] = perm[dim], perm[0] 125 | x = x.permute(perm).contiguous() 126 | out_size = list(x.size()) 127 | out_size[0] *= count 128 | batch = x.size(0) 129 | x = x.view(batch, -1) \ 130 | .transpose(0, 1) \ 131 | .repeat(count, 1) \ 132 | .transpose(0, 1) \ 133 | .contiguous() \ 134 | .view(*out_size) 135 | if dim != 0: 136 | x = x.permute(perm).contiguous() 137 | return x 138 | 139 | 140 | # borrow from https://github1s.com/bojone/vae/blob/master/vae_keras_cnn_gs.py 141 | def GumbelSoftmax(logits, tau=.8, noise=1e-20): 142 | eps = torch.rand(size=logits.shape, device=logits.device) # uniform distribution on the interval [0, 1) 143 | outputs = logits - torch.log(-torch.log(eps + noise) + noise) 144 | return torch.softmax(outputs / tau, -1) 145 | 146 | 147 | # borrow from https://blog.evjang.com/2016/11/tutorial-categorical-variational.html 148 | def sample_gumbel(shape, device, eps=1e-20): 149 | """Sample from Gumbel(0, 1)""" 150 | U = torch.rand(size=shape, device=device) # uniform distribution on the interval [0, 1) 151 | return -torch.log(-torch.log(U + eps) + eps) 152 | 153 | def gumbel_softmax_sample(logits, temperature=1.0): 154 | """ Draw a sample from the Gumbel-Softmax distribution""" 155 | y = logits + sample_gumbel(logits.shape, logits.device) 156 | return torch.softmax( y / temperature, dim=-1) 157 | 158 | def gumbel_softmax(logits, temperature=1.0, hard=False): 159 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 160 | Args: 161 | logits: [batch_size, n_class] unnormalized log-probs 162 | temperature: non-negative scalar 163 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 164 | Returns: 165 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 
166 | If hard=True, then the returned sample will be one-hot, otherwise it will 167 | be a probabilitiy distribution that sums to 1 across classes 168 | """ 169 | y = gumbel_softmax_sample(logits, temperature) 170 | if hard: 171 | y_hard = onehot_from_logits(y) 172 | y = (y_hard - y).detach() + y 173 | return y -------------------------------------------------------------------------------- /utils/rxn_graphs.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from utils.chem_utils import BOND_TYPES 3 | from rdkit import Chem 4 | from typing import List, Tuple, Union 5 | 6 | 7 | def get_sub_mol(mol, sub_atoms): 8 | new_mol = Chem.RWMol() 9 | atom_map = {} 10 | for idx in sub_atoms: 11 | atom = mol.GetAtomWithIdx(idx) 12 | atom_map[idx] = new_mol.AddAtom(atom) 13 | 14 | sub_atoms = set(sub_atoms) 15 | for idx in sub_atoms: 16 | a = mol.GetAtomWithIdx(idx) 17 | for b in a.GetNeighbors(): 18 | if b.GetIdx() not in sub_atoms: 19 | continue 20 | bond = mol.GetBondBetweenAtoms(a.GetIdx(), b.GetIdx()) 21 | bt = bond.GetBondType() 22 | if a.GetIdx() < b.GetIdx(): # each bond is enumerated twice 23 | new_mol.AddBond(atom_map[a.GetIdx()], atom_map[b.GetIdx()], bt) 24 | 25 | return new_mol.GetMol() 26 | 27 | 28 | class RxnGraph: 29 | """ 30 | RxnGraph is an abstract class for storing all elements of a reaction, like 31 | reactants, products and fragments. The edits associated with the reaction 32 | are also captured in edit labels. One can also use h_labels, which keep track 33 | of atoms with hydrogen changes. For reactions with multiple edits, a done 34 | label is also added to account for termination of edits. 35 | """ 36 | 37 | def __init__(self, 38 | prod_mol: Chem.Mol = None, 39 | frag_mol: Chem.Mol = None, 40 | reac_mol: Chem.Mol = None, 41 | rxn_class: int = None) -> None: 42 | """ 43 | Parameters 44 | ---------- 45 | prod_mol: Chem.Mol, 46 | Product molecule 47 | frag_mol: Chem.Mol, default None 48 | Fragment molecule(s) 49 | reac_mol: Chem.Mol, default None 50 | Reactant molecule(s) 51 | rxn_class: int, default None, 52 | Reaction class for this reaction. 53 | """ 54 | if prod_mol is not None: 55 | self.prod_mol = RxnElement(mol=prod_mol, rxn_class=rxn_class) 56 | if frag_mol is not None: 57 | self.frag_mol = MultiElement(mol=frag_mol, rxn_class=rxn_class) 58 | if reac_mol is not None: 59 | self.reac_mol = MultiElement(mol=reac_mol, rxn_class=rxn_class) 60 | self.rxn_class = rxn_class 61 | 62 | def get_attributes(self, mol_attrs: Tuple = ('prod_mol', 'frag_mol', 'reac_mol')) -> Tuple: 63 | """ 64 | Returns the different attributes associated with the reaction graph. 65 | 66 | Parameters 67 | ---------- 68 | mol_attrs: Tuple, 69 | Molecule objects to return 70 | """ 71 | return tuple(getattr(self, attr) for attr in mol_attrs if hasattr(self, attr)) 72 | 73 | 74 | class RxnElement: 75 | """ 76 | RxnElement is an abstract class for dealing with single molecule. The graph 77 | and corresponding molecule attributes are built for the molecule. The constructor 78 | accepts only mol objects, sidestepping the use of SMILES string which may always 79 | not be achievable, especially for a unkekulizable molecule. 80 | """ 81 | 82 | def __init__(self, mol: Chem.Mol, rxn_class: int = None) -> None: 83 | """ 84 | Parameters 85 | ---------- 86 | mol: Chem.Mol, 87 | Molecule 88 | rxn_class: int, default None, 89 | Reaction class for this reaction. 
90 | """ 91 | self.mol = mol 92 | self.rxn_class = rxn_class 93 | self._build_mol() 94 | self._build_graph() 95 | 96 | def _build_mol(self) -> None: 97 | """Builds the molecule attributes.""" 98 | self.num_atoms = self.mol.GetNumAtoms() 99 | self.num_bonds = self.mol.GetNumBonds() 100 | self.amap_to_idx = {atom.GetAtomMapNum(): atom.GetIdx() 101 | for atom in self.mol.GetAtoms()} 102 | self.idx_to_amap = {value: key for key, value in self.amap_to_idx.items()} 103 | 104 | def _build_graph(self) -> None: 105 | """Builds the graph attributes.""" 106 | self.G_undir = nx.Graph(Chem.rdmolops.GetAdjacencyMatrix(self.mol)) 107 | self.G_dir = nx.DiGraph(Chem.rdmolops.GetAdjacencyMatrix(self.mol)) 108 | 109 | for atom in self.mol.GetAtoms(): 110 | self.G_undir.nodes[atom.GetIdx()]['label'] = atom.GetSymbol() 111 | self.G_dir.nodes[atom.GetIdx()]['label'] = atom.GetSymbol() 112 | 113 | for bond in self.mol.GetBonds(): 114 | a1 = bond.GetBeginAtom().GetIdx() 115 | a2 = bond.GetEndAtom().GetIdx() 116 | btype = BOND_TYPES.index(bond.GetBondType()) 117 | self.G_undir[a1][a2]['label'] = btype 118 | self.G_dir[a1][a2]['label'] = btype 119 | self.G_dir[a2][a1]['label'] = btype 120 | 121 | self.atom_scope = (0, self.num_atoms) 122 | self.bond_scope = (0, self.num_bonds) 123 | 124 | def update_atom_scope(self, offset: int) -> Union[List, Tuple]: 125 | """Updates the atom indices by the offset. 126 | 127 | Parameters 128 | ---------- 129 | offset: int, 130 | Offset to apply 131 | """ 132 | # Note that the self. reference to atom_scope is dropped to keep self.atom_scope non-dynamic 133 | if isinstance(self.atom_scope, list): 134 | atom_scope = [(st + offset, le) for st, le in self.atom_scope] 135 | else: 136 | st, le = self.atom_scope 137 | atom_scope = (st + offset, le) 138 | 139 | return atom_scope 140 | 141 | def update_bond_scope(self, offset: int) -> Union[List, Tuple]: 142 | """Updates the bond indices by the offset. 143 | 144 | Parameters 145 | ---------- 146 | offset: int, 147 | Offset to apply 148 | """ 149 | # Note that the self. reference to bond_scope is dropped to keep self.bond_scope non-dynamic 150 | if isinstance(self.bond_scope, list): 151 | bond_scope = [(st + offset, le) for st, le in self.bond_scope] 152 | else: 153 | st, le = self.bond_scope 154 | bond_scope = (st + offset, le) 155 | 156 | return bond_scope 157 | 158 | 159 | class MultiElement(RxnElement): 160 | """ 161 | MultiElement is an abstract class for dealing with multiple molecules. The graph 162 | is built with all molecules, but different molecules and their sizes are stored. 163 | The constructor accepts only mol objects, sidestepping the use of SMILES string 164 | which may always not be achievable, especially for an invalid intermediates. 
165 | """ 166 | 167 | def _build_graph(self) -> None: 168 | """Builds the graph attributes.""" 169 | self.G_undir = nx.Graph(Chem.rdmolops.GetAdjacencyMatrix(self.mol)) 170 | self.G_dir = nx.DiGraph(Chem.rdmolops.GetAdjacencyMatrix(self.mol)) 171 | 172 | for atom in self.mol.GetAtoms(): 173 | self.G_undir.nodes[atom.GetIdx()]['label'] = atom.GetSymbol() 174 | self.G_dir.nodes[atom.GetIdx()]['label'] = atom.GetSymbol() 175 | 176 | for bond in self.mol.GetBonds(): 177 | a1 = bond.GetBeginAtom().GetIdx() 178 | a2 = bond.GetEndAtom().GetIdx() 179 | btype = BOND_TYPES.index(bond.GetBondType()) 180 | self.G_undir[a1][a2]['label'] = btype 181 | self.G_dir[a1][a2]['label'] = btype 182 | self.G_dir[a2][a1]['label'] = btype 183 | 184 | frag_indices = [c for c in nx.strongly_connected_components(self.G_dir)] 185 | self.mols = [get_sub_mol(self.mol, sub_atoms) for sub_atoms in frag_indices] 186 | 187 | atom_start = 0 188 | bond_start = 0 189 | self.atom_scope = [] 190 | self.bond_scope = [] 191 | 192 | for mol in self.mols: 193 | self.atom_scope.append((atom_start, mol.GetNumAtoms())) 194 | self.bond_scope.append((bond_start, mol.GetNumBonds())) 195 | atom_start += mol.GetNumAtoms() 196 | bond_start += mol.GetNumBonds() 197 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import numpy as np 5 | import os 6 | import sys 7 | import time 8 | import torch 9 | from models.graph2seq_series_rel import Graph2SeqSeriesRel 10 | from models.seq2seq import Seq2Seq 11 | from torch.utils.data import DataLoader 12 | from utils import parsing 13 | from utils.data_utils import canonicalize_smiles, load_vocab, S2SDataset, G2SDataset, tokenize_smiles 14 | from utils.train_utils import log_tensor, param_count, set_seed, setup_logger 15 | 16 | import json 17 | import warnings 18 | warnings.filterwarnings('ignore') 19 | 20 | def get_predict_parser(): 21 | parser = argparse.ArgumentParser("predict") 22 | parsing.add_common_args(parser) 23 | parsing.add_preprocess_args(parser) 24 | parsing.add_train_args(parser) 25 | parsing.add_predict_args(parser) 26 | 27 | return parser 28 | 29 | 30 | def main(args): 31 | start = time.time() 32 | parsing.log_args(args) 33 | 34 | os.makedirs(os.path.join("./results", args.data_name), exist_ok=True) 35 | 36 | # initialization ----------------- model 37 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 38 | 39 | checkpoints = glob.glob(os.path.join(args.load_from, "*.pt")) 40 | checkpoints = sorted( 41 | checkpoints, 42 | key=lambda ckpt: int(ckpt.split(".")[-2].split("_")[-1]), 43 | reverse=True 44 | ) 45 | checkpoints = [ckpt for ckpt in checkpoints 46 | if (args.checkpoint_step_start <= int(ckpt.split(".")[-2].split("_")[0])) 47 | and (args.checkpoint_step_end >= int(ckpt.split(".")[-2].split("_")[0]))] 48 | 49 | model = None 50 | val_dataset = None 51 | vocab_tokens = None 52 | with open(args.val_tgt.replace('tgt','src'), 'r') as f: 53 | lines = f.readlines() 54 | ans_path = os.path.join("/".join(args.val_tgt.split('/')[:-1]), "raw_val.json") 55 | with open(ans_path, 'r') as jfile: 56 | tgts_dict = json.load(jfile) 57 | for ckpt_i, checkpoint in enumerate(checkpoints): 58 | # ith = int(checkpoint.split('_')[-1].split('.')[0]) 59 | # if ith<38: 60 | # print(f'PASS {checkpoint}!') 61 | # continue 62 | logging.info(f"Loading from {checkpoint}") 63 | state = 
torch.load(checkpoint) 64 | 65 | pretrain_args = state["args"] 66 | pretrain_state_dict = state["state_dict"] 67 | 68 | if model is None: 69 | # initialization ----------------- model 70 | logging.info(f"Model is None, building model") 71 | logging.info(f"First logging args for training") 72 | parsing.log_args(pretrain_args) 73 | 74 | for attr in ["mpn_type", "rel_pos"]: 75 | try: 76 | getattr(pretrain_args, attr) 77 | except AttributeError: 78 | setattr(pretrain_args, attr, getattr(args, attr)) 79 | 80 | assert args.model == pretrain_args.model, f"Pretrained model is {pretrain_args.model}!" 81 | if args.model == "s2s": 82 | model_class = Seq2Seq 83 | dataset_class = S2SDataset 84 | elif args.model == "g2s_series_rel": 85 | model_class = Graph2SeqSeriesRel 86 | dataset_class = G2SDataset 87 | args.compute_graph_distance = True 88 | assert args.compute_graph_distance 89 | else: 90 | raise ValueError(f"Model {args.model} not supported!") 91 | 92 | # initialization ----------------- vocab 93 | vocab = load_vocab(pretrain_args.vocab_file) 94 | vocab_tokens = [k for k, v in sorted(vocab.items(), key=lambda tup: tup[1])] 95 | 96 | model = model_class(pretrain_args, vocab) 97 | logging.info(model) 98 | logging.info(f"Number of parameters = {param_count(model)}") 99 | 100 | # initialization ----------------- data 101 | val_dataset = dataset_class(pretrain_args, file=args.valid_bin) 102 | val_dataset.batch( 103 | batch_type=args.batch_type, 104 | batch_size=args.predict_batch_size 105 | ) 106 | with open(args.val_tgt, "r") as f: 107 | total = sum(1 for _ in f) 108 | 109 | model.load_state_dict(pretrain_state_dict, strict=False) 110 | logging.info(f"Loaded pretrained state_dict from {checkpoint}") 111 | 112 | model.to(device) 113 | model.eval() 114 | 115 | val_loader = DataLoader( 116 | dataset=val_dataset, 117 | batch_size=1, 118 | shuffle=False, 119 | collate_fn=lambda _batch: _batch[0], 120 | num_workers=16, 121 | pin_memory=True 122 | ) 123 | 124 | # prediction 125 | all_predictions = [] 126 | with torch.no_grad(): 127 | for val_idx, val_batch in enumerate(val_loader): 128 | if val_idx % args.log_iter == 0: 129 | logging.info(f"Doing inference on val step {val_idx}, time: {time.time() - start: .2f} s") 130 | sys.stdout.flush() 131 | 132 | val_batch.to(device) 133 | results = model.predict_step( 134 | reaction_batch=val_batch, 135 | batch_size=val_batch.size, 136 | beam_size=args.beam_size, 137 | n_best=args.n_best, 138 | temperature=args.temperature, 139 | min_length=args.predict_min_len, 140 | max_length=args.predict_max_len 141 | )[0] 142 | 143 | for i, predictions in enumerate(results["predictions"]): 144 | smis = [] 145 | for prediction in predictions: 146 | predicted_idx = prediction.detach().cpu().numpy() 147 | predicted_tokens = [vocab_tokens[idx] for idx in predicted_idx[:-1]] 148 | smi = " ".join(predicted_tokens) 149 | smis.append(smi) 150 | smis = ",".join(smis) 151 | all_predictions.append(f"{smis}\n") 152 | 153 | # saving prediction results 154 | result_file = f"{args.result_file}.{ckpt_i}" 155 | result_stat_file = f"{args.result_file}.stat.{ckpt_i}" 156 | with open(result_file, "w") as of: 157 | of.writelines(all_predictions) 158 | 159 | # scoring 160 | invalid = 0 161 | accuracies = np.zeros([total, args.n_best], dtype=np.float32) 162 | 163 | with open(result_file, "r") as f_predict: 164 | for i, line_predict in enumerate(f_predict): 165 | line_predict = "".join(line_predict.split()) 166 | smis_predict = line_predict.split(",") 167 | smis_predict = [canonicalize_smiles(smi, 
trim=False, suppress_warning=True) for smi in smis_predict] 168 | if not smis_predict[0]: 169 | invalid += 1 # top-1 invalid 170 | smis_predict = [smi for smi in smis_predict if smi and not smi == "CC"] # delete invalid prediction 171 | 172 | #++++++++++++++++++++ for exact matching 173 | product = lines[i].strip() 174 | for j, smi in enumerate(smis_predict): 175 | if tokenize_smiles(smi) in tgts_dict[product]['reactant']: 176 | accuracies[i,j:] = 1.0 177 | break 178 | 179 | with open(result_stat_file, "w") as of: 180 | line = f"Total: {total}, top 1 invalid: {invalid / total * 100: .2f} %" 181 | logging.info(line) 182 | of.write(f"{line}\n") 183 | 184 | mean_accuracies = np.mean(accuracies, axis=0) 185 | for n in range(args.n_best): 186 | line = f"Top {n+1} accuracy: {mean_accuracies[n] * 100: .2f} %" 187 | logging.info(line) 188 | of.write(f"{line}\n") 189 | 190 | torch.cuda.empty_cache() 191 | 192 | 193 | if __name__ == "__main__": 194 | predict_parser = get_predict_parser() 195 | args = predict_parser.parse_args() 196 | 197 | # set random seed (just in case) 198 | set_seed(args.seed) 199 | 200 | # logger setup 201 | logger = setup_logger(args, warning_off=True) 202 | 203 | torch.set_printoptions(profile="full") 204 | main(args) 205 | -------------------------------------------------------------------------------- /utils/parsing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | 5 | def log_args(args): 6 | logging.info(f"Logging arguments") 7 | for k, v in vars(args).items(): 8 | logging.info(f"**** {k} = *{v}*") 9 | 10 | 11 | def add_common_args(parser): 12 | group = parser.add_argument_group("Meta") 13 | group.add_argument("--model", help="Model architecture", 14 | choices=["s2s", "g2s", "g2s_series", "g2s_series_rel"], type=str, default="") 15 | group.add_argument("--data_name", help="Data name", type=str, default="") 16 | group.add_argument("--task", help="Task", choices=["reaction_prediction", "retrosynthesis", "autoencoding"], 17 | type=str, default="") 18 | group.add_argument("--representation_end", help="Final string representation to be fed", 19 | choices=["smiles", "selfies"], type=str, default="") 20 | group.add_argument("--seed", help="Random seed", type=int, default=42) 21 | group.add_argument("--max_src_len", help="Max source length", type=int, default=512) 22 | group.add_argument("--max_tgt_len", help="Max target length", type=int, default=512) 23 | group.add_argument("--num_workers", help="No. 
of workers", type=int, default=1) 24 | group.add_argument("--verbose", help="Whether to enable verbose debugging", action="store_true") 25 | 26 | group = parser.add_argument_group("Paths") 27 | group.add_argument("--log_file", help="Preprocess log file", type=str, default="") 28 | group.add_argument("--vocab_file", help="Vocab file", type=str, default="") 29 | group.add_argument("--preprocess_output_path", help="Path for saving preprocessed outputs", 30 | type=str, default="") 31 | group.add_argument("--save_dir", help="Path for saving checkpoints", type=str, default="") 32 | 33 | 34 | def add_preprocess_args(parser): 35 | group = parser.add_argument_group("Preprocessing options") 36 | # data paths 37 | group.add_argument("--train_src", help="Train source", type=str, default="") 38 | group.add_argument("--train_tgt", help="Train target", type=str, default="") 39 | group.add_argument("--val_src", help="Validation source", type=str, default="") 40 | group.add_argument("--val_tgt", help="Validation target", type=str, default="") 41 | group.add_argument("--test_src", help="Test source", type=str, default="") 42 | group.add_argument("--test_tgt", help="Test target", type=str, default="") 43 | # options 44 | group.add_argument("--representation_start", help="Initial string representation to be fed", 45 | choices=["smiles"], type=str, default="") 46 | group.add_argument("--do_tokenize", help="Whether to tokenize the data files", action="store_true") 47 | group.add_argument("--make_vocab_only", help="Whether to only make vocab", action="store_true") 48 | 49 | 50 | def add_train_args(parser): 51 | group = parser.add_argument_group("Training options") 52 | # file paths 53 | group.add_argument("--train_bin", help="Train npz", type=str, default="") 54 | group.add_argument("--valid_bin", help="Valid npz", type=str, default="") 55 | group.add_argument("--load_from", help="Checkpoint to load", type=str, default="") 56 | # model params 57 | group.add_argument("--embed_size", help="Decoder embedding size", type=int, default=256) 58 | group.add_argument("--share_embeddings", help="Whether to share encoder/decoder embeddings", action="store_true") 59 | # -------------- mpn encoder --------------- 60 | group.add_argument("--mpn_type", help="Type of MPN", type=str, 61 | choices=["dgcn", "dgat", "dgate", "dgates", "ffn"], default="") 62 | group.add_argument("--encoder_num_layers", help="No. of layers in transformer/mpn encoder", type=int, default=4) 63 | group.add_argument("--encoder_hidden_size", help="Encoder hidden size", type=int, default=256) 64 | group.add_argument("--encoder_attn_heads", help="Encoder no. of attention heads", type=int, default=8) 65 | group.add_argument("--encoder_filter_size", help="Encoder filter size", type=int, default=2048) 66 | group.add_argument("--encoder_norm", help="Encoder norm", type=str, default="none") 67 | group.add_argument("--encoder_skip_connection", help="Encoder skip connection", type=str, default="none") 68 | group.add_argument("--encoder_positional_encoding", help="Encoder positional encoding", type=str, default="") 69 | group.add_argument("--encoder_emb_scale", help="How to scale encoder embedding", type=str, default="") 70 | # -------------- attention encoder --------------- 71 | group.add_argument("--compute_graph_distance", help="Whether to compute graph distance", action="store_true") 72 | group.add_argument("--attn_enc_num_layers", help="No. 
of layers", type=int, default=4) 73 | group.add_argument("--attn_enc_hidden_size", help="Hidden size", type=int, default=256) 74 | group.add_argument("--attn_enc_heads", help="Hidden size", type=int, default=8) 75 | group.add_argument("--attn_enc_filter_size", help="Filter size", type=int, default=2048) 76 | group.add_argument("--rel_pos", help="type of rel. pos.", type=str, default="none") 77 | group.add_argument("--rel_pos_buckets", help="No. of relative position buckets", type=int, default=10) 78 | #++++++++++++++++++++ Variational decoder 79 | group.add_argument("--latent_K", help="Discrete latent size", type=int, default=30) 80 | group.add_argument("--variational_num_layers", help="No. of layers in variational decoder", type=int, default=1) 81 | group.add_argument("--latent_size", help="Decoder latent size", type=int, default=60) 82 | group.add_argument("--varFFN_hidden_size", help="Decoder VarFeedForward hidden size", type=int, default=128) 83 | # -------------- Transformer decoder --------------- 84 | group.add_argument("--decoder_num_layers", help="No. of layers in transformer decoder", type=int, default=4) 85 | group.add_argument("--decoder_hidden_size", help="Decoder hidden size", type=int, default=256) 86 | group.add_argument("--decoder_attn_heads", help="Decoder no. of attention heads", type=int, default=8) 87 | group.add_argument("--decoder_filter_size", help="Decoder filter size", type=int, default=2048) 88 | group.add_argument("--dropout", help="Hidden dropout", type=float, default=0.0) 89 | group.add_argument("--attn_dropout", help="Attention dropout", type=float, default=0.0) 90 | group.add_argument("--max_relative_positions", help="Max relative positions", type=int, default=0) 91 | # training params 92 | group.add_argument("--enable_amp", help="Whether to enable mixed precision training", action="store_true") 93 | group.add_argument("--epoch", help="Number of training epochs", type=int, default=300) 94 | group.add_argument("--max_steps", help="Number of max total steps", type=int, default=1000000) 95 | group.add_argument("--warmup_steps", help="Number of warmup steps", type=int, default=8000) 96 | group.add_argument("--lr", help="Learning rate", type=float, default=0.0) 97 | group.add_argument("--beta1", help="Adam beta 1", type=float, default=0.9) 98 | group.add_argument("--beta2", help="Adam beta 2", type=float, default=0.998) 99 | group.add_argument("--eps", help="Adam epsilon", type=float, default=1e-9) 100 | group.add_argument("--weight_decay", help="Adam weight decay", type=float, default=1e-2) 101 | group.add_argument("--clip_norm", help="Max norm for gradient clipping", type=float, default=20.0) 102 | group.add_argument("--batch_type", help="batch type", type=str, default="tokens") 103 | group.add_argument("--train_batch_size", help="Batch size for train", type=int, default=4096) 104 | group.add_argument("--valid_batch_size", help="Batch size for valid", type=int, default=4096) 105 | group.add_argument("--accumulation_count", help="No. of batches for gradient accumulation", type=int, default=1) 106 | group.add_argument("--log_iter", help="No. of steps per logging", type=int, default=100) 107 | group.add_argument("--eval_iter", help="No. of steps per evaluation", type=int, default=100) 108 | group.add_argument("--save_iter", help="No. of steps per saving", type=int, default=100) 109 | group.add_argument("--keep_last_ckpt", help="No. 
of steps remaining", type=int, default=10) 110 | # debug params 111 | group.add_argument("--do_profile", help="Whether to do profiling", action="store_true") 112 | group.add_argument("--record_shapes", help="Whether to record tensor shapes for profiling", action="store_true") 113 | 114 | return parser 115 | 116 | 117 | def add_predict_args(parser): 118 | group = parser.add_argument_group("Prediction options") 119 | group.add_argument("--do_predict", help="Whether to do prediction", action="store_true") 120 | group.add_argument("--do_score", help="Whether to score predictions", action="store_true") 121 | group.add_argument("--checkpoint_step_start", help="First checkpoint step", type=int) 122 | group.add_argument("--checkpoint_step_end", help="Last checkpoint step", type=int) 123 | group.add_argument("--predict_batch_size", help="Batch size for prediction", type=int, default=4096) 124 | # decoding params 125 | group.add_argument("--test_bin", help="Test npz", type=str, default="") 126 | group.add_argument("--result_file", help="Result file", type=str, default="") 127 | group.add_argument("--beam_size", help="Beam size for decoding", type=int, default=0) 128 | group.add_argument("--n_best", help="Number of best results to be retained", type=int, default=10) 129 | group.add_argument("--temperature", help="Beam search temperature", type=float, default=1.0) 130 | group.add_argument("--predict_min_len", help="Min length for prediction", type=int, default=1) 131 | group.add_argument("--predict_max_len", help="Max length for prediction", type=int, default=512) 132 | -------------------------------------------------------------------------------- /models/seq2seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.data_utils import S2SBatch 5 | from onmt.encoders.transformer import TransformerEncoder 6 | from onmt.decoders import TransformerDecoder 7 | from onmt.modules.embeddings import Embeddings 8 | from onmt.translate import BeamSearch, GNMTGlobalScorer, GreedySearch 9 | from typing import Dict 10 | 11 | 12 | class Seq2Seq(nn.Module): 13 | def __init__(self, args, vocab: Dict[str, int]): 14 | super().__init__() 15 | self.args = args 16 | self.vocab = vocab 17 | self.vocab_size = len(self.vocab) 18 | 19 | while args.enable_amp and not self.vocab_size % 8 == 0: 20 | self.vocab_size += 1 21 | 22 | self.encoder_embeddings = Embeddings( 23 | word_vec_size=args.embed_size, 24 | word_vocab_size=self.vocab_size, 25 | word_padding_idx=self.vocab["_PAD"], 26 | position_encoding=True, 27 | dropout=args.dropout 28 | ) 29 | 30 | self.decoder_embeddings = Embeddings( 31 | word_vec_size=args.embed_size, 32 | word_vocab_size=self.vocab_size, 33 | word_padding_idx=self.vocab["_PAD"], 34 | position_encoding=True, 35 | dropout=args.dropout 36 | ) 37 | 38 | if args.share_embeddings: 39 | self.decoder_embeddings.word_lut.weight = self.encoder_embeddings.word_lut.weight 40 | 41 | self.encoder = TransformerEncoder( 42 | num_layers=args.decoder_num_layers, 43 | d_model=args.decoder_hidden_size, 44 | heads=args.decoder_attn_heads, 45 | d_ff=args.decoder_filter_size, 46 | dropout=args.dropout, 47 | attention_dropout=args.attn_dropout, 48 | embeddings=self.encoder_embeddings, 49 | max_relative_positions=0 50 | ) 51 | 52 | self.decoder = TransformerDecoder( 53 | num_layers=args.decoder_num_layers, 54 | d_model=args.decoder_hidden_size, 55 | heads=args.decoder_attn_heads, 56 | 
d_ff=args.decoder_filter_size, 57 | copy_attn=False, 58 | self_attn_type="scaled-dot", 59 | dropout=args.dropout, 60 | attention_dropout=args.attn_dropout, 61 | embeddings=self.decoder_embeddings, 62 | max_relative_positions=0, 63 | aan_useffn=False, 64 | full_context_alignment=False, 65 | alignment_layer=-3, 66 | alignment_heads=0 67 | ) 68 | 69 | self.output_layer = nn.Linear(args.decoder_hidden_size, self.vocab_size, bias=True) 70 | 71 | self.criterion = nn.CrossEntropyLoss( 72 | ignore_index=self.vocab["_PAD"], 73 | reduction="mean" 74 | ) 75 | 76 | def encode_and_reshape(self, reaction_batch: S2SBatch): 77 | src = reaction_batch.src_token_ids 78 | lengths = reaction_batch.src_lengths 79 | 80 | src = src.transpose(0, 1).contiguous().unsqueeze(-1) # [b, src_t] => [src_t, b, 1] 81 | emb, out, length = self.encoder(src, lengths) 82 | self.decoder.init_state( 83 | src=src, 84 | memory_bank=out, 85 | enc_hidden=emb 86 | ) 87 | 88 | return out, length 89 | 90 | def forward(self, reaction_batch: S2SBatch): 91 | padded_memory_bank, memory_lengths = self.encode_and_reshape(reaction_batch) 92 | 93 | dec_in = reaction_batch.tgt_token_ids[:, :-1] # pop last and insert SOS for decoder input 94 | m = nn.ConstantPad1d((1, 0), self.vocab["_SOS"]) 95 | dec_in = m(dec_in) 96 | dec_in = dec_in.transpose(0, 1).unsqueeze(-1) # [b, tgt_t] => [tgt_t, b, 1] 97 | 98 | dec_outs, _ = self.decoder( 99 | tgt=dec_in, 100 | memory_bank=padded_memory_bank, 101 | memory_lengths=memory_lengths 102 | ) 103 | 104 | dec_outs = self.output_layer(dec_outs) # [t, b, h] => [t, b, v] 105 | dec_outs = dec_outs.permute(1, 2, 0) # [t, b, v] => [b, v, t] 106 | 107 | loss = self.criterion( 108 | input=dec_outs, 109 | target=reaction_batch.tgt_token_ids 110 | ) 111 | 112 | predictions = torch.argmax(dec_outs, dim=1) # [b, t] 113 | mask = (reaction_batch.tgt_token_ids != self.vocab["_PAD"]).long() 114 | accs = (predictions == reaction_batch.tgt_token_ids).float() 115 | accs = accs * mask 116 | # acc = accs.mean() omg this is so stupid 117 | acc = accs.sum() / mask.sum() 118 | 119 | return loss, acc 120 | 121 | def predict_step(self, reaction_batch: S2SBatch, 122 | batch_size: int, beam_size: int, min_length: int, max_length: int): 123 | if beam_size == 1: 124 | decode_strategy = GreedySearch( 125 | pad=self.vocab["_PAD"], 126 | bos=self.vocab["_SOS"], 127 | eos=self.vocab["_EOS"], 128 | batch_size=batch_size, 129 | min_length=min_length, 130 | max_length=max_length, 131 | block_ngram_repeat=0, 132 | exclusion_tokens=set(), 133 | return_attention=False, 134 | sampling_temp=0.0, 135 | keep_topk=1 136 | ) 137 | else: 138 | global_scorer = GNMTGlobalScorer(alpha=0.0, 139 | beta=0.0, 140 | length_penalty="none", 141 | coverage_penalty="none") 142 | decode_strategy = BeamSearch( 143 | beam_size=beam_size, 144 | batch_size=batch_size, 145 | pad=self.vocab["_PAD"], 146 | bos=self.vocab["_SOS"], 147 | eos=self.vocab["_EOS"], 148 | n_best=1, # TODO: this is hard-coded to return top 1 right now 149 | global_scorer=global_scorer, 150 | min_length=min_length, 151 | max_length=max_length, 152 | return_attention=False, 153 | block_ngram_repeat=0, 154 | exclusion_tokens=set(), 155 | stepwise_penalty=None, 156 | ratio=0.0 157 | ) 158 | 159 | padded_memory_bank, memory_lengths = self.encode_and_reshape(reaction_batch) 160 | 161 | # adapted from onmt.translate.translator 162 | results = { 163 | "predictions": None, 164 | "scores": None, 165 | "attention": None 166 | } 167 | 168 | # (2) prep decode_strategy. Possibly repeat src objects. 
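# [added note, not in the original source] decode_strategy.initialize() on the next
# lines returns (fn_map_state, memory_bank, memory_lengths, src_map); for BeamSearch
# it is expected to tile memory_bank and memory_lengths beam_size times along the
# batch dimension (a no-op for GreedySearch), so each decoding step below sees one
# row per live beam. Behaviour assumed from OpenNMT-py 1.2.0; compare the explicit
# tile(product_emb, beam_size, dim=1) call in Graph2SeqSeriesRel.predict_step.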
169 | src_map = None 170 | target_prefix = None 171 | fn_map_state, memory_bank, memory_lengths, src_map = decode_strategy.initialize( 172 | memory_bank=padded_memory_bank, 173 | src_lengths=memory_lengths, 174 | src_map=src_map, 175 | target_prefix=target_prefix 176 | ) 177 | 178 | # (3) Begin decoding step by step: 179 | for step in range(decode_strategy.max_length): 180 | decoder_input = decode_strategy.current_predictions.view(1, -1, 1) 181 | 182 | dec_out, dec_attn = self.decoder( 183 | tgt=decoder_input, 184 | memory_bank=memory_bank, 185 | memory_lengths=memory_lengths, 186 | step=step 187 | ) 188 | 189 | if "std" in dec_attn: 190 | attn = dec_attn["std"] 191 | else: 192 | attn = None 193 | 194 | dec_out = self.output_layer(dec_out) # [t, b, h] => [t, b, v] 195 | dec_out = dec_out.squeeze(0) # [t, b, v] => [b, v] 196 | log_probs = F.log_softmax(dec_out, dim=-1) 197 | 198 | # log_probs = self.model.generator(dec_out.squeeze(0)) 199 | 200 | decode_strategy.advance(log_probs, attn) 201 | any_finished = decode_strategy.is_finished.any() 202 | if any_finished: 203 | decode_strategy.update_finished() 204 | if decode_strategy.done: 205 | break 206 | 207 | select_indices = decode_strategy.select_indices 208 | 209 | if any_finished: 210 | # Reorder states. 211 | if isinstance(memory_bank, tuple): 212 | memory_bank = tuple(x.index_select(1, select_indices) 213 | for x in memory_bank) 214 | else: 215 | memory_bank = memory_bank.index_select(1, select_indices) 216 | 217 | memory_lengths = memory_lengths.index_select(0, select_indices) 218 | 219 | if src_map is not None: 220 | src_map = src_map.index_select(1, select_indices) 221 | 222 | if any_finished: 223 | self.map_state( 224 | lambda state, dim: state.index_select(dim, select_indices)) 225 | 226 | results["scores"] = decode_strategy.scores 227 | results["predictions"] = decode_strategy.predictions 228 | results["attention"] = decode_strategy.attention 229 | results["alignment"] = [[] for _ in range(self.args.predict_batch_size)] 230 | 231 | return results 232 | 233 | # adapted from onmt.decoders.transformer 234 | def map_state(self, fn): 235 | def _recursive_map(struct, batch_dim=0): 236 | for k, v in struct.items(): 237 | if v is not None: 238 | if isinstance(v, dict): 239 | _recursive_map(v) 240 | else: 241 | struct[k] = fn(v, batch_dim) 242 | 243 | # self.decoder.state["src"] = fn(self.decoder.state["src"], 1) 244 | # => self.state["src"] = self.state["src"].index_select(1, select_indices) 245 | 246 | if self.decoder.state["cache"] is not None: 247 | _recursive_map(self.decoder.state["cache"]) 248 | -------------------------------------------------------------------------------- /models/dgat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.model_utils import index_select_ND 4 | from typing import Tuple 5 | 6 | 7 | class DGATGRU(nn.Module): 8 | """GRU Message Passing layer.""" 9 | def __init__(self, args, input_size: int, h_size: int, depth: int): 10 | super().__init__() 11 | self.args = args 12 | 13 | self.input_size = input_size 14 | self.h_size = h_size 15 | self.depth = depth 16 | 17 | self._build_layer_components() 18 | self._build_attention() 19 | 20 | def _build_layer_components(self) -> None: 21 | """Build layer components.""" 22 | self.W_z = nn.Linear(self.input_size + self.h_size, self.h_size) 23 | self.W_r = nn.Linear(self.input_size, self.h_size, bias=False) 24 | self.U_r = nn.Linear(self.h_size, self.h_size) 25 | self.W_h = 
nn.Linear(self.input_size + self.h_size, self.h_size) 26 | 27 | def _build_attention(self) -> None: 28 | self.leaky_relu = nn.LeakyReLU() 29 | self.head_count = self.args.encoder_attn_heads 30 | self.dim_per_head = self.h_size // self.head_count 31 | 32 | self.attn_alpha = nn.Parameter( 33 | torch.Tensor(1, 1, self.head_count, 2 * self.dim_per_head), requires_grad=True) 34 | self.attn_bias = nn.Parameter( 35 | torch.Tensor(1, 1, self.head_count), requires_grad=True) 36 | 37 | self.attn_W_q = nn.Linear(self.input_size, self.h_size, bias=True) 38 | self.attn_W_k = nn.Linear(self.h_size, self.h_size, bias=True) 39 | self.attn_W_v = nn.Linear(self.h_size, self.h_size, bias=True) 40 | 41 | self.softmax = nn.Softmax(dim=1) 42 | self.dropout = nn.Dropout(self.args.dropout) 43 | self.attn_dropout = nn.Dropout(self.args.attn_dropout) 44 | 45 | def GRU(self, x: torch.Tensor, h_nei: torch.Tensor) -> torch.Tensor: 46 | """Implements the GRU gating equations. 47 | 48 | Parameters 49 | ---------- 50 | x: torch.Tensor, input tensor 51 | h_nei: torch.Tensor, hidden states of the neighbors 52 | """ 53 | # attention-based aggregation 54 | n_node, max_nn, h_size = h_nei.size() 55 | head_count = self.head_count 56 | dim_per_head = self.dim_per_head 57 | 58 | q = self.attn_W_q(x) # (n_node, input) -> (n_node, h) 59 | q = q.unsqueeze(1).repeat(1, max_nn, 1) # -> (n_node, max_nn, h) 60 | q = q.reshape( 61 | n_node, max_nn, head_count, dim_per_head) # -> (n_node, max_nn, head, h/head) 62 | 63 | k = self.attn_W_k(h_nei) # (n_node, max_nn, h) 64 | k = k.reshape( 65 | n_node, max_nn, head_count, dim_per_head) # -> (n_node, max_nn, head, h/head) 66 | 67 | v = self.attn_W_v(h_nei) # (n_node, max_nn, h) 68 | v = v.reshape( 69 | n_node, max_nn, head_count, dim_per_head) # -> (n_node, max_nn, head, h/head) 70 | 71 | qk = torch.cat([q, k], dim=-1) # -> (n_node, max_nn, head, 2*h/head) 72 | qk = self.leaky_relu(qk) 73 | 74 | attn_score = qk * self.attn_alpha # (n_node, max_nn, head, 2*h/head) 75 | attn_score = torch.sum(attn_score, dim=-1) # (n_node, max_nn, head, 2*h/head) -> (n_node, max_nn, head) 76 | attn_score = attn_score + self.attn_bias # (n_node, max_nn, head) 77 | 78 | attn_mask = (h_nei.sum(dim=2) == 0 79 | ).unsqueeze(2) # (n_node, max_nn, h) -> (n_node, max_nn, 1) 80 | attn_score = attn_score.masked_fill(attn_mask, -1e18) 81 | 82 | attn_weight = self.softmax(attn_score) # (n_node, max_nn, head), softmax over dim=1 83 | attn_weight = attn_weight.unsqueeze(3) # -> (n_node, max_nn, head, 1) 84 | 85 | attn_context = attn_weight * v # -> (n_node, max_nn, head, h/head) 86 | attn_context = attn_context.reshape( 87 | n_node, max_nn, h_size) # -> (n_node, max_nn, h) 88 | 89 | sum_h = attn_context.sum(dim=1) # -> (n_node, h) 90 | 91 | # GRU 92 | z_input = torch.cat([x, sum_h], dim=1) # x = [x_u; x_uv] 93 | z = torch.sigmoid(self.W_z(z_input)) # (10) 94 | 95 | r_1 = self.W_r(x) # (n_node, h) -> (n_node, h) 96 | r_2 = self.U_r(sum_h) # (n_node, h) -> (n_node, h) 97 | r = torch.sigmoid(r_1 + r_2) # (11) r_ku = f_r(x; m_ku) = W_r(x) + U_r(m_ku) 98 | 99 | sum_gated_h = r * sum_h # (n_node, h) 100 | h_input = torch.cat([x, sum_gated_h], dim=1) 101 | pre_h = torch.tanh(self.W_h(h_input)) # (13) 102 | new_h = (1.0 - z) * sum_h + z * pre_h # (14) 103 | 104 | return new_h 105 | 106 | def forward(self, fmess: torch.Tensor, bgraph: torch.Tensor) -> torch.Tensor: 107 | """Forward pass of the RNN 108 | 109 | Parameters 110 | ---------- 111 | fmess: torch.Tensor, contains the initial features passed as messages 112 | bgraph: 
torch.Tensor, bond graph tensor. Contains who passes messages to whom. 113 | """ 114 | h = torch.zeros(fmess.size()[0], self.h_size, device=fmess.device) 115 | mask = torch.ones(h.size()[0], 1, device=h.device) 116 | mask[0, 0] = 0 # first message is padding 117 | 118 | for i in range(self.depth): 119 | h_nei = index_select_ND(h, 0, bgraph) 120 | h = self.GRU(fmess, h_nei) 121 | h = h * mask 122 | return h 123 | 124 | 125 | class DGATEncoder(nn.Module): 126 | """MessagePassing Network based encoder. Messages are updated using an RNN 127 | and the final message is used to update atom embeddings.""" 128 | def __init__(self, args, input_size: int, node_fdim: int): 129 | super().__init__() 130 | self.args = args 131 | 132 | self.h_size = args.encoder_hidden_size 133 | self.depth = args.encoder_num_layers 134 | self.input_size = input_size 135 | self.node_fdim = node_fdim 136 | self.head_count = args.encoder_attn_heads 137 | self.dim_per_head = self.h_size // self.head_count 138 | 139 | self.leaky_relu = nn.LeakyReLU() 140 | 141 | self._build_layers() 142 | self._build_attention() 143 | 144 | def _build_layers(self) -> None: 145 | """Build layers associated with the MPNEncoder.""" 146 | self.W_o = nn.Sequential(nn.Linear(self.node_fdim + self.h_size, self.h_size), nn.GELU()) 147 | self.rnn = DGATGRU(self.args, self.input_size, self.h_size, self.depth) 148 | 149 | def _build_attention(self) -> None: 150 | self.attn_alpha = nn.Parameter( 151 | torch.Tensor(1, 1, self.head_count, 2 * self.dim_per_head), requires_grad=True) 152 | self.attn_bias = nn.Parameter( 153 | torch.Tensor(1, 1, self.head_count), requires_grad=True) 154 | 155 | self.attn_W_q = nn.Linear(self.node_fdim, self.h_size, bias=True) 156 | self.attn_W_k = nn.Linear(self.h_size, self.h_size, bias=True) 157 | self.attn_W_v = nn.Linear(self.h_size, self.h_size, bias=True) 158 | 159 | self.softmax = nn.Softmax(dim=1) 160 | self.dropout = nn.Dropout(self.args.dropout) 161 | self.attn_dropout = nn.Dropout(self.args.attn_dropout) 162 | 163 | def forward(self, fnode: torch.Tensor, fmess: torch.Tensor, 164 | agraph: torch.Tensor, bgraph: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, ...]: 165 | """Forward pass of the MPNEncoder. 
166 | 167 | Parameters 168 | ---------- 169 | fnode: torch.Tensor, node feature tensor 170 | fmess: torch.Tensor, message features 171 | agraph: torch.Tensor, neighborhood of an atom 172 | bgraph: torch.Tensor, neighborhood of a bond, 173 | except the directed bond from the destination node to the source node 174 | mask: torch.Tensor, masks on nodes 175 | """ 176 | h = self.rnn(fmess, bgraph) 177 | nei_message = index_select_ND(h, 0, agraph) 178 | 179 | # attention-based aggregation 180 | n_node, max_nn, h_size = nei_message.size() 181 | head_count = self.head_count 182 | dim_per_head = self.dim_per_head 183 | 184 | q = self.attn_W_q(fnode) # (n_node, h) 185 | q = q.unsqueeze(1).repeat(1, max_nn, 1) # -> (n_node, max_nn, h) 186 | q = q.reshape( 187 | n_node, max_nn, head_count, dim_per_head) # (n_node, max_nn, h) -> (n_node, max_nn, head, h/head) 188 | 189 | k = self.attn_W_k(nei_message) # (n_node, max_nn, h) 190 | k = k.reshape( 191 | n_node, max_nn, head_count, dim_per_head) # -> (n_node, max_nn, head, h/head) 192 | 193 | v = self.attn_W_v(nei_message) # (n_node, max_nn, h) 194 | v = v.reshape( 195 | n_node, max_nn, head_count, dim_per_head) # -> (n_node, max_nn, head, h/head) 196 | 197 | qk = torch.cat([q, k], dim=-1) # -> (n_node, max_nn, head, 2*h/head) 198 | qk = self.leaky_relu(qk) 199 | 200 | attn_score = qk * self.attn_alpha # (n_node, max_nn, head, 2*h/head) 201 | attn_score = torch.sum(attn_score, dim=-1) # (n_node, max_nn, head, 2*h/head) -> (n_node, max_nn, head) 202 | attn_score = attn_score + self.attn_bias # (n_node, max_nn, head) 203 | 204 | attn_mask = (nei_message.sum(dim=2) == 0 205 | ).unsqueeze(2) # (n_node, max_nn, h) -> (n_node, max_nn, 1) 206 | attn_score = attn_score.masked_fill(attn_mask, -1e18) 207 | 208 | attn_weight = self.softmax(attn_score) # (n_node, max_nn, head), softmax over dim=1 209 | attn_weight = attn_weight.unsqueeze(3) # -> (n_node, max_nn, head, 1) 210 | 211 | attn_context = attn_weight * v # -> (n_node, max_nn, head, h/head) 212 | attn_context = attn_context.reshape( 213 | n_node, max_nn, h_size) # -> (n_node, max_nn, h) 214 | 215 | nei_message = attn_context.sum(dim=1) # -> (n_node, h) 216 | 217 | # readout 218 | node_hiddens = torch.cat([fnode, nei_message], dim=1) 219 | node_hiddens = self.W_o(node_hiddens) 220 | 221 | if mask is None: 222 | mask = torch.ones(node_hiddens.size(0), 1, device=fnode.device) 223 | mask[0, 0] = 0 # first node is padding 224 | 225 | return node_hiddens * mask, h 226 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import numpy as np 4 | import os 5 | import sys 6 | import torch 7 | from models.graph2seq_series_rel import Graph2SeqSeriesRel 8 | from models.seq2seq import Seq2Seq 9 | from torch.utils.data import DataLoader 10 | from utils import parsing 11 | from utils.data_utils import canonicalize_smiles, load_vocab, S2SDataset, G2SDataset, tokenize_smiles 12 | from utils.train_utils import log_tensor, param_count, set_seed, setup_logger 13 | 14 | import warnings 15 | warnings.filterwarnings('ignore') 16 | 17 | 18 | def get_predict_parser(): 19 | parser = argparse.ArgumentParser("predict") 20 | parsing.add_common_args(parser) 21 | parsing.add_preprocess_args(parser) 22 | parsing.add_train_args(parser) 23 | parsing.add_predict_args(parser) 24 | 25 | return parser 26 | 27 | 28 | def main(args): 29 | parsing.log_args(args) 30 | 31 | if 
args.do_predict and os.path.exists(args.result_file): 32 | logging.info(f"Result file found at {args.result_file}, skipping prediction") 33 | logging.info(f"Loaded pretrained state_dict from {args.load_from}") 34 | 35 | elif args.do_predict and not os.path.exists(args.result_file): 36 | 37 | # initialization ----------------- model 38 | assert os.path.exists(args.load_from), f"{args.load_from} does not exist!" 39 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 40 | 41 | state = torch.load(args.load_from) 42 | pretrain_args = state["args"] 43 | pretrain_state_dict = state["state_dict"] 44 | 45 | for attr in ["mpn_type", "rel_pos"]: 46 | try: 47 | getattr(pretrain_args, attr) 48 | except AttributeError: 49 | setattr(pretrain_args, attr, getattr(args, attr)) 50 | 51 | assert args.model == pretrain_args.model, f"Pretrained model is {pretrain_args.model}!" 52 | if args.model == "s2s": 53 | model_class = Seq2Seq 54 | dataset_class = S2SDataset 55 | elif args.model == "g2s_series_rel": 56 | model_class = Graph2SeqSeriesRel 57 | dataset_class = G2SDataset 58 | args.compute_graph_distance = True 59 | assert args.compute_graph_distance 60 | else: 61 | raise ValueError(f"Model {args.model} not supported!") 62 | 63 | # initialization ----------------- vocab 64 | vocab = load_vocab(pretrain_args.vocab_file) 65 | vocab_tokens = [k for k, v in sorted(vocab.items(), key=lambda tup: tup[1])] 66 | 67 | model = model_class(pretrain_args, vocab) 68 | model.load_state_dict(pretrain_state_dict, strict=False) 69 | logging.info(f"Loaded pretrained state_dict from {args.load_from}") 70 | 71 | model.to(device) 72 | model.eval() 73 | 74 | logging.info(model) 75 | logging.info(f"Number of parameters = {param_count(model)}") 76 | logging.info(f"Loaded pretrained state_dict from {args.load_from}") 77 | 78 | # initialization ----------------- data 79 | test_dataset = dataset_class(pretrain_args, file=args.test_bin) 80 | test_dataset.batch( 81 | batch_type=args.batch_type, 82 | batch_size=args.predict_batch_size 83 | ) 84 | test_loader = DataLoader( 85 | dataset=test_dataset, 86 | batch_size=1, 87 | shuffle=False, 88 | collate_fn=lambda _batch: _batch[0], 89 | num_workers=16, 90 | pin_memory=True 91 | ) 92 | 93 | all_predictions = [] 94 | supp_info = [] 95 | 96 | with torch.no_grad(): 97 | with open(args.test_tgt.replace('tgt','src'), 'r') as f: 98 | lines = f.readlines() 99 | import json 100 | ans_path = os.path.join("/".join(args.test_tgt.split('/')[:-1]), "raw_test.json") 101 | with open(ans_path, 'r') as jfile: 102 | ans = json.load(jfile) 103 | for test_idx, test_batch in enumerate(test_loader): 104 | if test_idx % args.log_iter == 0: 105 | logging.info(f"Doing inference on test step {test_idx}") 106 | sys.stdout.flush() 107 | 108 | test_batch.to(device) 109 | if hasattr(model.decoder, 'cvae'): 110 | results, sample_z = model.predict_step( ## sample_z.shape=torch.Size([51, 30]) [batch_size, beam_size] 111 | reaction_batch=test_batch, 112 | batch_size=test_batch.size, 113 | beam_size=args.beam_size, 114 | n_best=args.n_best, 115 | temperature=args.temperature, 116 | min_length=args.predict_min_len, 117 | max_length=args.predict_max_len 118 | ) 119 | else: 120 | results = model.predict_step( 121 | reaction_batch=test_batch, 122 | batch_size=test_batch.size, 123 | beam_size=args.beam_size, 124 | n_best=args.n_best, 125 | temperature=args.temperature, 126 | min_length=args.predict_min_len, 127 | max_length=args.predict_max_len 128 | ) 129 | 130 | for i, predictions in 
enumerate(results["predictions"]): 131 | idx = test_batch.data_indice[i] 132 | if idx.cpu() not in test_dataset.ptr: 133 | continue 134 | product = lines[idx].strip() 135 | smis = [] 136 | for prediction in predictions: 137 | predicted_idx = prediction.detach().cpu().numpy() 138 | predicted_tokens = [vocab_tokens[idx] for idx in predicted_idx[:-1]] 139 | smi = " ".join(predicted_tokens) 140 | if smi in ans[product]['reactant']: #['C O C ( = O ) c 1 c c ( Cl ) c c c 1 Br'] 141 | if 'cover' in ans[product].keys(): 142 | ans[product]['cover'] += 1 143 | else: 144 | ans[product]['cover'] = 1 145 | ans[product]['reactant'].remove(smi) 146 | smis.append("".join(predicted_tokens)) 147 | 148 | smis = ",".join(smis) 149 | all_predictions.append(f"{smis}\n") 150 | if hasattr(model.decoder, 'cvae'): 151 | supp_info.append(f"1-to-N: {ans[product]['reaction']}; Product id {idx}: {''.join(product.split())}; Discrete Z: {sample_z[i].tolist()}\n") 152 | else: 153 | supp_info.append(f"1-to-N: {ans[product]['reaction']}; Product id {idx}: {''.join(product.split())}\n") 154 | 155 | save_dir = os.path.join(*args.result_file.split('/')[:-1]) 156 | if not os.path.exists(save_dir): 157 | os.makedirs(save_dir) 158 | metric = {} 159 | for k in ans.keys(): 160 | if ans[k]['reaction'] not in metric.keys(): 161 | metric[ans[k]['reaction']] = {} 162 | metric[ans[k]['reaction']]['num'] = 0 163 | metric[ans[k]['reaction']]['covers'] = [] 164 | metric[ans[k]['reaction']]['num'] += 1 165 | metric[ans[k]['reaction']]['covers'].append(1. - len(ans[k]['reactant']) * 1. / ans[k]['reaction']) 166 | for k in metric.keys(): 167 | metric[k]['coverage'] = np.mean(metric[k]['covers']) 168 | import json 169 | json_str = json.dumps(metric) 170 | with open(args.result_file+'.json', 'w') as json_file: 171 | json_file.write(json_str) 172 | logging.info(json_str) 173 | with open(args.result_file, "w") as of: 174 | of.writelines(all_predictions) 175 | with open(args.result_file[:-4]+'_info.txt', "w") as of: 176 | of.writelines(supp_info) 177 | 178 | if args.do_score: 179 | invalid = 0 180 | 181 | total = len(test_dataset.ptr) - 1 182 | accuracies = np.zeros([total, args.n_best], dtype=np.float32) 183 | with open(args.test_tgt, "r") as f_tgt, open(args.result_file, "r") as f_predict: 184 | targets = f_tgt.readlines() 185 | for i, line_predict in enumerate(f_predict): 186 | line_predict = "".join(line_predict.split()) 187 | smis_predict = line_predict.split(",") 188 | smis_predict = [canonicalize_smiles(smi, trim=False) for smi in smis_predict] 189 | if not smis_predict[0]: 190 | invalid += 1 191 | smis_predict = [smi for smi in smis_predict if smi and not smi == "CC"] 192 | smis_predict = list(dict.fromkeys(smis_predict)) ## Deduplication 193 | 194 | tgt_pre = test_dataset.ptr[i].item() 195 | tgt_post = test_dataset.ptr[i+1].item() 196 | line_tgts = [s.strip() for s in targets[tgt_pre:tgt_post]] 197 | smi_tgts = [] 198 | for line_tgt in line_tgts: 199 | smi_tgt = "".join(line_tgt.split()) 200 | smi_tgt = canonicalize_smiles(smi_tgt, trim=False) 201 | if not smi_tgt or smi_tgt == "CC": 202 | continue 203 | smi_tgts.append(smi_tgt) 204 | 205 | for j, smi in enumerate(smis_predict): 206 | if smi in smi_tgts: 207 | accuracies[i, j:] = 1.0 208 | break 209 | 210 | 211 | logging.info(f"Total: {total}, " 212 | f"top 1 invalid: {invalid / total * 100: .2f} %") 213 | 214 | mean_accuracies = np.mean(accuracies, axis=0) 215 | for n in range(args.n_best): 216 | logging.info(f"Top {n+1} accuracy: {mean_accuracies[n] * 100: .2f} %") 217 | 
logging.info(f"Loaded pretrained state_dict from {args.load_from}") 218 | logging.info(f"Top 1,3,5,10 accuracy: {mean_accuracies[0] * 100: .2f},{mean_accuracies[2] * 100: .2f}," 219 | f"{mean_accuracies[4] * 100: .2f},{mean_accuracies[9] * 100: .2f}") 220 | 221 | 222 | if __name__ == "__main__": 223 | predict_parser = get_predict_parser() 224 | args = predict_parser.parse_args() 225 | 226 | # set random seed (just in case) 227 | set_seed(args.seed) 228 | 229 | # logger setup 230 | logger = setup_logger(args, warning_off=True) 231 | 232 | torch.set_printoptions(profile="full") 233 | main(args) 234 | -------------------------------------------------------------------------------- /models/attention_xl.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from onmt.modules.embeddings import PositionalEncoding 5 | from onmt.modules.position_ffn import PositionwiseFeedForward 6 | from onmt.utils.misc import sequence_mask 7 | 8 | 9 | def get_sin_encodings(rel_pos_buckets, model_dim) -> torch.Tensor: 10 | pe = torch.zeros(rel_pos_buckets + 1, model_dim) 11 | position = torch.arange(0, rel_pos_buckets).unsqueeze(1) 12 | div_term = torch.exp((torch.arange(0, model_dim, 2, dtype=torch.float) * 13 | -(math.log(10000.0) / model_dim))) 14 | pe[:-1, 0::2] = torch.sin(position.float() * div_term) # leaving last "position" as padding 15 | pe[:-1, 1::2] = torch.cos(position.float() * div_term) 16 | 17 | return pe 18 | 19 | 20 | class MultiHeadedRelAttention(nn.Module): 21 | def __init__(self, args, head_count, model_dim, dropout, rel_pos_buckets, u, v): 22 | super().__init__() 23 | self.args = args 24 | 25 | assert model_dim % head_count == 0 26 | self.dim_per_head = model_dim // head_count 27 | self.model_dim = model_dim 28 | self.head_count = head_count 29 | 30 | self.linear_keys = nn.Linear(model_dim, model_dim) 31 | self.linear_values = nn.Linear(model_dim, model_dim) 32 | self.linear_query = nn.Linear(model_dim, model_dim) 33 | 34 | self.softmax = nn.Softmax(dim=-1) 35 | self.dropout = nn.Dropout(dropout) 36 | self.final_linear = nn.Linear(model_dim, model_dim) 37 | 38 | self.rel_pos_buckets = rel_pos_buckets 39 | 40 | if args.rel_pos == "enc_only": 41 | self.relative_pe = nn.Embedding.from_pretrained( 42 | embeddings=get_sin_encodings(rel_pos_buckets, model_dim), 43 | freeze=True, 44 | padding_idx=rel_pos_buckets 45 | ) 46 | # self.W_kR = nn.Parameter( 47 | # torch.Tensor(self.head_count, self.dim_per_head, self.dim_per_head), requires_grad=True) 48 | # self.b_kR = nn.Parameter( 49 | # torch.Tensor(self.head_count, self.dim_per_head), requires_grad=True) 50 | 51 | elif args.rel_pos == "emb_only": 52 | self.relative_pe = nn.Embedding( 53 | rel_pos_buckets + 1, 54 | model_dim, 55 | padding_idx=rel_pos_buckets 56 | ) 57 | # self.W_kR = nn.Parameter( 58 | # torch.Tensor(self.head_count, self.dim_per_head, self.dim_per_head), requires_grad=True) 59 | # self.b_kR = nn.Parameter( 60 | # torch.Tensor(self.head_count, self.dim_per_head), requires_grad=True) 61 | 62 | else: 63 | self.relative_pe = None 64 | self.W_kR = None 65 | self.b_kR = None 66 | 67 | self.u = u 68 | self.v = v 69 | 70 | def forward(self, inputs, mask, distances): 71 | """ 72 | Compute the context vector and the attention vectors. 
73 | 74 | Args: 75 | inputs (FloatTensor): set of `key_len` 76 | key vectors ``(batch, key_len, dim)`` 77 | mask: binary mask 1/0 indicating which keys have 78 | zero / non-zero attention ``(batch, query_len, key_len)`` 79 | distances: graph distance matrix (BUCKETED), ``(batch, key_len, key_len)`` 80 | Returns: 81 | (FloatTensor, FloatTensor): 82 | 83 | * output context vectors ``(batch, query_len, dim)`` 84 | * Attention vector in heads ``(batch, head, query_len, key_len)``. 85 | """ 86 | 87 | batch_size = inputs.size(0) 88 | dim_per_head = self.dim_per_head 89 | head_count = self.head_count 90 | 91 | def shape(x): 92 | """Projection.""" 93 | return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2) 94 | 95 | def unshape(x): 96 | """Compute context.""" 97 | return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head) 98 | 99 | # 1) Project key, value, and query. Seems that we don't need layer_cache here 100 | query = self.linear_query(inputs) 101 | key = self.linear_keys(inputs) 102 | value = self.linear_values(inputs) 103 | 104 | key = shape(key) # (b, t_k, h) -> (b, head, t_k, h/head) 105 | value = shape(value) 106 | query = shape(query) # (b, t_q, h) -> (b, head, t_q, h/head) 107 | 108 | key_len = key.size(2) 109 | query_len = query.size(2) 110 | 111 | # 2) Calculate and scale scores. 112 | query = query / math.sqrt(dim_per_head) 113 | 114 | if self.relative_pe is None: 115 | scores = torch.matmul( 116 | query, key.transpose(2, 3)) # (b, head, t_q, t_k) 117 | 118 | else: 119 | # a + c 120 | u = self.u.reshape(1, head_count, 1, dim_per_head) 121 | a_c = torch.matmul(query + u, key.transpose(2, 3)) 122 | 123 | rel_emb = self.relative_pe(distances) # (b, t_q, t_k) -> (b, t_q, t_k, h) 124 | rel_emb = rel_emb.reshape( # (b, t_q, t_k, h) -> (b, t_q, t_k, head, h/head) 125 | batch_size, query_len, key_len, head_count, dim_per_head) 126 | 127 | # W_kR = self.W_kR.reshape(1, 1, 1, head_count, dim_per_head, dim_per_head) 128 | # rel_emb = torch.matmul(rel_emb, W_kR) # (b, t_q, t_k, head, 1, h/head) 129 | # rel_emb = rel_emb.squeeze(-2) # (b, t_q, t_k, head, h/head) 130 | # 131 | # b_kR = self.b_kR.reshape(1, 1, 1, head_count, dim_per_head) 132 | # rel_emb = rel_emb + b_kR # (b, t_q, t_k, head, h/head) 133 | 134 | # b + d 135 | query = query.unsqueeze(-2) # (b, head, t_q, h/head) -> (b, head, t_q, 1, h/head) 136 | rel_emb_t = rel_emb.permute(0, 3, 1, 4, 2) # (b, t_q, t_k, head, h/head) -> (b, head, t_q, h/head, t_k) 137 | 138 | v = self.v.reshape(1, head_count, 1, 1, dim_per_head) 139 | b_d = torch.matmul(query + v, rel_emb_t 140 | ).squeeze(-2) # (b, head, t_q, 1, t_k) -> (b, head, t_q, t_k) 141 | 142 | scores = a_c + b_d 143 | 144 | scores = scores.float() 145 | 146 | mask = mask.unsqueeze(1) # (B, 1, 1, T_values) 147 | scores = scores.masked_fill(mask, -1e18) 148 | 149 | # 3) Apply attention dropout and compute context vectors. 150 | attn = self.softmax(scores) 151 | drop_attn = self.dropout(attn) 152 | 153 | context_original = torch.matmul(drop_attn, value) # -> (b, head, t_q, h/head) 154 | context = unshape(context_original) # -> (b, t_q, h) 155 | 156 | output = self.final_linear(context) 157 | attns = attn.view(batch_size, head_count, query_len, key_len) 158 | 159 | return output, attns 160 | 161 | 162 | class SALayerXL(nn.Module): 163 | """ 164 | A single layer of the self-attention encoder. 
165 | 166 | Args: 167 | d_model (int): the dimension of keys/values/queries in 168 | MultiHeadedAttention, also the input size of 169 | the first-layer of the PositionwiseFeedForward. 170 | heads (int): the number of head for MultiHeadedAttention. 171 | d_ff (int): the second-layer of the PositionwiseFeedForward. 172 | dropout: dropout probability(0-1.0). 173 | """ 174 | 175 | def __init__(self, args, d_model, heads, d_ff, dropout, attention_dropout, rel_pos_buckets: int, u, v): 176 | super().__init__() 177 | 178 | self.self_attn = MultiHeadedRelAttention( 179 | args, 180 | heads, d_model, dropout=attention_dropout, 181 | rel_pos_buckets=rel_pos_buckets, 182 | u=u, 183 | v=v 184 | ) 185 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 186 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 187 | self.dropout = nn.Dropout(dropout) 188 | 189 | def forward(self, inputs, mask, distances): 190 | """ 191 | Args: 192 | inputs (FloatTensor): ``(batch_size, src_len, model_dim)`` 193 | mask (LongTensor): ``(batch_size, 1, src_len)`` 194 | distances (LongTensor): ``(batch_size, src_len, src_len)`` 195 | 196 | Returns: 197 | (FloatTensor): 198 | 199 | * outputs ``(batch_size, src_len, model_dim)`` 200 | """ 201 | input_norm = self.layer_norm(inputs) 202 | context, _ = self.self_attn(input_norm, mask=mask, distances=distances) 203 | out = self.dropout(context) + inputs 204 | 205 | return self.feed_forward(out) 206 | 207 | 208 | class AttnEncoderXL(nn.Module): 209 | def __init__(self, args): 210 | super().__init__() 211 | self.args = args 212 | 213 | self.num_layers = args.attn_enc_num_layers 214 | self.d_model = args.attn_enc_hidden_size 215 | self.heads = args.attn_enc_heads 216 | self.d_ff = args.attn_enc_filter_size 217 | self.attention_dropout = args.attn_dropout 218 | self.rel_pos_buckets = args.rel_pos_buckets 219 | 220 | self.encoder_pe = None 221 | if args.encoder_positional_encoding == "transformer": 222 | self.encoder_pe = PositionalEncoding( 223 | dropout=args.dropout, 224 | dim=self.d_model, 225 | max_len=1024 # temporary hard-code. 
Seems that onmt fix the denominator as 10000.0 226 | ) 227 | else: 228 | self.dropout = nn.Dropout(p=args.dropout) 229 | 230 | if args.rel_pos in ["enc_only", "emb_only"]: 231 | self.u = nn.Parameter(torch.randn(self.d_model), requires_grad=True) 232 | self.v = nn.Parameter(torch.randn(self.d_model), requires_grad=True) 233 | else: 234 | self.u = None 235 | self.v = None 236 | 237 | self.attention_layers = nn.ModuleList( 238 | [SALayerXL( 239 | args, self.d_model, self.heads, self.d_ff, args.dropout, self.attention_dropout, 240 | self.rel_pos_buckets, self.u, self.v) 241 | for i in range(self.num_layers)]) 242 | self.layer_norm = nn.LayerNorm(self.d_model, eps=1e-6) 243 | 244 | def forward(self, src, lengths, distances): 245 | """adapt from onmt TransformerEncoder 246 | src: (t, b, h) 247 | lengths: (b,) 248 | distances: (b, t, t) 249 | """ 250 | 251 | if self.encoder_pe is not None: 252 | emb = self.encoder_pe(src) 253 | out = emb.transpose(0, 1).contiguous() 254 | else: 255 | out = src.transpose(0, 1).contiguous() 256 | if self.args.encoder_emb_scale == "sqrt": 257 | out = out * math.sqrt(self.d_model) 258 | out = self.dropout(out) 259 | 260 | mask = ~sequence_mask(lengths).unsqueeze(1) 261 | 262 | for layer in self.attention_layers: 263 | out = layer(out, mask, distances) 264 | out = self.layer_norm(out) 265 | 266 | return out.transpose(0, 1).contiguous() 267 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import numpy as np 4 | import os 5 | import sys 6 | import time 7 | import torch 8 | from multiprocessing import Pool 9 | from typing import Dict, List, Tuple 10 | from utils import parsing 11 | from utils.data_utils import get_graph_features_from_smi, load_vocab, make_vocab, \ 12 | tokenize_selfies_from_smiles, tokenize_smiles 13 | from utils.train_utils import log_tensor, set_seed, setup_logger 14 | 15 | 16 | def get_preprocess_parser(): 17 | parser = argparse.ArgumentParser("preprocess") 18 | parsing.add_common_args(parser) 19 | parsing.add_preprocess_args(parser) 20 | 21 | return parser 22 | 23 | 24 | def tokenize(fns: Dict[str, List[Tuple[str, str]]], output_path: str, repr_start: str, repr_end: str): 25 | assert repr_start == "smiles", f"{repr_start} input provided. Only smiles inputs are supported!" 26 | 27 | if repr_end == "smiles": 28 | tokenize_line = tokenize_smiles 29 | elif repr_end == "selfies": 30 | tokenize_line = tokenize_selfies_from_smiles 31 | else: 32 | raise ValueError(f"{repr_end} output required. 
Only smiles and selfies outputs are supported!") 33 | 34 | ofns = {} 35 | 36 | for phase, file_list in fns.items(): 37 | ofns[phase] = [] 38 | 39 | for src_file, tgt_file in file_list: 40 | src_output = os.path.join(output_path, f"{repr_end}_tokenized_{os.path.basename(src_file)}") 41 | tgt_output = os.path.join(output_path, f"{repr_end}_tokenized_{os.path.basename(tgt_file)}") 42 | 43 | for fn, ofn in [(src_file, src_output), 44 | (tgt_file, tgt_output)]: 45 | if os.path.exists(ofn): 46 | logging.info(f"Found {ofn}, skipping tokenization.") 47 | continue 48 | 49 | with open(fn, "r") as f, open(ofn, "w") as of: 50 | logging.info(f"Tokenizing input {fn} into {ofn}") 51 | for i, line in enumerate(f): 52 | line = "".join(line.strip().split()) 53 | newline = tokenize_line(line) 54 | of.write(f"{newline}\n") 55 | logging.info(f"Done, total lines: {i + 1}") 56 | 57 | ofns[phase].append((src_output, tgt_output)) 58 | 59 | return ofns 60 | 61 | 62 | def get_token_ids(tokens: list, vocab: Dict[str, int], max_len: int) -> Tuple[List, int]: 63 | # token_ids = [vocab["_SOS"]] # shouldn't really need this 64 | token_ids = [] 65 | token_ids.extend([vocab[token] for token in tokens]) 66 | token_ids = token_ids[:max_len-1] 67 | token_ids.append(vocab["_EOS"]) 68 | 69 | lens = len(token_ids) 70 | while len(token_ids) < max_len: 71 | token_ids.append(vocab["_PAD"]) 72 | 73 | return token_ids, lens 74 | 75 | 76 | def get_seq_features_from_line(_args) -> Tuple[np.ndarray, int, np.ndarray, int]: 77 | i, src_line, tgt_line, max_src_len, max_tgt_len = _args 78 | assert isinstance(src_line, str) and isinstance(tgt_line, str) 79 | if i > 0 and i % 10000 == 0: 80 | logging.info(f"Processing {i}th SMILES") 81 | 82 | global G_vocab 83 | 84 | src_tokens = src_line.strip().split() 85 | if not src_tokens: 86 | src_tokens = ["C", "C"] # hardcode to ignore 87 | tgt_tokens = tgt_line.strip().split() 88 | src_token_ids, src_lens = get_token_ids(src_tokens, G_vocab, max_len=max_src_len) 89 | tgt_token_ids, tgt_lens = get_token_ids(tgt_tokens, G_vocab, max_len=max_tgt_len) 90 | 91 | src_token_ids = np.array(src_token_ids, dtype=np.int32) 92 | tgt_token_ids = np.array(tgt_token_ids, dtype=np.int32) 93 | 94 | return src_token_ids, src_lens, tgt_token_ids, tgt_lens 95 | 96 | 97 | def binarize_s2s(src_file: str, tgt_file: str, prefix: str, output_path: str, 98 | max_src_len: int, max_tgt_len: int, num_workers: int = 1): 99 | output_file = os.path.join(output_path, f"{prefix}.npz") 100 | logging.info(f"Binarizing (s2s) src {src_file} and tgt {tgt_file}, saving to {output_file}") 101 | 102 | with open(src_file, "r") as f: 103 | src_lines = f.readlines() 104 | 105 | with open(tgt_file, "r") as f: 106 | tgt_lines = f.readlines() 107 | 108 | logging.info("Getting seq features") 109 | start = time.time() 110 | 111 | p = Pool(num_workers) 112 | seq_features_and_lengths = p.imap( 113 | get_seq_features_from_line, 114 | ((i, src_line, tgt_line, max_src_len, max_tgt_len) 115 | for i, (src_line, tgt_line) in enumerate(zip(src_lines, tgt_lines))) 116 | ) 117 | 118 | p.close() 119 | p.join() 120 | 121 | seq_features_and_lengths = list(seq_features_and_lengths) 122 | 123 | logging.info(f"Done seq featurization, time: {time.time() - start}. 
Collating") 124 | src_token_ids, src_lens, tgt_token_ids, tgt_lens = zip(*seq_features_and_lengths) 125 | 126 | src_token_ids = np.stack(src_token_ids, axis=0) 127 | src_lens = np.array(src_lens, dtype=np.int32) 128 | tgt_token_ids = np.stack(tgt_token_ids, axis=0) 129 | tgt_lens = np.array(tgt_lens, dtype=np.int32) 130 | 131 | np.savez( 132 | output_file, 133 | src_token_ids=src_token_ids, 134 | src_lens=src_lens, 135 | tgt_token_ids=tgt_token_ids, 136 | tgt_lens=tgt_lens 137 | ) 138 | 139 | 140 | def binarize_g2s(src_file: str, tgt_file: str, prefix: str, output_path: str, 141 | max_src_len: int, max_tgt_len: int, num_workers: int = 1): 142 | output_file = os.path.join(output_path, f"{prefix}.npz") 143 | logging.info(f"Binarizing (g2s) src {src_file} and tgt {tgt_file}, saving to {output_file}") 144 | 145 | with open(src_file, "r") as f: 146 | # lines = f.readlines()[164104:164106] 147 | src_lines = f.readlines() 148 | 149 | with open(tgt_file, "r") as f: 150 | tgt_lines = f.readlines() 151 | 152 | logging.info("Getting seq features") 153 | start = time.time() 154 | 155 | p = Pool(num_workers) 156 | seq_features_and_lengths = p.imap( 157 | get_seq_features_from_line, 158 | ((i, src_line, tgt_line, max_src_len, max_tgt_len) 159 | for i, (src_line, tgt_line) in enumerate(zip(src_lines, tgt_lines))) 160 | ) 161 | 162 | p.close() 163 | p.join() 164 | 165 | seq_features_and_lengths = list(seq_features_and_lengths) 166 | 167 | logging.info(f"Done seq featurization, time: {time.time() - start}. Collating") 168 | src_token_ids, src_lens, tgt_token_ids, tgt_lens = zip(*seq_features_and_lengths) 169 | 170 | src_token_ids = np.stack(src_token_ids, axis=0) 171 | src_lens = np.array(src_lens, dtype=np.int32) 172 | tgt_token_ids = np.stack(tgt_token_ids, axis=0) 173 | tgt_lens = np.array(tgt_lens, dtype=np.int32) 174 | 175 | logging.info("Getting graph features") 176 | start = time.time() 177 | 178 | p = Pool(num_workers) 179 | graph_features_and_lengths = p.imap( 180 | get_graph_features_from_smi, 181 | ((i, "".join(line.split()), False) for i, line in enumerate(src_lines)) 182 | ) 183 | 184 | p.close() 185 | p.join() 186 | 187 | graph_features_and_lengths = list(graph_features_and_lengths) 188 | logging.info(f"Done graph featurization, time: {time.time() - start}. 
Collating and saving...") 189 | a_scopes, a_scopes_lens, b_scopes, b_scopes_lens, a_features, a_features_lens, \ 190 | b_features, b_features_lens, a_graphs, b_graphs = zip(*graph_features_and_lengths) 191 | 192 | a_scopes = np.concatenate(a_scopes, axis=0) 193 | b_scopes = np.concatenate(b_scopes, axis=0) 194 | a_features = np.concatenate(a_features, axis=0) 195 | b_features = np.concatenate(b_features, axis=0) 196 | a_graphs = np.concatenate(a_graphs, axis=0) 197 | b_graphs = np.concatenate(b_graphs, axis=0) 198 | 199 | a_scopes_lens = np.array(a_scopes_lens, dtype=np.int32) 200 | b_scopes_lens = np.array(b_scopes_lens, dtype=np.int32) 201 | a_features_lens = np.array(a_features_lens, dtype=np.int32) 202 | b_features_lens = np.array(b_features_lens, dtype=np.int32) 203 | 204 | np.savez( 205 | output_file, 206 | src_token_ids=src_token_ids, 207 | src_lens=src_lens, 208 | tgt_token_ids=tgt_token_ids, 209 | tgt_lens=tgt_lens, 210 | a_scopes=a_scopes, 211 | b_scopes=b_scopes, 212 | a_features=a_features, 213 | b_features=b_features, 214 | a_graphs=a_graphs, 215 | b_graphs=b_graphs, 216 | a_scopes_lens=a_scopes_lens, 217 | b_scopes_lens=b_scopes_lens, 218 | a_features_lens=a_features_lens, 219 | b_features_lens=b_features_lens 220 | ) 221 | 222 | 223 | def preprocess_main(args): 224 | parsing.log_args(args) 225 | 226 | os.makedirs(args.preprocess_output_path, exist_ok=True) 227 | 228 | fns = { 229 | "train": [(args.train_src, args.train_tgt)], 230 | "val": [(args.val_src, args.val_tgt)], 231 | "test": [(args.test_src, args.test_tgt)] 232 | } 233 | 234 | if not args.representation_start == args.representation_end: 235 | assert args.do_tokenize, f"Different representations, start: {args.representation_start}, " \ 236 | f"end: {args.representation_end}. Please set '--do_tokenize'" 237 | 238 | if args.do_tokenize: 239 | ofns = tokenize(fns=fns, 240 | output_path=args.preprocess_output_path, 241 | repr_start=args.representation_start, 242 | repr_end=args.representation_end) 243 | fns = ofns # just pass the handle of tokenized files 244 | 245 | vocab_file = os.path.join(args.preprocess_output_path, 246 | f"vocab_{args.representation_end}.txt") 247 | if not os.path.exists(vocab_file): 248 | make_vocab( 249 | fns=fns, 250 | vocab_file=vocab_file, 251 | tokenized=True 252 | ) 253 | 254 | if args.make_vocab_only: 255 | logging.info(f"--make_vocab_only flag detected. 
Skipping featurization") 256 | exit(0) 257 | 258 | global G_vocab 259 | G_vocab = load_vocab(vocab_file) 260 | 261 | if args.model == "s2s": 262 | binarize = binarize_s2s 263 | elif args.model.startswith("g2s"): 264 | binarize = binarize_g2s 265 | else: 266 | raise ValueError(f"Model {args.model} not supported!") 267 | 268 | for phase, file_list in fns.items(): 269 | for i, (src_file, tgt_file) in enumerate(file_list): 270 | binarize( 271 | src_file=src_file, 272 | tgt_file=tgt_file, 273 | prefix=f"{phase}_{i}", 274 | output_path=args.preprocess_output_path, 275 | max_src_len=args.max_src_len, 276 | max_tgt_len=args.max_tgt_len, 277 | num_workers=args.num_workers 278 | ) 279 | 280 | 281 | if __name__ == "__main__": 282 | preprocess_parser = get_preprocess_parser() 283 | args = preprocess_parser.parse_args() 284 | 285 | # set random seed 286 | set_seed(args.seed) 287 | 288 | # logger setup 289 | logger = setup_logger(args) 290 | 291 | np.set_printoptions(threshold=sys.maxsize) 292 | torch.set_printoptions(profile="full") 293 | 294 | G_vocab = {} # global vocab 295 | 296 | preprocess_main(args) 297 | -------------------------------------------------------------------------------- /models/graph2seq_series_rel.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from utils.chem_utils import ATOM_FDIM, BOND_FDIM 7 | from utils.data_utils import G2SBatch 8 | from utils.train_utils import log_tensor 9 | from models.attention_xl import AttnEncoderXL 10 | from models.graphfeat import GraphFeatEncoder 11 | from onmt.decoders import TransformerDecoder 12 | from onmt.modules.embeddings import Embeddings 13 | from onmt.translate import BeamSearch, GNMTGlobalScorer, GreedySearch 14 | from typing import Any, Dict 15 | 16 | from models.VAR.var_dec import VariationalDecoder 17 | from models.VAR.utils import tile 18 | 19 | 20 | class Graph2SeqSeriesRel(nn.Module): 21 | def __init__(self, args, vocab: Dict[str, int]): 22 | super().__init__() 23 | self.args = args 24 | self.vocab = vocab 25 | self.vocab_size = len(self.vocab) 26 | 27 | while args.enable_amp and not self.vocab_size % 8 == 0: 28 | self.vocab_size += 1 29 | 30 | self.encoder = GraphFeatEncoder( 31 | args, 32 | n_atom_feat=sum(ATOM_FDIM), 33 | n_bond_feat=BOND_FDIM 34 | ) 35 | 36 | if args.attn_enc_num_layers > 0: 37 | self.attention_encoder = AttnEncoderXL(args) 38 | else: 39 | self.attention_encoder = None 40 | 41 | self.decoder_embeddings = Embeddings( 42 | word_vec_size=args.embed_size, 43 | word_vocab_size=self.vocab_size, 44 | word_padding_idx=self.vocab["_PAD"], 45 | position_encoding=True, 46 | dropout=args.dropout 47 | ) 48 | 49 | args.vocab_size = self.vocab_size 50 | self.decoder = VariationalDecoder(args, 51 | # self.decoder = TransformerDecoder( 52 | num_layers=args.decoder_num_layers, 53 | d_model=args.decoder_hidden_size, 54 | heads=args.decoder_attn_heads, 55 | d_ff=args.decoder_filter_size, 56 | copy_attn=False, 57 | self_attn_type="scaled-dot", 58 | dropout=args.dropout, 59 | attention_dropout=args.attn_dropout, 60 | embeddings=self.decoder_embeddings, 61 | max_relative_positions=args.max_relative_positions, 62 | aan_useffn=False, 63 | full_context_alignment=False, 64 | alignment_layer=-3, 65 | alignment_heads=0 66 | ) 67 | 68 | if not args.attn_enc_hidden_size == args.decoder_hidden_size: 69 | self.bridge_layer = nn.Linear(args.attn_enc_hidden_size, 
args.decoder_hidden_size, bias=True) 70 | 71 | self.output_layer = nn.Linear(args.decoder_hidden_size, self.vocab_size, bias=True) 72 | 73 | self.criterion = nn.CrossEntropyLoss( 74 | ignore_index=self.vocab["_PAD"], 75 | reduction="mean" 76 | ) 77 | 78 | def encode_and_reshape(self, reaction_batch: G2SBatch): 79 | hatom, _ = self.encoder(reaction_batch) # (n_atoms, h) 80 | if not self.args.attn_enc_hidden_size == self.args.decoder_hidden_size: 81 | hatom = self.bridge_layer(hatom) # bridging 82 | 83 | # hatom reshaping into [t, b, h] 84 | atom_scope = reaction_batch.atom_scope # list of b (n_components, 2) 85 | 86 | memory_lengths = [scope[-1][0] + scope[-1][1] - scope[0][0] 87 | for scope in atom_scope] # (b, ) 88 | 89 | # the 1+ corresponds to Atom(*) 90 | assert 1 + sum(memory_lengths) == hatom.size(0), \ 91 | f"Memory lengths calculation error, encoder output: {hatom.size(0)}, memory_lengths: {memory_lengths}" 92 | 93 | memory_bank = torch.split(hatom, [1] + memory_lengths, dim=0) # [n_atoms, h] => 1+b tup of (t, h) 94 | padded_memory_bank = [] 95 | max_length = max(memory_lengths) 96 | 97 | for length, h in zip(memory_lengths, memory_bank[1:]): 98 | m = nn.ZeroPad2d((0, 0, 0, max_length - length)) 99 | padded_memory_bank.append(m(h)) 100 | 101 | padded_memory_bank = torch.stack(padded_memory_bank, dim=1) # list of b (max_t, h) => [max_t, b, h] 102 | product_emb = padded_memory_bank 103 | 104 | memory_lengths = torch.tensor(memory_lengths, 105 | dtype=torch.long, 106 | device=padded_memory_bank.device) 107 | 108 | if self.attention_encoder is not None: 109 | padded_memory_bank = self.attention_encoder( 110 | padded_memory_bank, 111 | memory_lengths, 112 | reaction_batch.distances 113 | ) 114 | 115 | self.decoder.state["src"] = np.zeros(max_length) # TODO: this is hardcoded to make transformer decoder work 116 | 117 | return padded_memory_bank, memory_lengths, product_emb 118 | 119 | def forward(self, reaction_batch: G2SBatch): 120 | padded_memory_bank, memory_lengths, product_emb = self.encode_and_reshape(reaction_batch) 121 | 122 | # adapted from onmt.models 123 | dec_in = reaction_batch.tgt_token_ids[:, :-1] # pop last, insert SOS for decoder input 124 | m = nn.ConstantPad1d((1, 0), self.vocab["_SOS"]) 125 | dec_in = m(dec_in) 126 | dec_in = dec_in.transpose(0, 1).unsqueeze(-1) # [b, max_tgt_t] => [max_tgt_t, b, 1] 127 | 128 | dec_outs, _, aux_loss, kld_loss = self.decoder(product_emb, 129 | tgt=dec_in, 130 | memory_bank=padded_memory_bank, 131 | memory_lengths=memory_lengths) 132 | 133 | dec_outs = self.output_layer(dec_outs) # [t, b, h] => [t, b, v] 134 | dec_outs = dec_outs.permute(1, 2, 0) # [t, b, v] => [b, v, t] 135 | 136 | loss = self.criterion( 137 | input=dec_outs, 138 | target=reaction_batch.tgt_token_ids 139 | ) 140 | 141 | predictions = torch.argmax(dec_outs, dim=1) # [b, t] 142 | mask = (reaction_batch.tgt_token_ids != self.vocab["_PAD"]).long() 143 | accs = (predictions == reaction_batch.tgt_token_ids).float() 144 | accs = accs * mask 145 | acc = accs.sum() / mask.sum() 146 | 147 | return (loss, aux_loss, kld_loss), acc 148 | 149 | def predict_step(self, reaction_batch: G2SBatch, 150 | batch_size: int, beam_size: int, n_best: int, temperature: float, 151 | min_length: int, max_length: int) -> Dict[str, Any]: 152 | if beam_size == 1: 153 | decode_strategy = GreedySearch( 154 | pad=self.vocab["_PAD"], 155 | bos=self.vocab["_SOS"], 156 | eos=self.vocab["_EOS"], 157 | batch_size=batch_size, 158 | min_length=min_length, 159 | max_length=max_length, 160 | 
block_ngram_repeat=0, 161 | exclusion_tokens=set(), 162 | return_attention=False, 163 | sampling_temp=0.0, 164 | keep_topk=1 165 | ) 166 | else: 167 | global_scorer = GNMTGlobalScorer(alpha=0.0, 168 | beta=0.0, 169 | length_penalty="none", 170 | coverage_penalty="none") 171 | decode_strategy = BeamSearch( 172 | beam_size=beam_size, 173 | batch_size=batch_size, 174 | pad=self.vocab["_PAD"], 175 | bos=self.vocab["_SOS"], 176 | eos=self.vocab["_EOS"], 177 | n_best=n_best, 178 | global_scorer=global_scorer, 179 | min_length=min_length, 180 | max_length=max_length, 181 | return_attention=False, 182 | block_ngram_repeat=0, 183 | exclusion_tokens=set(), 184 | stepwise_penalty=None, 185 | ratio=0.0 186 | ) 187 | 188 | padded_memory_bank, memory_lengths, product_emb = self.encode_and_reshape(reaction_batch=reaction_batch) 189 | # adapted from onmt.translate.translator 190 | results = { 191 | "predictions": None, 192 | "scores": None, 193 | "attention": None 194 | } 195 | 196 | # (2) prep decode_strategy. Possibly repeat src objects. 197 | src_map = None 198 | target_prefix = None 199 | fn_map_state, memory_bank, memory_lengths, src_map = decode_strategy.initialize( 200 | memory_bank=padded_memory_bank, 201 | src_lengths=memory_lengths, 202 | src_map=src_map, 203 | target_prefix=target_prefix 204 | ) 205 | product_emb = tile(product_emb, beam_size, dim=1) 206 | 207 | # (3) Begin decoding step by step: 208 | for step in range(decode_strategy.max_length): 209 | decoder_input = decode_strategy.current_predictions.view(1, -1, 1) 210 | dec_out, dec_attn = self.decoder(product_emb, 211 | tgt=decoder_input, 212 | memory_bank=memory_bank, 213 | memory_lengths=memory_lengths, 214 | step=step) 215 | 216 | if "std" in dec_attn: 217 | attn = dec_attn["std"] 218 | else: 219 | attn = None 220 | 221 | dec_out = self.output_layer(dec_out) # [t, b, h] => [t, b, v] 222 | dec_out = dec_out / temperature 223 | dec_out = dec_out.squeeze(0) # [t, b, v] => [b, v] 224 | log_probs = F.log_softmax(dec_out, dim=-1) 225 | 226 | decode_strategy.advance(log_probs, attn) 227 | any_finished = decode_strategy.is_finished.any() 228 | if any_finished: 229 | decode_strategy.update_finished() 230 | if decode_strategy.done: 231 | break 232 | 233 | select_indices = decode_strategy.select_indices 234 | 235 | if any_finished: 236 | # Reorder states. 
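                # Beam search drops finished hypotheses, so every stateful tensor has to
                # be gathered with select_indices to stay aligned with the surviving
                # beams: the memory bank, the product embeddings, the memory lengths,
                # the optional src_map, and the decoder's self-attention cache
                # (via map_state below).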
237 | if isinstance(memory_bank, tuple): 238 | memory_bank = tuple(x.index_select(1, select_indices) 239 | for x in memory_bank) 240 | else: 241 | memory_bank = memory_bank.index_select(1, select_indices) 242 | product_emb = product_emb.index_select(1, select_indices) 243 | 244 | memory_lengths = memory_lengths.index_select(0, select_indices) 245 | 246 | if src_map is not None: 247 | src_map = src_map.index_select(1, select_indices) 248 | 249 | if any_finished: 250 | self.map_state( 251 | lambda state, dim: state.index_select(dim, select_indices)) 252 | 253 | results["scores"] = decode_strategy.scores 254 | results["predictions"] = decode_strategy.predictions 255 | results["attention"] = decode_strategy.attention 256 | results["alignment"] = [[] for _ in range(self.args.predict_batch_size)] 257 | 258 | if hasattr(self.decoder, 'y_sample'): 259 | return results, self.decoder.y_sample.argmax(-1).reshape(-1, beam_size) 260 | else: 261 | return results 262 | 263 | # adapted from onmt.decoders.transformer 264 | def map_state(self, fn): 265 | def _recursive_map(struct, batch_dim=0): 266 | for k, v in struct.items(): 267 | if v is not None: 268 | if isinstance(v, dict): 269 | _recursive_map(v) 270 | else: 271 | struct[k] = fn(v, batch_dim) 272 | 273 | # self.decoder.state["src"] = fn(self.decoder.state["src"], 1) 274 | # => self.state["src"] = self.state["src"].index_select(1, select_indices) 275 | 276 | if self.decoder.state["cache"] is not None: 277 | _recursive_map(self.decoder.state["cache"]) 278 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import numpy as np 4 | import os 5 | import sys 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from models.graph2seq_series_rel import Graph2SeqSeriesRel 11 | from models.seq2seq import Seq2Seq 12 | from torch.nn.init import xavier_uniform_ 13 | from torch.utils.data import DataLoader 14 | from utils import parsing 15 | from utils.data_utils import load_vocab, S2SDataset, G2SDataset 16 | from utils.train_utils import get_lr, grad_norm, NoamLR, param_count, param_norm, set_seed, setup_logger 17 | 18 | import math 19 | import shutil 20 | import warnings 21 | warnings.filterwarnings('ignore') 22 | 23 | 24 | def get_train_parser(): 25 | parser = argparse.ArgumentParser("train") 26 | parsing.add_common_args(parser) 27 | parsing.add_train_args(parser) 28 | parsing.add_predict_args(parser) 29 | 30 | return parser 31 | 32 | 33 | def main(args): 34 | parsing.log_args(args) 35 | 36 | # initialization ----------------- vocab 37 | if not os.path.exists(args.vocab_file): 38 | raise ValueError(f"Vocab file {args.vocab_file} not found!") 39 | vocab = load_vocab(args.vocab_file) 40 | vocab_tokens = [k for k, v in sorted(vocab.items(), key=lambda tup: tup[1])] 41 | 42 | # initialization ----------------- model 43 | os.makedirs(args.save_dir, exist_ok=True) 44 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 45 | 46 | if args.model == "s2s": 47 | model_class = Seq2Seq 48 | dataset_class = S2SDataset 49 | elif args.model == "g2s_series_rel": 50 | model_class = Graph2SeqSeriesRel 51 | dataset_class = G2SDataset 52 | assert args.compute_graph_distance 53 | else: 54 | raise ValueError(f"Model {args.model} not supported!") 55 | 56 | model = model_class(args, vocab) 57 | for p in model.parameters(): 58 | if p.dim() > 1 and 
p.requires_grad: 59 | xavier_uniform_(p) 60 | 61 | if args.load_from: 62 | state = torch.load(args.load_from) 63 | pretrain_args = state["args"] 64 | pretrain_state_dict = state["state_dict"] 65 | logging.info(f"Loaded pretrained state_dict from {args.load_from}") 66 | model.load_state_dict(pretrain_state_dict, strict=False) 67 | 68 | model.to(device) 69 | model.train() 70 | 71 | logging.info(model) 72 | logging.info(f"Number of parameters = {param_count(model)}") 73 | 74 | # initialization ----------------- optimizer 75 | optimizer = optim.AdamW( 76 | model.parameters(), 77 | lr=args.lr, 78 | betas=(args.beta1, args.beta2), 79 | eps=args.eps, 80 | weight_decay=args.weight_decay 81 | ) 82 | scheduler = NoamLR( 83 | optimizer, 84 | model_size=args.decoder_hidden_size, 85 | warmup_steps=args.warmup_steps 86 | ) 87 | 88 | # initialization ----------------- data 89 | train_dataset = dataset_class(args, file=args.train_bin) 90 | valid_dataset = dataset_class(args, file=args.valid_bin) 91 | 92 | total_step = 0 93 | accum = 0 94 | losses, accs = [], [] 95 | 96 | # Creates a GradScaler once at the beginning of training. 97 | scaler = torch.cuda.amp.GradScaler(enabled=args.enable_amp) 98 | 99 | o_start = time.time() 100 | 101 | logging.info("Start training") 102 | logging.info(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 103 | for epoch in range(args.epoch): 104 | model.zero_grad() 105 | 106 | train_dataset.sort() 107 | train_dataset.shuffle_in_bucket(bucket_size=1000) 108 | train_dataset.batch( 109 | batch_type=args.batch_type, 110 | batch_size=args.train_batch_size 111 | ) 112 | train_loader = DataLoader( 113 | dataset=train_dataset, 114 | batch_size=1, 115 | shuffle=True, 116 | collate_fn=lambda _batch: _batch[0], 117 | num_workers=16, 118 | pin_memory=True 119 | ) 120 | 121 | for batch_idx, batch in enumerate(train_loader): 122 | if total_step > args.max_steps: 123 | logging.info("Max steps reached, finish training") 124 | exit(0) 125 | 126 | batch.to(device) 127 | with torch.autograd.profiler.profile(enabled=args.do_profile, 128 | record_shapes=args.record_shapes, 129 | use_cuda=torch.cuda.is_available()) as prof: 130 | 131 | # Enables autocasting for the forward pass (model + loss) 132 | with torch.cuda.amp.autocast(enabled=args.enable_amp): 133 | loss_all, acc = model(batch) 134 | old_loss, aux_loss, kld_loss = loss_all 135 | kla_coef = min(math.tanh(2. * total_step / args.max_steps - 3) + 1, 1) * 0.03 136 | loss = old_loss + 0.1 * aux_loss + kla_coef * kld_loss 137 | 138 | # Exits the context manager before backward() 139 | # Scales loss. Calls backward() on scaled loss to create scaled gradients. 140 | scaler.scale(loss).backward() 141 | 142 | losses.append(loss.item()) 143 | accs.append(acc.item() * 100) 144 | 145 | accum += 1 146 | 147 | if accum == args.accumulation_count: 148 | # Unscales the gradients of optimizer's assigned params in-place 149 | scaler.unscale_(optimizer) 150 | 151 | # Since the gradients of optimizer's assigned params are unscaled, clips as usual: 152 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) 153 | 154 | # optimizer's gradients are already unscaled, so scaler.step does not unscale them, 155 | scaler.step(optimizer) 156 | 157 | # Updates the scale for next iteration. 
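                # update() shrinks the loss scale if this step's gradients contained
                # infs/NaNs (in which case scaler.step() skipped the optimizer step),
                # and grows it back after a run of clean steps.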
158 | scaler.update() 159 | 160 | scheduler.step() 161 | 162 | g_norm = grad_norm(model) 163 | model.zero_grad() 164 | accum = 0 165 | total_step += 1 166 | 167 | if args.do_profile: 168 | logging.info(prof 169 | .key_averages(group_by_input_shape=args.record_shapes) 170 | .table(sort_by="cuda_time_total")) 171 | sys.stdout.flush() 172 | 173 | if (accum == 0) and (total_step > 0) and (total_step % args.log_iter == 0): 174 | logging.info(f"Step {total_step}, loss: {np.mean(losses)}, acc: {np.mean(accs)}, " 175 | f"p_norm: {param_norm(model)}, g_norm: {g_norm}, " 176 | f"lr: {get_lr(optimizer): .6f}, elapsed time: {time.time() - o_start: .0f}") 177 | sys.stdout.flush() 178 | losses, accs = [], [] 179 | 180 | if (accum == 0) and (total_step > 0) and (total_step % args.eval_iter == 0): 181 | model.eval() 182 | eval_count = 100 183 | eval_meters = [0.0, 0.0] 184 | 185 | valid_dataset.sort() 186 | valid_dataset.shuffle_in_bucket(bucket_size=1000) 187 | valid_dataset.batch( 188 | batch_type=args.batch_type, 189 | batch_size=args.valid_batch_size 190 | ) 191 | valid_loader = DataLoader( 192 | dataset=valid_dataset, 193 | batch_size=1, 194 | shuffle=True, 195 | collate_fn=lambda _batch: _batch[0], 196 | num_workers=16, 197 | pin_memory=True 198 | ) 199 | 200 | with torch.no_grad(): 201 | for eval_idx, eval_batch in enumerate(valid_loader): 202 | if eval_idx >= eval_count: 203 | break 204 | eval_batch.to(device) 205 | 206 | eval_loss_all, eval_acc = model(eval_batch) 207 | old_loss, aux_loss, kld_loss = eval_loss_all 208 | kla_coef = min(math.tanh(2. * total_step / args.max_steps - 3) + 1, 1) * 0.03 209 | eval_loss = old_loss + 0.1 * aux_loss + kla_coef * kld_loss 210 | eval_meters[0] += eval_loss.item() / eval_count 211 | eval_meters[1] += eval_acc * 100 / eval_count 212 | 213 | logging.info(f"Evaluation (with teacher) at step {total_step}, eval loss: {eval_meters[0]}, " 214 | f"eval acc: {eval_meters[1]}") 215 | logging.info(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 216 | sys.stdout.flush() 217 | 218 | model.train() 219 | 220 | if (accum == 0) and (total_step > 0) and (total_step % args.save_iter == 0): 221 | n_iter = total_step // args.save_iter - 1 222 | 223 | model.eval() 224 | eval_count = 100 225 | 226 | valid_dataset.sort() 227 | valid_dataset.shuffle_in_bucket(bucket_size=1000) 228 | valid_dataset.batch( 229 | batch_type=args.batch_type, 230 | batch_size=args.valid_batch_size 231 | ) 232 | valid_loader = DataLoader( 233 | dataset=valid_dataset, 234 | batch_size=1, 235 | shuffle=True, 236 | collate_fn=lambda _batch: _batch[0], 237 | num_workers=16, 238 | pin_memory=True 239 | ) 240 | 241 | accs_token = [] 242 | accs_seq = [] 243 | 244 | with torch.no_grad(): 245 | for eval_idx, eval_batch in enumerate(valid_loader): 246 | if eval_idx >= eval_count: 247 | break 248 | 249 | eval_batch.to(device) 250 | results = model.predict_step( 251 | reaction_batch=eval_batch, 252 | batch_size=eval_batch.size, 253 | beam_size=args.beam_size, 254 | n_best=1, 255 | temperature=1.0, 256 | min_length=args.predict_min_len, 257 | max_length=args.predict_max_len 258 | )[0] 259 | predictions = [t[0].cpu().numpy() for t in results["predictions"]] 260 | 261 | for i, prediction in enumerate(predictions): 262 | acc_seq, acc_token = 0, 0 263 | target_ids = None 264 | for j in torch.arange(eval_batch.pres[i], eval_batch.posts[i]): 265 | tgt_length = valid_dataset.tgt_lens[j] 266 | tgt_token_ids = valid_dataset.tgt_token_ids[j][:tgt_length] 267 | acc_seq = max(np.array_equal(tgt_token_ids, 
prediction[:tgt_length]), acc_seq) 268 | while len(prediction) < tgt_length: 269 | prediction = np.append(prediction, vocab["_PAD"]) 270 | if np.mean(tgt_token_ids == prediction[:tgt_length])>=acc_token: 271 | acc_token = np.mean(tgt_token_ids == prediction[:tgt_length]) 272 | target_ids = tgt_token_ids 273 | 274 | accs_token.append(acc_token) 275 | accs_seq.append(acc_seq) 276 | 277 | if eval_idx % 20 == 0 and i == 0: 278 | logging.info(f"Target text: {' '.join([vocab_tokens[idx] for idx in target_ids])}") 279 | logging.info(f"Predicted text: {' '.join([vocab_tokens[idx] for idx in prediction])}") 280 | logging.info(f"acc_token: {acc_token}, acc_seq: {acc_seq}\n") 281 | 282 | logging.info(f"Evaluation (without teacher) at step {total_step}, " 283 | f"eval acc (token): {np.mean(accs_token)}, " 284 | f"eval acc (sequence): {np.mean(accs_seq)}") 285 | logging.info(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 286 | sys.stdout.flush() 287 | 288 | model.train() 289 | 290 | logging.info(f"Saving at step {total_step}") 291 | sys.stdout.flush() 292 | 293 | state = { 294 | "args": args, 295 | "state_dict": model.state_dict() 296 | } 297 | if not os.path.exists(args.save_dir): 298 | os.makedirs(args.save_dir) 299 | 300 | torch.save(state, os.path.join(args.save_dir, f"model.{total_step}_{n_iter}.pt")) 301 | if n_iter >= args.keep_last_ckpt-1: 302 | old_iter = n_iter - args.keep_last_ckpt 303 | old_path = os.path.join(args.save_dir, f"model_{(old_iter+1) * args.save_iter}_{old_iter}.pt") 304 | shutil.rmtree(old_path) 305 | 306 | # lastly 307 | if (args.accumulation_count > 1) and (accum > 0): 308 | scaler.unscale_(optimizer) 309 | 310 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) 311 | 312 | scaler.step(optimizer) 313 | scaler.update() 314 | 315 | scheduler.step() 316 | 317 | model.zero_grad() 318 | accum = 0 319 | 320 | 321 | if __name__ == "__main__": 322 | train_parser = get_train_parser() 323 | args = train_parser.parse_args() 324 | 325 | # set random seed 326 | set_seed(args.seed) 327 | 328 | # logger setup 329 | logger = setup_logger(args) 330 | 331 | torch.set_printoptions(profile="full") 332 | main(args) 333 | -------------------------------------------------------------------------------- /models/VAR/var_dec.py: -------------------------------------------------------------------------------- 1 | from numpy.core.fromnumeric import var 2 | import torch 3 | import torch.nn as nn 4 | 5 | from onmt.decoders.decoder import DecoderBase 6 | from onmt.modules import MultiHeadedAttention, AverageAttention 7 | from onmt.modules.position_ffn import PositionwiseFeedForward 8 | from onmt.utils.misc import sequence_mask 9 | 10 | 11 | from models.VAR.utils import gaussian_kld, Con_Layer_Norm, VarFeedForward, FeedForwardNet, gumbel_softmax, tile 12 | import torch.nn.functional as F 13 | 14 | class TransformerDecoderLayer(nn.Module): 15 | """Transformer Decoder layer block in Pre-Norm style. 16 | Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style, 17 | providing better converge speed and performance. This is also the actual 18 | implementation in tensor2tensor and also avalable in fairseq. 19 | See https://tunz.kr/post/4 and :cite:`DeeperTransformer`. 20 | 21 | .. 
mermaid:: 22 | 23 | graph LR 24 | %% "*SubLayer" can be self-attn, src-attn or feed forward block 25 | A(input) --> B[Norm] 26 | B --> C["*SubLayer"] 27 | C --> D[Drop] 28 | D --> E((+)) 29 | A --> E 30 | E --> F(out) 31 | 32 | 33 | Args: 34 | d_model (int): the dimension of keys/values/queries in 35 | :class:`MultiHeadedAttention`, also the input size of 36 | the first-layer of the :class:`PositionwiseFeedForward`. 37 | heads (int): the number of heads for MultiHeadedAttention. 38 | d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`. 39 | dropout (float): dropout in residual, self-attn(dot) and feed-forward 40 | attention_dropout (float): dropout in context_attn (and self-attn(avg)) 41 | self_attn_type (string): type of self-attention scaled-dot, average 42 | max_relative_positions (int): 43 | Max distance between inputs in relative positions representations 44 | aan_useffn (bool): Turn on the FFN layer in the AAN decoder 45 | full_context_alignment (bool): 46 | whether enable an extra full context decoder forward for alignment 47 | alignment_heads (int): 48 | N. of cross attention heads to use for alignment guiding 49 | """ 50 | 51 | def __init__(self, d_model, heads, d_ff, dropout, attention_dropout, 52 | self_attn_type="scaled-dot", max_relative_positions=0, 53 | aan_useffn=False, full_context_alignment=False, 54 | alignment_heads=0): 55 | super(TransformerDecoderLayer, self).__init__() 56 | 57 | if self_attn_type == "scaled-dot": 58 | self.self_attn = MultiHeadedAttention( 59 | heads, d_model, dropout=attention_dropout, 60 | max_relative_positions=max_relative_positions) 61 | elif self_attn_type == "average": 62 | self.self_attn = AverageAttention(d_model, 63 | dropout=attention_dropout, 64 | aan_useffn=aan_useffn) 65 | 66 | self.context_attn = MultiHeadedAttention( 67 | heads, d_model, dropout=attention_dropout) 68 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 69 | self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) 70 | self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) 71 | self.drop = nn.Dropout(dropout) 72 | self.full_context_alignment = full_context_alignment 73 | self.alignment_heads = alignment_heads 74 | 75 | def forward(self, *args, **kwargs): 76 | """ Extend `_forward` for (possibly) multiple decoder pass: 77 | Always a default (future masked) decoder forward pass, 78 | Possibly a second future aware decoder pass for joint learn 79 | full context alignement, :cite:`garg2019jointly`. 80 | 81 | Args: 82 | * All arguments of _forward. 83 | with_align (bool): whether return alignment attention. 
84 | 85 | Returns: 86 | (FloatTensor, FloatTensor, FloatTensor or None): 87 | 88 | * output ``(batch_size, T, model_dim)`` 89 | * top_attn ``(batch_size, T, src_len)`` 90 | * attn_align ``(batch_size, T, src_len)`` or None 91 | """ 92 | with_align = kwargs.pop('with_align', False) 93 | output, attns = self._forward(*args, **kwargs) 94 | top_attn = attns[:, 0, :, :].contiguous() 95 | attn_align = None 96 | if with_align: 97 | if self.full_context_alignment: 98 | # return _, (B, Q_len, K_len) 99 | _, attns = self._forward(*args, **kwargs, future=True) 100 | 101 | if self.alignment_heads > 0: 102 | attns = attns[:, :self.alignment_heads, :, :].contiguous() 103 | # layer average attention across heads, get ``(B, Q, K)`` 104 | # Case 1: no full_context, no align heads -> layer avg baseline 105 | # Case 2: no full_context, 1 align heads -> guided align 106 | # Case 3: full_context, 1 align heads -> full cte guided align 107 | attn_align = attns.mean(dim=1) 108 | return output, top_attn, attn_align 109 | 110 | def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, 111 | layer_cache=None, step=None, condition=None, future=False): 112 | """ A naive forward pass for transformer decoder. 113 | 114 | # T: could be 1 in the case of stepwise decoding or tgt_len 115 | 116 | Args: 117 | inputs (FloatTensor): ``(batch_size, T, model_dim)`` 118 | memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)`` 119 | src_pad_mask (bool): ``(batch_size, 1, src_len)`` 120 | tgt_pad_mask (bool): ``(batch_size, 1, T)`` 121 | layer_cache (dict or None): cached layer info when stepwise decode 122 | step (int or None): stepwise decoding counter 123 | future (bool): If set True, do not apply future_mask. 124 | 125 | Returns: 126 | (FloatTensor, FloatTensor): 127 | 128 | * output ``(batch_size, T, model_dim)`` 129 | * attns ``(batch_size, head, T, src_len)`` 130 | 131 | """ 132 | dec_mask = None 133 | 134 | if step is None: 135 | tgt_len = tgt_pad_mask.size(-1) 136 | if not future: # apply future_mask, result mask in (B, T, T) 137 | future_mask = torch.ones( 138 | [tgt_len, tgt_len], 139 | device=tgt_pad_mask.device, 140 | dtype=torch.uint8) 141 | future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len) 142 | # BoolTensor was introduced in pytorch 1.2 143 | try: 144 | future_mask = future_mask.bool() 145 | except AttributeError: 146 | pass 147 | dec_mask = torch.gt(tgt_pad_mask + future_mask, 0) 148 | else: # only mask padding, result mask in (B, 1, T) 149 | dec_mask = tgt_pad_mask 150 | 151 | input_norm = self.layer_norm_1(inputs) 152 | 153 | if isinstance(self.self_attn, MultiHeadedAttention): 154 | query, _ = self.self_attn(input_norm, input_norm, input_norm, 155 | mask=dec_mask, 156 | layer_cache=layer_cache, 157 | attn_type="self") 158 | elif isinstance(self.self_attn, AverageAttention): 159 | query, _ = self.self_attn(input_norm, mask=dec_mask, 160 | layer_cache=layer_cache, step=step) 161 | 162 | query = self.drop(query) + inputs 163 | 164 | query_norm = self.layer_norm_2(query) 165 | mid, attns = self.context_attn(memory_bank, memory_bank, query_norm, 166 | mask=src_pad_mask, 167 | layer_cache=layer_cache, 168 | attn_type="context") 169 | output = self.feed_forward(self.drop(mid) + query) 170 | 171 | return output, attns 172 | 173 | def update_dropout(self, dropout, attention_dropout): 174 | self.self_attn.update_dropout(attention_dropout) 175 | self.context_attn.update_dropout(attention_dropout) 176 | self.feed_forward.update_dropout(dropout) 177 | self.drop.p = dropout 178 | 179 | 180 | 
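# VariationalDecoderLayer below follows the same Pre-Norm layout as
# TransformerDecoderLayer, but turns the position-wise feed-forward block into a
# conditional-VAE sub-layer:
#   * self.prior maps the cross-attention output to (mu, log_var) of a diagonal
#     Gaussian prior;
#   * self.recog (the recognition / posterior net) additionally sees the decoder
#     inputs, and is only used when layer_cache is None, i.e. during training;
#   * z is drawn with the reparameterization trick, z = mu + std * eps, and
#     concatenated with the attended representation before the feed-forward net;
#   * the residual layer norms are conditioned on the memory bank (Con_Layer_Norm);
#   * the layer also returns a KL term between posterior and prior (gaussian_kld,
#     presumably the standard closed form for two diagonal Gaussians,
#     log(s_p/s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 * s_p^2) - 1/2 per dimension),
#     which train.py adds to the loss with the annealed coefficient kla_coef.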
class VariationalDecoderLayer(nn.Module): 181 | def __init__(self, args, d_model, heads, d_ff, dropout, attention_dropout, 182 | self_attn_type="scaled-dot", max_relative_positions=0, 183 | aan_useffn=False, full_context_alignment=False, 184 | alignment_heads=0): 185 | super(VariationalDecoderLayer, self).__init__() 186 | 187 | if self_attn_type == "scaled-dot": 188 | self.self_attn = MultiHeadedAttention( 189 | heads, d_model, dropout=attention_dropout, 190 | max_relative_positions=max_relative_positions) 191 | elif self_attn_type == "average": 192 | self.self_attn = AverageAttention(d_model, 193 | dropout=attention_dropout, 194 | aan_useffn=aan_useffn) 195 | 196 | self.context_attn = MultiHeadedAttention( 197 | heads, d_model, dropout=attention_dropout) 198 | # self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 199 | self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) 200 | self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) 201 | hid_size = args.varFFN_hidden_size # hid_size = 128 202 | latent_size = args.latent_size # latent_size = 60 203 | vocab_size = args.vocab_size # vocab_size = 75 204 | condition_size = d_model 205 | 206 | self.layer_norm_3 = nn.LayerNorm(d_model, eps=1e-6) 207 | self.prior = VarFeedForward(d_model, hid_size, 2 * latent_size, 208 | layer_config='lll', dropout=dropout) 209 | self.recog = VarFeedForward(d_model * 2, hid_size, 2 * latent_size, 210 | layer_config='lll', dropout=dropout) 211 | self.criterion = nn.NLLLoss(ignore_index=1) 212 | # self.z_score2 = nn.Sequential(nn.Linear(latent_size, vocab_size), nn.LogSoftmax(dim=-1)) 213 | # self.feed_forward = FeedForwardNet(d_model+latent_size, d_ff, d_model, dropout) # var version d_in != d_out 214 | self.feed_forward = FeedForwardNet(d_model+latent_size, d_ff, d_model, dropout=0) # var version for z d_in != d_out 215 | 216 | self.conln_1 = Con_Layer_Norm(d_model, condition_size) #(d_model, d_model) 217 | self.conln_2 = Con_Layer_Norm(d_model, condition_size) 218 | self.drop = nn.Dropout(dropout) 219 | self.full_context_alignment = full_context_alignment 220 | self.alignment_heads = alignment_heads 221 | 222 | def forward(self, *args, **kwargs): 223 | with_align = kwargs.pop('with_align', False) 224 | output, attns, z, kld_loss = self._forward(*args, **kwargs) 225 | top_attn = attns[:, 0, :, :].contiguous() 226 | attn_align = None 227 | if with_align: 228 | if self.full_context_alignment: 229 | # return _, (B, Q_len, K_len) 230 | _, attns, z = self._forward(*args, **kwargs, future=True) 231 | 232 | if self.alignment_heads > 0: 233 | attns = attns[:, :self.alignment_heads, :, :].contiguous() 234 | # layer average attention across heads, get ``(B, Q, K)`` 235 | # Case 1: no full_context, no align heads -> layer avg baseline 236 | # Case 2: no full_context, 1 align heads -> guided align 237 | # Case 3: full_context, 1 align heads -> full cte guided align 238 | attn_align = attns.mean(dim=1) 239 | return output, top_attn, z, kld_loss, attn_align 240 | 241 | def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, 242 | layer_cache=None, step=None, condition=None, future=False): 243 | train = True if layer_cache is None else False 244 | dec_mask = None 245 | 246 | if step is None: 247 | tgt_len = tgt_pad_mask.size(-1) 248 | if not future: # apply future_mask, result mask in (B, T, T) 249 | future_mask = torch.ones( 250 | [tgt_len, tgt_len], 251 | device=tgt_pad_mask.device, 252 | dtype=torch.uint8) 253 | future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len) 254 | # BoolTensor was 
introduced in pytorch 1.2 255 | try: 256 | future_mask = future_mask.bool() 257 | except AttributeError: 258 | pass 259 | dec_mask = torch.gt(tgt_pad_mask + future_mask, 0) 260 | else: # only mask padding, result mask in (B, 1, T) 261 | dec_mask = tgt_pad_mask 262 | 263 | input_norm = self.layer_norm_1(inputs) 264 | 265 | if isinstance(self.self_attn, MultiHeadedAttention): 266 | query, _ = self.self_attn(input_norm, input_norm, input_norm, 267 | mask=dec_mask, 268 | layer_cache=layer_cache, 269 | attn_type="self") 270 | elif isinstance(self.self_attn, AverageAttention): 271 | query, _ = self.self_attn(input_norm, mask=dec_mask, 272 | layer_cache=layer_cache, step=step) 273 | 274 | query = self.drop(query) + inputs 275 | 276 | query_norm = self.conln_1(query, memory_bank) 277 | mid, attns = self.context_attn(memory_bank, memory_bank, query_norm, 278 | mask=src_pad_mask, 279 | layer_cache=layer_cache, 280 | attn_type="context") 281 | mid_norm = self.conln_2(mid, memory_bank) # type_emb g_rep self.conln(mid, g_rep) 282 | prior = {"mean": [], "std": []} 283 | post = {"mean": [], "std": []} 284 | # Prior net for testing 285 | mu, log_var = self.prior(mid_norm).chunk(2, dim=-1) 286 | std = torch.exp(0.5 * log_var) 287 | prior["mean"] = mu 288 | prior["std"] = std 289 | # Posterior net for training 290 | kld_loss = 0 291 | if train: 292 | mu, log_var = self.recog(torch.cat([mid_norm, inputs], dim=-1)).chunk(2, dim=-1) 293 | 294 | std = torch.exp(0.5 * log_var) 295 | post["mean"] = mu 296 | post["std"] = std 297 | kld_loss = gaussian_kld(post["mean"], post["std"], prior["mean"], prior["std"]) # print(kld_loss[:3, :4]) 298 | kld_loss = torch.mean(kld_loss) # original 299 | 300 | # reparameterize 301 | eps = torch.randn(size=mu.size(), device=mu.device) 302 | z = eps * std + mu 303 | 304 | # Positionwise Feedforward 305 | output = self.feed_forward(torch.cat([z, mid_norm], dim=-1)) # torch.Size([640, 1, 256]) 306 | output = self.drop(output) + mid # torch.Size([141, 47, 256]) 307 | 308 | return output, attns, z, kld_loss 309 | 310 | def update_dropout(self, dropout, attention_dropout): 311 | self.self_attn.update_dropout(attention_dropout) 312 | self.context_attn.update_dropout(attention_dropout) 313 | # self.feed_forward.update_dropout(dropout) 314 | self.drop.p = dropout 315 | 316 | 317 | class VariationalDecoder(DecoderBase): 318 | def __init__(self, args, num_layers, d_model, heads, d_ff, 319 | copy_attn, self_attn_type, dropout, attention_dropout, 320 | embeddings, max_relative_positions, aan_useffn, 321 | full_context_alignment, alignment_layer, 322 | alignment_heads): 323 | super(VariationalDecoder, self).__init__() 324 | 325 | self.embeddings = embeddings 326 | 327 | # Decoder State 328 | self.state = {} 329 | 330 | var_layers = args.variational_num_layers 331 | num_layers = num_layers - var_layers 332 | self.transformer_layers = nn.ModuleList() 333 | for i in range(var_layers): 334 | self.transformer_layers.append(VariationalDecoderLayer(args, d_model, heads, d_ff, dropout, 335 | attention_dropout, self_attn_type=self_attn_type, 336 | max_relative_positions=max_relative_positions, 337 | aan_useffn=aan_useffn, 338 | full_context_alignment=full_context_alignment, 339 | alignment_heads=alignment_heads)) 340 | for i in range(num_layers): 341 | self.transformer_layers.append(TransformerDecoderLayer(d_model, heads, d_ff, dropout, 342 | attention_dropout, self_attn_type=self_attn_type, 343 | max_relative_positions=max_relative_positions, 344 | aan_useffn=aan_useffn, 345 | 
full_context_alignment=full_context_alignment, 346 | alignment_heads=alignment_heads)) 347 | 348 | # previously, there was a GlobalAttention module here for copy 349 | # attention. But it was never actually used -- the "copy" attention 350 | # just reuses the context attention. 351 | self._copy = copy_attn 352 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 353 | 354 | self.alignment_layer = alignment_layer 355 | 356 | # ++++++++++++++++ init GRU 357 | K = args.latent_K 358 | self.latent_emb = nn.Embedding(K, args.embed_size) 359 | self.prior = nn.GRU(d_model, hidden_size=K, num_layers=1, batch_first=True) 360 | self.recog = nn.GRU(d_model, hidden_size=K, num_layers=1, batch_first=True) 361 | self.gru = True 362 | self.cvae = True 363 | self.args = args 364 | 365 | 366 | @classmethod 367 | def from_opt(cls, opt, embeddings): 368 | """Alternate constructor.""" 369 | return cls( 370 | opt.dec_layers, 371 | opt.dec_rnn_size, 372 | opt.heads, 373 | opt.transformer_ff, 374 | opt.copy_attn, 375 | opt.self_attn_type, 376 | opt.dropout[0] if type(opt.dropout) is list else opt.dropout, 377 | opt.attention_dropout[0] if type(opt.attention_dropout) 378 | is list else opt.attention_dropout, 379 | embeddings, 380 | opt.max_relative_positions, 381 | opt.aan_useffn, 382 | opt.full_context_alignment, 383 | opt.alignment_layer, 384 | alignment_heads=opt.alignment_heads) 385 | 386 | def init_state(self, src, memory_bank, enc_hidden): 387 | """Initialize decoder state.""" 388 | self.state["src"] = src 389 | self.state["cache"] = None 390 | 391 | def map_state(self, fn): 392 | def _recursive_map(struct, batch_dim=0): 393 | for k, v in struct.items(): 394 | if v is not None: 395 | if isinstance(v, dict): 396 | _recursive_map(v) 397 | else: 398 | struct[k] = fn(v, batch_dim) 399 | 400 | self.state["src"] = fn(self.state["src"], 1) 401 | if self.state["cache"] is not None: 402 | _recursive_map(self.state["cache"]) 403 | 404 | def detach_state(self): 405 | self.state["src"] = self.state["src"].detach() 406 | 407 | def forward(self, product_emb, tgt, memory_bank, step=None, **kwargs): 408 | """Decode, possibly stepwise.""" 409 | if step == 0: 410 | self._init_cache(memory_bank) 411 | 412 | tgt_words = tgt[:, :, 0].transpose(0, 1) # [max_tgt_t, b, 1] => [b, max_tgt_t] 413 | 414 | emb = self.embeddings(tgt, step=step) 415 | assert emb.dim() == 3 # len x batch x embedding_dim [max_tgt_t, b, h] 416 | 417 | output = emb.transpose(0, 1).contiguous() # [max_tgt_t, b, h] => [b, max_tgt_t, h] 418 | src_memory_bank = memory_bank.transpose(0, 1).contiguous() # [max_src_t, b, h] => [b, max_src_t, h] 419 | product_emb = product_emb.transpose(0, 1).contiguous() # [max_src_t, b, h] => [b, max_src_t, h] 420 | #++++++++++++++++++++ Gumbel-Softmax 421 | kld_loss = 0 422 | ## for CVAE 423 | if hasattr(self, 'cvae') and (step is None or step==0): 424 | if hasattr(self, 'gru'): 425 | logits = self.prior(src_memory_bank, None)[-1][0] # GRU 426 | else: 427 | logits = self.prior(src_memory_bank.mean(dim=1)) 428 | prior = F.log_softmax(logits, dim=-1) 429 | if self.training: 430 | logits = self.recog(torch.cat([src_memory_bank, output], dim=1), None)[-1][0] # GRU 431 | recog = F.softmax(logits, dim=-1) 432 | kld_loss = F.kl_div(prior, recog, reduction="sum") 433 | y_sample = gumbel_softmax(logits, hard=True) 434 | ## for beam search during inference, i.e., perform validate.sh and predict.sh 435 | beam_size = self.args.beam_size 436 | if not self.training and beam_size>5: 437 | y_sample = torch.stack(y_sample.split(beam_size, dim=0), 
dim=1)[0] # [b, h] 438 | y_sample = tile(y_sample, count=beam_size, dim=0) 439 | latent = y_sample @ self.latent_emb.weight 440 | output[:,0] += latent 441 | self.y_sample = y_sample 442 | 443 | pad_idx = self.embeddings.word_padding_idx 444 | src_lens = kwargs["memory_lengths"] 445 | src_max_len = self.state["src"].shape[0] 446 | src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1) 447 | tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt] 448 | 449 | with_align = kwargs.pop('with_align', False) 450 | attn_aligns = [] 451 | 452 | for i, layer in enumerate(self.transformer_layers): 453 | layer_cache = self.state["cache"]["layer_{}".format(i)] \ 454 | if step is not None else None 455 | others = layer( 456 | output, 457 | src_memory_bank, 458 | src_pad_mask, 459 | tgt_pad_mask, 460 | layer_cache=layer_cache, 461 | step=step, 462 | with_align=with_align) 463 | if layer._get_name() == 'TransformerDecoderLayer': 464 | output, attn, attn_align = others 465 | elif layer._get_name() == 'VariationalDecoderLayer': 466 | output, attn, z, kld_loss, attn_align = others 467 | if attn_align is not None: 468 | attn_aligns.append(attn_align) 469 | 470 | aux_loss = 0 471 | output = self.layer_norm(output) 472 | dec_outs = output.transpose(0, 1).contiguous() 473 | attn = attn.transpose(0, 1).contiguous() 474 | 475 | attns = {"std": attn} 476 | if self._copy: 477 | attns["copy"] = attn 478 | if with_align: 479 | attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)` 480 | # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg 481 | 482 | # TODO change the way attns is returned dict => list or tuple (onnx) 483 | if step is None: ## train or validate 484 | return dec_outs, attns, aux_loss, kld_loss 485 | else: 486 | return dec_outs, attns 487 | 488 | def _init_cache(self, memory_bank): 489 | self.state["cache"] = {} 490 | batch_size = memory_bank.size(1) 491 | depth = memory_bank.size(-1) 492 | 493 | for i, layer in enumerate(self.transformer_layers): 494 | layer_cache = {"memory_keys": None, "memory_values": None} 495 | if isinstance(layer.self_attn, AverageAttention): 496 | layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth), 497 | device=memory_bank.device) 498 | else: 499 | layer_cache["self_keys"] = None 500 | layer_cache["self_values"] = None 501 | self.state["cache"]["layer_{}".format(i)] = layer_cache 502 | 503 | def update_dropout(self, dropout, attention_dropout): 504 | self.embeddings.update_dropout(dropout) 505 | for layer in self.transformer_layers: 506 | layer.update_dropout(dropout, attention_dropout) 507 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path 3 | 4 | import networkx as nx 5 | import numpy as np 6 | import re 7 | import selfies as sf 8 | import sys 9 | import time 10 | import torch 11 | from rdkit import Chem 12 | from torch.utils.data import Dataset 13 | from typing import Dict, List, Tuple 14 | from utils.chem_utils import ATOM_FDIM, BOND_FDIM, get_atom_features_sparse, get_bond_features 15 | from utils.rxn_graphs import RxnGraph 16 | 17 | import torch.nn as nn 18 | 19 | 20 | def tokenize_selfies_from_smiles(smi: str) -> str: 21 | encoded_selfies = sf.encoder(smi) 22 | tokens = list(sf.split_selfies(encoded_selfies)) 23 | assert encoded_selfies == "".join(tokens) 24 | 25 | return " ".join(tokens) 26 | 27 | 28 | def tokenize_smiles(smi: str) -> str: 29 | 
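    # Regex-based SMILES tokenizer: bracketed atoms ([...]), the two-letter
    # halogens Br/Cl (falling back to bare B/C), single-letter organic-subset
    # atoms (aromatic ones in lower case), bond/branch/ring symbols and %nn
    # two-digit ring closures each become one token; the assert below checks
    # that the split is lossless.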
pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 30 | regex = re.compile(pattern) 31 | tokens = [token for token in regex.findall(smi)] 32 | assert smi == "".join(tokens), f"Tokenization mismatch. smi: {smi}, tokens: {tokens}" 33 | 34 | return " ".join(tokens) 35 | 36 | 37 | def canonicalize_smiles(smiles, remove_atom_number=False, trim=True, suppress_warning=False): 38 | cano_smiles = "" 39 | 40 | mol = Chem.MolFromSmiles(smiles) 41 | 42 | if mol is None: 43 | cano_smiles = "" 44 | 45 | else: 46 | if trim and mol.GetNumHeavyAtoms() < 2: 47 | if not suppress_warning: 48 | logging.info(f"Problematic smiles: {smiles}, setting it to 'CC'") 49 | cano_smiles = "CC" # TODO: hardcode to ignore 50 | else: 51 | if remove_atom_number: 52 | [a.ClearProp('molAtomMapNumber') for a in mol.GetAtoms()] 53 | cano_smiles = Chem.MolToSmiles(mol, isomericSmiles=True) 54 | 55 | return cano_smiles 56 | 57 | 58 | def len2idx(lens) -> np.ndarray: 59 | # end_indices = np.cumsum(np.concatenate(lens, axis=0)) 60 | end_indices = np.cumsum(lens) 61 | start_indices = np.concatenate([[0], end_indices[:-1]], axis=0) 62 | indices = np.stack([start_indices, end_indices], axis=1) 63 | 64 | return indices 65 | 66 | 67 | class S2SBatch: 68 | def __init__(self, 69 | src_token_ids: torch.Tensor, 70 | src_lengths: torch.Tensor, 71 | tgt_token_ids: torch.Tensor, 72 | tgt_lengths: torch.Tensor): 73 | self.src_token_ids = src_token_ids 74 | self.src_lengths = src_lengths 75 | self.tgt_token_ids = tgt_token_ids 76 | self.tgt_lengths = tgt_lengths 77 | 78 | self.size = len(src_lengths) 79 | 80 | def to(self, device): 81 | self.src_token_ids = self.src_token_ids.to(device) 82 | self.src_lengths = self.src_lengths.to(device) 83 | self.tgt_token_ids = self.tgt_token_ids.to(device) 84 | self.tgt_lengths = self.tgt_lengths.to(device) 85 | 86 | def pin_memory(self): 87 | self.src_token_ids = self.src_token_ids.pin_memory() 88 | self.src_lengths = self.src_lengths.pin_memory() 89 | self.tgt_token_ids = self.tgt_token_ids.pin_memory() 90 | self.tgt_lengths = self.tgt_lengths.pin_memory() 91 | 92 | return self 93 | 94 | def log_tensor_shape(self): 95 | logging.info(f"src_token_ids: {self.src_token_ids.shape}, " 96 | f"src_lengths: {self.src_lengths}, " 97 | f"tgt_token_ids: {self.tgt_token_ids.shape}, " 98 | f"tgt_lengths: {self.tgt_lengths}") 99 | 100 | 101 | class S2SDataset(Dataset): 102 | def __init__(self, args, file: str): 103 | self.args = args 104 | 105 | self.src_token_ids = [] 106 | self.src_lens = [] 107 | self.tgt_token_ids = [] 108 | self.tgt_lens = [] 109 | 110 | self.data_indices = [] 111 | self.batch_sizes = [] 112 | self.batch_starts = [] 113 | self.batch_ends = [] 114 | 115 | logging.info(f"Loading preprocessed features from {file}") 116 | feat = np.load(file) 117 | for attr in ["src_token_ids", "src_lens", "tgt_token_ids", "tgt_lens"]: 118 | setattr(self, attr, feat[attr]) 119 | 120 | assert len(self.src_token_ids) == len(self.src_lens) == len(self.tgt_token_ids) == len(self.tgt_lens), \ 121 | f"Lengths of source and target mismatch!" 
122 | 123 | self.data_size = len(self.src_token_ids) 124 | self.data_indices = np.arange(self.data_size) 125 | 126 | logging.info(f"Loaded and initialized S2SDataset, size: {self.data_size}") 127 | 128 | def sort(self): 129 | start = time.time() 130 | 131 | logging.info(f"Calling S2SDataset.sort()") 132 | sys.stdout.flush() 133 | self.data_indices = np.argsort(self.src_lens + self.tgt_lens) 134 | 135 | logging.info(f"Done, time: {time.time() - start: .2f} s") 136 | sys.stdout.flush() 137 | 138 | def shuffle_in_bucket(self, bucket_size: int): 139 | start = time.time() 140 | 141 | logging.info(f"Calling S2SDataset.shuffle_in_bucket()") 142 | sys.stdout.flush() 143 | 144 | for i in range(0, self.data_size, bucket_size): 145 | np.random.shuffle(self.data_indices[i:i+bucket_size]) 146 | 147 | logging.info(f"Done, time: {time.time() - start: .2f} s") 148 | sys.stdout.flush() 149 | 150 | def batch(self, batch_type: str, batch_size: int): 151 | start = time.time() 152 | 153 | logging.info(f"Calling S2SDataset.batch()") 154 | sys.stdout.flush() 155 | 156 | self.batch_sizes = [] 157 | 158 | if batch_type == "samples": 159 | raise NotImplementedError 160 | 161 | elif batch_type == "atoms": 162 | raise NotImplementedError 163 | 164 | elif batch_type == "tokens": 165 | sample_size = 0 166 | max_batch_src_len = 0 167 | max_batch_tgt_len = 0 168 | 169 | for data_idx in self.data_indices: 170 | src_len = self.src_lens[data_idx] 171 | tgt_len = self.tgt_lens[data_idx] 172 | 173 | max_batch_src_len = max(src_len, max_batch_src_len) 174 | max_batch_tgt_len = max(tgt_len, max_batch_tgt_len) 175 | while self.args.enable_amp and not max_batch_src_len % 8 == 0: # for amp 176 | max_batch_src_len += 1 177 | while self.args.enable_amp and not max_batch_tgt_len % 8 == 0: # for amp 178 | max_batch_tgt_len += 1 179 | 180 | if (max_batch_src_len + max_batch_tgt_len) * (sample_size + 1) <= batch_size: 181 | sample_size += 1 182 | elif self.args.enable_amp and not sample_size % 8 == 0: 183 | sample_size += 1 184 | else: 185 | self.batch_sizes.append(sample_size) 186 | 187 | sample_size = 1 188 | max_batch_src_len = src_len 189 | max_batch_tgt_len = tgt_len 190 | while self.args.enable_amp and not max_batch_src_len % 8 == 0: # for amp 191 | max_batch_src_len += 1 192 | while self.args.enable_amp and not max_batch_tgt_len % 8 == 0: # for amp 193 | max_batch_tgt_len += 1 194 | 195 | # lastly 196 | self.batch_sizes.append(sample_size) 197 | self.batch_sizes = np.array(self.batch_sizes) 198 | assert np.sum(self.batch_sizes) == self.data_size, \ 199 | f"Size mismatch! 
Data size: {self.data_size}, sum batch sizes: {np.sum(self.batch_sizes)}" 200 | 201 | self.batch_ends = np.cumsum(self.batch_sizes) 202 | self.batch_starts = np.concatenate([[0], self.batch_ends[:-1]]) 203 | 204 | else: 205 | raise ValueError(f"batch_type {batch_type} not supported!") 206 | 207 | logging.info(f"Done, time: {time.time() - start: .2f} s, total batches: {self.__len__()}") 208 | sys.stdout.flush() 209 | 210 | def __getitem__(self, index: int) -> S2SBatch: 211 | batch_start = self.batch_starts[index] 212 | batch_end = self.batch_ends[index] 213 | 214 | data_indices = self.data_indices[batch_start:batch_end] 215 | 216 | # collating, essentially 217 | src_token_ids = self.src_token_ids[data_indices] 218 | src_lengths = self.src_lens[data_indices] 219 | tgt_token_ids = self.tgt_token_ids[data_indices] 220 | tgt_lengths = self.tgt_lens[data_indices] 221 | 222 | src_token_ids = src_token_ids[:, :max(src_lengths)] 223 | tgt_token_ids = tgt_token_ids[:, :max(tgt_lengths)] 224 | 225 | src_token_ids = torch.as_tensor(src_token_ids, dtype=torch.long) 226 | tgt_token_ids = torch.as_tensor(tgt_token_ids, dtype=torch.long) 227 | src_lengths = torch.tensor(src_lengths, dtype=torch.long) 228 | tgt_lengths = torch.tensor(tgt_lengths, dtype=torch.long) 229 | 230 | s2s_batch = S2SBatch( 231 | src_token_ids=src_token_ids, 232 | src_lengths=src_lengths, 233 | tgt_token_ids=tgt_token_ids, 234 | tgt_lengths=tgt_lengths 235 | ) 236 | # s2s_batch.log_tensor_shape() 237 | return s2s_batch 238 | 239 | def __len__(self): 240 | return len(self.batch_sizes) 241 | 242 | 243 | class G2SBatch: 244 | def __init__(self, 245 | fnode: torch.Tensor, 246 | fmess: torch.Tensor, 247 | agraph: torch.Tensor, 248 | bgraph: torch.Tensor, 249 | atom_scope: List, 250 | bond_scope: List, 251 | tgt_token_ids: torch.Tensor, 252 | tgt_lengths: torch.Tensor, 253 | distances: torch.Tensor = None, 254 | data_indice: torch.Tensor = None, 255 | multi_tgts: torch.Tensor = None, 256 | pres: torch.Tensor = None, 257 | posts: torch.Tensor = None): 258 | # N_tgt: torch.Tensor = None, 259 | # N_tgt_lens: torch.Tensor = None): 260 | self.fnode = fnode 261 | self.fmess = fmess 262 | self.agraph = agraph 263 | self.bgraph = bgraph 264 | self.atom_scope = atom_scope 265 | self.bond_scope = bond_scope 266 | self.tgt_token_ids = tgt_token_ids 267 | self.tgt_lengths = tgt_lengths 268 | self.distances = distances 269 | #++++++++++++++++++++++++ 270 | self.data_indice = data_indice 271 | self.multi_tgts = multi_tgts 272 | # self.N_tgt = N_tgt 273 | # self.N_tgt_lens = N_tgt_lens 274 | self.pres = pres 275 | self.posts = posts 276 | #++++++++++++++++++++++++ 277 | 278 | self.size = len(tgt_lengths) 279 | 280 | def to(self, device): 281 | self.fnode = self.fnode.to(device) 282 | self.fmess = self.fmess.to(device) 283 | self.agraph = self.agraph.to(device) 284 | self.bgraph = self.bgraph.to(device) 285 | self.tgt_token_ids = self.tgt_token_ids.to(device) 286 | self.tgt_lengths = self.tgt_lengths.to(device) 287 | #++++++++++++++++++++++++ 288 | self.data_indice = self.data_indice.to(device) 289 | self.multi_tgts = self.multi_tgts.to(device) 290 | # self.N_tgt = self.N_tgt.to(device) 291 | # self.N_tgt_lens = self.N_tgt_lens.to(device) 292 | self.pres = self.pres.to(device) 293 | self.posts = self.posts.to(device) 294 | #++++++++++++++++++++++++ 295 | 296 | 297 | if self.distances is not None: 298 | self.distances = self.distances.to(device) 299 | 300 | def pin_memory(self): 301 | self.fnode = self.fnode.pin_memory() 302 | self.fmess = 
self.fmess.pin_memory() 303 | self.agraph = self.agraph.pin_memory() 304 | self.bgraph = self.bgraph.pin_memory() 305 | self.tgt_token_ids = self.tgt_token_ids.pin_memory() 306 | self.tgt_lengths = self.tgt_lengths.pin_memory() 307 | #++++++++++++++++++++++++ 308 | self.data_indice = self.data_indice.pin_memory() 309 | self.multi_tgts = self.multi_tgts.pin_memory() 310 | # self.N_tgt = self.N_tgt.pin_memory() 311 | # self.N_tgt_lens = self.N_tgt_lens.pin_memory() 312 | self.pres = self.pres.pin_memory() 313 | self.posts = self.posts.pin_memory() 314 | #++++++++++++++++++++++++ 315 | 316 | if self.distances is not None: 317 | self.distances = self.distances.pin_memory() 318 | 319 | return self 320 | 321 | def log_tensor_shape(self): 322 | logging.info(f"fnode: {self.fnode.shape}, " 323 | f"fmess: {self.fmess.shape}, " 324 | f"tgt_token_ids: {self.tgt_token_ids.shape}, " 325 | f"tgt_lengths: {self.tgt_lengths}") 326 | 327 | 328 | class G2SDataset(Dataset): 329 | def __init__(self, args, file: str, unique=False): 330 | self.args = args 331 | 332 | self.a_scopes = [] 333 | self.b_scopes = [] 334 | self.a_features = [] 335 | self.b_features = [] 336 | self.a_graphs = [] 337 | self.b_graphs = [] 338 | self.a_scopes_lens = [] 339 | self.b_scopes_lens = [] 340 | self.a_features_lens = [] 341 | self.b_features_lens = [] 342 | 343 | self.src_token_ids = [] # loaded but not batched 344 | self.src_lens = [] 345 | self.tgt_token_ids = [] 346 | self.tgt_lens = [] 347 | 348 | self.data_indices = [] 349 | self.batch_sizes = [] 350 | self.batch_starts = [] 351 | self.batch_ends = [] 352 | 353 | self.vocab = load_vocab(args.vocab_file) 354 | self.vocab_tokens = [k for k, v in sorted(self.vocab.items(), key=lambda tup: tup[1])] 355 | 356 | #++++++++++++++++++++ 357 | save_dir = f"./data/{args.data_name}" # './data/miniFULL' 358 | if 'train' in file: 359 | fp = os.path.join(save_dir, "src-train.txt") 360 | elif 'val' in file: 361 | fp = os.path.join(save_dir, "src-val.txt") 362 | elif 'test' in file: 363 | fp = os.path.join(save_dir, "src-test.txt") 364 | with open(fp, 'r') as f: 365 | ptr = [] 366 | last_sp = None 367 | lines = f.readlines() 368 | for i, sp in enumerate(lines): 369 | if sp != last_sp: 370 | ptr.append(i) 371 | last_sp = sp 372 | ptr.append(i+1) 373 | self.ptr = torch.tensor(ptr, dtype=torch.long) 374 | #++++++++++++++++++++ 375 | 376 | logging.info(f"Loading preprocessed features from {file}") 377 | feat = np.load(file) 378 | for attr in ["a_scopes", "b_scopes", "a_features", "b_features", "a_graphs", "b_graphs", 379 | "a_scopes_lens", "b_scopes_lens", "a_features_lens", "b_features_lens", 380 | "src_token_ids", "src_lens", "tgt_token_ids", "tgt_lens"]: 381 | setattr(self, attr, feat[attr]) 382 | 383 | # mask out chiral tag (as UNSPECIFIED) 384 | self.a_features[:, 6] = 2 385 | 386 | assert len(self.a_scopes_lens) == len(self.b_scopes_lens) == \ 387 | len(self.a_features_lens) == len(self.b_features_lens) == \ 388 | len(self.src_token_ids) == len(self.src_lens) == \ 389 | len(self.tgt_token_ids) == len(self.tgt_lens), \ 390 | f"Lengths of source and target mismatch!" 
391 | 392 | self.a_scopes_indices = len2idx(self.a_scopes_lens) 393 | self.b_scopes_indices = len2idx(self.b_scopes_lens) 394 | self.a_features_indices = len2idx(self.a_features_lens) 395 | self.b_features_indices = len2idx(self.b_features_lens) 396 | 397 | del self.a_scopes_lens, self.b_scopes_lens, self.a_features_lens, self.b_features_lens 398 | 399 | self.data_size = len(self.src_token_ids) 400 | self.data_indices = np.arange(self.data_size) 401 | 402 | logging.info(f"Loaded and initialized G2SDataset, size: {self.data_size}") 403 | 404 | 405 | def sort(self): 406 | if self.args.verbose: 407 | start = time.time() 408 | 409 | logging.info(f"Calling G2SDataset.sort()") 410 | sys.stdout.flush() 411 | self.data_indices = np.argsort(self.src_lens) 412 | 413 | logging.info(f"Done, time: {time.time() - start: .2f} s") 414 | sys.stdout.flush() 415 | 416 | else: 417 | self.data_indices = np.argsort(self.src_lens) 418 | 419 | def shuffle_in_bucket(self, bucket_size: int): 420 | if self.args.verbose: 421 | start = time.time() 422 | 423 | logging.info(f"Calling G2SDataset.shuffle_in_bucket()") 424 | sys.stdout.flush() 425 | 426 | for i in range(0, self.data_size, bucket_size): 427 | np.random.shuffle(self.data_indices[i:i+bucket_size]) 428 | 429 | logging.info(f"Done, time: {time.time() - start: .2f} s") 430 | sys.stdout.flush() 431 | 432 | else: 433 | for i in range(0, self.data_size, bucket_size): 434 | np.random.shuffle(self.data_indices[i:i + bucket_size]) 435 | 436 | def batch(self, batch_type: str, batch_size: int): 437 | start = time.time() 438 | 439 | logging.info(f"Calling G2SDataset.batch()") 440 | sys.stdout.flush() 441 | 442 | self.batch_sizes = [] 443 | 444 | if batch_type == "samples": 445 | raise NotImplementedError 446 | 447 | elif batch_type == "atoms": 448 | raise NotImplementedError 449 | 450 | elif batch_type.startswith("tokens"): 451 | sample_size = 0 452 | max_batch_src_len = 0 453 | max_batch_tgt_len = 0 454 | 455 | for data_idx in self.data_indices: 456 | src_len = self.src_lens[data_idx] 457 | tgt_len = self.tgt_lens[data_idx] 458 | 459 | max_batch_src_len = max(src_len, max_batch_src_len) 460 | max_batch_tgt_len = max(tgt_len, max_batch_tgt_len) 461 | while self.args.enable_amp and not max_batch_src_len % 8 == 0: # for amp 462 | max_batch_src_len += 1 463 | while self.args.enable_amp and not max_batch_tgt_len % 8 == 0: # for amp 464 | max_batch_tgt_len += 1 465 | 466 | if batch_type == "tokens" and \ 467 | max_batch_src_len * (sample_size + 1) <= batch_size: 468 | sample_size += 1 469 | elif batch_type == "tokens_sum" and \ 470 | (max_batch_src_len + max_batch_tgt_len) * (sample_size + 1) <= batch_size: 471 | sample_size += 1 472 | elif self.args.enable_amp and not sample_size % 8 == 0: 473 | sample_size += 1 474 | else: 475 | self.batch_sizes.append(sample_size) 476 | 477 | sample_size = 1 478 | max_batch_src_len = src_len 479 | max_batch_tgt_len = tgt_len 480 | while self.args.enable_amp and not max_batch_src_len % 8 == 0: # for amp 481 | max_batch_src_len += 1 482 | while self.args.enable_amp and not max_batch_tgt_len % 8 == 0: # for amp 483 | max_batch_tgt_len += 1 484 | 485 | 486 | # lastly 487 | self.batch_sizes.append(sample_size) 488 | self.batch_sizes = np.array(self.batch_sizes) 489 | assert np.sum(self.batch_sizes) == self.data_size, \ 490 | f"Size mismatch! 
Data size: {self.data_size}, sum batch sizes: {np.sum(self.batch_sizes)}" 491 | 492 | self.batch_ends = np.cumsum(self.batch_sizes) 493 | self.batch_starts = np.concatenate([[0], self.batch_ends[:-1]]) 494 | 495 | else: 496 | raise ValueError(f"batch_type {batch_type} not supported!") 497 | 498 | logging.info(f"Done, time: {time.time() - start: .2f} s, total batches: {self.__len__()}") 499 | sys.stdout.flush() 500 | 501 | def __getitem__(self, index: int) -> G2SBatch: 502 | batch_index = index 503 | batch_start = self.batch_starts[batch_index] 504 | batch_end = self.batch_ends[batch_index] 505 | 506 | data_indices = self.data_indices[batch_start:batch_end] 507 | 508 | # collating, essentially 509 | # source (graph) 510 | graph_features = [] 511 | a_lengths = [] 512 | for data_index in data_indices: 513 | start, end = self.a_scopes_indices[data_index] 514 | a_scope = self.a_scopes[start:end] 515 | a_length = a_scope[-1][0] + a_scope[-1][1] - a_scope[0][0] 516 | 517 | start, end = self.b_scopes_indices[data_index] 518 | b_scope = self.b_scopes[start:end] 519 | 520 | start, end = self.a_features_indices[data_index] 521 | a_feature = self.a_features[start:end] 522 | a_graph = self.a_graphs[start:end] 523 | 524 | start, end = self.b_features_indices[data_index] 525 | b_feature = self.b_features[start:end] 526 | b_graph = self.b_graphs[start:end] 527 | 528 | graph_feature = (a_scope, b_scope, a_feature, b_feature, a_graph, b_graph) 529 | graph_features.append(graph_feature) 530 | a_lengths.append(a_length) 531 | 532 | fnode, fmess, agraph, bgraph, atom_scope, bond_scope = collate_graph_features(graph_features) 533 | 534 | # target (seq) 535 | tgt_token_ids = self.tgt_token_ids[data_indices] 536 | tgt_lengths = self.tgt_lens[data_indices] 537 | 538 | tgt_token_ids = tgt_token_ids[:, :max(tgt_lengths)] 539 | 540 | tgt_token_ids = torch.as_tensor(tgt_token_ids, dtype=torch.long) 541 | tgt_lengths = torch.tensor(tgt_lengths, dtype=torch.long) 542 | 543 | distances = None 544 | if self.args.compute_graph_distance: 545 | distances = collate_graph_distances(self.args, graph_features, a_lengths) 546 | 547 | """ 548 | logging.info("--------------------src_tokens--------------------") 549 | for data_index in data_indices: 550 | smi = "".join(self.vocab_tokens[src_token_id] for src_token_id in self.src_token_ids[data_index]) 551 | logging.info(smi) 552 | logging.info("--------------------distances--------------------") 553 | logging.info(f"{distances}") 554 | exit(0) 555 | """ 556 | #++++++++++++++++++++++++ 557 | pres, posts = [], [] 558 | for id in data_indices: #torch.as_tensor(data_indices, dtype=torch.long) 559 | post_id = (self.ptr > id).nonzero()[0] 560 | pre = self.ptr[post_id-1] 561 | post = self.ptr[post_id] 562 | pres.append(pre) 563 | posts.append(post) 564 | pres = torch.cat(pres, dim=-1) 565 | posts = torch.cat(posts, dim=-1) 566 | max_len = min(8, (posts-pres).max()) 567 | # max_len = (posts-pres).max() 568 | # TODO: max(tgt_lengths) may not be the truth 569 | tgt_ids = torch.as_tensor(self.tgt_token_ids, dtype=torch.long) 570 | ts=[tgt_ids[torch.arange(pres[i], posts[i]), :max(tgt_lengths)] for i in range(len(data_indices))] ## list, Sequence of unequal lengths 571 | for i, new_ts in enumerate(ts): 572 | old_t = tgt_token_ids[i] 573 | uni_idx = (new_ts!=old_t).float().argmax(-1) # len of len(new_ts) the first index of each row that new&old differs 574 | ind = torch.stack([torch.arange(len(uni_idx)),uni_idx],dim=0) 575 | val = torch.ones(size=(len(uni_idx),)) 576 | s = 
torch.sparse_coo_tensor(ind, val, (len(uni_idx),old_t.shape[-1])) 577 | mask = s.to_dense()>0 578 | ts[i] = torch.where(mask, new_ts,old_t.unsqueeze(0).expand(new_ts.shape)) 579 | for i,tgts in enumerate(ts): 580 | if len(tgts)<=max_len: 581 | ts[i]=nn.ZeroPad2d((0,0,0,max_len-len(tgts)))(tgts) 582 | else: 583 | ts[i]=tgts[torch.randperm(len(tgts))[:max_len]] 584 | ts = torch.cat(ts, dim=0) 585 | #++++++++++++++++++++++++ 586 | 587 | g2s_batch = G2SBatch( 588 | fnode=fnode, 589 | fmess=fmess, 590 | agraph=agraph, 591 | bgraph=bgraph, 592 | atom_scope=atom_scope, 593 | bond_scope=bond_scope, 594 | tgt_token_ids=tgt_token_ids, 595 | tgt_lengths=tgt_lengths, 596 | distances=distances, 597 | #++++++++++++++++++++++++ 598 | data_indice=torch.as_tensor(data_indices, dtype=torch.long), 599 | multi_tgts=ts, 600 | pres=pres, 601 | posts=posts 602 | # N_tgt=N_tgt, 603 | # N_tgt_lens=N_tgt_lens 604 | #++++++++++++++++++++++++ 605 | ) 606 | # g2s_batch.log_tensor_shape() 607 | 608 | return g2s_batch 609 | 610 | def __len__(self): 611 | return len(self.batch_sizes) 612 | 613 | 614 | def get_graph_from_smiles(smi: str): 615 | mol = Chem.MolFromSmiles(smi) 616 | rxn_graph = RxnGraph(reac_mol=mol) 617 | 618 | return rxn_graph 619 | 620 | 621 | def get_graph_features_from_smi(_args): 622 | i, smi, use_rxn_class = _args 623 | assert isinstance(smi, str) and isinstance(use_rxn_class, bool) 624 | if i > 0 and i % 10000 == 0: 625 | logging.info(f"Processing {i}th SMILES") 626 | 627 | atom_features = [] 628 | bond_features = [] 629 | edge_dict = {} 630 | 631 | if not smi.strip(): 632 | smi = "CC" # hardcode to ignore 633 | 634 | graph = get_graph_from_smiles(smi).reac_mol 635 | 636 | mol = graph.mol 637 | assert mol.GetNumAtoms() == len(graph.G_dir) 638 | 639 | G = nx.convert_node_labels_to_integers(graph.G_dir, first_label=0) 640 | 641 | # node iteration to get sparse atom features 642 | for v, attr in G.nodes(data="label"): 643 | atom_feat = get_atom_features_sparse(mol.GetAtomWithIdx(v), 644 | use_rxn_class=use_rxn_class, 645 | rxn_class=graph.rxn_class) 646 | atom_features.append(atom_feat) 647 | 648 | a_graphs = [[] for _ in range(len(atom_features))] 649 | 650 | # edge iteration to get (dense) bond features 651 | for u, v, attr in G.edges(data='label'): 652 | bond_feat = get_bond_features(mol.GetBondBetweenAtoms(u, v)) 653 | bond_feat = [u, v] + bond_feat 654 | bond_features.append(bond_feat) 655 | 656 | eid = len(edge_dict) 657 | edge_dict[(u, v)] = eid 658 | a_graphs[v].append(eid) 659 | 660 | b_graphs = [[] for _ in range(len(bond_features))] 661 | 662 | # second edge iteration to get neighboring edges (after edge_dict is updated fully) 663 | for bond_feat in bond_features: 664 | u, v = bond_feat[:2] 665 | eid = edge_dict[(u, v)] 666 | 667 | for w in G.predecessors(u): 668 | if not w == v: 669 | b_graphs[eid].append(edge_dict[(w, u)]) 670 | 671 | # padding 672 | for a_graph in a_graphs: 673 | while len(a_graph) < 11: # OH MY GOODNESS... Fe can be bonded to 10... 674 | a_graph.append(1e9) 675 | 676 | for b_graph in b_graphs: 677 | while len(b_graph) < 11: # OH MY GOODNESS... Fe can be bonded to 10... 
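            # (note) each neighbor list is padded with a large sentinel (1e9) up to a fixed
            # width of 11 entries; collate_graph_features() below resets any index that is
            # still >= 999999999 to 0, i.e. to the dummy edge stored at position 0 of fmess.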
678 | b_graph.append(1e9) 679 | 680 | a_scopes = np.array(graph.atom_scope, dtype=np.int32) 681 | a_scopes_lens = a_scopes.shape[0] 682 | b_scopes = np.array(graph.bond_scope, dtype=np.int32) 683 | b_scopes_lens = b_scopes.shape[0] 684 | a_features = np.array(atom_features, dtype=np.int32) 685 | a_features_lens = a_features.shape[0] 686 | b_features = np.array(bond_features, dtype=np.int32) 687 | b_features_lens = b_features.shape[0] 688 | a_graphs = np.array(a_graphs, dtype=np.int32) 689 | b_graphs = np.array(b_graphs, dtype=np.int32) 690 | 691 | return a_scopes, a_scopes_lens, b_scopes, b_scopes_lens, \ 692 | a_features, a_features_lens, b_features, b_features_lens, a_graphs, b_graphs 693 | 694 | 695 | def collate_graph_features(graph_features: List[Tuple], directed: bool = True, use_rxn_class: bool = False) \ 696 | -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[np.ndarray], List[np.ndarray]]: 697 | if directed: 698 | padded_features = get_atom_features_sparse(Chem.Atom("*"), use_rxn_class=use_rxn_class, rxn_class=0) 699 | fnode = [np.array(padded_features)] 700 | fmess = [np.zeros(shape=[1, 2 + BOND_FDIM], dtype=np.int32)] 701 | agraph = [np.zeros(shape=[1, 11], dtype=np.int32)] 702 | bgraph = [np.zeros(shape=[1, 11], dtype=np.int32)] 703 | 704 | n_unique_bonds = 1 705 | edge_offset = 1 706 | 707 | atom_scope, bond_scope = [], [] 708 | 709 | for bid, graph_feature in enumerate(graph_features): 710 | a_scope, b_scope, atom_features, bond_features, a_graph, b_graph = graph_feature 711 | 712 | a_scope = a_scope.copy() 713 | b_scope = b_scope.copy() 714 | atom_features = atom_features.copy() 715 | bond_features = bond_features.copy() 716 | a_graph = a_graph.copy() 717 | b_graph = b_graph.copy() 718 | 719 | atom_offset = len(fnode) 720 | bond_offset = n_unique_bonds 721 | n_unique_bonds += int(bond_features.shape[0] / 2) # This should be correct? 722 | 723 | a_scope[:, 0] += atom_offset 724 | b_scope[:, 0] += bond_offset 725 | atom_scope.append(a_scope) 726 | bond_scope.append(b_scope) 727 | 728 | # node iteration is reduced to an extend 729 | fnode.extend(atom_features) 730 | 731 | # edge iteration is reduced to an append 732 | bond_features[:, :2] += atom_offset 733 | fmess.append(bond_features) 734 | 735 | a_graph += edge_offset 736 | a_graph[a_graph >= 999999999] = 0 # resetting padding edge to point towards edge 0 737 | agraph.append(a_graph) 738 | 739 | b_graph += edge_offset 740 | b_graph[b_graph >= 999999999] = 0 # resetting padding edge to point towards edge 0 741 | bgraph.append(b_graph) 742 | 743 | edge_offset += bond_features.shape[0] 744 | 745 | # densification 746 | fnode = np.stack(fnode, axis=0) 747 | fnode_one_hot = np.zeros([fnode.shape[0], sum(ATOM_FDIM)], dtype=np.float32) 748 | 749 | for i in range(len(ATOM_FDIM) - 1): 750 | fnode[:, i+1:] += ATOM_FDIM[i] # cumsum, essentially 751 | 752 | for i, feat in enumerate(fnode): # Looks vectorizable? 
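            # (note) after the cumulative ATOM_FDIM offsets applied above, every sparse category
            # index addresses its own slice of the concatenated one-hot vector; padding indices
            # (>= sum(ATOM_FDIM) after offsetting) are dropped by the mask on the next line.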
753 | # fnode_one_hot[i, feat[feat < 9999]] = 1 754 | fnode_one_hot[i, feat[feat < sum(ATOM_FDIM)]] = 1 755 | 756 | fnode = torch.as_tensor(fnode_one_hot, dtype=torch.float) 757 | fmess = torch.as_tensor(np.concatenate(fmess, axis=0), dtype=torch.float) 758 | 759 | agraph = np.concatenate(agraph, axis=0) 760 | column_idx = np.argwhere(np.all(agraph[..., :] == 0, axis=0)) 761 | agraph = agraph[:, :column_idx[0, 0] + 1] # drop trailing columns of 0, leaving only 1 last column of 0 762 | 763 | bgraph = np.concatenate(bgraph, axis=0) 764 | column_idx = np.argwhere(np.all(bgraph[..., :] == 0, axis=0)) 765 | bgraph = bgraph[:, :column_idx[0, 0] + 1] # drop trailing columns of 0, leaving only 1 last column of 0 766 | 767 | agraph = torch.as_tensor(agraph, dtype=torch.long) 768 | bgraph = torch.as_tensor(bgraph, dtype=torch.long) 769 | 770 | else: 771 | raise NotImplementedError 772 | 773 | return fnode, fmess, agraph, bgraph, atom_scope, bond_scope 774 | 775 | 776 | def collate_graph_distances(args, graph_features: List[Tuple], a_lengths: List[int]) -> torch.Tensor: 777 | max_len = max(a_lengths) 778 | 779 | distances = [] 780 | for bid, (graph_feature, a_length) in enumerate(zip(graph_features, a_lengths)): 781 | _, _, _, bond_features, _, _ = graph_feature 782 | bond_features = bond_features.copy() 783 | 784 | # compute adjacency 785 | adjacency = np.zeros((a_length, a_length), dtype=np.int32) 786 | for bond_feature in bond_features: 787 | u, v = bond_feature[:2] 788 | adjacency[u, v] = 1 789 | 790 | # compute graph distance 791 | distance = adjacency.copy() 792 | shortest_paths = adjacency.copy() 793 | path_length = 2 794 | stop_counter = 0 795 | non_zeros = 0 796 | 797 | while 0 in distance: 798 | shortest_paths = np.matmul(shortest_paths, adjacency) 799 | shortest_paths = path_length * (shortest_paths > 0) 800 | new_distance = distance + (distance == 0) * shortest_paths 801 | 802 | # if np.count_nonzero(new_distance) == np.count_nonzero(distance): 803 | if np.count_nonzero(new_distance) <= non_zeros: 804 | stop_counter += 1 805 | else: 806 | non_zeros = np.count_nonzero(new_distance) 807 | stop_counter = 0 808 | 809 | if args.task == "reaction_prediction" and stop_counter == 3: 810 | break 811 | 812 | distance = new_distance 813 | path_length += 1 814 | 815 | # bucket 816 | distance[(distance > 8) & (distance < 15)] = 8 817 | distance[distance >= 15] = 9 818 | if args.task == "reaction_prediction": 819 | distance[distance == 0] = 10 820 | 821 | # reset diagonal 822 | np.fill_diagonal(distance, 0) 823 | 824 | # padding 825 | if args.task == "reaction_prediction": 826 | padded_distance = np.ones((max_len, max_len), dtype=np.int32) * 11 827 | else: 828 | padded_distance = np.ones((max_len, max_len), dtype=np.int32) * 10 829 | padded_distance[:a_length, :a_length] = distance 830 | 831 | distances.append(padded_distance) 832 | 833 | distances = np.stack(distances) 834 | distances = torch.as_tensor(distances, dtype=torch.long) 835 | 836 | return distances 837 | 838 | 839 | def make_vocab(fns: Dict[str, List[Tuple[str, str]]], vocab_file: str, tokenized=True): 840 | assert tokenized, f"Vocab can only be made from tokenized files" 841 | 842 | logging.info(f"Making vocab from {fns}") 843 | vocab = {} 844 | 845 | for phase, file_list in fns.items(): 846 | for src_file, tgt_file in file_list: 847 | for fn in [src_file, tgt_file]: 848 | with open(fn, "r") as f: 849 | for line in f: 850 | tokens = line.strip().split() 851 | for token in tokens: 852 | if token in vocab: 853 | vocab[token] += 1 854 | else: 
855 | vocab[token] = 1 856 | 857 | logging.info(f"Saving vocab into {vocab_file}") 858 | with open(vocab_file, "w") as of: 859 | of.write("_PAD\n_UNK\n_SOS\n_EOS\n") 860 | for token, count in vocab.items(): 861 | of.write(f"{token}\t{count}\n") 862 | 863 | 864 | def load_vocab(vocab_file: str) -> Dict[str, int]: 865 | if os.path.exists(vocab_file): 866 | logging.info(f"Loading vocab from {vocab_file}") 867 | else: 868 | vocab_file = "./preprocessed/default_vocab_smiles.txt" 869 | logging.info(f"Vocab file invalid, loading default vocab from {vocab_file}") 870 | 871 | vocab = {} 872 | with open(vocab_file, "r") as f: 873 | for i, line in enumerate(f): 874 | token = line.strip().split("\t")[0] 875 | vocab[token] = i 876 | 877 | return vocab 878 | 879 | 880 | def data_util_test(): 881 | pass 882 | 883 | 884 | if __name__ == "__main__": 885 | data_util_test() 886 | --------------------------------------------------------------------------------
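A note on G2SDataset.batch() above: with batch_type="tokens", samples are taken in length-sorted order and packed into a batch for as long as the padded source area, max_batch_src_len * (sample_size + 1), stays within batch_size. Below is a minimal standalone sketch of that bookkeeping; plan_token_batches is a hypothetical helper name used only for illustration, and the tokens_sum variant and the AMP padding to multiples of 8 are omitted.

import numpy as np


def plan_token_batches(src_lens, batch_size):
    """Group length-sorted samples so that max_src_len * n_samples <= batch_size.

    Hypothetical helper for illustration only: it mirrors how G2SDataset.batch()
    fills self.batch_sizes for batch_type="tokens", without the tokens_sum
    variant or the AMP padding to multiples of 8.
    """
    order = np.argsort(src_lens)             # length-sorted order, as after sort()
    batch_sizes = []
    sample_size = 0
    max_batch_src_len = 0

    for idx in order:
        src_len = src_lens[idx]
        max_batch_src_len = max(src_len, max_batch_src_len)

        if max_batch_src_len * (sample_size + 1) <= batch_size:
            sample_size += 1                  # current sample still fits
        else:
            batch_sizes.append(sample_size)   # close the batch and start a new one
            sample_size = 1
            max_batch_src_len = src_len

    batch_sizes.append(sample_size)           # lastly, flush the final batch
    assert sum(batch_sizes) == len(src_lens)
    return batch_sizes


if __name__ == "__main__":
    lens = [12, 30, 31, 35, 60, 64, 70, 120]
    print(plan_token_batches(lens, batch_size=128))  # -> [3, 2, 1, 1, 1]

Running the toy example packs the eight lengths into batches of [3, 2, 1, 1, 1]; in the dataset itself the same counts end up in self.batch_sizes, and batch_starts/batch_ends are simply their cumulative sums, which __getitem__ uses to slice data_indices.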
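Similarly, collate_graph_distances() derives pairwise graph distances by repeatedly multiplying the adjacency matrix with itself and recording the first power at which each atom pair becomes reachable, before bucketing (distances 9-14 map to 8, >= 15 to 9, plus the reaction_prediction special cases). The following is a small numpy sketch of just the distance part, under the hypothetical name graph_distance_matrix and with an explicit hop cap in place of the original stop_counter heuristic.

import numpy as np


def graph_distance_matrix(adjacency: np.ndarray) -> np.ndarray:
    """Shortest-path lengths via powers of the adjacency matrix.

    Hypothetical helper for illustration only: entry (u, v) receives the smallest
    k for which the k-th matrix power of A has a positive (u, v) entry, as in the
    core loop of collate_graph_distances(); bucketing and padding are left out.
    """
    n = adjacency.shape[0]
    distance = adjacency.copy()
    reach = adjacency.copy()                    # support of A ** (path_length - 1)
    path_length = 2

    while 0 in distance and path_length <= n:   # hop cap guards against disconnected graphs
        reach = (reach @ adjacency > 0).astype(np.int32)
        distance = distance + (distance == 0) * path_length * reach
        path_length += 1

    np.fill_diagonal(distance, 0)               # reset diagonal, as in the original
    return distance


if __name__ == "__main__":
    # toy 5-atom chain 0-1-2-3-4; expected distances are |u - v|
    A = np.zeros((5, 5), dtype=np.int32)
    for u, v in [(0, 1), (1, 2), (2, 3), (3, 4)]:
        A[u, v] = A[v, u] = 1
    print(graph_distance_matrix(A))

For the 5-atom chain this prints the |u - v| distance matrix; the dataset version additionally pads every matrix to the longest molecule in the batch using the bucket index reserved for padding (10, or 11 for reaction_prediction).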