├── .gitignore
├── FlowER.png
├── examples
│   ├── samples
│   │   └── unrecognized_pistachio.pkl
│   ├── beam_predict_seed.py
│   └── run_beam_predict_seed.sh
├── requirements.txt
├── scripts
│   ├── search.sh
│   ├── train.sh
│   ├── eval_multiGPU.sh
│   └── search_multiGPU.sh
├── LICENSE
├── run_FlowER_large_newData.sh
├── run_FlowER_large_oldData.sh
├── utils
│   ├── attn_utils.py
│   ├── rounding.py
│   ├── train_utils.py
│   └── data_utils.py
├── settings.py
├── model
│   ├── flow_matching.py
│   └── attn_encoder.py
├── sequence_evaluation.py
├── README.md
├── train.py
├── beam_predict.py
├── beam_predict_multiGPU.py
└── eval_multiGPU.py

/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | logs/*
3 | checkpoints/*
4 | results/*
5 | *.pyc
--------------------------------------------------------------------------------
/FlowER.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FongMunHong/FlowER/HEAD/FlowER.png
--------------------------------------------------------------------------------
/examples/samples/unrecognized_pistachio.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FongMunHong/FlowER/HEAD/examples/samples/unrecognized_pistachio.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.4.0
2 | numpy==1.26.4
3 | torchdiffeq==0.2.4
4 | rdkit==2024.3.3
5 | iteround==1.0.4
6 | networkx==3.3
7 | matplotlib==3.9.1
--------------------------------------------------------------------------------
/scripts/search.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | python beam_predict.py
4 | 
5 | # To ensure reproducibility of the paper's results in spite of stochasticity,
6 | # we aggregate results from 100 random seeds
7 | # python examples/beam_predict_seed.py
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | torchrun \
4 |     --node_rank="$NODE_RANK" \
5 |     --nnodes="$NUM_NODES" \
6 |     --nproc_per_node="$NUM_GPUS_PER_NODE" \
7 |     --rdzv-id=456 \
8 |     --rdzv-backend=c10d \
9 |     --rdzv-endpoint="$MASTER_ADDR:$MASTER_PORT" \
10 |     train.py
--------------------------------------------------------------------------------
/scripts/eval_multiGPU.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | torchrun \
4 |     --node_rank="$NODE_RANK" \
5 |     --nnodes="$NUM_NODES" \
6 |     --nproc_per_node="$NUM_GPUS_PER_NODE" \
7 |     --rdzv-id=456 \
8 |     --rdzv-backend=c10d \
9 |     --rdzv-endpoint="$MASTER_ADDR:$MASTER_PORT" \
10 |     eval_multiGPU.py
11 | 
--------------------------------------------------------------------------------
/scripts/search_multiGPU.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | torchrun \
4 |     --node_rank="$NODE_RANK" \
5 |     --nnodes="$NUM_NODES" \
6 |     --nproc_per_node="$NUM_GPUS_PER_NODE" \
7 |     --rdzv-id=456 \
8 |     --rdzv-backend=c10d \
9 |     --rdzv-endpoint="$MASTER_ADDR:$MASTER_PORT" \
10 |     beam_predict_multiGPU.py
11 | 
12 | 
--------------------------------------------------------------------------------
/examples/beam_predict_seed.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 | import os
4 | import torch
5 | import numpy as np
6 | from beam_predict import Args, setup_logger, log_args, main, log_rank_0
7 | 
8 | if __name__ == "__main__":
9 |     args = Args
10 |     args.local_rank = int(os.environ["LOCAL_RANK"]) if os.environ.get("LOCAL_RANK") else -1
11 |     logger = setup_logger(args, "beam")
12 |     log_args(args, 'evaluation')
13 | 
14 |     for i in range(100):
15 |         seed = torch.seed()
16 |         log_rank_0(f"Current_seed: {seed}")
17 |         torch.manual_seed(seed)
18 | 
19 |         main(args, seed=seed)  # pass the drawn seed through so each run writes a distinct result file
20 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 FongMunHong
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/examples/run_beam_predict_seed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | ## This script is intended for recreating paper results reliably.
4 | 5 | ### RUN sh examples/run_beam_predict_seed.sh - outside of examples folder 6 | 7 | # export DATA_NAME="flower_dataset" # old dataset 8 | # export EXP_NAME="best_large_hyperparam" 9 | # export EMB_DIM=256 10 | # export RBF_HIGH=18 11 | # export RBF_GAP=0.1 12 | # export SIGMA=0.15 13 | 14 | # export MODEL_NAME="model.2370000_78.pt" # your trained checkpoint here 15 | 16 | export DATA_NAME="flower_new_dataset" 17 | export EXP_NAME="best_large_hyperparam" 18 | export EMB_DIM=256 19 | export RBF_HIGH=12 20 | export RBF_GAP=0.1 21 | export SIGMA=0.15 22 | 23 | export MODEL_NAME="model.2880000_95.pt" # your trained checkpoint here 24 | 25 | 26 | export TRAIN_BATCH_SIZE=4096 27 | export VAL_BATCH_SIZE=4096 28 | export TEST_BATCH_SIZE=4096 29 | 30 | export NUM_WORKERS=4 31 | export CUDA_VISIBLE_DEVICES=0 32 | export NUM_GPUS_PER_NODE=1 33 | 34 | export NUM_NODES=1 35 | export NODE_RANK=0 36 | export MASTER_ADDR=localhost 37 | export MASTER_PORT=1235 38 | 39 | export TRAIN_FILE=$PWD/data/$DATA_NAME/train.txt 40 | export VAL_FILE=$PWD/data/$DATA_NAME/val.txt 41 | export TEST_FILE=$PWD/data/$DATA_NAME/beam.txt 42 | 43 | 44 | export MODEL_PATH=$PWD/checkpoints/$DATA_NAME/$EXP_NAME/ 45 | export RESULT_PATH=$PWD/results/$DATA_NAME/$EXP_NAME/ 46 | 47 | 48 | python examples/beam_predict_seed.py 49 | -------------------------------------------------------------------------------- /run_FlowER_large_newData.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | export DATA_NAME="flower_new_dataset" 5 | export EXP_NAME="best_large_hyperparam" 6 | export EMB_DIM=256 7 | export RBF_HIGH=12 8 | export RBF_GAP=0.1 9 | export SIGMA=0.15 10 | 11 | export MODEL_NAME="model.2880000_95.pt" # your trained checkpoint here 12 | 13 | export TRAIN_BATCH_SIZE=4096 14 | export VAL_BATCH_SIZE=4096 15 | export TEST_BATCH_SIZE=4096 16 | 17 | export NUM_WORKERS=4 18 | export CUDA_VISIBLE_DEVICES=0 19 | export NUM_GPUS_PER_NODE=1 20 | 21 | export NUM_NODES=1 22 | export NODE_RANK=0 23 | export MASTER_ADDR=localhost 24 | export MASTER_PORT=1235 25 | 26 | export TRAIN_FILE=$PWD/data/$DATA_NAME/train.txt 27 | export VAL_FILE=$PWD/data/$DATA_NAME/val.txt 28 | # export TEST_FILE=$PWD/data/$DATA_NAME/test.txt 29 | export TEST_FILE=$PWD/data/$DATA_NAME/beam.txt 30 | 31 | 32 | export MODEL_PATH=$PWD/checkpoints/$DATA_NAME/$EXP_NAME/ 33 | export RESULT_PATH=$PWD/results/$DATA_NAME/$EXP_NAME/ 34 | 35 | 36 | # [ -f $TRAIN_FILE ] || { echo $TRAIN_FILE does not exist; exit; } 37 | # [ -f $VAL_FILE ] || { echo $VAL_FILE does not exist; exit; } 38 | # [ -f $TEST_FILE ] || { echo $TEST_FILE does not exist; exit; } 39 | 40 | 41 | export SCALE=4 # smaller sample size during training validation 42 | # sh scripts/train.sh 43 | 44 | export SCALE=1 # larger sample size during testing 45 | # sh scripts/eval_multiGPU.sh 46 | sh scripts/search.sh 47 | # sh scripts/search_multiGPU.sh -------------------------------------------------------------------------------- /run_FlowER_large_oldData.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Reporting for FlowER-large trained on oldData 4 | 5 | export DATA_NAME="flower_dataset" # old dataset 6 | export EXP_NAME="best_large_hyperparam" 7 | export EMB_DIM=256 8 | export RBF_HIGH=18 9 | export RBF_GAP=0.1 10 | export SIGMA=0.15 11 | 12 | export MODEL_NAME="model.2370000_78.pt" # your trained checkpoint here 13 | 14 | export TRAIN_BATCH_SIZE=4096 15 | export VAL_BATCH_SIZE=4096 16 | export 
TEST_BATCH_SIZE=4096 17 | 18 | export NUM_WORKERS=4 19 | export CUDA_VISIBLE_DEVICES=0 20 | export NUM_GPUS_PER_NODE=1 21 | 22 | export NUM_NODES=1 23 | export NODE_RANK=0 24 | export MASTER_ADDR=localhost 25 | export MASTER_PORT=1235 26 | 27 | export TRAIN_FILE=$PWD/data/$DATA_NAME/train.txt 28 | export VAL_FILE=$PWD/data/$DATA_NAME/val.txt 29 | # export TEST_FILE=$PWD/data/$DATA_NAME/test.txt 30 | export TEST_FILE=$PWD/data/$DATA_NAME/beam.txt 31 | 32 | 33 | export MODEL_PATH=$PWD/checkpoints/$DATA_NAME/$EXP_NAME/ 34 | export RESULT_PATH=$PWD/results/$DATA_NAME/$EXP_NAME/ 35 | 36 | 37 | [ -f $TRAIN_FILE ] || { echo $TRAIN_FILE does not exist; exit; } 38 | [ -f $VAL_FILE ] || { echo $VAL_FILE does not exist; exit; } 39 | [ -f $TEST_FILE ] || { echo $TEST_FILE does not exist; exit; } 40 | 41 | 42 | export SCALE=4 # smaller sample size during training validation 43 | sh scripts/train.sh 44 | 45 | export SCALE=1 # larger sample size during testing 46 | # sh scripts/eval_multiGPU.sh 47 | # sh scripts/search.sh 48 | # sh scripts/search_multiGPU.sh -------------------------------------------------------------------------------- /utils/attn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class PositionwiseFeedForward(nn.Module): 5 | """ A two-layer Feed-Forward-Network with residual layer norm. 6 | 7 | Args: 8 | d_model (int): the size of input for the first-layer of the FFN. 9 | d_ff (int): the hidden layer size of the second-layer 10 | of the FNN. 11 | dropout (float): dropout probability in :math:`[0, 1)`. 12 | """ 13 | 14 | def __init__(self, d_model, d_ff, dropout=0.1): 15 | super(PositionwiseFeedForward, self).__init__() 16 | self.w_1 = nn.Linear(d_model, d_ff) 17 | self.w_2 = nn.Linear(d_ff, d_model) 18 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 19 | self.dropout_1 = nn.Dropout(dropout) 20 | self.gelu = nn.GELU() 21 | self.dropout_2 = nn.Dropout(dropout) 22 | 23 | def forward(self, x): 24 | """Layer definition. 25 | 26 | Args: 27 | x: ``(batch_size, input_len, model_dim)`` 28 | 29 | Returns: 30 | (FloatTensor): Output ``(batch_size, input_len, model_dim)``. 31 | """ 32 | 33 | inter = self.dropout_1(self.gelu(self.w_1(self.layer_norm(x)))) 34 | output = self.dropout_2(self.w_2(inter)) 35 | return output + x 36 | 37 | 38 | def sequence_mask(lengths, max_len=None): 39 | """ 40 | Creates a boolean mask from sequence lengths. 
41 | """ 42 | batch_size = lengths.numel() 43 | max_len = max_len or lengths.max() 44 | return (torch.arange(0, max_len, device=lengths.device) 45 | .type_as(lengths) 46 | .repeat(batch_size, 1) 47 | .lt(lengths.unsqueeze(1))) 48 | 49 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | DATA_NAME = os.environ.get("DATA_NAME", "USPTO") 5 | EXP_NAME = os.environ.get("EXP_NAME", "") 6 | 7 | SCALE = int(os.environ.get("SCALE", 4)) # train & val 8 | # SCALE = 1 # test 9 | SAMPLE_SIZE = 64 // SCALE 10 | NUM_GPU = int(os.environ.get("NUM_GPUS_PER_NODE", 1)) 11 | 12 | 13 | TRAIN_BATCH_SIZE = int(os.environ.get("TRAIN_BATCH_SIZE", 4096)) 14 | VAL_BATCH_SIZE = int(os.environ.get("VAL_BATCH_SIZE", 4096)) 15 | TEST_BATCH_SIZE = int(os.environ.get("TEST_BATCH_SIZE", 512 * NUM_GPU * SCALE)) 16 | 17 | NUM_NODES = int(os.environ.get("NUM_NODES", 1)) 18 | ACCUMULATION_COUNT = int(os.environ.get("ACCUMULATION_COUNT", 1)) 19 | NUM_WORKERS = int(os.environ.get("NUM_WORKERS", 16)) 20 | 21 | MODEL_NAME = os.environ.get("MODEL_NAME") 22 | 23 | class Args: 24 | # train # 25 | model_name = MODEL_NAME 26 | exp_name = EXP_NAME 27 | train_path = os.environ.get("TRAIN_FILE") 28 | val_path = os.environ.get("VAL_FILE") 29 | test_path = os.environ.get("TEST_FILE") 30 | model_path = os.environ.get("MODEL_PATH") 31 | result_path = os.environ.get("RESULT_PATH") 32 | data_name = f"{DATA_NAME}" 33 | log_file = f"FlowER" 34 | load_from = "" 35 | # resume = True 36 | # load_from = f"{model_path}{MODEL_NAME}" 37 | 38 | backend = "nccl" 39 | num_workers = NUM_WORKERS 40 | emb_dim = int(os.environ.get("EMB_DIM")) 41 | enc_num_layers = 12 42 | post_processing_layers = 1 43 | enc_heads = 32 44 | enc_filter_size = 2048 45 | dropout = 0.0 46 | attn_dropout = 0.0 47 | rel_pos = "emb_only" 48 | shared_attention_layer = 0 49 | sigma = float(os.environ.get("SIGMA")) 50 | train_batch_size = (TRAIN_BATCH_SIZE / ACCUMULATION_COUNT / NUM_GPU / NUM_NODES) 51 | val_batch_size = (VAL_BATCH_SIZE / ACCUMULATION_COUNT / NUM_GPU / NUM_NODES) 52 | test_batch_size = TEST_BATCH_SIZE 53 | batch_type = "tokens_sum" 54 | lr = 0.0001 55 | beta1 = 0.9 56 | beta2 = 0.998 57 | eps = 1e-9 58 | weight_decay = 1e-2 59 | warmup_steps = 30000 60 | clip_norm = 200 61 | 62 | 63 | epoch = int(os.environ.get("EPOCH", 100)) 64 | max_steps = 3000000 65 | accumulation_count = ACCUMULATION_COUNT 66 | save_iter = int(os.environ.get("SAVE_ITER", 30000)) 67 | log_iter = int(os.environ.get("LOG_ITER", 100)) 68 | eval_iter = int(os.environ.get("EVAL_ITER", 30000)) 69 | 70 | 71 | sample_size = SAMPLE_SIZE 72 | rbf_low = 0 73 | rbf_high = float(os.environ.get("RBF_HIGH")) 74 | rbf_gap = float(os.environ.get("RBF_GAP")) 75 | 76 | # validation # 77 | # do_validate = True 78 | # steps2validate = ["1050000", "1320000", "1500000", "930000", "1020000"] 79 | 80 | # inference # 81 | do_validate = False 82 | 83 | # beam-search # 84 | beam_size = 5 85 | nbest = 3 86 | max_depth = 15 87 | chunk_size = 50 -------------------------------------------------------------------------------- /utils/rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | def saferound_tensor( 5 | x: torch.Tensor, 6 | places: int, 7 | strategy: str = "difference", 8 | topline: Optional[float] = None 9 | ) -> torch.Tensor: 10 | """ 11 | Round a tensor elementwise to `places` 
decimal places, adjusting 12 | a minimal number of entries so that the total sum is exactly preserved. 13 | 14 | Args: 15 | x (torch.Tensor): input tensor of floats. 16 | places (int): number of decimal places to round to. 17 | strategy (str): one of {"difference","largest","smallest"}: 18 | - "difference": pick the entries with largest fractional parts first. 19 | - "largest" : pick the largest values first. 20 | - "smallest" : pick the smallest values first. 21 | topline (float, optional): if given, override the target sum 22 | with `topline`. Otherwise target is x.sum(). 23 | 24 | Returns: 25 | torch.Tensor: same shape as `x`, rounded to `places`, but whose 26 | sum exactly equals the rounded(original_sum, places). 27 | """ 28 | assert isinstance(places, int), "places must be integer" 29 | assert strategy in ("difference","largest","smallest"), f"Unknown strategy {strategy}" 30 | 31 | # Flatten for simplicity 32 | orig = x.view(-1).to(dtype=torch.float64) 33 | N = orig.numel() 34 | 35 | # Determine the exact sum we need to hit 36 | total = topline if topline is not None else orig.sum().item() 37 | scale = 10 ** places 38 | target_int = int(round(total * scale)) 39 | 40 | # Scale and take floor/ceil 41 | scaled = orig * scale 42 | low = torch.floor(scaled).to(torch.int64) # integer floors 43 | high = torch.ceil(scaled).to(torch.int64) # integer ceils 44 | sum_low = int(low.sum().item()) 45 | residual = target_int - sum_low # how many +1’s we need 46 | 47 | if residual != 0: 48 | # Depending on strategy, create a sort key 49 | if strategy == "difference": 50 | # fractional part, descending 51 | frac = (scaled - low).cpu() 52 | _, indices = torch.sort(frac, descending=True) 53 | elif strategy == "largest": 54 | # values descending 55 | _, indices = torch.sort(orig.cpu(), descending=True) 56 | else: # "smallest" 57 | # values ascending 58 | _, indices = torch.sort(orig.cpu(), descending=False) 59 | 60 | # Pick exactly `abs(residual)` indices 61 | k = min(abs(residual), N) 62 | chosen = indices[:k] 63 | 64 | # Apply the +1 or -1 65 | if residual > 0: 66 | low[chosen] += 1 67 | else: 68 | # In the very rare case sum_low > target_int, we go back down 69 | low[chosen] -= 1 70 | 71 | # Convert back to float decimals 72 | rounded_flat = low.to(torch.float64).mul_(1.0 / scale) 73 | # reshape back 74 | return rounded_flat.view_as(x).to(dtype=x.dtype) 75 | 76 | -------------------------------------------------------------------------------- /model/flow_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.data_utils import MATRIX_PAD 4 | 5 | def zero_center_func(x, node_mask): 6 | N = node_mask.sum() 7 | mean = torch.sum(x) / N 8 | x = x - mean * node_mask 9 | return x 10 | 11 | class ConditionalFlowMatcher(nn.Module): 12 | """Base class for conditional flow matching methods. This class implements the independent 13 | conditional flow matching methods from [1] and serves as a parent class for all other flow 14 | matching methods. 15 | 16 | It implements: 17 | - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function 18 | - conditional flow matching ut(x1|x0) = x1 - x0 19 | - score function $\nabla log p_t(x|x0, x1)$ 20 | """ 21 | 22 | def __init__(self, args): 23 | r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$. 
24 | 25 | Parameters 26 | ---------- 27 | sigma : float 28 | """ 29 | super().__init__() 30 | self.args = args 31 | self.device = args.device 32 | self.sigma = args.sigma 33 | self.dim = args.emb_dim 34 | 35 | def zero_centered_noise(self, size, node_mask_batch): 36 | rand = torch.randn(size).to(self.device) 37 | x_batch = rand * node_mask_batch 38 | map_zero_center = torch.vmap(zero_center_func) # map on multiple batch 39 | return map_zero_center(x_batch, node_mask_batch).masked_fill(~(node_mask_batch.bool()), 1e-19) 40 | 41 | def sample_be_matrix(self, matrix): 42 | node_mask = (matrix[:, :, 0] != MATRIX_PAD) 43 | masks = (node_mask.unsqueeze(1) * node_mask.unsqueeze(2)).long() 44 | 45 | noise = self.zero_centered_noise(masks.shape, masks) # (n, n, b, d) 46 | noise = 0.5 * (noise + noise.transpose(1, 2)) 47 | matrix = matrix + noise * self.sigma 48 | 49 | return matrix 50 | 51 | def sample_conditional_pt(self, x0, x1, t): 52 | """ 53 | Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1]. 54 | 55 | Parameters 56 | ---------- 57 | x0 : Tensor, shape (bs, *dim) 58 | represents the source minibatch 59 | x1 : Tensor, shape (bs, *dim) 60 | represents the target minibatch 61 | t : FloatTensor, shape (bs) 62 | 63 | Returns 64 | ------- 65 | xt : Tensor, shape (bs, *dim) 66 | 67 | References 68 | ---------- 69 | [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. 70 | """ 71 | t = t.reshape(-1, *([1] * (x0.dim() - 1))) 72 | mu_t = t * x1 + (1 - t) * x0 73 | return self.sample_be_matrix(mu_t) 74 | 75 | def compute_conditional_vector_field(self, x0, x1): 76 | """ 77 | Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]. 78 | 79 | Parameters 80 | ---------- 81 | x0 : Tensor, shape (bs, *dim) 82 | represents the source minibatch 83 | x1 : Tensor, shape (bs, *dim) 84 | represents the target minibatch 85 | 86 | Returns 87 | ------- 88 | ut : conditional vector field ut(x1|x0) = x1 - x0 89 | 90 | References 91 | ---------- 92 | [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. 
93 | """ 94 | return x1 - x0 -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import numpy as np 4 | import os 5 | import random 6 | import sys 7 | import torch 8 | import torch.nn as nn 9 | from datetime import datetime 10 | from rdkit import RDLogger 11 | from torch.optim.lr_scheduler import _LRScheduler 12 | 13 | def log_args(args, phase: str): 14 | log_rank_0(f"Logging {phase} arguments") 15 | for k, v in vars(args).items(): 16 | if "__" in k: continue 17 | log_rank_0(f"**** {k} = *{v}*") 18 | 19 | def param_count(model: nn.Module) -> int: 20 | return sum(param.numel() for param in model.parameters() if param.requires_grad) 21 | 22 | 23 | def param_norm(m): 24 | return math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()])) 25 | 26 | 27 | def grad_norm(m): 28 | return math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None])) 29 | 30 | 31 | def get_lr(optimizer): 32 | for param_group in optimizer.param_groups: 33 | return param_group["lr"] 34 | 35 | 36 | def set_seed(seed): 37 | random.seed(seed) 38 | os.environ['PYTHONHASHSEED'] = str(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) 41 | torch.cuda.manual_seed_all(seed) 42 | torch.backends.cudnn.benchmark = False 43 | torch.backends.cudnn.deterministic = True 44 | 45 | 46 | def setup_logger(args, phase, warning_off: bool = False): 47 | if warning_off: 48 | RDLogger.DisableLog("rdApp.*") 49 | else: 50 | RDLogger.DisableLog("rdApp.warning") 51 | 52 | os.makedirs(f"./logs/{args.data_name}/{args.exp_name}", exist_ok=True) 53 | dt = datetime.strftime(datetime.now(), "%y%m%d-%H%Mh") 54 | 55 | logger = logging.getLogger() 56 | logger.setLevel(logging.INFO) 57 | fh = logging.FileHandler(f"./logs/{args.data_name}/{args.exp_name}/{phase}_{args.log_file}.{dt}") 58 | sh = logging.StreamHandler(sys.stdout) 59 | fh.setLevel(logging.INFO) 60 | sh.setLevel(logging.INFO) 61 | logger.addHandler(fh) 62 | logger.addHandler(sh) 63 | 64 | return logger 65 | 66 | 67 | def log_rank_0(message): 68 | if torch.distributed.is_initialized(): 69 | if torch.distributed.get_rank() == 0: 70 | logging.info(message) 71 | sys.stdout.flush() 72 | else: 73 | logging.info(message) 74 | sys.stdout.flush() 75 | 76 | 77 | def log_tensor(tensor, tensor_name: str, shape_only=False): 78 | log_rank_0(f"--------------------------{tensor_name}--------------------------") 79 | if not shape_only: 80 | log_rank_0(tensor) 81 | if isinstance(tensor, torch.Tensor): 82 | log_rank_0(tensor.shape) 83 | elif isinstance(tensor, np.ndarray): 84 | log_rank_0(tensor.shape) 85 | elif isinstance(tensor, list): 86 | try: 87 | for item in tensor: 88 | log_rank_0(item.shape) 89 | except Exception as e: 90 | log_rank_0(f"Error: {e}") 91 | log_rank_0("List items are not tensors, skip shape logging.") 92 | 93 | 94 | class NoamLR(_LRScheduler): 95 | """ 96 | Adapted from https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py 97 | 98 | Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate 99 | linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally 100 | to the inverse square root of the step number, scaled by the inverse square root of the 101 | dimensionality of the model. Time will tell if this is just madness or it's actually important. 
102 | Parameters 103 | ---------- 104 | warmup_steps: ``int``, required. 105 | The number of steps to linearly increase the learning rate. 106 | """ 107 | def __init__(self, optimizer, model_size, warmup_steps): 108 | self.model_size = model_size 109 | self.warmup_steps = warmup_steps 110 | super().__init__(optimizer) 111 | 112 | def get_lr(self): 113 | step = max(1, self._step_count) 114 | scale = self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup_steps**(-1.5)) 115 | scale *= 1e4 116 | 117 | return [base_lr * scale for base_lr in self.base_lrs] 118 | -------------------------------------------------------------------------------- /sequence_evaluation.py: -------------------------------------------------------------------------------- 1 | 2 | import networkx as nx 3 | import numpy as np 4 | from rdkit import Chem 5 | from collections import defaultdict 6 | from multiprocessing import Pool, cpu_count 7 | from rdkit import RDLogger 8 | RDLogger.DisableLog('rdApp.*') 9 | 10 | 11 | BAD = 1000 12 | 13 | def assign_rank(predictions): 14 | sorted_items = sorted(predictions, key=lambda x: x[1], reverse=True) 15 | 16 | rank_list = [] 17 | current_rank = 0 18 | prev_count = None 19 | 20 | for i, (smi, count, _) in enumerate(sorted_items): 21 | # If the current value is the same as the previous one, assign the same rank 22 | if prev_count == count: 23 | rank_list.append((current_rank, smi, count, _)) 24 | else: 25 | # Otherwise, assign a new rank 26 | current_rank = i 27 | rank_list.append((current_rank, smi, count, _)) 28 | 29 | prev_count = count 30 | 31 | return rank_list 32 | 33 | def clean(smi): 34 | # try: 35 | mol = Chem.MolFromSmiles(smi, sanitize=False) 36 | mol = Chem.RemoveHs(mol) 37 | [atom.SetAtomMapNum(0) for atom in mol.GetAtoms()] 38 | return Chem.MolToSmiles(mol, isomericSmiles=False) 39 | 40 | def process_topk_acc_n_seq_rank(line): 41 | sequence_idx, sequence_preds = line 42 | seq_graph_gt = nx.DiGraph() 43 | 44 | new_sequence_edges = {} 45 | gt_neigh = defaultdict(set) 46 | for pred_info in sequence_preds: 47 | rxn = pred_info["rxn"] 48 | reactant, product = rxn.strip().split('>>') 49 | seq_graph_gt.add_edge(reactant, product) 50 | new_sequence_edges[(reactant, product)] = np.inf 51 | gt_neigh[reactant].add(product) 52 | 53 | 54 | starting_reac = [node for node, in_degree in seq_graph_gt.in_degree() if in_degree == 0] 55 | terminal_prods = list(nx.nodes_with_selfloops(seq_graph_gt)) 56 | 57 | if len(starting_reac) != 1: return BAD, 1, sequence_idx # if starting reactant is not 1 58 | starting_reac = starting_reac[0] 59 | 60 | if len(terminal_prods) == 0: return BAD, 2, sequence_idx # if we have a loop 61 | 62 | # merge predictions 63 | for pred_info in sequence_preds: 64 | reactant, _ = pred_info["rxn"].strip().split('>>') 65 | predictions = pred_info["predictions"] 66 | predictions = assign_rank(predictions) 67 | for rank, pred, pred_count, _ in predictions: 68 | if pred in gt_neigh[reactant]: 69 | cur_rank = new_sequence_edges.get((reactant, pred)) 70 | if cur_rank == np.inf: # and it's the first time 71 | new_sequence_edges[(reactant, pred)] = rank 72 | 73 | seq_graph_pred = nx.DiGraph() 74 | for (reac, prod), rank in new_sequence_edges.items(): 75 | seq_graph_pred.add_edge(reac, prod, weight=rank) 76 | 77 | max_depth = 0 78 | min_sequences_rank = np.inf 79 | for terminal in terminal_prods: 80 | for path in nx.all_simple_paths(seq_graph_pred, source=starting_reac, target=terminal): 81 | max_depth = max(len(path), max_depth) 82 | edges = nx.utils.pairwise(path) 
83 |             ranks = [seq_graph_pred.get_edge_data(u, v)['weight'] for u, v in edges]
84 |             max_topk_within_one_seq = max(ranks)
85 |             min_sequences_rank = min(max_topk_within_one_seq, min_sequences_rank)
86 | 
87 |     terminal_prods = [clean(prod) for prod in terminal_prods]
88 |     return min_sequences_rank, 0, sequence_idx, (clean(starting_reac), terminal_prods), max_depth # min of all sequences
89 | 
90 | 
91 | def remove_atom_map_rxn(line):
92 |     ps = Chem.SmilesParserParams()
93 |     ps.removeHs = False
94 |     ps.sanitize = True
95 |     try:
96 |         rxn, sequence_idx = line.strip().split("|")
97 |     except:
98 |         rxn, rxn_class, condition, elem_step, sequence_idx = line.strip().split("|")
99 | 
100 |     reactant, product = rxn.split(">>")
101 |     reac = Chem.MolFromSmiles(reactant, ps)
102 |     prod = Chem.MolFromSmiles(product, ps)
103 | 
104 |     assert reac is not None
105 |     assert prod is not None
106 | 
107 |     [a.ClearProp('molAtomMapNumber') for a in reac.GetAtoms()]
108 |     [a.ClearProp('molAtomMapNumber') for a in prod.GetAtoms()]
109 | 
110 |     reac_smi = Chem.MolToSmiles(reac, isomericSmiles=False)
111 |     prod_smi = Chem.MolToSmiles(prod, isomericSmiles=False)
112 | 
113 |     reac = Chem.MolFromSmiles(reac_smi, ps)
114 |     prod = Chem.MolFromSmiles(prod_smi, ps)
115 |     reac_smi = Chem.MolToSmiles(reac, isomericSmiles=False)
116 |     prod_smi = Chem.MolToSmiles(prod, isomericSmiles=False)
117 | 
118 |     rxn = f"{reac_smi}>>{prod_smi}|{sequence_idx}"
119 | 
120 |     return rxn
121 | 
122 | def reparse(line):
123 |     ps = Chem.SmilesParserParams()
124 |     ps.removeHs = False
125 |     ps.sanitize = True
126 |     metrics, not_sym, predictions = line.strip().split("|")
127 |     predictions = eval(predictions)
128 |     predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
129 |     # new_predictions = []
130 |     pred_dict = defaultdict(int)
131 |     for (pred, pred_count, val) in predictions:
132 |         pred_mol = Chem.MolFromSmiles(pred, ps)
133 |         if pred_mol is None: continue
134 |         pred_smi = Chem.MolToSmiles(pred_mol, isomericSmiles=False)
135 |         # new_predictions.append((pred_smi, pred_count, val))
136 |         pred_dict[pred_smi] += pred_count
137 | 
138 |     pred_dict = dict(sorted(pred_dict.items(), key=lambda x: x[1], reverse=True))
139 |     new_predictions = [(pred_smi, prob, True) for pred_smi, prob in pred_dict.items()]
140 | 
141 |     return f"{metrics}|{not_sym}|{new_predictions}"
142 | 
143 | 
144 | with open("data/flower_dataset/test.txt") as gt_o, \
145 |     open("results/flower_dataset/best_hyperparam/result-32-1440000_47.txt") as result_o:
146 | 
147 |     # Preprocessing
148 |     result = result_o.readlines()
149 |     gt = gt_o.readlines()
150 | 
151 |     assert len(gt) == len(result)
152 | 
153 |     print("Ground Truth lines")
154 |     p = Pool(cpu_count())
155 |     gt = p.imap(remove_atom_map_rxn, (rxn for rxn in gt))
156 |     gt = list(gt)
157 | 
158 |     print("Prediction lines")
159 |     result = p.imap(reparse, (res for res in result))
160 |     result = list(result)
161 | 
162 |     nbest = 10
163 |     topk_accs = np.zeros([len(gt), nbest], dtype=np.float32)
164 | 
165 |     invalid = []
166 |     bag_of_vals = defaultdict(list)
167 |     reac_prod_rank = {}
168 |     for i, (line_res, line_gt) in enumerate(zip(result, gt)):
169 |         metrics, not_sym, predictions = line_res.strip().split("|")
170 |         metrics, predictions = eval(metrics), eval(predictions)
171 | 
172 |         invalid.append(metrics[3] / sum(metrics))
173 |         predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
174 | 
175 |         rxn, sequence_idx = line_gt.strip().split("|")
176 |         reactant, product = rxn.split(">>")
177 |         if (reactant, product) in reac_prod_rank:  # keys are (reactant, product) tuples, so check the pair
178 |             extract_rank = reac_prod_rank[(reactant, product)]
179 |             topk_accs[i, extract_rank:] = 1
180 |         else:
181 |             for rank, (pred, pred_count, _) in enumerate(predictions):
182 |                 if pred == product:
183 |                     topk_accs[i, rank:] = 1
184 |                     reac_prod_rank[(reactant, product)] = rank
185 |                     break
186 | 
187 |         if sequence_idx in ['PM', 'RS', 'RC', 'PC']: continue
188 |         bag_of_vals[sequence_idx].append(
189 |             {
190 |                 "rxn": rxn,
191 |                 "metrics": metrics,
192 |                 "predictions": predictions
193 |             }
194 |         )
195 |     avg_invalid = sum(invalid) / len(invalid)
196 |     print(f"Valid percentage: {((1 - avg_invalid) * 100): .2f}%")
197 | 
198 | 
199 |     print("Calculating Topk Step Accuracy")
200 |     mean_seq_accuracies = np.mean(topk_accs, axis=0)
201 |     for n in range(nbest):
202 |         line = f"Top {n+1} step accuracy: {mean_seq_accuracies[n] * 100: .2f} %"
203 |         print(line)
204 | 
205 |     sequence_accs = np.zeros([len(bag_of_vals), nbest], dtype=np.float32)
206 | 
207 |     print("Calculating Pathway Accuracy")
208 | 
209 |     no_starting_point = 0
210 |     no_starting_point_set = set()
211 |     count_no_terminal = 0
212 |     count_no_terminal_set = set()
213 | 
214 |     seq_ranks = p.imap(process_topk_acc_n_seq_rank, ((seq_idx, seq_infos) for seq_idx, seq_infos in bag_of_vals.items()))
215 |     for i, (rank, error, seq_idx, (reactant, prod_list), max_depth) in enumerate(seq_ranks):
216 |         if error == 1:
217 |             no_starting_point += 1
218 |             no_starting_point_set.add(seq_idx)
219 |         if error == 2:
220 |             count_no_terminal += 1
221 |             count_no_terminal_set.add(seq_idx)
222 |         if rank >= nbest: continue
223 |         sequence_accs[i, rank:] = 1
224 | 
225 |     p.close()
226 |     p.join()
227 | 
228 |     # print('no_starting_point', no_starting_point)
229 |     # print('no_starting_point_set', no_starting_point_set)
230 |     # print('count_no_terminal', count_no_terminal)
231 |     # print('count_no_terminal_set', count_no_terminal_set)
232 | 
233 |     mean_seq_accuracies = np.mean(sequence_accs, axis=0)
234 |     for n in range(nbest):
235 |         line = f"Top {n+1} pathway accuracy: {mean_seq_accuracies[n] * 100: .2f} %"
236 |         print(line)
237 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FlowER: Flow Matching for Electron Redistribution
2 | _Joonyoung F. Joung*, Mun Hong Fong*, Nicholas Casetti, Jordan P. Liles, Ne S. Dassanayake, Connor W. Coley_
3 | 
4 | **NOW published in _Nature_!**
5 | 
6 | “Electron flow matching for generative reaction mechanism prediction.” *Nature* **645**, 115–123 (2025).
7 | DOI: [10.1038/s41586-025-09426-9](https://doi.org/10.1038/s41586-025-09426-9)
8 | 
9 | ![Alt Text](FlowER.png)
10 | 
11 | FlowER uses flow matching to model chemical reactions as a process of electron redistribution, conceptually
12 | aligning with arrow-pushing formalisms. It aims to capture the probabilistic nature of reactions under mass conservation,
13 | where multiple outcomes are reached through branching mechanistic networks evolving in time.
14 | 
15 | ## Environment Setup
16 | ### System requirements
17 | **Ubuntu**: >= 16.04
18 | **conda**: >= 4.0
19 | **GPU**: at least 25 GB of memory, with CUDA >= 12.2
20 | 
21 | ```bash
22 | $ conda create -n flower python=3.10
23 | $ conda activate flower
24 | $ pip install -r requirements.txt
25 | ```
26 | 
27 | ## Data/Model preparation
28 | FlowER is trained on a combination of a subset of USPTO-FULL (Dai et al.) and the RmechDB and PmechDB datasets (Baldi et al.).
29 | To retrain/reproduce FlowER, download `data.zip` and `checkpoints.zip` from [this link](https://figshare.com/articles/dataset/FlowER_-_Mechanistic_datasets_and_model_checkpoint/28359407/3), unzip them, and place them under `FlowER/`, e.g.:
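The following layout is an illustrative sketch using the dataset and experiment names from `run_FlowER_large_newData.sh`; your checkpoint file name will differ depending on training.
```
FlowER/
├── data/
│   └── flower_new_dataset/
│       ├── train.txt
│       ├── val.txt
│       ├── test.txt
│       └── beam.txt
└── checkpoints/
    └── flower_new_dataset/
        └── best_large_hyperparam/
            └── model.2880000_95.pt
```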
30 | In general, the `data` folder follows the structure `data/{DATASET_NAME}/{train,val,test}.txt`, and the `checkpoints` folder follows `checkpoints/{DATASET_NAME}/{EXPERIMENT_NAME}/model.{STEP}_{IDX}.pt`
31 | 
32 | ## On how FlowER is structured
33 | The workflow of FlowER revolves mainly around two files: `run_FlowER_large_(old|new)Data.sh` and `settings.py`.
34 | The main idea is to use `#` comments to turn configurations on/off when training, validating, or running inference with FlowER.
35 | `run_FlowER_large_(old|new)Data.sh` lets you specify the data folder name, experiment name, and GPU configuration, and choose which scripts to run (see the sketch below).
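For instance, the tail of `run_FlowER_large_newData.sh` (excerpted from the script later in this document) selects the stage to run; leaving exactly one stage uncommented switches between training, evaluation, and search:
```sh
export SCALE=4   # smaller sample size during training validation
# sh scripts/train.sh

export SCALE=1   # larger sample size during testing
# sh scripts/eval_multiGPU.sh
sh scripts/search.sh             # currently active stage
# sh scripts/search_multiGPU.sh
```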
36 | `settings.py` lets you specify different configurations for different workflows.
37 | 
38 | ## Training Pipeline
39 | ### 1. Train FlowER
40 | Ensure that the `data/` folder is populated accordingly and that `run_FlowER_large_(old|new)Data.sh` points to the correct files.
41 | ```
42 | export TRAIN_FILE=$PWD/data/$DATA_NAME/train.txt
43 | export VAL_FILE=$PWD/data/$DATA_NAME/val.txt
44 | ```
45 | Check that `run_FlowER_large_(old|new)Data.sh` has `scripts/train.sh` uncommented.
46 | ```bash
47 | $ sh run_FlowER_large_(old|new)Data.sh
48 | ```
49 | 
50 | ### 2. Validate FlowER
51 | To validate FlowER on the validation set, ensure these lines in `settings.py` are uncommented.
52 | ```
53 | # validation #
54 | do_validate = True
55 | steps2validate = ["1050000", "1320000", "1500000", "930000", "1020000"]
56 | ```
57 | `steps2validate` lists the checkpoints to validate, selected based on the training logs in the `logs/` folder.
58 | Check that `run_FlowER_large_(old|new)Data.sh` has `scripts/eval_multiGPU.sh` uncommented.
59 | ```bash
60 | $ sh run_FlowER_large_(old|new)Data.sh
61 | ```
62 | 
63 | 
64 | ### 3. Test FlowER
65 | To test FlowER on the test set, specify your checkpoint at `MODEL_NAME` in `settings.py` and ensure these lines are uncommented.
66 | ```
67 | # inference #
68 | do_validate = False
69 | ```
70 | Check that `run_FlowER_large_(old|new)Data.sh` has `scripts/eval_multiGPU.sh` uncommented.
71 | ```bash
72 | $ sh run_FlowER_large_(old|new)Data.sh
73 | ```
74 | 
75 | #### FlowER train/valid/test input
76 | FlowER takes atom-mapped reactions as input for training, validation and testing. During evaluation with `sequence_evaluation.py`, the elementary reaction steps FlowER is trained on can be grouped together using their sequence index. \
77 | \
78 | An elementary reaction step follows the format `mapped_reaction|sequence_idx`. Examples are as follows:
79 | ```
80 | [Cl:1][S:2]([Cl:3])=[O:4].[Cl:5][C:6]1=[N:7][S:8][C:9]([C:10](=[O:11])[O:12][H:15])=[C:13]1[Cl:14]>>[Cl:1][S:2]([Cl:3])([O-:4])[O:11][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15]|11831
81 | [Cl:1][S:2]([Cl:3])([O-:4])[O:11][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15]>>[Cl-:1].[S:2]([Cl:3])(=[O:4])[O:11][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15]|11831
82 | [Cl-:1].[S:2]([Cl:3])(=[O:4])[O:11][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15]>>[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)([O:11][S:2]([Cl:3])=[O:4])[O:12][H:15]|11831
83 | [Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)([O:11][S:2]([Cl:3])=[O:4])[O:12][H:15]>>[Cl-:3].[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15].[S:2](=[O:4])=[O:11]|11831
84 | [Cl-:3].[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15].[S:2](=[O:4])=[O:11]>>[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O:12].[Cl:3][H:15].[S:2](=[O:4])=[O:11]|11831
85 | [Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O:12].[Cl:3][H:15].[S:2](=[O:4])=[O:11]>>[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O:12].[Cl:3][H:15].[S:2](=[O:4])=[O:11]|11831
86 | ```
87 | 
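To make the grouping concrete, here is a minimal sketch (not part of the repository) that collects the elementary steps of each mechanistic sequence, assuming the two-field `mapped_reaction|sequence_idx` layout shown above (`sequence_evaluation.py` also tolerates a five-field variant):
```python
from collections import defaultdict

def group_steps(path):
    """Collect (reactant, product) pairs per mechanistic sequence index."""
    sequences = defaultdict(list)
    with open(path) as f:
        for line in f:
            rxn, sequence_idx = line.strip().split("|")
            reactant, product = rxn.split(">>")
            sequences[sequence_idx].append((reactant, product))
    return sequences

# e.g. group_steps("data/flower_dataset/train.txt")["11831"] -> the six steps above
```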
88 | **Train/Valid/Test hyperparameters**
89 | 
90 | ### Model Architecture
91 | - **`emb_dim`** - Embedding dimension of the atom embeddings
92 | - **`enc_num_layers`** - Number of transformer layers to be applied
93 | - **`enc_heads`** - Number of attention heads
94 | - **`enc_filter_size`** - Dimension of the feed-forward network in the transformer block
95 | - **`(attn_)dropout`** - Dropout for the transformer block (0.0 empirically works well)
96 | - **`sigma`** - Standard deviation of the Gaussian noise added for reparameterizing the bond-electron (BE) matrix
97 | 
98 | ### Optimization
99 | - **`lr`** - Learning rate for training (NoamLR)
100 | - **`warmup_steps`** - Warmup steps before LR decay (NoamLR)
101 | - **`clip_norm`** - Gradient clipping threshold to prevent exploding gradients
102 | - **`beta1`**, **`beta2`** - Adam optimizer’s momentum terms
103 | - **`eps`** - Adam optimizer’s denominator term for numerical stability
104 | - **`weight_decay`** - L2 regularization strength to prevent overfitting
105 | 
106 | ### Input representation (Bond-Electron matrix)
107 | - **`rbf_low`** - Lowest value of the Radial Basis Function (RBF) centers
108 | - **`rbf_high`** - Highest value of the RBF centers
109 | - **`rbf_gap`** - Granularity of the RBF center increments
110 | 
111 | ### Inference
112 | - **`do_validate`** - True to trigger validation, False to trigger testing
113 | - **`steps2validate`** - List of checkpoints to run FlowER on for validation
114 | - **`sample_size`** - Number of samples FlowER generates for evaluation
115 | 
116 | 
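As a side note on the optimization settings above: the `NoamLR` scheduler in `utils/train_utils.py` multiplies the base `lr` by a factor that rises linearly over `warmup_steps` and then decays with the inverse square root of the step. A standalone sketch of that factor (the `1e4` rescaling mirrors this repo's implementation):
```python
def noam_scale(step: int, model_size: int = 256, warmup_steps: int = 30000) -> float:
    """LR multiplier, as computed by NoamLR in utils/train_utils.py."""
    step = max(1, step)
    return 1e4 * model_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# With lr = 1e-4 and emb_dim = 256, the effective LR peaks near step 30000 at ~3.6e-4.
```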
117 | 
118 | ### 4. Use FlowER for search
119 | FlowER mainly uses beam search to find plausible mechanistic pathways. Users can input their SMILES at `data/flower_dataset/beam.txt`.
120 | Ensure that in `run_FlowER_large_(old|new)Data.sh`, the `TEST_FILE` variable points to the correct file.
121 | ```
122 | export TEST_FILE=$PWD/data/$DATA_NAME/beam.txt
123 | ```
124 | Ensure that in `settings.py`, the beam-search configurations are uncommented and specified accordingly.
125 | ```
126 | test_path = f"data/{DATA_NAME}/beam.txt"
127 | 
128 | # beam-search #
129 | beam_size = 5
130 | nbest = 3
131 | max_depth = 15
132 | chunk_size = 50
133 | ```
134 | Check that `run_FlowER_large_(old|new)Data.sh` has `scripts/search.sh` or `scripts/search_multiGPU.sh` uncommented.
135 | ```bash
136 | $ sh run_FlowER_large_(old|new)Data.sh
137 | ```
138 | Visualize your route at `examples/vis_network.ipynb`
139 | 
140 | #### FlowER search input
141 | FlowER takes non-atom-mapped reactions as input for beam search, specified in `beam.txt`.
142 | Each line follows the format `reactant>>product1|product2|...`, where multiple major and minor products can be specified, separated by `|`:
143 | ```
144 | CC(=O)CC(=O)C(F)(F)F.NNc1cccc(Br)c1>>Cc1cc(C(F)(F)F)n(-c2cccc(Br)c2)n1
145 | CC(=O)CC(=O)C(F)(F)F.NNc1cccc(Br)c1>>Cc1cc(C(F)(F)F)n(-c2cccc(Br)c2)n1|Cc1cc(C(F)(F)F)nn1-c1cccc(Br)c1
146 | ```
147 | 
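A hypothetical helper (not part of the repository) for composing such a line; `beam_predict.py` applies the same RDKit canonicalization to the listed products when checking whether a search run reached them:
```python
from rdkit import Chem

def make_beam_line(reactant_smi, product_smis):
    """Compose one beam.txt entry: reactant>>product1|product2|..."""
    canon = lambda smi: Chem.MolToSmiles(Chem.MolFromSmiles(smi))
    return f"{canon(reactant_smi)}>>{'|'.join(canon(p) for p in product_smis)}"

print(make_beam_line("CC(=O)CC(=O)C(F)(F)F.NNc1cccc(Br)c1",
                     ["Cc1cc(C(F)(F)F)n(-c2cccc(Br)c2)n1"]))
```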
148 | **Search hyperparameters**
149 | 
150 | - **`beam_size`** - Size of the top-k selection of candidates (based on cumulative probability)
151 | to be further expanded. Increasing this makes the overall search more comprehensive, but at the
152 | cost of slower runtime.
153 | - **`nbest`** - Cut-off size of the top-k outcomes generated by FlowER after the expansion.
154 | This cutoff filters out unlikely outcomes from the selection.
155 | - **`sample_size`** - Number of samples FlowER generates for evaluation
156 | - **`max_depth`** - Maximum depth the beam search should explore.
157 | - **`chunk_size`** - Number of reactant sets on which beam search is run concurrently.
158 | 
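Each search run pickles its results per chunk (see `beam_predict.py`) as a list of `(graph, root, (reactant, products), achieved)` tuples. A minimal sketch for inspecting them; the file name below is illustrative, since the chunk index and seed depend on your run:
```python
import pickle

# Hypothetical result file; the actual chunk index and seed depend on the run.
path = "results/flower_new_dataset/best_large_hyperparam/result_chunk_0_s0.pickle"
with open(path, "rb") as f:
    results = pickle.load(f)

for graph, root, (reactant, products), achieved in results:
    # `graph` is a networkx.DiGraph over species; self-loops mark terminal products.
    print(f"{reactant}: reached {len(achieved)}/{len(products)} targets, "
          f"{graph.number_of_nodes()} species explored")
```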
160 | 161 | ## Citation 162 | ```bibtex 163 | @article{joung2025electron, 164 | title={Electron flow matching for generative reaction mechanism prediction obeying conservation laws}, 165 | author={Joung, Joonyoung F and Fong, Mun Hong and Casetti, Nicholas and Liles, Jordan P and Dassanayake, Ne S and Coley, Connor W}, 166 | journal={arXiv preprint arXiv:2502.12979}, 167 | year={2025} 168 | } 169 | ``` 170 | 171 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import datetime 5 | import logging 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | from model.attn_encoder import AttnEncoderXL 11 | from utils.data_utils import ReactionDataset 12 | from torch.utils.data import DataLoader 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 15 | from torch.utils.data.distributed import DistributedSampler 16 | from settings import Args 17 | from model.flow_matching import ConditionalFlowMatcher 18 | from utils.train_utils import get_lr, grad_norm, log_rank_0, NoamLR, \ 19 | param_count, param_norm, set_seed, setup_logger, log_args 20 | from torch.nn.init import xavier_uniform_ 21 | import torch.optim as optim 22 | 23 | torch.set_printoptions(precision=4, profile="full", sci_mode=False, linewidth=10000) 24 | np.set_printoptions(threshold=sys.maxsize, precision=4, suppress=True, linewidth=500) 25 | 26 | def init_dist(args): 27 | if args.local_rank != -1: 28 | dist.init_process_group(backend=args.backend, 29 | init_method='env://', 30 | timeout=datetime.timedelta(minutes=10)) 31 | torch.cuda.set_device(args.local_rank) 32 | torch.backends.cudnn.benchmark = False 33 | 34 | if dist.is_initialized(): 35 | logging.info(f"Device rank: {dist.get_rank()}") 36 | sys.stdout.flush() 37 | 38 | 39 | def init_model(args): 40 | state = {} 41 | if args.load_from: 42 | log_rank_0(f"Loading pretrained state from {args.load_from}") 43 | state = torch.load(args.load_from, map_location=torch.device("cpu")) 44 | pretrain_args = state["args"] 45 | pretrain_args.local_rank = args.local_rank 46 | 47 | graph_attn_model = AttnEncoderXL(pretrain_args) 48 | pretrain_state_dict = state["state_dict"] 49 | pretrain_state_dict = {k.replace("module.", ""): v for k, v in pretrain_state_dict.items()} 50 | graph_attn_model.load_state_dict(pretrain_state_dict) 51 | log_rank_0("Loaded pretrained model state_dict.") 52 | flow_model = ConditionalFlowMatcher(args) 53 | else: 54 | graph_attn_model = AttnEncoderXL(args) 55 | flow_model = ConditionalFlowMatcher(args) 56 | for p in graph_attn_model.parameters(): 57 | if p.dim() > 1 and p.requires_grad: 58 | xavier_uniform_(p) 59 | 60 | graph_attn_model.to(args.device) 61 | flow_model.to(args.device) 62 | if args.local_rank != -1: 63 | graph_attn_model = DDP( 64 | graph_attn_model, 65 | device_ids=[args.local_rank], 66 | output_device=args.local_rank 67 | ) 68 | log_rank_0("DDP setup finished") 69 | 70 | os.makedirs(args.model_path, exist_ok=True) 71 | 72 | return graph_attn_model, flow_model, state 73 | 74 | def init_loader(args, dataset, batch_size: int, bucket_size: int = 1000, 75 | shuffle: bool = False, epoch: int = None, use_sort: bool =True): 76 | if use_sort: dataset.sort() 77 | if shuffle: dataset.shuffle_in_bucket(bucket_size=bucket_size) 78 | dataset.batch( 79 | 
batch_type=args.batch_type, 80 | batch_size=batch_size 81 | ) 82 | 83 | if args.local_rank != -1: 84 | sampler = DistributedSampler(dataset, shuffle=shuffle) 85 | if epoch is not None: 86 | sampler.set_epoch(epoch) 87 | else: 88 | sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset) 89 | 90 | loader = DataLoader( 91 | dataset=dataset, 92 | batch_size=1, 93 | sampler=sampler, 94 | num_workers=args.num_workers, 95 | collate_fn=lambda _batch: _batch[0], 96 | pin_memory=True 97 | ) 98 | 99 | return loader 100 | 101 | def get_optimizer_and_scheduler(args, model, state=None): 102 | optimizer = optim.AdamW( 103 | model.parameters(), 104 | lr=args.lr, 105 | betas=(args.beta1, args.beta2), 106 | eps=args.eps, 107 | weight_decay=args.weight_decay 108 | ) 109 | # scheduler = None 110 | scheduler = NoamLR( 111 | optimizer, 112 | model_size=args.emb_dim, 113 | warmup_steps=args.warmup_steps 114 | ) 115 | # scheduler = optim.lr_scheduler.StepLR( 116 | # optimizer, 117 | # step_size=args.eval_iter, gamma=0.99 118 | # ) 119 | 120 | if state and args.resume: 121 | optimizer.load_state_dict(state["optimizer"]) 122 | scheduler.load_state_dict(state["scheduler"]) 123 | log_rank_0("Loaded pretrained optimizer and scheduler state_dicts.") 124 | 125 | return optimizer, scheduler 126 | 127 | def _optimize(args, model, optimizer, scheduler): 128 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) 129 | optimizer.step() 130 | scheduler.step() 131 | g_norm = grad_norm(model) 132 | model.zero_grad(set_to_none=True) 133 | return g_norm 134 | 135 | def main(args): 136 | args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 137 | device = args.device 138 | 139 | init_dist(args) 140 | log_args(args, 'training') 141 | model, flow, state = init_model(args) 142 | total_step = state["total_step"] if state else 0 143 | log_rank_0(f"Number of parameters: {param_count(model)}") 144 | 145 | optimizer, scheduler = get_optimizer_and_scheduler(args, model, state) 146 | 147 | log_rank_0(f"Initializing training ...") 148 | log_rank_0(f"Loading data ...") 149 | with open(args.train_path, 'r') as train_o: 150 | train_smiles_list = train_o.readlines() 151 | with open(args.val_path, 'r') as val_o: 152 | val_smiles_list = val_o.readlines() 153 | 154 | train_dataset = ReactionDataset(args, train_smiles_list) 155 | val_dataset = ReactionDataset(args, val_smiles_list) 156 | 157 | accum = 0 158 | g_norm = 0 159 | losses, accs = [], [] 160 | o_start = time.time() 161 | log_rank_0("Start training") 162 | 163 | accuracy = [] 164 | for epoch in range(args.epoch): 165 | log_rank_0(f"Epoch: {epoch}") 166 | train_loader = init_loader(args, train_dataset, 167 | batch_size=args.train_batch_size, 168 | shuffle=True, 169 | epoch=epoch) 170 | for train_batch in train_loader: 171 | if total_step > args.max_steps: 172 | log_rank_0("Max steps reached, finish training") 173 | exit(0) 174 | 175 | train_batch.to(device) 176 | model.train() 177 | model.zero_grad(set_to_none=True) 178 | 179 | y = train_batch.src_token_ids 180 | y_len = train_batch.src_lens 181 | x0 = train_batch.src_matrices 182 | x1 = train_batch.tgt_matrices 183 | matrix_masks = train_batch.matrix_masks 184 | 185 | 186 | x0_sample = flow.sample_be_matrix(x0) 187 | 188 | t = torch.rand(x0.shape[0]).type_as(x0) 189 | 190 | xt = flow.sample_conditional_pt(x0, x1, t) 191 | ut = flow.compute_conditional_vector_field(x0_sample, x1) 192 | 193 | if hasattr(model, "module"): 194 | model = model.module # unwrap DDP attn_model to enable 
accessing attn_model func directly 195 | y_emb = model.id2emb(y) 196 | vt = model(y_emb, y_len, xt, t) 197 | 198 | loss = (vt - ut) * matrix_masks 199 | loss = torch.sum((loss) ** 2) / loss.shape[0] 200 | (loss / args.accumulation_count).backward() 201 | losses.append(loss.item()) 202 | 203 | accum += 1 204 | if accum == args.accumulation_count: 205 | g_norm = _optimize(args, model, optimizer, scheduler) 206 | accum = 0 207 | total_step += 1 208 | 209 | if (accum == 0) and (total_step > 0) and (total_step % args.log_iter == 0): 210 | log_rank_0(f"Step {total_step}, loss: {np.mean(losses): .4f}, " 211 | # f"acc: {np.mean(accs): .4f}, 212 | f"p_norm: {param_norm(model): .4f}, g_norm: {g_norm: .4f}, " 213 | f"lr: {get_lr(optimizer): .6f}, " 214 | f"elapsed time: {time.time() - o_start: .0f}") 215 | losses, acc = [], [] 216 | 217 | if (accum == 0) and (total_step > 0) and (total_step % args.eval_iter == 0): 218 | val_count = 50 219 | val_loader = init_loader(args, val_dataset, 220 | batch_size=args.val_batch_size, 221 | shuffle=True, 222 | epoch=epoch) 223 | from eval_multiGPU import get_predictions 224 | metrics = get_predictions(args, model, flow, val_loader, val_count) 225 | if dist.get_rank() == 0: 226 | metrics = np.array(metrics) 227 | log_rank_0(metrics.shape) 228 | topk_accuracies = np.mean(metrics[:, 0].astype(bool)) # correct smiles 229 | log_rank_0(f"Topk accuracies: {(topk_accuracies * 100): .2f}") 230 | model.train() 231 | 232 | # Important: saving only at one node or the ckpt would be corrupted! 233 | if dist.is_initialized() and dist.get_rank() > 0: 234 | continue 235 | 236 | if (accum == 0) and (total_step > 0) and (total_step % args.save_iter == 0): 237 | n_iter = total_step // args.save_iter - 1 238 | log_rank_0(f"Saving at step {total_step}") 239 | if scheduler is not None: 240 | state = { 241 | "args": args, 242 | "total_step": total_step, 243 | "state_dict": model.state_dict(), 244 | "optimizer": optimizer.state_dict(), 245 | "scheduler": scheduler.state_dict() 246 | } 247 | else: 248 | state = { 249 | "args": args, 250 | "total_step": total_step, 251 | "state_dict": model.state_dict(), 252 | "optimizer": optimizer.state_dict(), 253 | } 254 | torch.save(state, os.path.join(args.model_path, f"model.{total_step}_{n_iter}.pt")) 255 | 256 | # lastly 257 | if (args.accumulation_count > 1) and (accum > 0): 258 | _optimize(args, model, optimizer, scheduler) 259 | accum = 0 260 | # total_step += 1 # for partial batch, do not increase total_step 261 | 262 | if args.local_rank != -1: 263 | dist.barrier() 264 | log_rank_0("Epoch ended") 265 | if dist.is_initialized(): 266 | dist.destroy_process_group() 267 | 268 | if __name__ == "__main__": 269 | args = Args 270 | logger = setup_logger(args, "train") 271 | args.local_rank = int(os.environ["LOCAL_RANK"]) if os.environ.get("LOCAL_RANK") else -1 272 | main(args) 273 | -------------------------------------------------------------------------------- /beam_predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import torch 4 | import numpy as np 5 | from rdkit import Chem 6 | from utils.data_utils import ReactionDataset, BEmatrix_to_mol, ps 7 | import torch.distributed as dist 8 | from train import init_model, init_loader 9 | from utils.train_utils import log_rank_0, setup_logger, log_args 10 | from eval_multiGPU import custom_round 11 | from settings import Args 12 | from collections import defaultdict 13 | import networkx as nx 14 | import pickle 15 | from 
eval_multiGPU import predict_batch 16 | 17 | import warnings 18 | warnings.filterwarnings("ignore", category=FutureWarning) 19 | 20 | def standardize_smiles(mol): 21 | return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True) 22 | 23 | def select(args, frontiers_dict, graph_list): 24 | filtered_frontiers_dict = {} 25 | for g_idx, frontiers in frontiers_dict.items(): 26 | graph, root, _ = graph_list[g_idx] 27 | rank_frontiers = {} 28 | for frontier in frontiers: 29 | min_sequences_rank = np.inf 30 | for path in nx.all_simple_paths(graph, root, frontier): 31 | max_depth = max(graph.nodes[root]['depth'], len(path)) 32 | graph.nodes[root]['depth'] = max_depth 33 | edges = list(nx.utils.pairwise(path)) 34 | ranks = [graph.get_edge_data(u, v)['rank'] for u, v in edges] 35 | probs = [graph.get_edge_data(u, v)['count'] / args.sample_size for u, v in edges] 36 | cum_prob = np.prod(probs) 37 | max_topk_within_one_seq = max(ranks) 38 | min_sequences_rank = min(max_topk_within_one_seq, min_sequences_rank) 39 | 40 | # rank_frontiers[frontier] = min_sequences_rank 41 | rank_frontiers[frontier] = -cum_prob 42 | rank_frontiers = sorted(rank_frontiers.items(), key=lambda x:x[1])[:args.beam_size] 43 | # leftover_frontiers = sorted(rank_frontiers.items(), key=lambda x:x[1])[args.beam_size:] 44 | # graph.remove_nodes_from([frontier for frontier, prob in leftover_frontiers]) 45 | 46 | filtered_frontiers_dict[g_idx] = list(dict(rank_frontiers).keys()) 47 | return filtered_frontiers_dict 48 | 49 | 50 | def expand(args, model, flow, data_loader): 51 | sample_size = args.sample_size 52 | 53 | overall_dict = {} 54 | for batch_idx, data_batch in enumerate(data_loader): 55 | # print(data_batch.src_matrices.shape) 56 | data_batch.to(args.device) 57 | src_data_indices = data_batch.src_data_indices 58 | y = data_batch.src_token_ids 59 | y_len = data_batch.src_lens 60 | x0 = data_batch.src_matrices 61 | matrix_masks = data_batch.matrix_masks 62 | src_smiles_list = data_batch.src_smiles_list 63 | 64 | batch_size, n, n = x0.shape 65 | 66 | if (batch_size*n*n) <= 5*360*360: 67 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 1) 68 | else: 69 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 2) 70 | 71 | 72 | last_step = traj_list[-1] 73 | product_BE_matrices = custom_round(last_step) 74 | product_BE_matrices_batch = torch.split(product_BE_matrices, sample_size) 75 | 76 | for idx in range(batch_size): 77 | reac_smi, product_BE_matrices = \ 78 | src_smiles_list[idx], product_BE_matrices_batch[idx] 79 | 80 | reac_mol = Chem.MolFromSmiles(reac_smi, ps) 81 | matrices, counts = torch.unique(product_BE_matrices, dim=0, return_counts=True) 82 | matrices, counts = matrices.cpu().numpy(), counts.cpu().numpy() 83 | 84 | pred_smis_dict = defaultdict(int) 85 | for i in range(matrices.shape[0]): # all unique matrices 86 | pred_prod_be_matrix, count = matrices[i], counts[i] # predicted product matrix and it's count 87 | num_nodes = y_len[idx] 88 | pred_prod_be_matrix = pred_prod_be_matrix[:num_nodes, :num_nodes] 89 | reac_be_matrix = x0[idx][:num_nodes, :num_nodes].detach().cpu().numpy() 90 | 91 | assert pred_prod_be_matrix.shape == reac_be_matrix.shape, "pred and reac not the same shape" 92 | 93 | try: 94 | pred_mol = BEmatrix_to_mol(reac_mol, pred_prod_be_matrix) 95 | pred_smi = standardize_smiles(pred_mol) 96 | pred_mol = Chem.MolFromSmiles(pred_smi, ps) 97 | pred_smi = standardize_smiles(pred_mol) 98 | pred_smis_dict[pred_smi] += count 99 | except: pass 100 | 101 | pred_smis_tuples 
= sorted(pred_smis_dict.items(), key=lambda x: x[1], reverse=True) 102 | 103 | pred_smis_dict = dict(pred_smis_tuples[:args.nbest]) 104 | overall_dict[reac_smi] = pred_smis_dict 105 | 106 | return overall_dict 107 | 108 | def reactant_process(smi): 109 | try: 110 | mol = Chem.MolFromSmiles(smi) 111 | mol = Chem.AddHs(mol, explicitOnly=False) 112 | for idx, atom in enumerate(mol.GetAtoms()): 113 | atom.SetAtomMapNum(idx+1) 114 | # src_smi = reactant_process(src_smi) 115 | # print(src_smi) 116 | return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True) 117 | except: 118 | print(smi) 119 | raise 120 | 121 | def clean(smi): 122 | # try: 123 | mol = Chem.MolFromSmiles(smi, sanitize=False) 124 | mol = Chem.RemoveHs(mol) 125 | [atom.SetAtomMapNum(0) for atom in mol.GetAtoms()] 126 | return Chem.MolToSmiles(mol, isomericSmiles=False) 127 | 128 | def beam_search(args, model, flow, frontiers_dict, graph_list): 129 | smiles_list = [frontier for frontiers in frontiers_dict.values() for frontier in frontiers] 130 | # print('frontiers', smiles_list) 131 | # print() 132 | if len(smiles_list) == 0: return 133 | log_rank_0(f"Current Depth: {[graph.nodes[root]['depth'] for graph, root, _ in graph_list]}") 134 | exclude_gidx = [g_idx for g_idx, (graph, root, _) in enumerate(graph_list) 135 | if graph.nodes[root]['depth'] >= args.max_depth] 136 | 137 | test_dataset = ReactionDataset(args, smiles_list, reactant_only=True) 138 | try: 139 | test_loader = init_loader(args, test_dataset, 140 | batch_size=args.test_batch_size, 141 | shuffle=False, epoch=None, use_sort=False) 142 | except Exception as e: 143 | print(e) 144 | return 145 | 146 | overall_dict = expand(args, model, flow, test_loader) 147 | new_frontiers_dict = defaultdict(list) 148 | 149 | existing_reactions = {g_idx: {} for g_idx in frontiers_dict.keys()} 150 | for g_idx, frontiers in frontiers_dict.items(): 151 | if g_idx in exclude_gidx: continue 152 | existing_reaction = existing_reactions[g_idx] 153 | graph, _, _ = graph_list[g_idx] 154 | for frontier in frontiers: 155 | clean_frontier = clean(frontier) # --- 156 | try: product_info_dict = overall_dict[frontier] # given reactant, product info 157 | except: continue 158 | for rank, (product, count) in enumerate(product_info_dict.items()): 159 | try: clean_product = clean(product) # -- 160 | except: continue 161 | if (clean_frontier, clean_product) in existing_reaction: 162 | stored_frontier, stored_product = existing_reaction[(clean_frontier, clean_product)] 163 | parent_current = list(graph.predecessors(frontier)) 164 | parent_stored = list(graph.predecessors(stored_frontier)) 165 | 166 | if parent_current == parent_stored: 167 | graph[stored_frontier][stored_product]["count"] += count 168 | 169 | else: 170 | if not graph.has_node(product): 171 | new_frontiers_dict[g_idx].append(product) 172 | graph.add_edge(frontier, product, rank=rank, count=count) 173 | existing_reaction[(clean_frontier, clean_product)] = (frontier, product) 174 | 175 | 176 | filtered_frontiers_dict = select(args, new_frontiers_dict, graph_list) 177 | beam_search(args, model, flow, filtered_frontiers_dict, graph_list) 178 | 179 | def check_if_successful(graph, products): 180 | nodes_with_loops = list(nx.nodes_with_selfloops(graph)) 181 | achieved_products = set() 182 | for node in graph.nodes(): 183 | node_in_products = set(clean(node).split('.')) & set(products) 184 | if node_in_products and node in nodes_with_loops: 185 | achieved_products.update(node_in_products) 186 | return achieved_products 187 | 188 | def 
main(args, seed=0): 189 | args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 190 | device = args.device 191 | if args.local_rank != -1: 192 | dist.init_process_group(backend=args.backend, init_method='env://', timeout=datetime.timedelta(0, 7200)) 193 | torch.cuda.set_device(args.local_rank) 194 | torch.backends.cudnn.benchmark = True 195 | 196 | with open(args.test_path, 'r') as test_o: 197 | test_smiles_list = test_o.readlines() 198 | 199 | chunk_size = args.chunk_size 200 | chunked_list = [test_smiles_list[i:i + chunk_size] for i in range(0, len(test_smiles_list), chunk_size)] 201 | 202 | for i, chunk in enumerate(chunked_list): 203 | log_rank_0(f"Group Chunk-{i} called:") 204 | checkpoint = os.path.join(args.model_path, args.model_name) 205 | state = torch.load(checkpoint, weights_only=False, map_location=device) 206 | pretrain_args = state["args"] 207 | pretrain_args.load_from = None 208 | pretrain_args.device = device 209 | 210 | pretrain_state_dict = state["state_dict"] 211 | pretrain_args.local_rank = args.local_rank 212 | 213 | attn_model, flow, state = init_model(pretrain_args) 214 | if hasattr(attn_model, "module"): 215 | attn_model = attn_model.module # unwrap DDP attn_model to enable accessing attn_model func directly 216 | 217 | pretrain_state_dict = {k.replace("module.", ""): v for k, v in pretrain_state_dict.items()} 218 | attn_model.load_state_dict(pretrain_state_dict) 219 | log_rank_0(f"Loaded pretrained state_dict from {checkpoint}") 220 | 221 | graph_list = [] 222 | frontiers_dict = defaultdict(list) 223 | for idx, line in enumerate(chunk): 224 | if ">>" in line: 225 | ori_reactant = line.strip().split(">>")[0] 226 | products = line.strip().split(">>")[1].split("|") # major products 227 | products = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in products] 228 | else: 229 | ori_reactant = line.strip() 230 | products = [] 231 | reactant = reactant_process(ori_reactant) 232 | graph = nx.DiGraph() 233 | graph.add_node(reactant, depth=1) 234 | graph_list.append((graph, reactant, (ori_reactant, products))) 235 | frontiers_dict[idx].append(reactant) 236 | 237 | beam_search(args, attn_model, flow, frontiers_dict, graph_list) 238 | 239 | all_results = [] 240 | for beam_idx, (graph, root, (reactant, products)) in enumerate(graph_list): 241 | # print(output_chunk_idx, reaction) 242 | check = check_if_successful(graph, products) 243 | log_rank_0(f"Beam Search Results {beam_idx}: {len(check)}/{len(products)} - {check}") 244 | all_results.append((graph, root, (reactant, products), check)) 245 | 246 | os.makedirs(args.result_path, exist_ok=True) 247 | saving_file = os.path.join(args.result_path, f'result_chunk_{i}_s{seed}.pickle') 248 | with open(saving_file, "wb") as f_out: 249 | pickle.dump(all_results, f_out) 250 | 251 | 252 | if __name__ == "__main__": 253 | args = Args 254 | args.local_rank = int(os.environ["LOCAL_RANK"]) if os.environ.get("LOCAL_RANK") else -1 255 | logger = setup_logger(args, "beam") 256 | log_args(args, 'evaluation') 257 | main(args) 258 | -------------------------------------------------------------------------------- /beam_predict_multiGPU.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from rdkit import Chem 5 | from utils.data_utils import ReactionDataset, BEmatrix_to_mol, ps 6 | import torch.distributed as dist 7 | from train import init_model, init_loader 8 | from utils.train_utils import log_rank_0, setup_logger, 
log_args 9 | from eval_multiGPU import custom_round 10 | from settings import Args 11 | from collections import defaultdict 12 | import networkx as nx 13 | import pickle 14 | import torch.multiprocessing as mp 15 | import time 16 | from eval_multiGPU import predict_batch 17 | 18 | import warnings 19 | warnings.filterwarnings("ignore", category=FutureWarning) 20 | 21 | def standardize_smiles(mol): 22 | return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True) 23 | 24 | def select(args, frontiers_dict, graph_list): 25 | filtered_frontiers_dict = {} 26 | for g_idx, frontiers in frontiers_dict.items(): 27 | graph, root, _ = graph_list[g_idx] 28 | rank_frontiers = {} 29 | for frontier in frontiers: 30 | min_sequences_rank = np.inf 31 | for path in nx.all_simple_paths(graph, root, frontier): 32 | max_depth = max(graph.nodes[root]['depth'], len(path)) 33 | graph.nodes[root]['depth'] = max_depth 34 | edges = list(nx.utils.pairwise(path)) 35 | ranks = [graph.get_edge_data(u, v)['rank'] for u, v in edges] 36 | probs = [graph.get_edge_data(u, v)['count'] / args.sample_size for u, v in edges] 37 | cum_prob = np.prod(probs) 38 | max_topk_within_one_seq = max(ranks) 39 | min_sequences_rank = min(max_topk_within_one_seq, min_sequences_rank) 40 | 41 | # rank_frontiers[frontier] = min_sequences_rank 42 | rank_frontiers[frontier] = -cum_prob 43 | rank_frontiers = sorted(rank_frontiers.items(), key=lambda x:x[1])[:args.beam_size] 44 | # leftover_frontiers = sorted(rank_frontiers.items(), key=lambda x:x[1])[args.beam_size:] 45 | # graph.remove_nodes_from([frontier for frontier, prob in leftover_frontiers]) 46 | 47 | filtered_frontiers_dict[g_idx] = list(dict(rank_frontiers).keys()) 48 | return filtered_frontiers_dict 49 | 50 | 51 | def expand(args, model, flow, data_loader): 52 | sample_size = args.sample_size 53 | 54 | overall_dict = {} 55 | for batch_idx, data_batch in enumerate(data_loader): 56 | data_batch.to(args.device) 57 | src_data_indices = data_batch.src_data_indices 58 | y = data_batch.src_token_ids 59 | y_len = data_batch.src_lens 60 | x0 = data_batch.src_matrices 61 | matrix_masks = data_batch.matrix_masks 62 | src_smiles_list = data_batch.src_smiles_list 63 | 64 | batch_size, n, n = x0.shape 65 | 66 | if (batch_size*n*n) <= 5*360*360: 67 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 1) 68 | else: 69 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 2) 70 | 71 | 72 | last_step = traj_list[-1] 73 | product_BE_matrices = custom_round(last_step) 74 | product_BE_matrices_batch = torch.split(product_BE_matrices, sample_size) 75 | 76 | for idx in range(batch_size): 77 | reac_smi, product_BE_matrices = \ 78 | src_smiles_list[idx], product_BE_matrices_batch[idx] 79 | 80 | reac_mol = Chem.MolFromSmiles(reac_smi, ps) 81 | matrices, counts = torch.unique(product_BE_matrices, dim=0, return_counts=True) 82 | matrices, counts = matrices.cpu().numpy(), counts.cpu().numpy() 83 | 84 | pred_smis_dict = defaultdict(int) 85 | for i in range(matrices.shape[0]): # all unique matrices 86 | pred_prod_be_matrix, count = matrices[i], counts[i] # predicted product matrix and it's count 87 | num_nodes = y_len[idx] 88 | pred_prod_be_matrix = pred_prod_be_matrix[:num_nodes, :num_nodes] 89 | reac_be_matrix = x0[idx][:num_nodes, :num_nodes].detach().cpu().numpy() 90 | 91 | assert pred_prod_be_matrix.shape == reac_be_matrix.shape, "pred and reac not the same shape" 92 | 93 | try: 94 | pred_mol = BEmatrix_to_mol(reac_mol, pred_prod_be_matrix) 95 | pred_smi = 
standardize_smiles(pred_mol) 96 | pred_mol = Chem.MolFromSmiles(pred_smi, ps) 97 | pred_smi = standardize_smiles(pred_mol) 98 | pred_smis_dict[pred_smi] += count 99 | except: pass 100 | 101 | pred_smis_tuples = sorted(pred_smis_dict.items(), key=lambda x: x[1], reverse=True) 102 | 103 | pred_smis_dict = dict(pred_smis_tuples[:args.nbest]) 104 | overall_dict[reac_smi] = pred_smis_dict 105 | 106 | return overall_dict 107 | 108 | def reactant_process(smi): 109 | try: 110 | mol = Chem.MolFromSmiles(smi) 111 | mol = Chem.AddHs(mol, explicitOnly=False) 112 | for idx, atom in enumerate(mol.GetAtoms()): 113 | atom.SetAtomMapNum(idx+1) 114 | return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True) 115 | except: 116 | print(smi) 117 | raise 118 | 119 | def clean(smi): 120 | # try: 121 | mol = Chem.MolFromSmiles(smi, sanitize=False) 122 | mol = Chem.RemoveHs(mol) 123 | [atom.SetAtomMapNum(0) for atom in mol.GetAtoms()] 124 | return Chem.MolToSmiles(mol, isomericSmiles=False) 125 | 126 | def beam_search(args, model, flow, frontiers_dict, graph_list): 127 | smiles_list = [frontier for frontiers in frontiers_dict.values() for frontier in frontiers] 128 | # print('frontiers', smiles_list) 129 | # print() 130 | if len(smiles_list) == 0: return 131 | print(f"Current Depth: {[graph.nodes[root]['depth'] for graph, root, _ in graph_list]}") 132 | exclude_gidx = [g_idx for g_idx, (graph, root, _) in enumerate(graph_list) 133 | if graph.nodes[root]['depth'] >= args.max_depth] 134 | 135 | test_dataset = ReactionDataset(args, smiles_list, reactant_only=True) 136 | try: 137 | test_loader = init_loader(args, test_dataset, 138 | batch_size=args.test_batch_size, 139 | shuffle=False, epoch=None, use_sort=False) 140 | except Exception as e: 141 | print(e) 142 | return 143 | 144 | overall_dict = expand(args, model, flow, test_loader) 145 | new_frontiers_dict = defaultdict(list) 146 | 147 | existing_reactions = {g_idx: {} for g_idx in frontiers_dict.keys()} 148 | for g_idx, frontiers in frontiers_dict.items(): 149 | if g_idx in exclude_gidx: continue 150 | existing_reaction = existing_reactions[g_idx] 151 | graph, _, _ = graph_list[g_idx] 152 | for frontier in frontiers: 153 | clean_frontier = clean(frontier) # --- 154 | try: product_info_dict = overall_dict[frontier] # given reactant, product info 155 | except: continue 156 | for rank, (product, count) in enumerate(product_info_dict.items()): 157 | try: clean_product = clean(product) # -- 158 | except: continue 159 | if (clean_frontier, clean_product) in existing_reaction: 160 | stored_frontier, stored_product = existing_reaction[(clean_frontier, clean_product)] 161 | parent_current = list(graph.predecessors(frontier)) 162 | parent_stored = list(graph.predecessors(stored_frontier)) 163 | 164 | if parent_current == parent_stored: 165 | graph[stored_frontier][stored_product]["count"] += count 166 | 167 | else: 168 | if not graph.has_node(product): 169 | new_frontiers_dict[g_idx].append(product) 170 | graph.add_edge(frontier, product, rank=rank, count=count) 171 | existing_reaction[(clean_frontier, clean_product)] = (frontier, product) 172 | 173 | 174 | filtered_frontiers_dict = select(args, new_frontiers_dict, graph_list) 175 | beam_search(args, model, flow, filtered_frontiers_dict, graph_list) 176 | 177 | 178 | def group_lists(lists, group_size): 179 | result = [] 180 | # Process lists in chunks of group_size 181 | for i in range(0, len(lists), group_size): 182 | # Take a slice of size group_size (or remaining elements if less) 183 | chunk = lists[i:i + 
group_size] 184 | # Convert the chunk to a tuple and add to result 185 | result.append(tuple(chunk)) 186 | 187 | return result 188 | 189 | import signal 190 | import os 191 | 192 | def init_process_killer(): 193 | global processes 194 | processes = [] 195 | 196 | def signal_handler(sig, frame): 197 | print('\nTerminating all processes...') 198 | for p in processes: 199 | if p.is_alive(): 200 | p.terminate() 201 | p.join() # Wait for process to finish 202 | os._exit(0) 203 | 204 | signal.signal(signal.SIGINT, signal_handler) 205 | return processes 206 | 207 | def worker(rank, args, chunk, chunk_idx, lock, queue): 208 | """Worker function that runs on each GPU""" 209 | # Set random seeds for reproducibility 210 | 211 | # Set device for this process 212 | torch.cuda.set_device(rank) 213 | device = torch.device(f'cuda:{rank}') 214 | args.device = device 215 | args.local_rank = -1 # Disable distributed training 216 | 217 | # Load model for this process 218 | checkpoint = os.path.join(args.model_path, args.model_name) 219 | state = torch.load(checkpoint, weights_only=False, map_location=device) 220 | pretrain_args = state["args"] 221 | pretrain_args.load_from = None 222 | pretrain_args.device = device 223 | pretrain_args.local_rank = -1 # Disable distributed training 224 | 225 | # Initialize model without DDP 226 | pretrain_state_dict = state["state_dict"] 227 | attn_model, flow, _ = init_model(pretrain_args) 228 | 229 | # Remove DDP wrapper if present 230 | if hasattr(attn_model, "module"): 231 | attn_model = attn_model.module 232 | 233 | pretrain_state_dict = {k.replace("module.", ""): v for k, v in pretrain_state_dict.items()} 234 | attn_model.load_state_dict(pretrain_state_dict) 235 | 236 | # print(f"GPU {rank} starting processing {len(chunk)} items") 237 | 238 | # Process chunk 239 | graph_list = [] 240 | frontiers_dict = defaultdict(list) 241 | for idx, line in enumerate(chunk): 242 | if ">>" in line: 243 | ori_reactant = line.strip().split(">>")[0] 244 | products = line.strip().split(">>")[1].split("|") 245 | products = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in products] 246 | else: 247 | ori_reactant = line.strip() 248 | products = [] 249 | reactant = reactant_process(ori_reactant) 250 | graph = nx.DiGraph() 251 | graph.add_node(reactant, depth=1) 252 | graph_list.append((graph, reactant, (ori_reactant, products))) 253 | frontiers_dict[idx].append(reactant) 254 | 255 | beam_search(args, attn_model, flow, frontiers_dict, graph_list) 256 | 257 | lock.acquire() 258 | try: 259 | queue.put((rank, chunk_idx, graph_list)) 260 | finally: 261 | lock.release() 262 | 263 | # print(f"GPU {rank} finished processing") 264 | 265 | def check_if_successful(graph, products): 266 | nodes_with_loops = list(nx.nodes_with_selfloops(graph)) 267 | achieved_products = set() 268 | for node in graph.nodes(): 269 | node_in_products = set(clean(node).split('.')) & set(products) 270 | if node_in_products and node in nodes_with_loops: 271 | achieved_products.update(node_in_products) 272 | return achieved_products 273 | 274 | def main_multi_gpu(args): 275 | start = time.time() 276 | global processes 277 | processes = init_process_killer() 278 | 279 | # Get number of available GPUs 280 | world_size = torch.cuda.device_count() 281 | log_rank_0(f"Found {world_size} GPUs") 282 | 283 | # Read all test smiles 284 | with open(args.test_path, 'r') as test_o: 285 | test_smiles_list = test_o.readlines() 286 | 287 | # Calculate chunk size and create chunks 288 | # chunk_size = math.ceil(len(test_smiles_list) / 
world_size) 289 | chunk_size = args.chunk_size // world_size 290 | chunks = [test_smiles_list[i:i + chunk_size] for i in range(0, len(test_smiles_list), chunk_size)] 291 | 292 | group_chunks = group_lists(chunks, world_size) 293 | log_rank_0(f"Number of group chunks: {len(group_chunks)}") 294 | 295 | os.makedirs(args.result_path, exist_ok=True) 296 | # Start processes 297 | lock = mp.Lock() 298 | q = mp.Queue() 299 | chunk_idx = 0 300 | for group_chunk_id, group_chunk in enumerate(group_chunks): 301 | log_rank_0(f"Group Chunk-{group_chunk_id} called:") 302 | all_results = [] 303 | processes = [] 304 | for gpu_idx, chunk in enumerate(group_chunk): 305 | p = mp.Process(target=worker, args=(gpu_idx, args, chunk, chunk_idx, lock, q)) 306 | p.start() 307 | processes.append(p) 308 | time.sleep(1) # Add small delay between process starts 309 | chunk_idx += 1 310 | 311 | outputs = [] 312 | for _ in processes: 313 | output = q.get(timeout=1000) 314 | if output is None: continue 315 | outputs.append(output) 316 | outputs = sorted(outputs, key=lambda x:x[0]) 317 | for output in outputs: 318 | _, output_chunk_idx, graph_list = output 319 | for beam_idx, (graph, root, (reactant, products)) in enumerate(graph_list): 320 | check = check_if_successful(graph, products) 321 | log_rank_0(f"Beam Search Results {beam_idx}: {len(check)}/{len(products)} - {check}") 322 | all_results.append((graph, root, (reactant, products), check)) 323 | 324 | # Wait for all processes to complete 325 | for p in processes: 326 | p.join() 327 | 328 | saving_file = os.path.join(args.result_path, f'result_chunk_{group_chunk_id}.pickle') 329 | with open(saving_file, "wb") as f_out: 330 | pickle.dump(all_results, f_out) 331 | 332 | log_rank_0(f"---- Time used: {(time.time() - start):.2f}s ----") 333 | 334 | 335 | log_rank_0("Done!") 336 | 337 | if __name__ == "__main__": 338 | # Ensure clean startup 339 | mp.set_start_method('spawn') 340 | 341 | args = Args 342 | logger = setup_logger(args, "beam") 343 | log_args(args, 'evaluation') 344 | main_multi_gpu(args) -------------------------------------------------------------------------------- /eval_multiGPU.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | import datetime 5 | import torch 6 | import numpy as np 7 | from rdkit import Chem 8 | import torchdiffeq 9 | from utils.data_utils import ReactionDataset, BEmatrix_to_mol, ps 10 | from utils.rounding import saferound_tensor 11 | import torch.distributed as dist 12 | from train import init_model, init_loader 13 | from utils.train_utils import log_rank_0, setup_logger, log_args 14 | from settings import Args 15 | from collections import defaultdict 16 | import time 17 | import iteround 18 | 19 | ps = Chem.SmilesParserParams() 20 | ps.removeHs = False 21 | ps.sanitize = True 22 | 23 | def is_sym(a): 24 | return (a.transpose(1, 0) == a).all() 25 | 26 | def redist_fix(pred_matrix, reac_smi, reac_be_matrix): 27 | pred_electron_sum = np.zeros([len(pred_matrix)]) 28 | for i in range(len(pred_matrix)): 29 | pred_electron_sum[i] = \ 30 | np.sum(pred_matrix[i, :]) + np.sum(pred_matrix[:, i]) - pred_matrix[i, i] 31 | 32 | reac_electron_sum = np.zeros([len(reac_be_matrix)]) 33 | for i in range(len(reac_be_matrix)): 34 | reac_electron_sum[i] = \ 35 | np.sum(reac_be_matrix[i, :]) + np.sum(reac_be_matrix[:, i]) - reac_be_matrix[i, i] 36 | 37 | diff = reac_electron_sum - pred_electron_sum 38 | 39 | if np.sum(diff) == 0: 40 | pred_matrix[np.diag_indices_from(pred_matrix)] += diff 
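    # Conservation fix: `diff` is the per-atom electron deficit of the rounded
    # prediction relative to the reactant BE matrix. When the total electron
    # count already matches (np.sum(diff) == 0) but individual atoms are off,
    # the imbalance is pushed onto the diagonal, i.e. into the lone-pair
    # counts, the only per-atom slot that can absorb a correction without
    # changing a bond order. Illustrative toy case (hypothetical numbers, not
    # taken from any dataset):
    #   reac_electron_sum = [8, 6], pred_electron_sum = [7, 7]
    #   diff = [1, -1], sum(diff) == 0  ->  diagonal += [1, -1]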
41 | 
42 |     return pred_matrix
43 | 
44 | # # old implementation uses CPU
45 | # def redistribute_round(x):
46 | #     rounded_diff = iteround.saferound(x.flatten().cpu().numpy().tolist(), 0)
47 | #     rounded_diff = torch.as_tensor(rounded_diff, dtype=torch.float).view(*x.shape)
48 | #     return rounded_diff.to(x)
49 | 
50 | # new implementation uses GPU
51 | def redistribute_round(x):
52 |     rounded = saferound_tensor(x, places=0, strategy="difference")
53 |     return rounded
54 | 
55 | def custom_round(x):
56 |     output = []
57 |     for i in range(x.shape[0]):
58 |         try: output.append(redistribute_round(x[i]))
59 |         except: output.append(torch.round(x[i]))
60 |     return torch.stack(output)
61 | 
62 | def standardize_smiles(mol):
63 |     [a.SetAtomMapNum(0) for a in mol.GetAtoms()]
64 |     return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True)
65 | 
66 | def split_number(number, num_parts):
67 |     if number % num_parts != 0:
68 |         raise ValueError("The number cannot be evenly divided into the specified number of parts.")
69 |     return [number // num_parts] * num_parts
70 | 
71 | start = time.time()
72 | def predict_batch(args, batch_idx, data_batch, model, flow, split, rand_matrix=None):
73 |     src_data_indices = data_batch.src_data_indices
74 |     y = data_batch.src_token_ids
75 |     y_len = data_batch.src_lens
76 |     x0 = data_batch.src_matrices
77 |     # x1 = data_batch.tgt_matrices
78 |     matrix_masks = data_batch.matrix_masks
79 | 
80 |     batch_size, n, n = x0.shape
81 | 
82 |     log_rank_0(f"Batch idx: {batch_idx}, batch_shape {batch_size, n, n} {(time.time() - start): .2f}s")
83 |     # --------ODE inference--------------#
84 |     SAMPLE_BATCH = args.sample_size
85 |     # split_sample_batches = split_number(SAMPLE_BATCH, 2) if n >= 400 else split_number(SAMPLE_BATCH, 1)
86 |     # split_sample_batches = split_number(SAMPLE_BATCH, 1)
87 |     split_sample_batches = split_number(SAMPLE_BATCH, split)
88 | 
89 |     big_traj_list = []
90 |     for sample_size in split_sample_batches:
91 |         src_data_indices = src_data_indices.repeat_interleave(sample_size, dim=0)
92 |         x0_repeated = x0.repeat_interleave(sample_size, dim=0)
93 |         x0_sample_repeated = flow.sample_be_matrix(x0_repeated)
94 | 
95 |         matrix_masks_repeated = matrix_masks.repeat_interleave(sample_size, dim=0)
96 |         x0_sample_repeated = x0_sample_repeated.masked_fill(~(matrix_masks_repeated.bool()), 0)  # the ODE initial step applies RMS norm, so NaN padding has to be swapped to 0
97 | 
98 |         del matrix_masks_repeated
99 |         torch.cuda.empty_cache()
100 | 
101 |         y_repeated = y.repeat_interleave(sample_size, dim=0)
102 |         y_emb_repeated = model.id2emb(y_repeated)
103 |         y_len_batch_repeated = y_len.repeat_interleave(sample_size, dim=0)
104 | 
105 |         traj_list = torchdiffeq.odeint_adjoint(
106 |             lambda t, x: model.forward(y_emb_repeated, y_len_batch_repeated, x, t),
107 |             x0_sample_repeated,
108 |             torch.linspace(0, 1, 2).to(args.device),
109 |             atol=1e-4,
110 |             rtol=1e-4,
111 |             method="dopri5",
112 |             adjoint_params=()
113 |         )
114 |         big_traj_list.append((traj_list.transpose(0, 1).detach().cpu(), sample_size))
115 | 
116 |     # merging
117 |     all_traj_list = []
118 |     for bs in range(batch_size):
119 |         for traj_list, sample_size in big_traj_list:
120 |             all_traj_list.append(traj_list[bs*sample_size:(bs+1)*sample_size].transpose(0, 1))
121 |     traj_list = torch.concat(all_traj_list, dim=1)  # concat on sampling dimension
122 |     # ------------------------------------#
123 |     return traj_list
124 | 
125 | def get_predictions(args, model, flow, data_loader, iter_count=np.inf, write_o=None):
126 |     accuracy = []
127 |     model.eval()
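    # `predict_batch` above draws `sample_size` noisy starting matrices per
    # reaction via `flow.sample_be_matrix`, then integrates the learned vector
    # field from t=0 to t=1 with torchdiffeq's adaptive dopri5 solver; only the
    # endpoint of the two-point time grid torch.linspace(0, 1, 2) is kept and
    # later rounded. A minimal standalone sketch of the same odeint call
    # pattern (toy vector field; every name below is illustrative, not part of
    # this repo):
    #
    #   import torch, torchdiffeq
    #   x0 = torch.randn(4, 8, 8)              # toy "BE matrices"
    #   vf = lambda t, x: -x                   # toy field: decay toward 0
    #   traj = torchdiffeq.odeint(vf, x0, torch.linspace(0, 1, 2),
    #                             atol=1e-4, rtol=1e-4, method="dopri5")
    #   x1 = traj[-1]                          # endpoint, analogous to last_step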
128 |     with torch.no_grad():
129 |         log_rank_0('Start ODE Prediction...')
130 |         if dist.get_rank() == 0:
131 |             inferenced_indexes = set()
132 | 
133 |         for batch_idx, data_batch in enumerate(data_loader):
134 |             if batch_idx >= iter_count: break
135 |             data_batch.to(args.device)
136 | 
137 |             src_data_indices = data_batch.src_data_indices
138 |             x0 = data_batch.src_matrices
139 |             y_len = data_batch.src_lens
140 |             batch_size, n, n = x0.shape
141 |             src_smiles_list = data_batch.src_smiles_list
142 |             tgt_smiles_list = data_batch.tgt_smiles_list
143 | 
144 | 
145 |             # if (batch_size*n*n) <= 5*360*360:
146 |             if (batch_size*n*n) <= 15*130*130:
147 |                 traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 1)
148 |             else:
149 |                 traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 2)
150 | 
151 |             if torch.distributed.is_initialized() and dist.get_world_size() > 1:
152 |                 gathered_results = [None for _ in range(dist.get_world_size())]
153 |                 dist.gather_object(
154 |                     (src_data_indices, traj_list, x0, y_len, src_smiles_list, tgt_smiles_list),
155 |                     gathered_results if dist.get_rank() == 0 else None,
156 |                     dst=0
157 |                 )
158 |             else:
159 |                 gathered_results = [(src_data_indices, traj_list, x0, y_len, src_smiles_list, tgt_smiles_list)]
160 | 
161 |             if dist.get_rank() > 0:
162 |                 continue
163 | 
164 |             for result in gathered_results:
165 |                 src_data_indices, traj_list, x0, y_len, src_smiles_list, tgt_smiles_list = result
166 |                 batch_size, n, n = x0.shape
167 | 
168 |                 last_step = traj_list[-1]
169 |                 product_BE_matrices = custom_round(last_step)
170 |                 product_BE_matrices_batch = torch.split(product_BE_matrices, args.sample_size)
171 | 
172 |                 for idx in range(batch_size):
173 |                     reac_smi, product_smi, product_BE_matrices = \
174 |                         src_smiles_list[idx], tgt_smiles_list[idx], product_BE_matrices_batch[idx]
175 | 
176 |                     data_idx = int(src_data_indices[idx].detach().cpu())
177 |                     if data_idx in inferenced_indexes: continue
178 |                     else: inferenced_indexes.add(data_idx)
179 | 
180 |                     reac_mol = Chem.MolFromSmiles(reac_smi, ps)
181 |                     prod_mol = Chem.MolFromSmiles(product_smi, ps)
182 | 
183 |                     tgt_smiles = standardize_smiles(prod_mol)
184 | 
185 |                     matrices, counts = torch.unique(product_BE_matrices, dim=0, return_counts=True)
186 |                     matrices, counts = matrices.cpu().numpy(), counts.cpu().numpy()
187 | 
188 |                     not_sym = 0
189 | 
190 |                     correct = wrong_smi_conserved = wrong_smi_non_conserved = 0
191 |                     no_smi_conserved = no_smi_non_conserved = 0
192 | 
193 |                     pred_smi_dict = defaultdict(int)
194 |                     pred_conserved_dict = defaultdict(bool)
195 |                     # Evaluation on unique predicted BE matrices
196 |                     for i in range(matrices.shape[0]):
197 |                         pred_prod_be_matrix, count = matrices[i], counts[i]  # predicted product matrix and its count
198 |                         num_nodes = y_len[idx]
199 |                         pred_prod_be_matrix = pred_prod_be_matrix[:num_nodes, :num_nodes]
200 |                         reac_be_matrix = x0[idx][:num_nodes, :num_nodes].detach().cpu().numpy()
201 | 
202 |                         # print(f"Matrix{i} - {count}")
203 |                         pred_prod_be_matrix = redist_fix(pred_prod_be_matrix, reac_smi, reac_be_matrix)
204 | 
205 |                         assert pred_prod_be_matrix.shape == reac_be_matrix.shape, "pred and reac not the same shape"
206 | 
207 |                         if not is_sym(pred_prod_be_matrix):
208 |                             not_sym += 1
209 | 
210 |                         try:
211 |                             pred_mol = BEmatrix_to_mol(reac_mol, pred_prod_be_matrix)
212 |                             pred_smi = standardize_smiles(pred_mol)
213 | 
214 |                             pred_mol = Chem.MolFromSmiles(pred_smi, ps)
215 |                             pred_smi = standardize_smiles(pred_mol)
216 |                             tgt_mol = Chem.MolFromSmiles(tgt_smiles, ps)
217 |                             tgt_smiles = standardize_smiles(tgt_mol)
218 | 
219 | 
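                            # The if/elif/else chain below sorts each unique prediction
                            # into one of five buckets:
                            #   correct                 - right SMILES, electrons conserved
                            #   wrong_smi_conserved     - valid but wrong SMILES, electrons conserved
                            #   wrong_smi_non_conserved - valid SMILES, electrons not conserved
                            #   no_smi_conserved        - no valid molecule, electrons conserved
                            #   no_smi_non_conserved    - no valid molecule, electrons not conserved
                            # These five counts form the per-reaction `metric` list written below.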
220 |                             if pred_smi == tgt_smiles and pred_prod_be_matrix.sum() == reac_be_matrix.sum():
221 |                                 correct += count
222 |                                 pred_smi_dict[pred_smi] += count
223 |                                 pred_conserved_dict[pred_smi] = True
224 |                             elif pred_prod_be_matrix.sum() == reac_be_matrix.sum():  # conserves electrons, gives wrong smiles
225 |                                 wrong_smi_conserved += count
226 |                                 pred_smi_dict[pred_smi] += count
227 |                                 pred_conserved_dict[pred_smi] = True
228 |                             else:  # gives SMILES but does not conserve electrons
229 |                                 wrong_smi_non_conserved += count  ########### this is an added metric
230 |                         except:
231 |                             if pred_prod_be_matrix.sum() == reac_be_matrix.sum():
232 |                                 no_smi_conserved += count
233 |                             else:
234 |                                 no_smi_non_conserved += count
235 | 
236 |                     metric = [correct, wrong_smi_conserved, wrong_smi_non_conserved, no_smi_conserved, no_smi_non_conserved]
237 |                     predictions = [(smi, pred_smi_dict[smi], pred_conserved_dict[smi]) for smi in pred_smi_dict]
238 |                     if write_o is not None:
239 |                         write_o.write(f"{metric}|{not_sym}|{predictions}\n")
240 |                         write_o.flush()
241 |                     accuracy.append(metric)
242 | 
243 |     return accuracy
244 | 
245 | 
246 | def main(args):
247 |     args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
248 |     device = args.device
249 |     if args.local_rank != -1:
250 |         dist.init_process_group(backend=args.backend, init_method='env://', timeout=datetime.timedelta(0, 7200))
251 |         torch.cuda.set_device(args.local_rank)
252 |         torch.backends.cudnn.benchmark = True
253 | 
254 |     if args.do_validate:
255 |         phase = "valid"
256 |         checkpoints = glob.glob(os.path.join(args.model_path, "*.pt"))
257 |         checkpoints = sorted(
258 |             checkpoints,
259 |             key=lambda ckpt: int(ckpt.split(".")[-2].split("_")[-1]),
260 |             reverse=True
261 |         )
262 |         assert len(args.steps2validate) > 0, "Nothing to validate on"
263 |         checkpoints = [ckpt for ckpt in checkpoints
264 |                        if ckpt.split(".")[-2].split("_")[0] in args.steps2validate]  # lr0.001
265 |     else:
266 |         phase = "test"
267 |         checkpoints = [os.path.join(args.model_path, args.model_name)]
268 | 
269 | 
270 |     for ckpt_i, checkpoint in enumerate(checkpoints):
271 |         state = torch.load(checkpoint, weights_only=False, map_location=device)
272 |         pretrain_args = state["args"]
273 |         pretrain_args.load_from = None
274 |         pretrain_args.device = device
275 | 
276 |         pretrain_state_dict = state["state_dict"]
277 |         pretrain_args.local_rank = args.local_rank
278 | 
279 |         attn_model, flow, state = init_model(pretrain_args)
280 |         if hasattr(attn_model, "module"):
281 |             attn_model = attn_model.module  # unwrap DDP attn_model to enable accessing attn_model func directly
282 | 
283 |         pretrain_state_dict = {k.replace("module.", ""): v for k, v in pretrain_state_dict.items()}
284 |         attn_model.load_state_dict(pretrain_state_dict)
285 |         log_rank_0(f"Loaded pretrained state_dict from {checkpoint}")
286 | 
287 |         os.makedirs(args.result_path, exist_ok=True)
288 |         results_path = os.path.join(args.result_path, f'{phase}-{args.sample_size}-{checkpoint.split(".")[-2]}.txt')
289 |         if os.path.isfile(results_path):
290 |             with open(results_path, 'r') as fp:
291 |                 n_lines = len(fp.readlines())
292 |             file_mod = 'a'
293 |             start = n_lines
294 |             log_rank_0(f"Continuing previous runs at reaction {start}...")
295 |         else:
296 |             log_rank_0("Starting new run...")
297 |             file_mod = 'w'
298 |             start = 0
299 | 
300 |         if args.do_validate:
301 |             with open(args.val_path, 'r') as test_o:
302 |                 test_smiles_list = test_o.readlines()[start:]
303 |         else:
304 |             with open(args.test_path, 'r') as test_o:
305 |                 test_smiles_list = test_o.readlines()[start:]
306 | 
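        # Each line written by get_predictions has the form "metric|not_sym|predictions".
        # A hedged sketch of aggregating such a results file into top-1 accuracy
        # offline (the ast-based parsing below is an assumption for illustration,
        # not a utility shipped with this repo):
        #
        #   import ast
        #   with open(results_path) as f:
        #       metrics = [ast.literal_eval(line.split("|")[0]) for line in f]
        #   top1 = sum(m[0] > 0 for m in metrics) / len(metrics)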
307 | assert len(test_smiles_list) > 0, "Nothing to do inference" 308 | 309 | test_dataset = ReactionDataset(args, test_smiles_list) 310 | test_loader = init_loader(args, test_dataset, 311 | batch_size=args.test_batch_size, 312 | shuffle=False, epoch=None, use_sort=False) 313 | 314 | with open(results_path, file_mod) as result_o: 315 | metrics = get_predictions(args, attn_model, flow, test_loader, write_o=result_o) 316 | if dist.get_rank() == 0: 317 | metrics = np.array(metrics) 318 | topk_accuracies = np.mean(metrics[:, 0].astype(bool)) # correct smiles 319 | log_rank_0(f"Topk accuracies: {(topk_accuracies * 100): .2f}") 320 | 321 | 322 | if __name__ == "__main__": 323 | args = Args 324 | args.local_rank = int(os.environ["LOCAL_RANK"]) if os.environ.get("LOCAL_RANK") else -1 325 | logger = setup_logger(args, "eval") 326 | log_args(args, 'evaluation') 327 | main(args) 328 | -------------------------------------------------------------------------------- /model/attn_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from utils.attn_utils import PositionwiseFeedForward 5 | from utils.attn_utils import sequence_mask 6 | from utils.data_utils import ELEM_LIST 7 | from model.flow_matching import zero_center_func 8 | 9 | def timestep_embedding(timesteps, dim, max_period=10000): 10 | """Create sinusoidal timestep embeddings. 11 | 12 | :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. 13 | :param dim: the dimension of the output. 14 | :param max_period: controls the minimum frequency of the embeddings. 15 | :return: an [N x dim] Tensor of positional embeddings. 16 | """ 17 | half = dim // 2 18 | freqs = torch.exp( 19 | -math.log(max_period) 20 | * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) 21 | / half 22 | ) 23 | args = timesteps[:, None].float() * freqs[None] 24 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 25 | if dim % 2: 26 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 27 | return embedding 28 | 29 | def zero_center_output(x_batch, ori_node_mask_batch): 30 | x_batch = x_batch.masked_fill(ori_node_mask_batch, 1e-19) 31 | node_mask_batch = (~ori_node_mask_batch).long() 32 | map_zero_center = torch.vmap(zero_center_func) 33 | return map_zero_center(x_batch, node_mask_batch).masked_fill(~(node_mask_batch.bool()), 1e-19) 34 | 35 | class RBFExpansion(nn.Module): 36 | def __init__(self, args): 37 | """ 38 | Adapted from Schnet. 
39 |         https://github.com/atomistic-machine-learning/SchNet/blob/master/src/schnet/nn/layers/rbf.py
40 |         """
41 |         super().__init__()
42 |         self.args = args
43 |         self.device = args.device
44 |         self.low = args.rbf_low
45 |         self.high = args.rbf_high
46 |         self.gap = args.rbf_gap
47 | 
48 |         self.xrange = self.high - self.low
49 | 
50 |         self.centers = torch.linspace(self.low, self.high,
51 |                                       int(torch.ceil(torch.tensor(self.xrange / self.gap)))).to(self.device)
52 |         self.dim = len(self.centers)
53 | 
54 |     def forward(self, matrix, matrix_mask):
55 |         matrix = matrix.masked_fill(matrix_mask, 1e9)
56 |         matrix = matrix.unsqueeze(-1)  # Add a new dimension at the end
57 |         # Compute the RBF
58 |         matrix = matrix - self.centers
59 |         rbf = torch.exp(-(matrix ** 2) / self.gap)
60 |         return rbf
61 | 
62 | class MultiHeadedRelAttention(nn.Module):
63 |     def __init__(self, args, head_count, model_dim, dropout, u, v):
64 |         super().__init__()
65 |         self.args = args
66 | 
67 |         assert model_dim % head_count == 0
68 |         self.dim_per_head = model_dim // head_count
69 |         self.model_dim = model_dim
70 |         self.head_count = head_count
71 | 
72 |         self.linear_keys = nn.Linear(model_dim, model_dim)
73 |         self.linear_values = nn.Linear(model_dim, model_dim)
74 |         self.linear_query = nn.Linear(model_dim, model_dim)
75 | 
76 |         self.softmax = nn.Softmax(dim=-1)
77 |         self.dropout = nn.Dropout(dropout)
78 |         self.final_linear = nn.Linear(model_dim, model_dim)
79 | 
80 |         self.u = u if u is not None else \
81 |             nn.Parameter(torch.randn(self.model_dim), requires_grad=True)
82 |         self.v = v if v is not None else \
83 |             nn.Parameter(torch.randn(self.model_dim), requires_grad=True)
84 | 
85 |     def forward(self, inputs, mask, rel_emb):
86 |         """
87 |         Compute the context vector and the attention vectors.
88 | 
89 |         Args:
90 |             inputs (FloatTensor): set of `key_len`
91 |                 key vectors ``(batch, key_len, dim)``
92 |             mask: binary mask 1/0 indicating which keys have
93 |                 zero / non-zero attention ``(batch, query_len, key_len)``
94 |             rel_emb: relative embedding (RBF-expanded bond matrix plus timestep embedding), ``(batch, key_len, key_len, dim)``
95 |         Returns:
96 |             (FloatTensor, FloatTensor):
97 | 
98 |             * output context vectors ``(batch, query_len, dim)``
99 |             * Attention vector in heads ``(batch, head, query_len, key_len)``.
100 |         """
101 | 
102 |         batch_size = inputs.size(0)
103 |         dim_per_head = self.dim_per_head
104 |         head_count = self.head_count
105 | 
106 |         def shape(x):
107 |             """Projection."""
108 |             return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)
109 | 
110 |         def unshape(x):
111 |             """Compute context."""
112 |             return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)
113 | 
114 |         # 1) Project key, value, and query. Seems that we don't need layer_cache here
115 |         query = self.linear_query(inputs)
116 |         key = self.linear_keys(inputs)
117 |         value = self.linear_values(inputs)
118 | 
119 |         key = shape(key)  # (b, t_k, h) -> (b, head, t_k, h/head)
120 |         value = shape(value)
121 |         query = shape(query)  # (b, t_q, h) -> (b, head, t_q, h/head)
122 | 
123 |         key_len = key.size(2)
124 |         query_len = query.size(2)
125 | 
126 |         # 2) Calculate and scale scores.
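        # The two-term score computed below follows the Transformer-XL style
        # relative-attention decomposition: u and v are learned global biases,
        # r is the per-pair relative embedding (here the RBF-expanded bond
        # matrix concatenated with the timestep embedding):
        #   a_c = (q + u) @ k^T    (content terms a and c)
        #   b_d = (q + v) @ r^T    (relative-position terms b and d)
        # Shape walkthrough (b=batch, head=heads, t=atoms, h=model dim):
        #   a_c: (b, head, t, h/head) @ (b, head, h/head, t)       -> (b, head, t, t)
        #   b_d: (b, head, t, 1, h/head) @ (b, head, t, h/head, t) -> (b, head, t, t)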
127 | query = query / math.sqrt(dim_per_head) 128 | 129 | if rel_emb is None: 130 | scores = torch.matmul( 131 | query, key.transpose(2, 3)) # (b, head, t_q, t_k) 132 | # scores = scores + rel_emb.unsqueeze(1) 133 | else: 134 | # a + c 135 | u = self.u.reshape(1, head_count, 1, dim_per_head) 136 | a_c = torch.matmul(query + u, key.transpose(2, 3)) 137 | 138 | # rel_emb = self.relative_pe(rel_emb) # (b, t_q, t_k) -> (b, t_q, t_k, h) 139 | rel_emb = rel_emb.reshape( # (b, t_q, t_k, h) -> (b, t_q, t_k, head, h/head) 140 | batch_size, query_len, key_len, head_count, dim_per_head) 141 | 142 | # b + d 143 | query = query.unsqueeze(-2) # (b, head, t_q, h/head) -> (b, head, t_q, 1, h/head) 144 | rel_emb_t = rel_emb.permute(0, 3, 1, 4, 2) # (b, t_q, t_k, head, h/head) -> (b, head, t_q, h/head, t_k) 145 | 146 | v = self.v.reshape(1, head_count, 1, 1, dim_per_head) 147 | b_d = torch.matmul(query + v, rel_emb_t 148 | ).squeeze(-2) # (b, head, t_q, 1, t_k) -> (b, head, t_q, t_k) 149 | 150 | scores = a_c + b_d 151 | 152 | scores = scores.float() 153 | 154 | mask = mask.unsqueeze(1) # (B, 1, 1, T_values) 155 | scores = scores.masked_fill(mask, -1e18) 156 | 157 | # 3) Apply attention dropout and compute context vectors. 158 | attn = self.softmax(scores) 159 | drop_attn = self.dropout(attn) 160 | 161 | context_original = torch.matmul(drop_attn, value) # -> (b, head, t_q, h/head) 162 | context = unshape(context_original) # -> (b, t_q, h) 163 | 164 | output = self.final_linear(context) 165 | attns = attn.view(batch_size, head_count, query_len, key_len) 166 | 167 | return output, attns 168 | 169 | 170 | class SALayerXL(nn.Module): 171 | """ 172 | A single layer of the self-attention encoder. 173 | 174 | Args: 175 | d_model (int): the dimension of keys/values/queries in 176 | MultiHeadedAttention, also the input size of 177 | the first-layer of the PositionwiseFeedForward. 178 | heads (int): the number of head for MultiHeadedAttention. 179 | d_ff (int): the second-layer of the PositionwiseFeedForward. 180 | dropout: dropout probability(0-1.0). 
181 | """ 182 | 183 | def __init__(self, args, d_model, heads, d_ff, dropout, attention_dropout, u, v): 184 | super().__init__() 185 | 186 | self.self_attn = MultiHeadedRelAttention( 187 | args, 188 | heads, d_model, dropout=attention_dropout, 189 | u=u, 190 | v=v 191 | ) 192 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 193 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 194 | self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) 195 | self.dropout = nn.Dropout(dropout) 196 | 197 | def forward(self, inputs, mask, rel_emb): 198 | """ 199 | Args: 200 | inputs (FloatTensor): ``(batch_size, src_len, model_dim)`` 201 | mask (LongTensor): ``(batch_size, 1, src_len)`` 202 | rel_emb (LongTensor): ``(batch_size, src_len, src_len)`` 203 | 204 | Returns: 205 | (FloatTensor): 206 | 207 | * outputs ``(batch_size, src_len, model_dim)`` 208 | """ 209 | normed_inputs = self.layer_norm(inputs) 210 | context, _ = self.self_attn(normed_inputs, mask=mask, rel_emb=rel_emb) 211 | out = self.dropout(context) + inputs 212 | return self.feed_forward(self.layer_norm_2(out)) + out 213 | 214 | class Block(nn.Module): 215 | def __init__(self, size: int): 216 | super().__init__() 217 | 218 | self.ff = nn.Linear(size, size) 219 | self.act = nn.GELU() 220 | self.layer_norm = nn.LayerNorm(size, eps=1e-6) 221 | 222 | def forward(self, x: torch.Tensor): 223 | return x + self.layer_norm(self.act(self.ff(x))) 224 | 225 | class AttnEncoderXL(nn.Module): 226 | def __init__(self, args): 227 | super().__init__() 228 | self.args = args 229 | self.device = args.device 230 | self.num_layers = args.enc_num_layers 231 | self.post_processing_layers = args.post_processing_layers 232 | self.d_model = args.emb_dim 233 | self.heads = args.enc_heads 234 | self.d_ff = args.enc_filter_size 235 | self.attention_dropout = args.attn_dropout 236 | 237 | self.atom_embedding = nn.Embedding(len(ELEM_LIST) , self.d_model, padding_idx=0) 238 | self.rbf = RBFExpansion(args) 239 | 240 | self.time_dim = self.d_model - self.rbf.dim 241 | self.time_embed = nn.Sequential( 242 | nn.Linear(self.time_dim, self.time_dim), 243 | nn.SiLU(), 244 | nn.Linear(self.time_dim, self.time_dim), 245 | ) 246 | 247 | self.dropout = nn.Dropout(p=args.dropout) 248 | if args.rel_pos in ["enc_only", "emb_only"]: 249 | self.u = nn.Parameter(torch.randn(self.d_model), requires_grad=True) 250 | self.v = nn.Parameter(torch.randn(self.d_model), requires_grad=True) 251 | else: 252 | self.u = None 253 | self.v = None 254 | 255 | if args.shared_attention_layer == 1: 256 | self.attention_layer = SALayerXL( 257 | args, self.d_model, self.heads, self.d_ff, args.dropout, self.attention_dropout, 258 | self.u, self.v) 259 | else: 260 | self.attention_layers = nn.ModuleList( 261 | [SALayerXL( 262 | args, self.d_model, self.heads, self.d_ff, args.dropout, self.attention_dropout, 263 | self.u, self.v) 264 | for i in range(self.num_layers)]) 265 | self.layer_norm = nn.LayerNorm(self.d_model, eps=1e-6) 266 | 267 | self.query_w = torch.nn.Sequential( 268 | *[Block(self.d_model) for _ in range(self.post_processing_layers)] 269 | ) 270 | self.key_w = torch.nn.Sequential( 271 | *[Block(self.d_model) for _ in range(self.post_processing_layers)] 272 | ) 273 | 274 | self.query_diag_w = torch.nn.Sequential( 275 | *[Block(self.d_model) for _ in range(self.post_processing_layers)] 276 | ) 277 | self.key_diag_w = torch.nn.Sequential( 278 | *[Block(self.d_model) for _ in range(self.post_processing_layers)] 279 | ) 280 | self.value_diag_w = torch.nn.Sequential( 281 | 
*[Block(self.d_model) for _ in range(self.post_processing_layers)] 282 | ) 283 | self.final_diag_w = torch.nn.Linear(self.d_model, 1) 284 | 285 | self.rel_emb_w = torch.nn.Sequential( 286 | *[*[Block(self.d_model) for _ in range(self.post_processing_layers)], 287 | torch.nn.Linear(self.d_model, 1)] 288 | ) 289 | self.softmax = nn.Softmax(dim=-1) 290 | 291 | rbf_layers = [] 292 | # rbf_layers.append(torch.nn.Linear(self.rbf.dim+1, self.rbf.dim)) 293 | for _ in range(self.post_processing_layers): 294 | rbf_layers.append(Block(self.rbf.dim)) 295 | self.rbf_linear = torch.nn.Sequential(*rbf_layers) 296 | self.rbf_final_linear = torch.nn.Linear(self.rbf.dim, 1) 297 | 298 | def id2emb(self, src_token_id): 299 | return self.atom_embedding(src_token_id) 300 | 301 | def forward(self, src, lengths, bond_matrix, timestep): 302 | """adapt from onmt TransformerEncoder 303 | src_token_id: (b, t, h) 304 | lengths: (b,) 305 | 306 | NEW on Jan'23: return: (b, t, h) 307 | """ 308 | if timestep.dim() == 0: 309 | timestep = timestep.repeat(lengths.shape[0]) 310 | 311 | b, n, _ = bond_matrix.shape 312 | timestep = self.time_embed(timestep_embedding(timestep, self.time_dim)) 313 | timestep = timestep.unsqueeze(1).unsqueeze(1) # unsqueeze to match bond n x n 314 | timestep = timestep.repeat(1, n, n, 1) # unsqueeze to match bond n x n 315 | 316 | mask = ~sequence_mask(lengths).unsqueeze(1) 317 | 318 | matrix_masks = ~(~mask * ~mask.transpose(1, 2)).bool() 319 | rbf_bond_matrix = self.rbf(bond_matrix, matrix_masks) 320 | rbf_bond_matrix = self.rbf_linear(rbf_bond_matrix) # b, n, n, 1 -> b, n, n, rbf-dim 321 | 322 | rel_emb = torch.cat((rbf_bond_matrix, timestep), dim=-1) # b, n, n, d 323 | 324 | # src = self.atom_embedding(src_token_id) 325 | # h_place = (src_token_id == 1).float().unsqueeze(-1).repeat(1, 1, src.shape[-1]) 326 | 327 | b, n, d = src.shape 328 | 329 | # a_i - raw atom embeddings 330 | a_i = src * math.sqrt(self.d_model) 331 | a_i = self.dropout(a_i) 332 | 333 | if self.args.shared_attention_layer == 1: 334 | layer = self.attention_layer 335 | for i in range(self.num_layers): 336 | a_i = layer(a_i, mask, rel_emb) 337 | else: 338 | for layer in self.attention_layers: 339 | a_i = layer(a_i, mask, rel_emb) 340 | a_i = self.layer_norm(a_i) # b,n,d 341 | 342 | # a_i - atom embeddings after multiheaded attention on atom embeddings + rbf expansion 343 | 344 | # diagonal prediction 345 | query_diag = self.query_diag_w(a_i) # b,n,d @ d,d -> b,n,d 346 | key_diag = self.key_diag_w(a_i) # b,n,d @ d,d -> b,n,d 347 | value_diag = self.value_diag_w(a_i) # b,n,d @ d,d -> b,n,d 348 | 349 | diag_scores = torch.matmul(query_diag, key_diag.transpose(1, 2)) # b,n,d @ b,d,n -> b,n,n 350 | diag_scores = diag_scores.masked_fill(matrix_masks, 1e-9) 351 | diag_scores = self.softmax(diag_scores) / math.sqrt(self.d_model) 352 | context = torch.matmul(diag_scores, value_diag) # b,n,n @ b,n,d -> b,n,d 353 | diag = self.final_diag_w(context).view(b, n) # b,n,d @ d,1 -> b,n,1 -> b,n 354 | 355 | # non diagonal prediction 356 | query = self.query_w(a_i) # b,n,d @ d,d -> b,n,d 357 | key = self.key_w(a_i) 358 | 359 | scores = torch.matmul(query, key.transpose(1, 2)) # b,n,d @ b,d,n -> b,n,n 360 | a_ij = scores / math.sqrt(self.d_model) 361 | 362 | rbfw_ij = self.rel_emb_w(rel_emb).view(b, n, n) # b,n,n,d @ d,1 -> b,n,n,1 -> b,n,n 363 | out = a_ij + rbfw_ij 364 | 365 | for i in range(b): 366 | indices = torch.arange(n) 367 | out[i, indices, indices] = 0 368 | out[i].diagonal().add_(diag[i]) 369 | 370 | out = zero_center_output(out, 
matrix_masks) 371 | out = (out + out.transpose(1, 2)) 372 | 373 | return out 374 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from rdkit import Chem 4 | import numpy as np 5 | import sys 6 | from rdkit import RDLogger 7 | RDLogger.DisableLog('rdApp.*') 8 | from multiprocessing import Pool, cpu_count 9 | 10 | np.set_printoptions(threshold=sys.maxsize, linewidth=500) 11 | torch.set_printoptions(profile="full") 12 | 13 | ELEM_LIST = ['PAD', 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Na', 'Mg', 'Al', 'Si', \ 14 | 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', \ 15 | 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Sr', 'Y', 'Zr', 'Mo', 'Tc', 'Ru', \ 16 | 'Rh', 'Pd', 'Ag', 'In', 'Sn', 'Sb', 'Te', 'I', 'Cs', 'Ba', 'La', 'Ce', 'Eu', \ 17 | 'Yb', 'Ta', 'W', 'Os', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'V', 'Sm'] 18 | 19 | MATRIX_PAD = -30 20 | bt_to_electron = {Chem.rdchem.BondType.SINGLE: 2, 21 | Chem.rdchem.BondType.DOUBLE: 4, 22 | Chem.rdchem.BondType.TRIPLE: 6, 23 | Chem.rdchem.BondType.AROMATIC: 3} 24 | 25 | tbl = Chem.GetPeriodicTable() 26 | 27 | def bond_features(bond): 28 | bt = bond.GetBondType() 29 | 30 | return bt_to_electron[bt] 31 | 32 | def count_lone_pairs(a): 33 | v=tbl.GetNOuterElecs(a.GetAtomicNum()) 34 | c=a.GetFormalCharge() 35 | b=sum([bond.GetBondTypeAsDouble() for bond in a.GetBonds()]) 36 | h=a.GetTotalNumHs() 37 | return v-c-b-h 38 | 39 | ps = Chem.SmilesParserParams() 40 | ps.removeHs = False 41 | ps.sanitize = True 42 | def get_BE_matrix(r): 43 | rmol = Chem.MolFromSmiles(r, ps) 44 | Chem.Kekulize(rmol) 45 | max_natoms = len(rmol.GetAtoms()) 46 | f = np.zeros((max_natoms,max_natoms)) 47 | 48 | for atom in rmol.GetAtoms(): 49 | lone_pair = count_lone_pairs(atom) 50 | f[atom.GetIntProp('molAtomMapNumber') - 1, atom.GetIntProp('molAtomMapNumber') - 1] = lone_pair 51 | 52 | for bond in rmol.GetBonds(): 53 | a1 = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') - 1 54 | a2 = bond.GetEndAtom().GetIntProp('molAtomMapNumber') - 1 55 | f[(a1,a2)] = f[(a2,a1)] = bond_features(bond)/2 # so that bond electron diff matrix sums up to 0 56 | 57 | return f 58 | 59 | electron_to_bo = {val:key for key, val in bt_to_electron.items()} 60 | 61 | def get_formal_charge(a, electron): 62 | v=tbl.GetNOuterElecs(a.GetAtomicNum()) 63 | b=sum([bond.GetBondTypeAsDouble() for bond in a.GetBonds()]) 64 | h=a.GetTotalNumHs() 65 | f =v - electron - b - h 66 | return f 67 | 68 | def mol_prop_compute(matrix): 69 | """ 70 | vectorized way of computing atom dict and bond dict from matrix 71 | """ 72 | n = matrix.shape[0] 73 | 74 | # 1) Compute symmetric bond sums once: 75 | Mplus = matrix + matrix.T 76 | 77 | # 2) Extract all off-diagonal i>') 132 | 133 | error = "" 134 | try: 135 | _ = get_BE_matrix(src_smi) 136 | _ = get_BE_matrix(tgt_smi) 137 | src_vocab_id_list, src_len = smi2vocabid(src_smi) 138 | tgt_vocab_id_list, tgt_len = smi2vocabid(tgt_smi) 139 | assert (src_vocab_id_list == tgt_vocab_id_list).all() 140 | except Exception as e: 141 | error = e 142 | src_smi, tgt_smi = '', '' 143 | src_vocab_id_list, src_len = [], 0 144 | tgt_vocab_id_list, tgt_len = [], 0 145 | 146 | # Return a tuple of results for this smiles pair 147 | return { 148 | 'src_smi': src_smi, 149 | 'tgt_smi': tgt_smi, 150 | 'src_vocab_id_list': src_vocab_id_list, 151 | 'tgt_vocab_id_list': tgt_vocab_id_list, 
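        # NOTE: by construction (asserted above) the src and tgt vocab-id lists
        # are identical: reactant and product contain the same mapped atoms in
        # the same order and differ only in electron placement, so only the
        # lengths and BE matrices distinguish them downstream.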
152 | 'src_len': src_len, 153 | 'tgt_len': tgt_len 154 | # "error": error 155 | } 156 | 157 | class ReactionBatch: 158 | def __init__(self, 159 | src_data_indices: torch.Tensor, 160 | src_token_ids: torch.Tensor, 161 | src_lens: torch.Tensor, 162 | src_matrices: torch.Tensor, 163 | tgt_matrices: torch.Tensor, 164 | matrix_masks: torch.Tensor, 165 | src_smiles_list: list, 166 | tgt_smiles_list: list, 167 | ): 168 | self.src_data_indices = src_data_indices 169 | self.src_token_ids = src_token_ids 170 | self.src_lens = src_lens 171 | self.src_matrices = src_matrices 172 | self.tgt_matrices = tgt_matrices 173 | self.matrix_masks = matrix_masks 174 | self.src_smiles_list = src_smiles_list 175 | self.tgt_smiles_list = tgt_smiles_list 176 | 177 | def to(self, device): 178 | self.src_data_indices = self.src_data_indices.to(device) 179 | self.src_token_ids = self.src_token_ids.to(device) 180 | self.src_lens = self.src_lens.to(device) 181 | self.src_matrices = self.src_matrices.to(device) 182 | self.tgt_matrices = self.tgt_matrices.to(device) 183 | self.matrix_masks = self.matrix_masks.to(device) 184 | 185 | def pin_memory(self): 186 | self.src_data_indices = self.src_data_indices.pin_memory() 187 | self.src_token_ids = self.src_token_ids.pin_memory() 188 | self.src_lens = self.src_lens.pin_memory() 189 | self.src_matrices = self.src_matrices.pin_memory() 190 | self.tgt_matrices = self.tgt_matrices.pin_memory() 191 | self.matrix_masks = self.matrix_masks.pin_memory() 192 | 193 | return self 194 | 195 | class ReactionDataset(Dataset): 196 | def __init__(self, args, smiles_list, parallel=True, reactant_only=False): 197 | self.args = args 198 | self.device = args.device 199 | self.reactant_only = reactant_only 200 | self.smiles_list = smiles_list 201 | self.src_smis = [] 202 | self.tgt_smis = [] 203 | 204 | self.src_token_ids = [] 205 | self.tgt_token_ids = [] 206 | 207 | self.src_lens = [] 208 | self.tgt_lens = [] 209 | 210 | if reactant_only: 211 | self.parse_reactant_only() 212 | else: 213 | if parallel: 214 | self.parse_data_parallel() 215 | else: 216 | self.parse_data() 217 | 218 | self.src_lens = np.asarray(self.src_lens) 219 | 220 | self.data_size = len(self.src_smis) 221 | self.data_indices = np.arange(self.data_size) 222 | 223 | def parse_reactant_only(self): 224 | for src_smi in self.smiles_list: 225 | src_smi = src_smi.strip() 226 | try: 227 | _ = get_BE_matrix(src_smi) 228 | src_vocab_id_list, src_len = smi2vocabid(src_smi) 229 | except Exception as e: 230 | print(e) 231 | continue 232 | 233 | self.src_smis.append(src_smi) 234 | self.src_token_ids.append(src_vocab_id_list) 235 | self.src_lens.append(src_len) 236 | self.tgt_lens.append(src_len) 237 | 238 | assert len(self.src_smis) > 0, "Empty Data" 239 | 240 | def parse_data(self): 241 | for smiles in self.smiles_list: 242 | src_smi, tgt_smi = smiles.strip().split('|')[0].split('>>') 243 | 244 | try: 245 | src_matrix = get_BE_matrix(src_smi) 246 | tgt_matrix = get_BE_matrix(tgt_smi) 247 | src_vocab_id_list, src_len = smi2vocabid(src_smi) 248 | tgt_vocab_id_list, tgt_len = smi2vocabid(tgt_smi) 249 | assert (src_vocab_id_list == tgt_vocab_id_list).all() 250 | assert src_len == tgt_len, "src len and tgt len should be the same" 251 | except Exception as e: 252 | print(e) 253 | continue 254 | 255 | self.src_smis.append(src_smi) 256 | self.tgt_smis.append(tgt_smi) 257 | self.src_token_ids.append(src_vocab_id_list) 258 | self.tgt_token_ids.append(tgt_vocab_id_list) 259 | self.src_lens.append(src_len) 260 | self.tgt_lens.append(tgt_len) 261 | 262 | 
assert len(self.src_smis) == len(self.tgt_smis) == len(self.src_lens) == len(self.tgt_lens) \
263 |             == len(self.src_token_ids) == len(self.tgt_token_ids)
264 | 
265 |     def parse_data_parallel(self):
266 | 
267 |         p = Pool(cpu_count())
268 |         results = p.imap(process_smiles, ((smiles) for smiles in self.smiles_list))
269 |         p.close()
270 |         p.join()
271 | 
272 |         # Prepare the final data structures
273 |         count = 0
274 |         total = 0
275 |         for result in results:
276 |             total += 1
277 |             if len(result['src_vocab_id_list']) == 0 or result['src_len'] == 0:
278 |                 # print(f"{result['src_smi']}>>{result['tgt_smi']}")
279 |                 # print(result['error'])
280 |                 count += 1
281 |                 continue
282 |             self.src_smis.append(result['src_smi'])
283 |             self.tgt_smis.append(result['tgt_smi'])
284 |             self.src_token_ids.append(result['src_vocab_id_list'])
285 |             self.tgt_token_ids.append(result['tgt_vocab_id_list'])
286 |             self.src_lens.append(result['src_len'])
287 |             self.tgt_lens.append(result['tgt_len'])
288 | 
289 |         print(f"{count*100/total:.2f}% of data is unparseable")
290 | 
291 |     def sort(self):
292 |         self.data_indices = np.argsort(self.src_lens)
293 | 
294 |     def shuffle_in_bucket(self, bucket_size: int):
295 |         for i in range(0, self.data_size, bucket_size):
296 |             np.random.shuffle(self.data_indices[i:i + bucket_size])
297 | 
298 |     def batch(self, batch_type: str, batch_size: int, verbose=False):
299 | 
300 |         self.batch_sizes = []
301 |         if batch_type.startswith("tokens"):
302 |             sample_size = 0
303 |             max_batch_src_len = 0
304 |             max_batch_tgt_len = 0
305 | 
306 |             for data_idx in self.data_indices:
307 |                 src_len = self.src_lens[data_idx]
308 |                 tgt_len = self.tgt_lens[data_idx]
309 | 
310 |                 max_batch_src_len = max(src_len, max_batch_src_len)
311 |                 max_batch_tgt_len = max(tgt_len, max_batch_tgt_len)
312 | 
313 |                 if batch_type == "tokens" and \
314 |                         max_batch_src_len * (sample_size + 1) <= batch_size:
315 |                     sample_size += 1
316 |                 elif batch_type == "tokens_sum" and \
317 |                         (max_batch_src_len + max_batch_tgt_len) * (sample_size + 1) <= batch_size:
318 |                     sample_size += 1
319 |                 else:
320 |                     self.batch_sizes.append(sample_size)
321 | 
322 |                     sample_size = 1
323 |                     max_batch_src_len = src_len
324 |                     max_batch_tgt_len = tgt_len
325 | 
326 |             # lastly
327 |             self.batch_sizes.append(sample_size)
328 |             self.batch_sizes = np.array(self.batch_sizes)
329 |             assert np.sum(self.batch_sizes) == self.data_size, \
330 |                 f"Size mismatch!
Data size: {self.data_size}, sum batch sizes: {np.sum(self.batch_sizes)}" 331 | 332 | self.batch_ends = np.cumsum(self.batch_sizes) 333 | self.batch_starts = np.concatenate([[0], self.batch_ends[:-1]]) 334 | 335 | else: 336 | raise ValueError(f"batch_type {batch_type} not supported!") 337 | 338 | 339 | def __len__(self): 340 | return len(self.batch_sizes) 341 | 342 | def __getitem__(self, idx : int): 343 | batch_index = idx 344 | 345 | batch_start = self.batch_starts[batch_index] 346 | batch_end = self.batch_ends[batch_index] 347 | 348 | data_indices = self.data_indices[batch_start:batch_end] 349 | 350 | # print(self.src_lens[data_indices], data_indices) 351 | max_len = max(self.src_lens[data_indices]) 352 | 353 | src_token_id_batch = [] 354 | src_len_batch = [] 355 | src_matrix_batch = [] 356 | tgt_matrix_batch = [] 357 | src_smiles_batch = [] 358 | tgt_smiles_batch = [] 359 | for data_index in data_indices: 360 | # src_token_id, _ = smi2vocabid(self.src_smis[data_index]) 361 | src_token_id = self.src_token_ids[data_index] 362 | src_len = self.src_lens[data_index] 363 | src_token_id = np.pad(src_token_id, (0, max_len - src_len), 364 | mode='constant', constant_values=0) # constant value 0 based on 'PAD' in ELEM_LIST 365 | src_token_id = torch.as_tensor(src_token_id, dtype=torch.long) 366 | 367 | src_token_id_batch.append(src_token_id) 368 | 369 | src_matrix = get_BE_matrix(self.src_smis[data_index]) 370 | src_matrix = np.pad(src_matrix, ((0, max_len - src_len), (0, max_len - src_len)), 371 | mode='constant', constant_values=MATRIX_PAD) 372 | src_len_batch.append(src_len) 373 | src_matrix_batch.append(src_matrix) 374 | src_smiles_batch.append(self.src_smis[data_index]) 375 | 376 | if not self.reactant_only: 377 | tgt_matrix = get_BE_matrix(self.tgt_smis[data_index]) 378 | tgt_matrix = np.pad(tgt_matrix, ((0, max_len - src_len), (0, max_len - src_len)), 379 | mode='constant', constant_values=MATRIX_PAD) 380 | tgt_matrix_batch.append(tgt_matrix) 381 | tgt_smiles_batch.append(self.tgt_smis[data_index]) 382 | 383 | src_data_indices = torch.as_tensor(data_indices, dtype=torch.long) 384 | src_len_batch = torch.as_tensor(src_len_batch, dtype=torch.long) 385 | src_token_id_batch = torch.stack(src_token_id_batch) 386 | src_matrix_batch = torch.as_tensor(np.stack(src_matrix_batch), dtype=torch.float) 387 | if not self.reactant_only: 388 | tgt_matrix_batch = torch.as_tensor(np.stack(tgt_matrix_batch), dtype=torch.float) 389 | else: tgt_matrix_batch = src_matrix_batch 390 | 391 | node_mask = (src_matrix_batch[:, :, 0] != MATRIX_PAD) 392 | matrix_masks = (node_mask.unsqueeze(1) * node_mask.unsqueeze(2)).long() 393 | 394 | reaction_batch = ReactionBatch( 395 | src_data_indices=src_data_indices, 396 | src_token_ids=src_token_id_batch, 397 | src_lens=src_len_batch, 398 | src_matrices=src_matrix_batch, 399 | tgt_matrices=tgt_matrix_batch, 400 | matrix_masks=matrix_masks, 401 | src_smiles_list=src_smiles_batch, 402 | tgt_smiles_list=tgt_smiles_batch 403 | ) 404 | 405 | return reaction_batch 406 | --------------------------------------------------------------------------------
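A minimal, self-contained sketch of the frontier-ranking idea used by `select` in beam_predict.py and beam_predict_multiGPU.py: every edge records how often a step was sampled out of `args.sample_size` ODE draws, and frontiers are kept by the cumulative probability of their path from the root. All graph contents and constants below are illustrative stand-ins, not values from this repo.

import networkx as nx
import numpy as np

SAMPLE_SIZE = 32  # stand-in for args.sample_size
BEAM_SIZE = 2     # stand-in for args.beam_size

# Toy mechanism graph: root -> two first-step intermediates, one of which
# leads to a deeper product. `count` = how often that step was sampled.
g = nx.DiGraph()
g.add_edge("root", "int_a", rank=0, count=24)    # sampled 24/32 times
g.add_edge("root", "int_b", rank=1, count=8)     # sampled 8/32 times
g.add_edge("int_a", "prod_c", rank=0, count=16)  # sampled 16/32 times

scores = {}
for frontier in ["int_b", "prod_c"]:
    for path in nx.all_simple_paths(g, "root", frontier):
        probs = [g.get_edge_data(u, v)["count"] / SAMPLE_SIZE
                 for u, v in nx.utils.pairwise(path)]
        cum_prob = np.prod(probs)
    # negated so an ascending sort puts high-probability frontiers first,
    # mirroring rank_frontiers[frontier] = -cum_prob in select()
    scores[frontier] = -cum_prob

kept = [f for f, _ in sorted(scores.items(), key=lambda x: x[1])[:BEAM_SIZE]]
print(kept)  # ['prod_c', 'int_b']: prod_c scores 0.75 * 0.5 = 0.375 vs 0.25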