├── .gitignore
├── FlowER.png
├── examples
│   ├── samples
│   │   └── unrecognized_pistachio.pkl
│   ├── beam_predict_seed.py
│   └── run_beam_predict_seed.sh
├── requirements.txt
├── scripts
│   ├── search.sh
│   ├── train.sh
│   ├── eval_multiGPU.sh
│   └── search_multiGPU.sh
├── LICENSE
├── run_FlowER_large_newData.sh
├── run_FlowER_large_oldData.sh
├── utils
│   ├── attn_utils.py
│   ├── rounding.py
│   ├── train_utils.py
│   └── data_utils.py
├── settings.py
├── model
│   ├── flow_matching.py
│   └── attn_encoder.py
├── sequence_evaluation.py
├── README.md
├── train.py
├── beam_predict.py
├── beam_predict_multiGPU.py
└── eval_multiGPU.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | logs/*
3 | checkpoints/*
4 | results/*
5 | *.pyc
--------------------------------------------------------------------------------
/FlowER.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FongMunHong/FlowER/HEAD/FlowER.png
--------------------------------------------------------------------------------
/examples/samples/unrecognized_pistachio.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FongMunHong/FlowER/HEAD/examples/samples/unrecognized_pistachio.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.4.0
2 | numpy==1.26.4
3 | torchdiffeq==0.2.4
4 | rdkit==2024.3.3
5 | iteround==1.0.4
6 | networkx==3.3
7 | matplotlib==3.9.1
--------------------------------------------------------------------------------
/scripts/search.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python beam_predict.py
4 |
5 | # To ensure reproducibility of the results in the paper despite stochasticity,
6 | # we aggregate results from 100 random seeds:
7 | # python examples/beam_predict_seed.py
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | torchrun \
4 | --node_rank="$NODE_RANK" \
5 | --nnodes="$NUM_NODES" \
6 | --nproc_per_node="$NUM_GPUS_PER_NODE" \
7 | --rdzv-id=456 \
8 | --rdzv-backend=c10d \
9 | --rdzv-endpoint="$MASTER_ADDR:$MASTER_PORT" \
10 | train.py
--------------------------------------------------------------------------------
/scripts/eval_multiGPU.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | torchrun \
4 | --node_rank="$NODE_RANK" \
5 | --nnodes="$NUM_NODES" \
6 | --nproc_per_node="$NUM_GPUS_PER_NODE" \
7 | --rdzv-id=456 \
8 | --rdzv-backend=c10d \
9 | --rdzv-endpoint="$MASTER_ADDR:$MASTER_PORT" \
10 | eval_multiGPU.py
11 |
--------------------------------------------------------------------------------
/scripts/search_multiGPU.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | torchrun \
4 | --node_rank="$NODE_RANK" \
5 | --nnodes="$NUM_NODES" \
6 | --nproc_per_node="$NUM_GPUS_PER_NODE" \
7 | --rdzv-id=456 \
8 | --rdzv-backend=c10d \
9 | --rdzv-endpoint="$MASTER_ADDR:$MASTER_PORT" \
10 | beam_predict_multiGPU.py
11 |
12 |
--------------------------------------------------------------------------------
/examples/beam_predict_seed.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 | import os
4 | import torch
5 | import numpy as np
6 | from beam_predict import Args, setup_logger, log_args, main, log_rank_0
7 |
8 | if __name__ == "__main__":
9 | args = Args
10 | args.local_rank = int(os.environ["LOCAL_RANK"]) if os.environ.get("LOCAL_RANK") else -1
11 | logger = setup_logger(args, "beam")
12 | log_args(args, 'evaluation')
13 |
14 | for i in range(100):
15 | seed = torch.seed()
16 | log_rank_0(f"Current_seed: {seed}")
17 | torch.manual_seed(seed)
18 |
19 | main(args, seed=seed)
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 FongMunHong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/examples/run_beam_predict_seed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | ## This script is intended for recreating the paper results reliably.
4 |
5 | ### Run: sh examples/run_beam_predict_seed.sh (from the repository root, not from inside the examples folder)
6 |
7 | # export DATA_NAME="flower_dataset" # old dataset
8 | # export EXP_NAME="best_large_hyperparam"
9 | # export EMB_DIM=256
10 | # export RBF_HIGH=18
11 | # export RBF_GAP=0.1
12 | # export SIGMA=0.15
13 |
14 | # export MODEL_NAME="model.2370000_78.pt" # your trained checkpoint here
15 |
16 | export DATA_NAME="flower_new_dataset"
17 | export EXP_NAME="best_large_hyperparam"
18 | export EMB_DIM=256
19 | export RBF_HIGH=12
20 | export RBF_GAP=0.1
21 | export SIGMA=0.15
22 |
23 | export MODEL_NAME="model.2880000_95.pt" # your trained checkpoint here
24 |
25 |
26 | export TRAIN_BATCH_SIZE=4096
27 | export VAL_BATCH_SIZE=4096
28 | export TEST_BATCH_SIZE=4096
29 |
30 | export NUM_WORKERS=4
31 | export CUDA_VISIBLE_DEVICES=0
32 | export NUM_GPUS_PER_NODE=1
33 |
34 | export NUM_NODES=1
35 | export NODE_RANK=0
36 | export MASTER_ADDR=localhost
37 | export MASTER_PORT=1235
38 |
39 | export TRAIN_FILE=$PWD/data/$DATA_NAME/train.txt
40 | export VAL_FILE=$PWD/data/$DATA_NAME/val.txt
41 | export TEST_FILE=$PWD/data/$DATA_NAME/beam.txt
42 |
43 |
44 | export MODEL_PATH=$PWD/checkpoints/$DATA_NAME/$EXP_NAME/
45 | export RESULT_PATH=$PWD/results/$DATA_NAME/$EXP_NAME/
46 |
47 |
48 | python examples/beam_predict_seed.py
49 |
--------------------------------------------------------------------------------
/run_FlowER_large_newData.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 |
4 | export DATA_NAME="flower_new_dataset"
5 | export EXP_NAME="best_large_hyperparam"
6 | export EMB_DIM=256
7 | export RBF_HIGH=12
8 | export RBF_GAP=0.1
9 | export SIGMA=0.15
10 |
11 | export MODEL_NAME="model.2880000_95.pt" # your trained checkpoint here
12 |
13 | export TRAIN_BATCH_SIZE=4096
14 | export VAL_BATCH_SIZE=4096
15 | export TEST_BATCH_SIZE=4096
16 |
17 | export NUM_WORKERS=4
18 | export CUDA_VISIBLE_DEVICES=0
19 | export NUM_GPUS_PER_NODE=1
20 |
21 | export NUM_NODES=1
22 | export NODE_RANK=0
23 | export MASTER_ADDR=localhost
24 | export MASTER_PORT=1235
25 |
26 | export TRAIN_FILE=$PWD/data/$DATA_NAME/train.txt
27 | export VAL_FILE=$PWD/data/$DATA_NAME/val.txt
28 | # export TEST_FILE=$PWD/data/$DATA_NAME/test.txt
29 | export TEST_FILE=$PWD/data/$DATA_NAME/beam.txt
30 |
31 |
32 | export MODEL_PATH=$PWD/checkpoints/$DATA_NAME/$EXP_NAME/
33 | export RESULT_PATH=$PWD/results/$DATA_NAME/$EXP_NAME/
34 |
35 |
36 | # [ -f $TRAIN_FILE ] || { echo $TRAIN_FILE does not exist; exit; }
37 | # [ -f $VAL_FILE ] || { echo $VAL_FILE does not exist; exit; }
38 | # [ -f $TEST_FILE ] || { echo $TEST_FILE does not exist; exit; }
39 |
40 |
41 | export SCALE=4 # smaller sample size during training/validation
42 | # sh scripts/train.sh
43 |
44 | export SCALE=1 # larger sample size during testing
45 | # sh scripts/eval_multiGPU.sh
46 | sh scripts/search.sh
47 | # sh scripts/search_multiGPU.sh
--------------------------------------------------------------------------------
/run_FlowER_large_oldData.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # Reporting for FlowER-large trained on oldData
4 |
5 | export DATA_NAME="flower_dataset" # old dataset
6 | export EXP_NAME="best_large_hyperparam"
7 | export EMB_DIM=256
8 | export RBF_HIGH=18
9 | export RBF_GAP=0.1
10 | export SIGMA=0.15
11 |
12 | export MODEL_NAME="model.2370000_78.pt" # your trained checkpoint here
13 |
14 | export TRAIN_BATCH_SIZE=4096
15 | export VAL_BATCH_SIZE=4096
16 | export TEST_BATCH_SIZE=4096
17 |
18 | export NUM_WORKERS=4
19 | export CUDA_VISIBLE_DEVICES=0
20 | export NUM_GPUS_PER_NODE=1
21 |
22 | export NUM_NODES=1
23 | export NODE_RANK=0
24 | export MASTER_ADDR=localhost
25 | export MASTER_PORT=1235
26 |
27 | export TRAIN_FILE=$PWD/data/$DATA_NAME/train.txt
28 | export VAL_FILE=$PWD/data/$DATA_NAME/val.txt
29 | # export TEST_FILE=$PWD/data/$DATA_NAME/test.txt
30 | export TEST_FILE=$PWD/data/$DATA_NAME/beam.txt
31 |
32 |
33 | export MODEL_PATH=$PWD/checkpoints/$DATA_NAME/$EXP_NAME/
34 | export RESULT_PATH=$PWD/results/$DATA_NAME/$EXP_NAME/
35 |
36 |
37 | [ -f $TRAIN_FILE ] || { echo $TRAIN_FILE does not exist; exit; }
38 | [ -f $VAL_FILE ] || { echo $VAL_FILE does not exist; exit; }
39 | [ -f $TEST_FILE ] || { echo $TEST_FILE does not exist; exit; }
40 |
41 |
42 | export SCALE=4 # smaller sample size during training/validation
43 | sh scripts/train.sh
44 |
45 | export SCALE=1 # larger sample size during testing
46 | # sh scripts/eval_multiGPU.sh
47 | # sh scripts/search.sh
48 | # sh scripts/search_multiGPU.sh
--------------------------------------------------------------------------------
/utils/attn_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class PositionwiseFeedForward(nn.Module):
5 | """ A two-layer Feed-Forward-Network with residual layer norm.
6 |
7 | Args:
8 | d_model (int): the size of input for the first-layer of the FFN.
9 | d_ff (int): the hidden layer size of the second-layer
10 | of the FFN.
11 | dropout (float): dropout probability in :math:`[0, 1)`.
12 | """
13 |
14 | def __init__(self, d_model, d_ff, dropout=0.1):
15 | super(PositionwiseFeedForward, self).__init__()
16 | self.w_1 = nn.Linear(d_model, d_ff)
17 | self.w_2 = nn.Linear(d_ff, d_model)
18 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
19 | self.dropout_1 = nn.Dropout(dropout)
20 | self.gelu = nn.GELU()
21 | self.dropout_2 = nn.Dropout(dropout)
22 |
23 | def forward(self, x):
24 | """Layer definition.
25 |
26 | Args:
27 | x: ``(batch_size, input_len, model_dim)``
28 |
29 | Returns:
30 | (FloatTensor): Output ``(batch_size, input_len, model_dim)``.
31 | """
32 |
33 | inter = self.dropout_1(self.gelu(self.w_1(self.layer_norm(x))))
34 | output = self.dropout_2(self.w_2(inter))
35 | return output + x
36 |
37 |
38 | def sequence_mask(lengths, max_len=None):
39 | """
40 | Creates a boolean mask from sequence lengths.
41 | """
42 | batch_size = lengths.numel()
43 | max_len = max_len or lengths.max()
44 | return (torch.arange(0, max_len, device=lengths.device)
45 | .type_as(lengths)
46 | .repeat(batch_size, 1)
47 | .lt(lengths.unsqueeze(1)))
48 |
49 |
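50 | if __name__ == "__main__":
51 |     # Hedged usage sketch (illustration only, not part of the module API):
52 |     # mask for a padded batch with sequence lengths 2 and 4, padded to 5.
53 |     lengths = torch.tensor([2, 4])
54 |     print(sequence_mask(lengths, max_len=5))
55 |     # tensor([[ True,  True, False, False, False],
56 |     #         [ True,  True,  True,  True, False]])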
--------------------------------------------------------------------------------
/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | DATA_NAME = os.environ.get("DATA_NAME", "USPTO")
5 | EXP_NAME = os.environ.get("EXP_NAME", "")
6 |
7 | SCALE = int(os.environ.get("SCALE", 4)) # train & val
8 | # SCALE = 1 # test
9 | SAMPLE_SIZE = 64 // SCALE
10 | NUM_GPU = int(os.environ.get("NUM_GPUS_PER_NODE", 1))
11 |
12 |
13 | TRAIN_BATCH_SIZE = int(os.environ.get("TRAIN_BATCH_SIZE", 4096))
14 | VAL_BATCH_SIZE = int(os.environ.get("VAL_BATCH_SIZE", 4096))
15 | TEST_BATCH_SIZE = int(os.environ.get("TEST_BATCH_SIZE", 512 * NUM_GPU * SCALE))
16 |
17 | NUM_NODES = int(os.environ.get("NUM_NODES", 1))
18 | ACCUMULATION_COUNT = int(os.environ.get("ACCUMULATION_COUNT", 1))
19 | NUM_WORKERS = int(os.environ.get("NUM_WORKERS", 16))
20 |
21 | MODEL_NAME = os.environ.get("MODEL_NAME")
22 |
23 | class Args:
24 | # train #
25 | model_name = MODEL_NAME
26 | exp_name = EXP_NAME
27 | train_path = os.environ.get("TRAIN_FILE")
28 | val_path = os.environ.get("VAL_FILE")
29 | test_path = os.environ.get("TEST_FILE")
30 | model_path = os.environ.get("MODEL_PATH")
31 | result_path = os.environ.get("RESULT_PATH")
32 | data_name = f"{DATA_NAME}"
33 | log_file = f"FlowER"
34 | load_from = ""
35 | # resume = True
36 | # load_from = f"{model_path}{MODEL_NAME}"
37 |
38 | backend = "nccl"
39 | num_workers = NUM_WORKERS
40 | emb_dim = int(os.environ.get("EMB_DIM"))
41 | enc_num_layers = 12
42 | post_processing_layers = 1
43 | enc_heads = 32
44 | enc_filter_size = 2048
45 | dropout = 0.0
46 | attn_dropout = 0.0
47 | rel_pos = "emb_only"
48 | shared_attention_layer = 0
49 | sigma = float(os.environ.get("SIGMA"))
50 | train_batch_size = (TRAIN_BATCH_SIZE / ACCUMULATION_COUNT / NUM_GPU / NUM_NODES)
51 | val_batch_size = (VAL_BATCH_SIZE / ACCUMULATION_COUNT / NUM_GPU / NUM_NODES)
52 | test_batch_size = TEST_BATCH_SIZE
53 | batch_type = "tokens_sum"
54 | lr = 0.0001
55 | beta1 = 0.9
56 | beta2 = 0.998
57 | eps = 1e-9
58 | weight_decay = 1e-2
59 | warmup_steps = 30000
60 | clip_norm = 200
61 |
62 |
63 | epoch = int(os.environ.get("EPOCH", 100))
64 | max_steps = 3000000
65 | accumulation_count = ACCUMULATION_COUNT
66 | save_iter = int(os.environ.get("SAVE_ITER", 30000))
67 | log_iter = int(os.environ.get("LOG_ITER", 100))
68 | eval_iter = int(os.environ.get("EVAL_ITER", 30000))
69 |
70 |
71 | sample_size = SAMPLE_SIZE
72 | rbf_low = 0
73 | rbf_high = float(os.environ.get("RBF_HIGH"))
74 | rbf_gap = float(os.environ.get("RBF_GAP"))
75 |
76 | # validation #
77 | # do_validate = True
78 | # steps2validate = ["1050000", "1320000", "1500000", "930000", "1020000"]
79 |
80 | # inference #
81 | do_validate = False
82 |
83 | # beam-search #
84 | beam_size = 5
85 | nbest = 3
86 | max_depth = 15
87 | chunk_size = 50
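88 | 
89 | # Sketch of the batch-size bookkeeping above: the per-process batch is the
90 | # global budget divided across gradient accumulation, GPUs, and nodes, e.g.
91 | # TRAIN_BATCH_SIZE=4096 with ACCUMULATION_COUNT=1, NUM_GPUS_PER_NODE=2 and
92 | # NUM_NODES=1 gives train_batch_size = 2048 tokens per process per step
93 | # (batch_type="tokens_sum" batches by total token count).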
--------------------------------------------------------------------------------
/utils/rounding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 |
4 | def saferound_tensor(
5 | x: torch.Tensor,
6 | places: int,
7 | strategy: str = "difference",
8 | topline: Optional[float] = None
9 | ) -> torch.Tensor:
10 | """
11 | Round a tensor elementwise to `places` decimal places, adjusting
12 | a minimal number of entries so that the total sum is exactly preserved.
13 |
14 | Args:
15 | x (torch.Tensor): input tensor of floats.
16 | places (int): number of decimal places to round to.
17 | strategy (str): one of {"difference","largest","smallest"}:
18 | - "difference": pick the entries with largest fractional parts first.
19 | - "largest" : pick the largest values first.
20 | - "smallest" : pick the smallest values first.
21 | topline (float, optional): if given, override the target sum
22 | with `topline`. Otherwise target is x.sum().
23 |
24 | Returns:
25 | torch.Tensor: same shape as `x`, rounded to `places`, but whose
26 | sum exactly equals the rounded(original_sum, places).
27 | """
28 | assert isinstance(places, int), "places must be integer"
29 | assert strategy in ("difference","largest","smallest"), f"Unknown strategy {strategy}"
30 |
31 | # Flatten for simplicity
32 | orig = x.view(-1).to(dtype=torch.float64)
33 | N = orig.numel()
34 |
35 | # Determine the exact sum we need to hit
36 | total = topline if topline is not None else orig.sum().item()
37 | scale = 10 ** places
38 | target_int = int(round(total * scale))
39 |
40 | # Scale and take floor/ceil
41 | scaled = orig * scale
42 | low = torch.floor(scaled).to(torch.int64) # integer floors
43 | high = torch.ceil(scaled).to(torch.int64) # integer ceils
44 | sum_low = int(low.sum().item())
45 | residual = target_int - sum_low # how many +1’s we need
46 |
47 | if residual != 0:
48 | # Depending on strategy, create a sort key
49 | if strategy == "difference":
50 | # fractional part, descending
51 | frac = (scaled - low).cpu()
52 | _, indices = torch.sort(frac, descending=True)
53 | elif strategy == "largest":
54 | # values descending
55 | _, indices = torch.sort(orig.cpu(), descending=True)
56 | else: # "smallest"
57 | # values ascending
58 | _, indices = torch.sort(orig.cpu(), descending=False)
59 |
60 | # Pick exactly `abs(residual)` indices
61 | k = min(abs(residual), N)
62 | chosen = indices[:k]
63 |
64 | # Apply the +1 or -1
65 | if residual > 0:
66 | low[chosen] += 1
67 | else:
68 | # In the very rare case sum_low > target_int, we go back down
69 | low[chosen] -= 1
70 |
71 | # Convert back to float decimals
72 | rounded_flat = low.to(torch.float64).mul_(1.0 / scale)
73 | # reshape back
74 | return rounded_flat.view_as(x).to(dtype=x.dtype)
75 |
76 |
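77 | if __name__ == "__main__":
78 |     # Hedged usage sketch: elementwise rounding that preserves the rounded total.
79 |     x = torch.tensor([0.333, 0.333, 0.334, 1.577])
80 |     y = saferound_tensor(x, places=2)
81 |     print(y, y.sum())  # tensor([0.3300, 0.3300, 0.3400, 1.5800]), sum == 2.58
82 |     assert abs(y.sum().item() - round(x.sum().item(), 2)) < 1e-6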
--------------------------------------------------------------------------------
/model/flow_matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils.data_utils import MATRIX_PAD
4 |
5 | def zero_center_func(x, node_mask):
6 | N = node_mask.sum()
7 | mean = torch.sum(x) / N
8 | x = x - mean * node_mask
9 | return x
10 |
11 | class ConditionalFlowMatcher(nn.Module):
12 | """Base class for conditional flow matching methods. This class implements the independent
13 | conditional flow matching methods from [1] and serves as a parent class for all other flow
14 | matching methods.
15 |
16 | It implements:
17 | - drawing samples from the Gaussian probability path N(t * x1 + (1 - t) * x0, sigma)
18 | - the conditional flow matching target ut(x1|x0) = x1 - x0
19 | - the score function $\nabla \log p_t(x|x_0, x_1)$
20 | """
21 |
22 | def __init__(self, args):
23 | r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
24 |
25 | Parameters
26 | ----------
27 | sigma : float
28 | """
29 | super().__init__()
30 | self.args = args
31 | self.device = args.device
32 | self.sigma = args.sigma
33 | self.dim = args.emb_dim
34 |
35 | def zero_centered_noise(self, size, node_mask_batch):
36 | rand = torch.randn(size).to(self.device)
37 | x_batch = rand * node_mask_batch
38 | map_zero_center = torch.vmap(zero_center_func) # map on multiple batch
39 | return map_zero_center(x_batch, node_mask_batch).masked_fill(~(node_mask_batch.bool()), 1e-19)
40 |
41 | def sample_be_matrix(self, matrix):
42 | node_mask = (matrix[:, :, 0] != MATRIX_PAD)
43 | masks = (node_mask.unsqueeze(1) * node_mask.unsqueeze(2)).long()
44 |
45 | noise = self.zero_centered_noise(masks.shape, masks) # (b, n, n)
46 | noise = 0.5 * (noise + noise.transpose(1, 2))
47 | matrix = matrix + noise * self.sigma
48 |
49 | return matrix
50 |
51 | def sample_conditional_pt(self, x0, x1, t):
52 | """
53 | Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
54 |
55 | Parameters
56 | ----------
57 | x0 : Tensor, shape (bs, *dim)
58 | represents the source minibatch
59 | x1 : Tensor, shape (bs, *dim)
60 | represents the target minibatch
61 | t : FloatTensor, shape (bs)
62 |
63 | Returns
64 | -------
65 | xt : Tensor, shape (bs, *dim)
66 |
67 | References
68 | ----------
69 | [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
70 | """
71 | t = t.reshape(-1, *([1] * (x0.dim() - 1)))
72 | mu_t = t * x1 + (1 - t) * x0
73 | return self.sample_be_matrix(mu_t)
74 |
75 | def compute_conditional_vector_field(self, x0, x1):
76 | """
77 | Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
78 |
79 | Parameters
80 | ----------
81 | x0 : Tensor, shape (bs, *dim)
82 | represents the source minibatch
83 | x1 : Tensor, shape (bs, *dim)
84 | represents the target minibatch
85 |
86 | Returns
87 | -------
88 | ut : conditional vector field ut(x1|x0) = x1 - x0
89 |
90 | References
91 | ----------
92 | [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
93 | """
94 | return x1 - x0
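95 | 
96 | # Hedged training-step sketch (cf. train.py, which additionally noises x0 via
97 | # sample_be_matrix before computing ut; `flow`, `model`, `x0`, `x1` assumed):
98 | #   t  = torch.rand(x0.shape[0]).type_as(x0)
99 | #   xt = flow.sample_conditional_pt(x0, x1, t)          # x_t ~ N(t*x1 + (1-t)*x0, sigma)
100 | #   ut = flow.compute_conditional_vector_field(x0, x1)  # u_t = x1 - x0
101 | #   loss = ((model(xt, t) - ut) ** 2).sum() / xt.shape[0]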
--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import numpy as np
4 | import os
5 | import random
6 | import sys
7 | import torch
8 | import torch.nn as nn
9 | from datetime import datetime
10 | from rdkit import RDLogger
11 | from torch.optim.lr_scheduler import _LRScheduler
12 |
13 | def log_args(args, phase: str):
14 | log_rank_0(f"Logging {phase} arguments")
15 | for k, v in vars(args).items():
16 | if "__" in k: continue
17 | log_rank_0(f"**** {k} = *{v}*")
18 |
19 | def param_count(model: nn.Module) -> int:
20 | return sum(param.numel() for param in model.parameters() if param.requires_grad)
21 |
22 |
23 | def param_norm(m):
24 | return math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
25 |
26 |
27 | def grad_norm(m):
28 | return math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))
29 |
30 |
31 | def get_lr(optimizer):
32 | for param_group in optimizer.param_groups:
33 | return param_group["lr"]
34 |
35 |
36 | def set_seed(seed):
37 | random.seed(seed)
38 | os.environ['PYTHONHASHSEED'] = str(seed)
39 | np.random.seed(seed)
40 | torch.manual_seed(seed)
41 | torch.cuda.manual_seed_all(seed)
42 | torch.backends.cudnn.benchmark = False
43 | torch.backends.cudnn.deterministic = True
44 |
45 |
46 | def setup_logger(args, phase, warning_off: bool = False):
47 | if warning_off:
48 | RDLogger.DisableLog("rdApp.*")
49 | else:
50 | RDLogger.DisableLog("rdApp.warning")
51 |
52 | os.makedirs(f"./logs/{args.data_name}/{args.exp_name}", exist_ok=True)
53 | dt = datetime.strftime(datetime.now(), "%y%m%d-%H%Mh")
54 |
55 | logger = logging.getLogger()
56 | logger.setLevel(logging.INFO)
57 | fh = logging.FileHandler(f"./logs/{args.data_name}/{args.exp_name}/{phase}_{args.log_file}.{dt}")
58 | sh = logging.StreamHandler(sys.stdout)
59 | fh.setLevel(logging.INFO)
60 | sh.setLevel(logging.INFO)
61 | logger.addHandler(fh)
62 | logger.addHandler(sh)
63 |
64 | return logger
65 |
66 |
67 | def log_rank_0(message):
68 | if torch.distributed.is_initialized():
69 | if torch.distributed.get_rank() == 0:
70 | logging.info(message)
71 | sys.stdout.flush()
72 | else:
73 | logging.info(message)
74 | sys.stdout.flush()
75 |
76 |
77 | def log_tensor(tensor, tensor_name: str, shape_only=False):
78 | log_rank_0(f"--------------------------{tensor_name}--------------------------")
79 | if not shape_only:
80 | log_rank_0(tensor)
81 | if isinstance(tensor, torch.Tensor):
82 | log_rank_0(tensor.shape)
83 | elif isinstance(tensor, np.ndarray):
84 | log_rank_0(tensor.shape)
85 | elif isinstance(tensor, list):
86 | try:
87 | for item in tensor:
88 | log_rank_0(item.shape)
89 | except Exception as e:
90 | log_rank_0(f"Error: {e}")
91 | log_rank_0("List items are not tensors, skip shape logging.")
92 |
93 |
94 | class NoamLR(_LRScheduler):
95 | """
96 | Adapted from https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py
97 |
98 | Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
99 | linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
100 | to the inverse square root of the step number, scaled by the inverse square root of the
101 | dimensionality of the model. Time will tell if this is just madness or it's actually important.
102 | Parameters
103 | ----------
104 | warmup_steps: ``int``, required.
105 | The number of steps to linearly increase the learning rate.
106 | """
107 | def __init__(self, optimizer, model_size, warmup_steps):
108 | self.model_size = model_size
109 | self.warmup_steps = warmup_steps
110 | super().__init__(optimizer)
111 |
112 | def get_lr(self):
113 | step = max(1, self._step_count)
114 | scale = self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup_steps**(-1.5))
115 | scale *= 1e4
116 |
117 | return [base_lr * scale for base_lr in self.base_lrs]
118 |
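119 | if __name__ == "__main__":
120 |     # Hedged sketch of the schedule: linear warmup for `warmup_steps`, then
121 |     # decay proportional to step**-0.5, scaled by model_size**-0.5 (and the
122 |     # 1e4 factor above).
123 |     lin = nn.Linear(2, 2)
124 |     opt = torch.optim.AdamW(lin.parameters(), lr=1e-4)
125 |     sched = NoamLR(opt, model_size=256, warmup_steps=30000)
126 |     for _ in range(3):
127 |         opt.step()
128 |         sched.step()
129 |     print(get_lr(opt))  # early in warmup: lr is still far below its peak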
--------------------------------------------------------------------------------
/sequence_evaluation.py:
--------------------------------------------------------------------------------
1 |
2 | import networkx as nx
3 | import numpy as np
4 | from rdkit import Chem
5 | from collections import defaultdict
6 | from multiprocessing import Pool, cpu_count
7 | from rdkit import RDLogger
8 | RDLogger.DisableLog('rdApp.*')
9 |
10 |
11 | BAD = 1000
12 |
13 | def assign_rank(predictions):
14 | sorted_items = sorted(predictions, key=lambda x: x[1], reverse=True)
15 |
16 | rank_list = []
17 | current_rank = 0
18 | prev_count = None
19 |
20 | for i, (smi, count, _) in enumerate(sorted_items):
21 | # If the current value is the same as the previous one, assign the same rank
22 | if prev_count == count:
23 | rank_list.append((current_rank, smi, count, _))
24 | else:
25 | # Otherwise, assign a new rank
26 | current_rank = i
27 | rank_list.append((current_rank, smi, count, _))
28 |
29 | prev_count = count
30 |
31 | return rank_list
32 |
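# Illustration (hedged): ties share a rank and the next distinct count skips
# ahead, e.g. assign_rank([("A", 5, 0), ("B", 5, 0), ("C", 3, 0)])
# returns [(0, "A", 5, 0), (0, "B", 5, 0), (2, "C", 3, 0)].
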
33 | def clean(smi):
35 | mol = Chem.MolFromSmiles(smi, sanitize=False)
36 | mol = Chem.RemoveHs(mol)
37 | [atom.SetAtomMapNum(0) for atom in mol.GetAtoms()]
38 | return Chem.MolToSmiles(mol, isomericSmiles=False)
39 |
40 | def process_topk_acc_n_seq_rank(line):
41 | sequence_idx, sequence_preds = line
42 | seq_graph_gt = nx.DiGraph()
43 |
44 | new_sequence_edges = {}
45 | gt_neigh = defaultdict(set)
46 | for pred_info in sequence_preds:
47 | rxn = pred_info["rxn"]
48 | reactant, product = rxn.strip().split('>>')
49 | seq_graph_gt.add_edge(reactant, product)
50 | new_sequence_edges[(reactant, product)] = np.inf
51 | gt_neigh[reactant].add(product)
52 |
53 |
54 | starting_reac = [node for node, in_degree in seq_graph_gt.in_degree() if in_degree == 0]
55 | terminal_prods = list(nx.nodes_with_selfloops(seq_graph_gt))
56 |
57 | if len(starting_reac) != 1: return BAD, 1, sequence_idx, (None, []), 0 # starting reactant count is not exactly 1
58 | starting_reac = starting_reac[0]
59 |
60 | if len(terminal_prods) == 0: return BAD, 2, sequence_idx, (None, []), 0 # cycle without a terminal (self-loop) product
61 |
62 | # merge predictions
63 | for pred_info in sequence_preds:
64 | reactant, _ = pred_info["rxn"].strip().split('>>')
65 | predictions = pred_info["predictions"]
66 | predictions = assign_rank(predictions)
67 | for rank, pred, pred_count, _ in predictions:
68 | if pred in gt_neigh[reactant]:
69 | cur_rank = new_sequence_edges.get((reactant, pred))
70 | if cur_rank == np.inf: # and it's the first time
71 | new_sequence_edges[(reactant, pred)] = rank
72 |
73 | seq_graph_pred = nx.DiGraph()
74 | for (reac, prod), rank in new_sequence_edges.items():
75 | seq_graph_pred.add_edge(reac, prod, weight=rank)
76 |
77 | max_depth = 0
78 | min_sequences_rank = np.inf
79 | for terminal in terminal_prods:
80 | for path in nx.all_simple_paths(seq_graph_pred, source=starting_reac, target=terminal):
81 | max_depth = max(len(path), max_depth)
82 | edges = nx.utils.pairwise(path)
83 | ranks = [seq_graph_pred.get_edge_data(u, v)['weight'] for u, v in edges]
84 | max_topk_within_one_seq = max(ranks)
85 | min_sequences_rank = min(max_topk_within_one_seq, min_sequences_rank)
86 |
87 | terminal_prods = [clean(prod) for prod in terminal_prods]
88 | return min_sequences_rank, 0, sequence_idx, (clean(starting_reac), terminal_prods), max_depth # min of all sequences
89 |
90 |
91 | def remove_atom_map_rxn(line):
92 | ps = Chem.SmilesParserParams()
93 | ps.removeHs = False
94 | ps.sanitize = True
95 | try:
96 | rxn, sequence_idx = line.strip().split("|")
97 | except:
98 | rxn, rxn_class, condition, elem_step, sequence_idx = line.strip().split("|")
99 |
100 | reactant, product = rxn.split(">>")
101 | reac = Chem.MolFromSmiles(reactant, ps)
102 | prod = Chem.MolFromSmiles(product, ps)
103 |
104 | assert reac is not None
105 | assert prod is not None
106 |
107 | [a.ClearProp('molAtomMapNumber') for a in reac.GetAtoms()]
108 | [a.ClearProp('molAtomMapNumber') for a in prod.GetAtoms()]
109 |
110 | reac_smi = Chem.MolToSmiles(reac, isomericSmiles=False)
111 | prod_smi = Chem.MolToSmiles(prod, isomericSmiles=False)
112 |
113 | reac = Chem.MolFromSmiles(reac_smi, ps)
114 | prod = Chem.MolFromSmiles(prod_smi, ps)
115 | reac_smi = Chem.MolToSmiles(reac, isomericSmiles=False)
116 | prod_smi = Chem.MolToSmiles(prod, isomericSmiles=False)
117 |
118 | rxn = f"{reac_smi}>>{prod_smi}|{sequence_idx}"
119 |
120 | return rxn
121 |
122 | def reparse(line):
123 | ps = Chem.SmilesParserParams()
124 | ps.removeHs = False
125 | ps.sanitize = True
126 | metrics, not_sym, predictions = line.strip().split("|")
127 | predictions = eval(predictions)
128 | predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
129 | # new_predictions = []
130 | pred_dict = defaultdict(int)
131 | for (pred, pred_count, val) in predictions:
132 | pred_mol = Chem.MolFromSmiles(pred, ps)
133 | if pred_mol is None: continue
134 | pred_smi = Chem.MolToSmiles(pred_mol, isomericSmiles=False)
135 | # new_predictions.append((pred_smi, pred_count, val))
136 | pred_dict[pred_smi] += pred_count
137 |
138 | pred_dict = dict(sorted(pred_dict.items(), key=lambda x: x[1], reverse=True))
139 | new_predictions = [(pred_smi, prob, True) for pred_smi, prob in pred_dict.items()]
140 |
141 | return f"{metrics}|{not_sym}|{new_predictions}"
142 |
143 |
144 | with open("data/flower_dataset/test.txt") as gt_o, \
145 | open("results/flower_dataset/best_hyperparam/result-32-1440000_47.txt") as result_o:
146 |
147 | # Preprocessing
148 | result = result_o.readlines()
149 | gt = gt_o.readlines()
150 |
151 | assert len(gt) == len(result)
152 |
153 | print("Ground Truth lines")
154 | p = Pool(cpu_count())
155 | gt = p.imap(remove_atom_map_rxn, (rxn for rxn in gt))
156 | gt = list(gt)
157 |
158 | print("Prediction lines")
159 | result = p.imap(reparse, (res for res in result))
160 | result = list(result)
161 |
162 | nbest = 10
163 | topk_accs = np.zeros([len(gt), nbest], dtype=np.float32)
164 |
165 | invalid = []
166 | bag_of_vals = defaultdict(list)
167 | reac_prod_rank = {}
168 | for i, (line_res, line_gt) in enumerate(zip(result, gt)):
169 | metrics, not_sym, predictions = line_res.strip().split("|")
170 | metrics, predictions = eval(metrics), eval(predictions)
171 |
172 | invalid.append(metrics[3] / sum(metrics))
173 | predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
174 |
175 | rxn, sequence_idx = line_gt.strip().split("|")
176 | reactant, product = rxn.split(">>")
177 | if (reactant, product) in reac_prod_rank:
178 | extract_rank = reac_prod_rank[(reactant, product)]
179 | topk_accs[i, extract_rank:] = 1
180 | else:
181 | for rank, (pred, pred_count, _) in enumerate(predictions):
182 | if pred == product:
183 | topk_accs[i, rank:] = 1
184 | reac_prod_rank[(reactant, product)] = rank
185 | break
186 |
187 | if sequence_idx in ['PM', 'RS', 'RC', 'PC']: continue
188 | bag_of_vals[sequence_idx].append(
189 | {
190 | "rxn": rxn,
191 | "metrics": metrics,
192 | "predictions": predictions
193 | }
194 | )
195 | avg_invalid = sum(invalid) / len(invalid)
196 | print(f"Valid percentage: {((1 - avg_invalid) * 100): .2f}%")
197 |
198 |
199 | print("Calculating Topk Step Accuracy")
200 | mean_seq_accuracies = np.mean(topk_accs, axis=0)
201 | for n in range(nbest):
202 | line = f"Top {n+1} step accuracy: {mean_seq_accuracies[n] * 100: .2f} %"
203 | print(line)
204 |
205 | sequence_accs = np.zeros([len(bag_of_vals), nbest], dtype=np.float32)
206 |
207 | print("Calculating Pathway Accuracy")
208 |
209 | no_starting_point = 0
210 | no_starting_point_set = set()
211 | count_no_terminal = 0
212 | count_no_terminal_set = set()
213 |
214 | seq_ranks = p.imap(process_topk_acc_n_seq_rank, ((seq_idx, seq_infos) for seq_idx, seq_infos in bag_of_vals.items()))
215 | for i, (rank, error, seq_idx, (reactant, prod_list), max_depth) in enumerate(seq_ranks):
216 | if error == 1:
217 | no_starting_point += 1
218 | no_starting_point_set.add(seq_idx)
219 | if error == 2:
220 | count_no_terminal += 1
221 | count_no_terminal_set.add(seq_idx)
222 | if rank >= nbest: continue
223 | sequence_accs[i, rank:] = 1
224 |
225 | p.close()
226 | p.join()
227 |
228 | # print('no_starting_point', no_starting_point)
229 | # print('no_starting_point_set', no_starting_point_set)
230 | # print('count_no_terminal', count_no_terminal)
231 | # print('count_no_terminal_set', count_no_terminal_set)
232 |
233 | mean_seq_accuracies = np.mean(sequence_accs, axis=0)
234 | for n in range(nbest):
235 | line = f"Top {n+1} pathway accuracy: {mean_seq_accuracies[n] * 100: .2f} %"
236 | print(line)
237 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FlowER: Flow Matching for Electron Redistribution
2 | _Joonyoung F. Joung*, Mun Hong Fong*, Nicholas Casetti, Jordan P. Liles, Ne S. Dassanayake, Connor W. Coley_
3 |
4 | **NOW published in _Nature_!**
5 |
6 | “Electron flow matching for generative reaction mechanism prediction.” *Nature* **645**, 115–123 (2025).
7 | DOI: [10.1038/s41586-025-09426-9](https://doi.org/10.1038/s41586-025-09426-9)
8 |
9 | 
10 |
11 | FlowER uses flow matching to model chemical reactions as a process of electron redistribution, which conceptually
12 | aligns with arrow-pushing formalisms. It aims to capture the probabilistic nature of reactions while conserving mass,
13 | with multiple outcomes reached through branching mechanistic networks evolving in time.
14 |
15 | ## Environment Setup
16 | ### System requirements
17 | **Ubuntu**: >= 16.04
18 | **conda**: >= 4.0
19 | **GPU**: at least 25GB Memory with CUDA >= 12.2
20 |
21 | ```bash
22 | $ conda create -n flower python=3.10
23 | $ conda activate flower
24 | $ pip install -r requirements.txt
25 | ```
26 |
27 | ## Data/Model preparation
28 | FlowER is trained on a combination of a subset of USPTO-FULL (Dai et al.) with RmechDB and PmechDB (Baldi et al.).
29 | To retrain/reproduce FlowER, download `data.zip` and `checkpoints.zip` from [this link](https://figshare.com/articles/dataset/FlowER_-_Mechanistic_datasets_and_model_checkpoint/28359407/3), unzip them, and place the resulting folders under `FlowER/`.
30 | The expected layouts are `data/{DATASET_NAME}/{train,val,test}.txt` for the `data` folder and `checkpoints/{DATASET_NAME}/{EXPERIMENT_NAME}/model.{STEP}_{IDX}.pt` for the `checkpoints` folder.
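
As a quick sanity check, a minimal sketch (paths follow the layout above; only the datasets you actually downloaded will be present) to confirm the files landed in the right place:

```python
from pathlib import Path

for name in ("flower_dataset", "flower_new_dataset"):
    root = Path("data") / name
    # expect train.txt / val.txt / test.txt per the layout above
    print(name, sorted(p.name for p in root.glob("*.txt")))
```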
31 |
32 | ## On how FlowER is structured
33 | The FlowER workflow revolves mainly around two files: `run_FlowER_large_(old|new)Data.sh` and `settings.py`.
34 | The main idea is to use `#` comments to turn configurations on and off when training, validating, or running inference with FlowER.
35 | `run_FlowER_large_(old|new)Data.sh` lets you specify the data folder name, experiment name, and GPU configuration, and choose which scripts to run.
36 | `settings.py` lets you specify the configuration for each workflow.
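
For example, `settings.py` reads the exported variables at import time, so whatever the wrapper script exports becomes the run configuration (a sketch of the pattern used in the repo):

```python
import os

DATA_NAME = os.environ.get("DATA_NAME", "USPTO")   # set by run_FlowER_large_*.sh
EMB_DIM = int(os.environ.get("EMB_DIM"))           # e.g. export EMB_DIM=256
```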
37 |
38 | ## Training Pipeline
39 | ### 1. Train FlowER
40 | Ensure that the `data/` folder is populated accordingly and that `run_FlowER_large_(old|new)Data.sh` points to the correct files.
41 | ```
42 | export TRAIN_FILE=$PWD/data/$DATA_NAME/train.txt
43 | export VAL_FILE=$PWD/data/$DATA_NAME/val.txt
44 | ```
45 | Check that `run_FlowER_large_(old|new)Data.sh` has `scripts/train.sh` uncommented.
46 | ```bash
47 | $ sh run_FlowER_large_(old|new)Data.sh
48 | ```
49 |
50 | ### 2. Validate FlowER
51 | To validate FlowER on the validation set, ensure these lines are uncommented in `settings.py`:
52 | ```
53 | # validation #
54 | do_validate = True
55 | steps2validate = ["1050000", "1320000", "1500000", "930000", "1020000"]
56 | ```
57 | `steps2validate` lists the checkpoint steps to evaluate, selected based on the training logs in the `logs/` folder.
58 | Check that `run_FlowER_large_(old|new)Data.sh` has `scripts/eval_multiGPU.sh` uncommented.
59 | ```bash
60 | $ sh run_FlowER_large_(old|new)Data.sh
61 | ```
62 |
63 |
64 | ### 3. Test FlowER
65 | To test FlowER on the test set, specify your checkpoint via `MODEL_NAME` in `settings.py` and ensure these lines are uncommented:
66 | ```
67 | # inference #
68 | do_validate = False
69 | ```
70 | Check that `run_FlowER_large_(old|new)Data.sh` has `scripts/eval_multiGPU.sh` uncommented.
71 | ```bash
72 | $ sh run_FlowER_large_(old|new)Data.sh
73 | ```
74 |
75 | #### FlowER train/valid/test input
76 | FlowER takes atom-mapped reactions as input for training, validation, and testing. The elementary reaction steps FlowER is trained on can be grouped by their sequence index during evaluation with `sequence_evaluation.py`. \
77 | \
78 | An elementary reaction step follows the format `mapped_reaction|sequence_idx`. Examples:
79 | ```
80 | [Cl:1][S:2]([Cl:3])=[O:4].[Cl:5][C:6]1=[N:7][S:8][C:9]([C:10](=[O:11])[O:12][H:15])=[C:13]1[Cl:14]>>[Cl:1][S:2]([Cl:3])([O-:4])[O:11][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15]|11831
81 | [Cl:1][S:2]([Cl:3])([O-:4])[O:11][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15]>>[Cl-:1].[S:2]([Cl:3])(=[O:4])[O:11][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15]|11831
82 | [Cl-:1].[S:2]([Cl:3])(=[O:4])[O:11][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15]>>[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)([O:11][S:2]([Cl:3])=[O:4])[O:12][H:15]|11831
83 | [Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)([O:11][S:2]([Cl:3])=[O:4])[O:12][H:15]>>[Cl-:3].[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15].[S:2](=[O:4])=[O:11]|11831
84 | [Cl-:3].[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O+:12][H:15].[S:2](=[O:4])=[O:11]>>[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O:12].[Cl:3][H:15].[S:2](=[O:4])=[O:11]|11831
85 | [Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O:12].[Cl:3][H:15].[S:2](=[O:4])=[O:11]>>[Cl:1][C:10]([C:9]1=[C:13]([Cl:14])[C:6]([Cl:5])=[N:7][S:8]1)=[O:12].[Cl:3][H:15].[S:2](=[O:4])=[O:11]|11831
86 | ```
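
For illustration, a line can be split into its parts with a small helper (hypothetical; not part of the repo):

```python
def parse_step(line: str):
    """Split a `mapped_reaction|sequence_idx` line into its components."""
    rxn, sequence_idx = line.strip().split("|")
    reactant, product = rxn.split(">>")
    return reactant, product, sequence_idx
```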
87 |
88 | Train/Valid/Test hyperparameters
89 |
90 | ### Model Architecture
91 | - **`emb_dim`** - Embedding dimension size of atom embeddings
92 | - **`enc_num_layers`** - Number of transformer layers to be applied
93 | - **`enc_heads`** - Number of attention heads
94 | - **`enc_filter_size`** - Dimension of Feed-Forward Network in Transformer block
95 | - **`(attn)_dropout`** - Dropout for Transformer block (0.0 empirically works well)
96 | - **`sigma`** - Standard deviation of Gaussian noise added for reparameterizing the bond-electron (BE) matrix
97 |
98 | ### Optimization
99 | - **`lr`** - Learning rate for training (NoamLR)
100 | - **`warmup`** - Warmup steps before LR decay (NoamLR)
101 | - **`clip_norm`** - Gradient clipping threshold to prevent exploding gradients
102 | - **`beta1`**, **`beta2`** - Adam optimizer’s momentum terms
103 | - **`eps`** - Adam optimizer’s denominator term for numerical stability
104 | - **`weight_decay`** - L2 regularization strength to prevent overfitting
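
Together these map onto the optimizer setup in `train.py`; a minimal sketch, assuming `model` and `args` (from `settings.Args`) are in scope:

```python
import torch.optim as optim
from utils.train_utils import NoamLR

optimizer = optim.AdamW(
    model.parameters(),
    lr=args.lr,
    betas=(args.beta1, args.beta2),
    eps=args.eps,
    weight_decay=args.weight_decay,
)
scheduler = NoamLR(optimizer, model_size=args.emb_dim, warmup_steps=args.warmup_steps)
```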
105 |
106 | ### Input representation (Bond-Electron matrix)
107 | - **`rbf_low`** - Radial Basis Function (RBF) centers lowest value
108 | - **`rbf_high`** - Radial Basis Function (RBF) centers highest value
109 | - **`rbf_gap`** - Granularity of the RBF center increments
110 |
111 | ### Inference
112 | - **`do_validate`** - True to trigger validation, False to trigger testing
113 | - **`steps2validate`** - List of checkpoints to run FlowER on for validation
114 | - **`sample_size`** - Number of samples FlowER generates for evaluation
115 |
116 |
117 |
118 | ### 4. Use FlowER for search
119 | FlowER mainly uses beam search to find plausible mechanistic pathways. Users can add their SMILES to `data/flower_dataset/beam.txt`.
120 | Ensure that in `run_FlowER_large_(old|new)Data.sh` the `TEST_FILE` variable points to the correct file.
121 | ```
122 | export TEST_FILE=$PWD/data/$DATA_NAME/beam.txt
123 | ```
124 | Ensure that in `settings.py` the beam-search configuration is uncommented and specified accordingly.
125 | ```
126 | test_path = f"data/{DATA_NAME}/beam.txt"
127 |
128 | # beam-search #
129 | beam_size = 5
130 | nbest = 3
131 | max_depth = 15
132 | chunk_size = 50
133 | ```
134 | Check that `run_FlowER_large_(old|new)Data.sh` has `scripts/search.sh` or `scripts/search_multiGPU.sh` uncommented.
135 | ```bash
136 | $ sh run_FlowER_large_(old|new)Data.sh
137 | ```
138 | Visualize your routes with `examples/vis_network.ipynb`.
139 |
140 | #### FlowER search input
141 | For beam search, FlowER takes non-atom-mapped reactions, specified in `beam.txt`.
142 | Each line follows the format `reactant>>product1|product2|...`, where multiple major and minor products can be listed, separated by `|`:
143 | ```
144 | CC(=O)CC(=O)C(F)(F)F.NNc1cccc(Br)c1>>Cc1cc(C(F)(F)F)n(-c2cccc(Br)c2)n1
145 | CC(=O)CC(=O)C(F)(F)F.NNc1cccc(Br)c1>>Cc1cc(C(F)(F)F)n(-c2cccc(Br)c2)n1|Cc1cc(C(F)(F)F)nn1-c1cccc(Br)c1
146 | ```
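
For illustration, a hypothetical helper (not part of the repo) to split a line into reactants and accepted products:

```python
def parse_beam_line(line: str):
    """Split `reactant>>product1|product2|...` into (reactants, [products])."""
    reactants, products = line.strip().split(">>")
    return reactants, products.split("|")
```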
147 |
148 | Search hyperparameters
149 |
150 | - **`beam_size`** - Size of the top-k selection of candidates (based on cumulative probability)
151 | to be further expanded. Increasing this makes the overall search more comprehensive, at the
152 | cost of slower runtime.
153 | - **`nbest`** - Cut-off size of the top-k outcomes generated by FlowER after the expansion.
154 | This cutoff filters unlikely outcomes out of the selection.
155 | - **`sample_size`** - Number of samples FlowER generates for evaluation
156 | - **`max_depth`** - Maximum depth the beam search should explore.
157 | - **`chunk_size`** - Number of reactant sets run through beam search concurrently.
158 |
159 |
160 |
161 | ## Citation
162 | ```bibtex
163 | @article{joung2025electron,
164 | title={Electron flow matching for generative reaction mechanism prediction obeying conservation laws},
165 | author={Joung, Joonyoung F and Fong, Mun Hong and Casetti, Nicholas and Liles, Jordan P and Dassanayake, Ne S and Coley, Connor W},
166 | journal={arXiv preprint arXiv:2502.12979},
167 | year={2025}
168 | }
169 | ```
170 |
171 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import datetime
5 | import logging
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.distributed as dist
10 | from model.attn_encoder import AttnEncoderXL
11 | from utils.data_utils import ReactionDataset
12 | from torch.utils.data import DataLoader
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
15 | from torch.utils.data.distributed import DistributedSampler
16 | from settings import Args
17 | from model.flow_matching import ConditionalFlowMatcher
18 | from utils.train_utils import get_lr, grad_norm, log_rank_0, NoamLR, \
19 | param_count, param_norm, set_seed, setup_logger, log_args
20 | from torch.nn.init import xavier_uniform_
21 | import torch.optim as optim
22 |
23 | torch.set_printoptions(precision=4, profile="full", sci_mode=False, linewidth=10000)
24 | np.set_printoptions(threshold=sys.maxsize, precision=4, suppress=True, linewidth=500)
25 |
26 | def init_dist(args):
27 | if args.local_rank != -1:
28 | dist.init_process_group(backend=args.backend,
29 | init_method='env://',
30 | timeout=datetime.timedelta(minutes=10))
31 | torch.cuda.set_device(args.local_rank)
32 | torch.backends.cudnn.benchmark = False
33 |
34 | if dist.is_initialized():
35 | logging.info(f"Device rank: {dist.get_rank()}")
36 | sys.stdout.flush()
37 |
38 |
39 | def init_model(args):
40 | state = {}
41 | if args.load_from:
42 | log_rank_0(f"Loading pretrained state from {args.load_from}")
43 | state = torch.load(args.load_from, map_location=torch.device("cpu"))
44 | pretrain_args = state["args"]
45 | pretrain_args.local_rank = args.local_rank
46 |
47 | graph_attn_model = AttnEncoderXL(pretrain_args)
48 | pretrain_state_dict = state["state_dict"]
49 | pretrain_state_dict = {k.replace("module.", ""): v for k, v in pretrain_state_dict.items()}
50 | graph_attn_model.load_state_dict(pretrain_state_dict)
51 | log_rank_0("Loaded pretrained model state_dict.")
52 | flow_model = ConditionalFlowMatcher(args)
53 | else:
54 | graph_attn_model = AttnEncoderXL(args)
55 | flow_model = ConditionalFlowMatcher(args)
56 | for p in graph_attn_model.parameters():
57 | if p.dim() > 1 and p.requires_grad:
58 | xavier_uniform_(p)
59 |
60 | graph_attn_model.to(args.device)
61 | flow_model.to(args.device)
62 | if args.local_rank != -1:
63 | graph_attn_model = DDP(
64 | graph_attn_model,
65 | device_ids=[args.local_rank],
66 | output_device=args.local_rank
67 | )
68 | log_rank_0("DDP setup finished")
69 |
70 | os.makedirs(args.model_path, exist_ok=True)
71 |
72 | return graph_attn_model, flow_model, state
73 |
74 | def init_loader(args, dataset, batch_size: int, bucket_size: int = 1000,
75 | shuffle: bool = False, epoch: int = None, use_sort: bool =True):
76 | if use_sort: dataset.sort()
77 | if shuffle: dataset.shuffle_in_bucket(bucket_size=bucket_size)
78 | dataset.batch(
79 | batch_type=args.batch_type,
80 | batch_size=batch_size
81 | )
82 |
83 | if args.local_rank != -1:
84 | sampler = DistributedSampler(dataset, shuffle=shuffle)
85 | if epoch is not None:
86 | sampler.set_epoch(epoch)
87 | else:
88 | sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
89 |
90 | loader = DataLoader(
91 | dataset=dataset,
92 | batch_size=1,
93 | sampler=sampler,
94 | num_workers=args.num_workers,
95 | collate_fn=lambda _batch: _batch[0],
96 | pin_memory=True
97 | )
98 |
99 | return loader
100 |
101 | def get_optimizer_and_scheduler(args, model, state=None):
102 | optimizer = optim.AdamW(
103 | model.parameters(),
104 | lr=args.lr,
105 | betas=(args.beta1, args.beta2),
106 | eps=args.eps,
107 | weight_decay=args.weight_decay
108 | )
109 | # scheduler = None
110 | scheduler = NoamLR(
111 | optimizer,
112 | model_size=args.emb_dim,
113 | warmup_steps=args.warmup_steps
114 | )
115 | # scheduler = optim.lr_scheduler.StepLR(
116 | # optimizer,
117 | # step_size=args.eval_iter, gamma=0.99
118 | # )
119 |
120 | if state and args.resume:
121 | optimizer.load_state_dict(state["optimizer"])
122 | scheduler.load_state_dict(state["scheduler"])
123 | log_rank_0("Loaded pretrained optimizer and scheduler state_dicts.")
124 |
125 | return optimizer, scheduler
126 |
127 | def _optimize(args, model, optimizer, scheduler):
128 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
129 | optimizer.step()
130 | scheduler.step()
131 | g_norm = grad_norm(model)
132 | model.zero_grad(set_to_none=True)
133 | return g_norm
134 |
135 | def main(args):
136 | args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
137 | device = args.device
138 |
139 | init_dist(args)
140 | log_args(args, 'training')
141 | model, flow, state = init_model(args)
142 | total_step = state["total_step"] if state else 0
143 | log_rank_0(f"Number of parameters: {param_count(model)}")
144 |
145 | optimizer, scheduler = get_optimizer_and_scheduler(args, model, state)
146 |
147 | log_rank_0(f"Initializing training ...")
148 | log_rank_0(f"Loading data ...")
149 | with open(args.train_path, 'r') as train_o:
150 | train_smiles_list = train_o.readlines()
151 | with open(args.val_path, 'r') as val_o:
152 | val_smiles_list = val_o.readlines()
153 |
154 | train_dataset = ReactionDataset(args, train_smiles_list)
155 | val_dataset = ReactionDataset(args, val_smiles_list)
156 |
157 | accum = 0
158 | g_norm = 0
159 | losses, accs = [], []
160 | o_start = time.time()
161 | log_rank_0("Start training")
162 |
163 | accuracy = []
164 | for epoch in range(args.epoch):
165 | log_rank_0(f"Epoch: {epoch}")
166 | train_loader = init_loader(args, train_dataset,
167 | batch_size=args.train_batch_size,
168 | shuffle=True,
169 | epoch=epoch)
170 | for train_batch in train_loader:
171 | if total_step > args.max_steps:
172 | log_rank_0("Max steps reached, finish training")
173 | exit(0)
174 |
175 | train_batch.to(device)
176 | model.train()
177 | model.zero_grad(set_to_none=True)
178 |
179 | y = train_batch.src_token_ids
180 | y_len = train_batch.src_lens
181 | x0 = train_batch.src_matrices
182 | x1 = train_batch.tgt_matrices
183 | matrix_masks = train_batch.matrix_masks
184 |
185 |
186 | x0_sample = flow.sample_be_matrix(x0)
187 |
188 | t = torch.rand(x0.shape[0]).type_as(x0)
189 |
190 | xt = flow.sample_conditional_pt(x0, x1, t)
191 | ut = flow.compute_conditional_vector_field(x0_sample, x1)
192 |
193 | if hasattr(model, "module"):
194 | model = model.module # unwrap DDP attn_model to enable accessing attn_model func directly
195 | y_emb = model.id2emb(y)
196 | vt = model(y_emb, y_len, xt, t)
197 |
198 | loss = (vt - ut) * matrix_masks
199 | loss = torch.sum((loss) ** 2) / loss.shape[0]
200 | (loss / args.accumulation_count).backward()
201 | losses.append(loss.item())
202 |
203 | accum += 1
204 | if accum == args.accumulation_count:
205 | g_norm = _optimize(args, model, optimizer, scheduler)
206 | accum = 0
207 | total_step += 1
208 |
209 | if (accum == 0) and (total_step > 0) and (total_step % args.log_iter == 0):
210 | log_rank_0(f"Step {total_step}, loss: {np.mean(losses): .4f}, "
211 | # f"acc: {np.mean(accs): .4f},
212 | f"p_norm: {param_norm(model): .4f}, g_norm: {g_norm: .4f}, "
213 | f"lr: {get_lr(optimizer): .6f}, "
214 | f"elapsed time: {time.time() - o_start: .0f}")
215 | losses, accs = [], []
216 |
217 | if (accum == 0) and (total_step > 0) and (total_step % args.eval_iter == 0):
218 | val_count = 50
219 | val_loader = init_loader(args, val_dataset,
220 | batch_size=args.val_batch_size,
221 | shuffle=True,
222 | epoch=epoch)
223 | from eval_multiGPU import get_predictions
224 | metrics = get_predictions(args, model, flow, val_loader, val_count)
225 | if (not dist.is_initialized()) or dist.get_rank() == 0:
226 | metrics = np.array(metrics)
227 | log_rank_0(metrics.shape)
228 | topk_accuracies = np.mean(metrics[:, 0].astype(bool)) # correct smiles
229 | log_rank_0(f"Topk accuracies: {(topk_accuracies * 100): .2f}")
230 | model.train()
231 |
232 | # Important: saving only at one node or the ckpt would be corrupted!
233 | if dist.is_initialized() and dist.get_rank() > 0:
234 | continue
235 |
236 | if (accum == 0) and (total_step > 0) and (total_step % args.save_iter == 0):
237 | n_iter = total_step // args.save_iter - 1
238 | log_rank_0(f"Saving at step {total_step}")
239 | if scheduler is not None:
240 | state = {
241 | "args": args,
242 | "total_step": total_step,
243 | "state_dict": model.state_dict(),
244 | "optimizer": optimizer.state_dict(),
245 | "scheduler": scheduler.state_dict()
246 | }
247 | else:
248 | state = {
249 | "args": args,
250 | "total_step": total_step,
251 | "state_dict": model.state_dict(),
252 | "optimizer": optimizer.state_dict(),
253 | }
254 | torch.save(state, os.path.join(args.model_path, f"model.{total_step}_{n_iter}.pt"))
255 |
256 | # lastly
257 | if (args.accumulation_count > 1) and (accum > 0):
258 | _optimize(args, model, optimizer, scheduler)
259 | accum = 0
260 | # total_step += 1 # for partial batch, do not increase total_step
261 |
262 | if args.local_rank != -1:
263 | dist.barrier()
264 | log_rank_0("Epoch ended")
265 | if dist.is_initialized():
266 | dist.destroy_process_group()
267 |
268 | if __name__ == "__main__":
269 | args = Args
270 | logger = setup_logger(args, "train")
271 | args.local_rank = int(os.environ["LOCAL_RANK"]) if os.environ.get("LOCAL_RANK") else -1
272 | main(args)
273 |
--------------------------------------------------------------------------------
/beam_predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import torch
4 | import numpy as np
5 | from rdkit import Chem
6 | from utils.data_utils import ReactionDataset, BEmatrix_to_mol, ps
7 | import torch.distributed as dist
8 | from train import init_model, init_loader
9 | from utils.train_utils import log_rank_0, setup_logger, log_args
10 | from eval_multiGPU import custom_round
11 | from settings import Args
12 | from collections import defaultdict
13 | import networkx as nx
14 | import pickle
15 | from eval_multiGPU import predict_batch
16 |
17 | import warnings
18 | warnings.filterwarnings("ignore", category=FutureWarning)
19 |
20 | def standardize_smiles(mol):
21 | return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True)
22 |
23 | def select(args, frontiers_dict, graph_list):
24 | filtered_frontiers_dict = {}
25 | for g_idx, frontiers in frontiers_dict.items():
26 | graph, root, _ = graph_list[g_idx]
27 | rank_frontiers = {}
28 | for frontier in frontiers:
29 | min_sequences_rank = np.inf
30 | for path in nx.all_simple_paths(graph, root, frontier):
31 | max_depth = max(graph.nodes[root]['depth'], len(path))
32 | graph.nodes[root]['depth'] = max_depth
33 | edges = list(nx.utils.pairwise(path))
34 | ranks = [graph.get_edge_data(u, v)['rank'] for u, v in edges]
35 | probs = [graph.get_edge_data(u, v)['count'] / args.sample_size for u, v in edges]
36 | cum_prob = np.prod(probs)
37 | max_topk_within_one_seq = max(ranks)
38 | min_sequences_rank = min(max_topk_within_one_seq, min_sequences_rank)
39 |
40 | # rank_frontiers[frontier] = min_sequences_rank
41 | rank_frontiers[frontier] = -cum_prob
42 | rank_frontiers = sorted(rank_frontiers.items(), key=lambda x:x[1])[:args.beam_size]
43 | # leftover_frontiers = sorted(rank_frontiers.items(), key=lambda x:x[1])[args.beam_size:]
44 | # graph.remove_nodes_from([frontier for frontier, prob in leftover_frontiers])
45 |
46 | filtered_frontiers_dict[g_idx] = list(dict(rank_frontiers).keys())
47 | return filtered_frontiers_dict
48 |
49 |
50 | def expand(args, model, flow, data_loader):
51 | sample_size = args.sample_size
52 |
53 | overall_dict = {}
54 | for batch_idx, data_batch in enumerate(data_loader):
55 | # print(data_batch.src_matrices.shape)
56 | data_batch.to(args.device)
57 | src_data_indices = data_batch.src_data_indices
58 | y = data_batch.src_token_ids
59 | y_len = data_batch.src_lens
60 | x0 = data_batch.src_matrices
61 | matrix_masks = data_batch.matrix_masks
62 | src_smiles_list = data_batch.src_smiles_list
63 |
64 | batch_size, n, n = x0.shape
65 |
66 | if (batch_size*n*n) <= 5*360*360:
67 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 1)
68 | else:
69 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 2)
70 |
71 |
72 | last_step = traj_list[-1]
73 | product_BE_matrices = custom_round(last_step)
74 | product_BE_matrices_batch = torch.split(product_BE_matrices, sample_size)
75 |
76 | for idx in range(batch_size):
77 | reac_smi, product_BE_matrices = \
78 | src_smiles_list[idx], product_BE_matrices_batch[idx]
79 |
80 | reac_mol = Chem.MolFromSmiles(reac_smi, ps)
81 | matrices, counts = torch.unique(product_BE_matrices, dim=0, return_counts=True)
82 | matrices, counts = matrices.cpu().numpy(), counts.cpu().numpy()
83 |
84 | pred_smis_dict = defaultdict(int)
85 | for i in range(matrices.shape[0]): # all unique matrices
86 | pred_prod_be_matrix, count = matrices[i], counts[i] # predicted product matrix and its count
87 | num_nodes = y_len[idx]
88 | pred_prod_be_matrix = pred_prod_be_matrix[:num_nodes, :num_nodes]
89 | reac_be_matrix = x0[idx][:num_nodes, :num_nodes].detach().cpu().numpy()
90 |
91 | assert pred_prod_be_matrix.shape == reac_be_matrix.shape, "pred and reac not the same shape"
92 |
93 | try:
94 | pred_mol = BEmatrix_to_mol(reac_mol, pred_prod_be_matrix)
95 | pred_smi = standardize_smiles(pred_mol)
96 | pred_mol = Chem.MolFromSmiles(pred_smi, ps)
97 | pred_smi = standardize_smiles(pred_mol)
98 | pred_smis_dict[pred_smi] += count
99 |                 except Exception: pass  # skip samples that do not decode to a valid molecule
100 |
101 | pred_smis_tuples = sorted(pred_smis_dict.items(), key=lambda x: x[1], reverse=True)
102 |
103 | pred_smis_dict = dict(pred_smis_tuples[:args.nbest])
104 | overall_dict[reac_smi] = pred_smis_dict
105 |
106 | return overall_dict
107 |
108 | def reactant_process(smi):
109 | try:
110 | mol = Chem.MolFromSmiles(smi)
111 | mol = Chem.AddHs(mol, explicitOnly=False)
112 | for idx, atom in enumerate(mol.GetAtoms()):
113 | atom.SetAtomMapNum(idx+1)
116 | return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True)
117 | except:
118 | print(smi)
119 | raise
120 |
121 | def clean(smi):
122 |     mol = Chem.MolFromSmiles(smi, sanitize=False)
123 |     mol = Chem.RemoveHs(mol)
124 |     for atom in mol.GetAtoms():
125 |         atom.SetAtomMapNum(0)
126 | return Chem.MolToSmiles(mol, isomericSmiles=False)
127 |
128 | def beam_search(args, model, flow, frontiers_dict, graph_list):
129 | smiles_list = [frontier for frontiers in frontiers_dict.values() for frontier in frontiers]
130 | # print('frontiers', smiles_list)
131 | # print()
132 | if len(smiles_list) == 0: return
133 | log_rank_0(f"Current Depth: {[graph.nodes[root]['depth'] for graph, root, _ in graph_list]}")
134 | exclude_gidx = [g_idx for g_idx, (graph, root, _) in enumerate(graph_list)
135 | if graph.nodes[root]['depth'] >= args.max_depth]
136 |
137 | test_dataset = ReactionDataset(args, smiles_list, reactant_only=True)
138 | try:
139 | test_loader = init_loader(args, test_dataset,
140 | batch_size=args.test_batch_size,
141 | shuffle=False, epoch=None, use_sort=False)
142 | except Exception as e:
143 | print(e)
144 | return
145 |
146 | overall_dict = expand(args, model, flow, test_loader)
147 | new_frontiers_dict = defaultdict(list)
148 |
149 | existing_reactions = {g_idx: {} for g_idx in frontiers_dict.keys()}
150 | for g_idx, frontiers in frontiers_dict.items():
151 | if g_idx in exclude_gidx: continue
152 | existing_reaction = existing_reactions[g_idx]
153 | graph, _, _ = graph_list[g_idx]
154 | for frontier in frontiers:
155 |             clean_frontier = clean(frontier)  # map-free canonical SMILES, used to deduplicate reactions
156 |             try: product_info_dict = overall_dict[frontier]  # predicted products for this reactant
157 |             except KeyError: continue  # frontier failed featurization or prediction
158 |             for rank, (product, count) in enumerate(product_info_dict.items()):
159 |                 try: clean_product = clean(product)
160 |                 except Exception: continue  # unparseable predicted SMILES
161 | if (clean_frontier, clean_product) in existing_reaction:
162 | stored_frontier, stored_product = existing_reaction[(clean_frontier, clean_product)]
163 | parent_current = list(graph.predecessors(frontier))
164 | parent_stored = list(graph.predecessors(stored_frontier))
165 |
166 | if parent_current == parent_stored:
167 | graph[stored_frontier][stored_product]["count"] += count
168 |
169 | else:
170 | if not graph.has_node(product):
171 | new_frontiers_dict[g_idx].append(product)
172 | graph.add_edge(frontier, product, rank=rank, count=count)
173 | existing_reaction[(clean_frontier, clean_product)] = (frontier, product)
174 |
175 |
176 | filtered_frontiers_dict = select(args, new_frontiers_dict, graph_list)
177 | beam_search(args, model, flow, filtered_frontiers_dict, graph_list)
178 |
179 | def check_if_successful(graph, products):
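   |     # A node counts as "achieving" a product when (i) one of its components
   |     # matches a target product SMILES and (ii) the node carries a self-loop,
   |     # i.e. the model predicted the species as its own product, which is read
   |     # here as the mechanistic sequence having terminated.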
180 | nodes_with_loops = list(nx.nodes_with_selfloops(graph))
181 | achieved_products = set()
182 | for node in graph.nodes():
183 | node_in_products = set(clean(node).split('.')) & set(products)
184 | if node_in_products and node in nodes_with_loops:
185 | achieved_products.update(node_in_products)
186 | return achieved_products
187 |
188 | def main(args, seed=0):
189 | args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
190 | device = args.device
191 | if args.local_rank != -1:
192 | dist.init_process_group(backend=args.backend, init_method='env://', timeout=datetime.timedelta(0, 7200))
193 | torch.cuda.set_device(args.local_rank)
194 | torch.backends.cudnn.benchmark = True
195 |
196 | with open(args.test_path, 'r') as test_o:
197 | test_smiles_list = test_o.readlines()
198 |
199 | chunk_size = args.chunk_size
200 | chunked_list = [test_smiles_list[i:i + chunk_size] for i in range(0, len(test_smiles_list), chunk_size)]
201 |
202 | for i, chunk in enumerate(chunked_list):
203 | log_rank_0(f"Group Chunk-{i} called:")
204 | checkpoint = os.path.join(args.model_path, args.model_name)
205 | state = torch.load(checkpoint, weights_only=False, map_location=device)
206 | pretrain_args = state["args"]
207 | pretrain_args.load_from = None
208 | pretrain_args.device = device
209 |
210 | pretrain_state_dict = state["state_dict"]
211 | pretrain_args.local_rank = args.local_rank
212 |
213 | attn_model, flow, state = init_model(pretrain_args)
214 | if hasattr(attn_model, "module"):
215 | attn_model = attn_model.module # unwrap DDP attn_model to enable accessing attn_model func directly
216 |
217 | pretrain_state_dict = {k.replace("module.", ""): v for k, v in pretrain_state_dict.items()}
218 | attn_model.load_state_dict(pretrain_state_dict)
219 | log_rank_0(f"Loaded pretrained state_dict from {checkpoint}")
220 |
221 | graph_list = []
222 | frontiers_dict = defaultdict(list)
223 | for idx, line in enumerate(chunk):
224 | if ">>" in line:
225 | ori_reactant = line.strip().split(">>")[0]
226 | products = line.strip().split(">>")[1].split("|") # major products
227 | products = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in products]
228 | else:
229 | ori_reactant = line.strip()
230 | products = []
231 | reactant = reactant_process(ori_reactant)
232 | graph = nx.DiGraph()
233 | graph.add_node(reactant, depth=1)
234 | graph_list.append((graph, reactant, (ori_reactant, products)))
235 | frontiers_dict[idx].append(reactant)
236 |
237 | beam_search(args, attn_model, flow, frontiers_dict, graph_list)
238 |
239 | all_results = []
240 | for beam_idx, (graph, root, (reactant, products)) in enumerate(graph_list):
241 | # print(output_chunk_idx, reaction)
242 | check = check_if_successful(graph, products)
243 | log_rank_0(f"Beam Search Results {beam_idx}: {len(check)}/{len(products)} - {check}")
244 | all_results.append((graph, root, (reactant, products), check))
245 |
246 | os.makedirs(args.result_path, exist_ok=True)
247 | saving_file = os.path.join(args.result_path, f'result_chunk_{i}_s{seed}.pickle')
248 | with open(saving_file, "wb") as f_out:
249 | pickle.dump(all_results, f_out)
250 |
251 |
252 | if __name__ == "__main__":
253 | args = Args
254 | args.local_rank = int(os.environ["LOCAL_RANK"]) if os.environ.get("LOCAL_RANK") else -1
255 | logger = setup_logger(args, "beam")
256 | log_args(args, 'evaluation')
257 |     main(args)  # seed only tags the output filename here; presumably invoked once per seed when aggregating
258 |
--------------------------------------------------------------------------------
/beam_predict_multiGPU.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from rdkit import Chem
5 | from utils.data_utils import ReactionDataset, BEmatrix_to_mol, ps
6 | import torch.distributed as dist
7 | from train import init_model, init_loader
8 | from utils.train_utils import log_rank_0, setup_logger, log_args
9 | from eval_multiGPU import custom_round
10 | from settings import Args
11 | from collections import defaultdict
12 | import networkx as nx
13 | import pickle
14 | import torch.multiprocessing as mp
15 | import time
16 | from eval_multiGPU import predict_batch
17 |
18 | import warnings
19 | warnings.filterwarnings("ignore", category=FutureWarning)
20 |
21 | def standardize_smiles(mol):
22 | return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True)
23 |
24 | def select(args, frontiers_dict, graph_list):
25 | filtered_frontiers_dict = {}
26 | for g_idx, frontiers in frontiers_dict.items():
27 | graph, root, _ = graph_list[g_idx]
28 | rank_frontiers = {}
29 | for frontier in frontiers:
30 | min_sequences_rank = np.inf
31 | for path in nx.all_simple_paths(graph, root, frontier):
32 | max_depth = max(graph.nodes[root]['depth'], len(path))
33 |                 graph.nodes[root]['depth'] = max_depth  # track the deepest root->frontier path seen (used for the max_depth cutoff)
34 | edges = list(nx.utils.pairwise(path))
35 | ranks = [graph.get_edge_data(u, v)['rank'] for u, v in edges]
36 | probs = [graph.get_edge_data(u, v)['count'] / args.sample_size for u, v in edges]
37 | cum_prob = np.prod(probs)
38 | max_topk_within_one_seq = max(ranks)
39 | min_sequences_rank = min(max_topk_within_one_seq, min_sequences_rank)
40 |
41 | # rank_frontiers[frontier] = min_sequences_rank
42 |             rank_frontiers[frontier] = -cum_prob  # note: uses the last enumerated path's probability; the rank-based score above is kept for reference
43 | rank_frontiers = sorted(rank_frontiers.items(), key=lambda x:x[1])[:args.beam_size]
44 | # leftover_frontiers = sorted(rank_frontiers.items(), key=lambda x:x[1])[args.beam_size:]
45 | # graph.remove_nodes_from([frontier for frontier, prob in leftover_frontiers])
46 |
47 | filtered_frontiers_dict[g_idx] = list(dict(rank_frontiers).keys())
48 | return filtered_frontiers_dict
49 |
50 |
51 | def expand(args, model, flow, data_loader):
52 | sample_size = args.sample_size
53 |
54 |     overall_dict = {}  # reactant SMILES -> {predicted product SMILES: count over samples}
55 | for batch_idx, data_batch in enumerate(data_loader):
56 | data_batch.to(args.device)
57 | src_data_indices = data_batch.src_data_indices
58 | y = data_batch.src_token_ids
59 | y_len = data_batch.src_lens
60 | x0 = data_batch.src_matrices
61 | matrix_masks = data_batch.matrix_masks
62 | src_smiles_list = data_batch.src_smiles_list
63 |
64 | batch_size, n, n = x0.shape
65 |
66 |         if (batch_size*n*n) <= 5*360*360:  # empirical size threshold: one ODE pass for small batches, two for large ones
67 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 1)
68 | else:
69 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 2)
70 |
71 |
72 | last_step = traj_list[-1]
73 | product_BE_matrices = custom_round(last_step)
74 | product_BE_matrices_batch = torch.split(product_BE_matrices, sample_size)
75 |
76 | for idx in range(batch_size):
77 | reac_smi, product_BE_matrices = \
78 | src_smiles_list[idx], product_BE_matrices_batch[idx]
79 |
80 | reac_mol = Chem.MolFromSmiles(reac_smi, ps)
81 | matrices, counts = torch.unique(product_BE_matrices, dim=0, return_counts=True)
82 | matrices, counts = matrices.cpu().numpy(), counts.cpu().numpy()
83 |
84 | pred_smis_dict = defaultdict(int)
85 | for i in range(matrices.shape[0]): # all unique matrices
86 |                 pred_prod_be_matrix, count = matrices[i], counts[i] # predicted product matrix and its count
87 | num_nodes = y_len[idx]
88 | pred_prod_be_matrix = pred_prod_be_matrix[:num_nodes, :num_nodes]
89 | reac_be_matrix = x0[idx][:num_nodes, :num_nodes].detach().cpu().numpy()
90 |
91 | assert pred_prod_be_matrix.shape == reac_be_matrix.shape, "pred and reac not the same shape"
92 |
93 | try:
94 | pred_mol = BEmatrix_to_mol(reac_mol, pred_prod_be_matrix)
95 | pred_smi = standardize_smiles(pred_mol)
96 | pred_mol = Chem.MolFromSmiles(pred_smi, ps)
97 | pred_smi = standardize_smiles(pred_mol)
98 | pred_smis_dict[pred_smi] += count
99 |                 except Exception: pass  # skip samples that do not decode to a valid molecule
100 |
101 | pred_smis_tuples = sorted(pred_smis_dict.items(), key=lambda x: x[1], reverse=True)
102 |
103 | pred_smis_dict = dict(pred_smis_tuples[:args.nbest])
104 | overall_dict[reac_smi] = pred_smis_dict
105 |
106 | return overall_dict
107 |
108 | def reactant_process(smi):
109 | try:
110 | mol = Chem.MolFromSmiles(smi)
111 | mol = Chem.AddHs(mol, explicitOnly=False)
112 | for idx, atom in enumerate(mol.GetAtoms()):
113 | atom.SetAtomMapNum(idx+1)
114 | return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True)
115 | except:
116 | print(smi)
117 | raise
118 |
119 | def clean(smi):
120 |     mol = Chem.MolFromSmiles(smi, sanitize=False)
121 |     mol = Chem.RemoveHs(mol)
122 |     for atom in mol.GetAtoms():
123 |         atom.SetAtomMapNum(0)
124 | return Chem.MolToSmiles(mol, isomericSmiles=False)
125 |
126 | def beam_search(args, model, flow, frontiers_dict, graph_list):
127 | smiles_list = [frontier for frontiers in frontiers_dict.values() for frontier in frontiers]
128 | # print('frontiers', smiles_list)
129 | # print()
130 | if len(smiles_list) == 0: return
131 | print(f"Current Depth: {[graph.nodes[root]['depth'] for graph, root, _ in graph_list]}")
132 | exclude_gidx = [g_idx for g_idx, (graph, root, _) in enumerate(graph_list)
133 | if graph.nodes[root]['depth'] >= args.max_depth]
134 |
135 | test_dataset = ReactionDataset(args, smiles_list, reactant_only=True)
136 | try:
137 | test_loader = init_loader(args, test_dataset,
138 | batch_size=args.test_batch_size,
139 | shuffle=False, epoch=None, use_sort=False)
140 | except Exception as e:
141 | print(e)
142 | return
143 |
144 | overall_dict = expand(args, model, flow, test_loader)
145 | new_frontiers_dict = defaultdict(list)
146 |
147 | existing_reactions = {g_idx: {} for g_idx in frontiers_dict.keys()}
148 | for g_idx, frontiers in frontiers_dict.items():
149 | if g_idx in exclude_gidx: continue
150 | existing_reaction = existing_reactions[g_idx]
151 | graph, _, _ = graph_list[g_idx]
152 | for frontier in frontiers:
153 |             clean_frontier = clean(frontier)  # map-free canonical SMILES, used to deduplicate reactions
154 |             try: product_info_dict = overall_dict[frontier]  # predicted products for this reactant
155 |             except KeyError: continue  # frontier failed featurization or prediction
156 |             for rank, (product, count) in enumerate(product_info_dict.items()):
157 |                 try: clean_product = clean(product)
158 |                 except Exception: continue  # unparseable predicted SMILES
159 | if (clean_frontier, clean_product) in existing_reaction:
160 | stored_frontier, stored_product = existing_reaction[(clean_frontier, clean_product)]
161 | parent_current = list(graph.predecessors(frontier))
162 | parent_stored = list(graph.predecessors(stored_frontier))
163 |
164 | if parent_current == parent_stored:
165 | graph[stored_frontier][stored_product]["count"] += count
166 |
167 | else:
168 | if not graph.has_node(product):
169 | new_frontiers_dict[g_idx].append(product)
170 | graph.add_edge(frontier, product, rank=rank, count=count)
171 | existing_reaction[(clean_frontier, clean_product)] = (frontier, product)
172 |
173 |
174 | filtered_frontiers_dict = select(args, new_frontiers_dict, graph_list)
175 | beam_search(args, model, flow, filtered_frontiers_dict, graph_list)
176 |
177 |
178 | def group_lists(lists, group_size):
179 | result = []
180 | # Process lists in chunks of group_size
181 | for i in range(0, len(lists), group_size):
182 | # Take a slice of size group_size (or remaining elements if less)
183 | chunk = lists[i:i + group_size]
184 | # Convert the chunk to a tuple and add to result
185 | result.append(tuple(chunk))
186 |
187 | return result
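   |
   | # e.g. group_lists([c0, c1, c2, c3, c4], group_size=2)
   | #   -> [(c0, c1), (c2, c3), (c4,)]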
188 |
189 | import signal
190 |
191 |
192 | def init_process_killer():
193 | global processes
194 | processes = []
195 |
196 | def signal_handler(sig, frame):
197 | print('\nTerminating all processes...')
198 | for p in processes:
199 | if p.is_alive():
200 | p.terminate()
201 | p.join() # Wait for process to finish
202 | os._exit(0)
203 |
204 | signal.signal(signal.SIGINT, signal_handler)
205 | return processes
206 |
207 | def worker(rank, args, chunk, chunk_idx, lock, queue):
208 | """Worker function that runs on each GPU"""
209 |     # random seeds could be set here for strict per-worker reproducibility (none are set currently)
210 |
211 | # Set device for this process
212 | torch.cuda.set_device(rank)
213 | device = torch.device(f'cuda:{rank}')
214 | args.device = device
215 | args.local_rank = -1 # Disable distributed training
216 |
217 | # Load model for this process
218 | checkpoint = os.path.join(args.model_path, args.model_name)
219 | state = torch.load(checkpoint, weights_only=False, map_location=device)
220 | pretrain_args = state["args"]
221 | pretrain_args.load_from = None
222 | pretrain_args.device = device
223 | pretrain_args.local_rank = -1 # Disable distributed training
224 |
225 | # Initialize model without DDP
226 | pretrain_state_dict = state["state_dict"]
227 | attn_model, flow, _ = init_model(pretrain_args)
228 |
229 | # Remove DDP wrapper if present
230 | if hasattr(attn_model, "module"):
231 | attn_model = attn_model.module
232 |
233 | pretrain_state_dict = {k.replace("module.", ""): v for k, v in pretrain_state_dict.items()}
234 | attn_model.load_state_dict(pretrain_state_dict)
235 |
236 | # print(f"GPU {rank} starting processing {len(chunk)} items")
237 |
238 | # Process chunk
239 | graph_list = []
240 | frontiers_dict = defaultdict(list)
241 | for idx, line in enumerate(chunk):
242 | if ">>" in line:
243 | ori_reactant = line.strip().split(">>")[0]
244 | products = line.strip().split(">>")[1].split("|")
245 | products = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in products]
246 | else:
247 | ori_reactant = line.strip()
248 | products = []
249 | reactant = reactant_process(ori_reactant)
250 | graph = nx.DiGraph()
251 | graph.add_node(reactant, depth=1)
252 | graph_list.append((graph, reactant, (ori_reactant, products)))
253 | frontiers_dict[idx].append(reactant)
254 |
255 | beam_search(args, attn_model, flow, frontiers_dict, graph_list)
256 |
257 | lock.acquire()
258 | try:
259 | queue.put((rank, chunk_idx, graph_list))
260 | finally:
261 | lock.release()
262 |
263 | # print(f"GPU {rank} finished processing")
264 |
265 | def check_if_successful(graph, products):
266 | nodes_with_loops = list(nx.nodes_with_selfloops(graph))
267 | achieved_products = set()
268 | for node in graph.nodes():
269 | node_in_products = set(clean(node).split('.')) & set(products)
270 | if node_in_products and node in nodes_with_loops:
271 | achieved_products.update(node_in_products)
272 | return achieved_products
273 |
274 | def main_multi_gpu(args):
275 | start = time.time()
276 | global processes
277 | processes = init_process_killer()
278 |
279 | # Get number of available GPUs
280 | world_size = torch.cuda.device_count()
281 | log_rank_0(f"Found {world_size} GPUs")
282 |
283 | # Read all test smiles
284 | with open(args.test_path, 'r') as test_o:
285 | test_smiles_list = test_o.readlines()
286 |
287 | # Calculate chunk size and create chunks
288 | # chunk_size = math.ceil(len(test_smiles_list) / world_size)
289 | chunk_size = args.chunk_size // world_size
290 | chunks = [test_smiles_list[i:i + chunk_size] for i in range(0, len(test_smiles_list), chunk_size)]
291 |
292 | group_chunks = group_lists(chunks, world_size)
293 | log_rank_0(f"Number of group chunks: {len(group_chunks)}")
294 |
295 | os.makedirs(args.result_path, exist_ok=True)
296 | # Start processes
297 | lock = mp.Lock()
298 | q = mp.Queue()
299 | chunk_idx = 0
300 | for group_chunk_id, group_chunk in enumerate(group_chunks):
301 | log_rank_0(f"Group Chunk-{group_chunk_id} called:")
302 | all_results = []
303 | processes = []
304 | for gpu_idx, chunk in enumerate(group_chunk):
305 | p = mp.Process(target=worker, args=(gpu_idx, args, chunk, chunk_idx, lock, q))
306 | p.start()
307 | processes.append(p)
308 | time.sleep(1) # Add small delay between process starts
309 | chunk_idx += 1
310 |
311 | outputs = []
312 | for _ in processes:
313 | output = q.get(timeout=1000)
314 | if output is None: continue
315 | outputs.append(output)
316 | outputs = sorted(outputs, key=lambda x:x[0])
317 | for output in outputs:
318 | _, output_chunk_idx, graph_list = output
319 | for beam_idx, (graph, root, (reactant, products)) in enumerate(graph_list):
320 | check = check_if_successful(graph, products)
321 | log_rank_0(f"Beam Search Results {beam_idx}: {len(check)}/{len(products)} - {check}")
322 | all_results.append((graph, root, (reactant, products), check))
323 |
324 |         # Wait for all processes to complete (the queue is drained above first, which avoids join() deadlocks on full pipes)
325 | for p in processes:
326 | p.join()
327 |
328 | saving_file = os.path.join(args.result_path, f'result_chunk_{group_chunk_id}.pickle')
329 | with open(saving_file, "wb") as f_out:
330 | pickle.dump(all_results, f_out)
331 |
332 | log_rank_0(f"---- Time used: {(time.time() - start):.2f}s ----")
333 |
334 |
335 | log_rank_0("Done!")
336 |
337 | if __name__ == "__main__":
338 | # Ensure clean startup
339 | mp.set_start_method('spawn')
340 |
341 | args = Args
342 | logger = setup_logger(args, "beam")
343 | log_args(args, 'evaluation')
344 | main_multi_gpu(args)
--------------------------------------------------------------------------------
/eval_multiGPU.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import glob
4 | import datetime
5 | import torch
6 | import numpy as np
7 | from rdkit import Chem
8 | import torchdiffeq
9 | from utils.data_utils import ReactionDataset, BEmatrix_to_mol, ps
10 | from utils.rounding import saferound_tensor
11 | import torch.distributed as dist
12 | from train import init_model, init_loader
13 | from utils.train_utils import log_rank_0, setup_logger, log_args
14 | from settings import Args
15 | from collections import defaultdict
16 | import time
17 | import iteround
18 |
19 | ps = Chem.SmilesParserParams()  # shadows the ps imported from utils.data_utils (same settings)
20 | ps.removeHs = False
21 | ps.sanitize = True
22 |
23 | def is_sym(a):
24 | return (a.transpose(1, 0) == a).all()
25 |
26 | def redist_fix(pred_matrix, reac_smi, reac_be_matrix):
27 | pred_electron_sum = np.zeros([len(pred_matrix)])
28 | for i in range(len(pred_matrix)):
29 | pred_electron_sum[i] = \
30 | np.sum(pred_matrix[i, :]) + np.sum(pred_matrix[:, i]) - pred_matrix[i, i]
31 |
32 | reac_electron_sum = np.zeros([len(reac_be_matrix)])
33 | for i in range(len(reac_be_matrix)):
34 | reac_electron_sum[i] = \
35 | np.sum(reac_be_matrix[i, :]) + np.sum(reac_be_matrix[:, i]) - reac_be_matrix[i, i]
36 |
37 | diff = reac_electron_sum - pred_electron_sum
38 |
39 | if np.sum(diff) == 0:
40 | pred_matrix[np.diag_indices_from(pred_matrix)] += diff
41 |
42 | return pred_matrix
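   |
   | # The two per-atom loops above are equivalent to this vectorized sketch:
   | #   pred_electron_sum = pred_matrix.sum(axis=0) + pred_matrix.sum(axis=1) - np.diag(pred_matrix)
   | #   reac_electron_sum = reac_be_matrix.sum(axis=0) + reac_be_matrix.sum(axis=1) - np.diag(reac_be_matrix)
   | # i.e. each atom's electron count is its row sum plus its column sum, with
   | # the diagonal (lone-pair) entry counted once.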
43 |
44 | # # old implementation uses CPU
45 | # def redistribute_round(x):
46 | # rounded_diff = iteround.saferound(x.flatten().cpu().numpy().tolist(), 0)
47 | # rounded_diff = torch.as_tensor(rounded_diff, dtype=torch.float).view(*x.shape)
48 | # return rounded_diff.to(x)
49 |
50 | # new implementation uses GPU
51 | def redistribute_round(x):
52 | rounded = saferound_tensor(x, places=0, strategy="difference")
53 | return rounded
54 |
55 | def custom_round(x):
56 | output = []
57 | for i in range(x.shape[0]):
58 |         try: output.append(redistribute_round(x[i]))
59 |         except Exception: output.append(torch.round(x[i]))  # fall back to naive rounding if safe rounding fails
60 | return torch.stack(output)
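   |
   | # Why safe rounding matters here: naive element-wise rounding can change the
   | # total electron count, e.g. torch.round(torch.tensor([0.6, 0.7, 0.7]))
   | # gives [1., 1., 1.] (sum 3) even though the input sums to 2. The
   | # saferound-style strategy redistributes the rounding so the sum is
   | # preserved, which the electron-conservation checks below rely on.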
61 |
62 | def standardize_smiles(mol):
63 | [a.SetAtomMapNum(0) for a in mol.GetAtoms()]
64 | return Chem.MolToSmiles(mol, isomericSmiles=False, allHsExplicit=True)
65 |
66 | def split_number(number, num_parts):
67 | if number % num_parts != 0:
68 | raise ValueError("The number cannot be evenly divided into the specified number of parts.")
69 | return [number // num_parts] * num_parts
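   |
   | # e.g. split_number(100, 2) -> [50, 50]; split_number(100, 3) raises ValueError.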
70 |
71 | start = time.time()  # module-level timer referenced by predict_batch logging
72 | def predict_batch(args, batch_idx, data_batch, model, flow, split, rand_matrix=None):
73 | src_data_indices = data_batch.src_data_indices
74 | y = data_batch.src_token_ids
75 | y_len = data_batch.src_lens
76 | x0 = data_batch.src_matrices
77 | # x1 = data_batch.tgt_matrices
78 | matrix_masks = data_batch.matrix_masks
79 |
80 | batch_size, n, n = x0.shape
81 |
82 | log_rank_0(f"Batch idx: {batch_idx}, batch_shape {batch_size, n, n} {(time.time() - start): .2f}s")
83 | # --------ODE inference--------------#
84 | SAMPLE_BATCH = args.sample_size
85 | # split_sample_batches = split_number(SAMPLE_BATCH, 2) if n >= 400 else split_number(SAMPLE_BATCH, 1)
86 | # split_sample_batches = split_number(SAMPLE_BATCH, 1)
87 | split_sample_batches = split_number(SAMPLE_BATCH, split)
88 |
89 | big_traj_list = []
90 | for sample_size in split_sample_batches:
91 |         # (src_data_indices are not consumed during ODE integration, so they are not repeated here)
92 | x0_repeated = x0.repeat_interleave(sample_size, dim=0)
93 | x0_sample_repeated = flow.sample_be_matrix(x0_repeated)
94 |
95 | matrix_masks_repeated = matrix_masks.repeat_interleave(sample_size, dim=0)
96 |         x0_sample_repeated = x0_sample_repeated.masked_fill(~(matrix_masks_repeated.bool()), 0) # the ODE initial step applies RMS norm, so padded NaNs must be swapped for 0
97 |
98 | del matrix_masks_repeated
99 | torch.cuda.empty_cache()
100 |
101 | y_repeated = y.repeat_interleave(sample_size, dim=0)
102 | y_emb_repeated = model.id2emb(y_repeated)
103 | y_len_batch_repeated = y_len.repeat_interleave(sample_size, dim=0)
104 |
105 | traj_list = torchdiffeq.odeint_adjoint(
106 | lambda t, x: model.forward(y_emb_repeated, y_len_batch_repeated, x, t),
107 | x0_sample_repeated,
108 | torch.linspace(0, 1, 2).to(args.device),
109 | atol=1e-4,
110 | rtol=1e-4,
111 | method="dopri5",
112 | adjoint_params=()
113 | )
114 | big_traj_list.append((traj_list.transpose(0, 1).detach().cpu(), sample_size))
115 |
116 | # merging
117 | all_traj_list = []
118 | for bs in range(batch_size):
119 | for traj_list, sample_size in big_traj_list:
120 | all_traj_list.append(traj_list[bs*sample_size:(bs+1)*sample_size].transpose(0, 1))
121 | traj_list = torch.concat(all_traj_list, dim=1) # concat on sampling dimension
122 | # ------------------------------------#
123 | return traj_list
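   |
   | # predict_batch returns a trajectory tensor of shape
   | # (n_ode_steps, batch_size * sample_size, n, n) -- here 2 time points,
   | # t = 0 and t = 1 -- so callers take traj_list[-1] as the final
   | # bond-electron matrices.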
124 |
125 | def get_predictions(args, model, flow, data_loader, iter_count=np.inf, write_o=None):
126 | accuracy = []
127 | model.eval()
128 | with torch.no_grad():
129 | log_rank_0('Start ODE Prediction...')
130 |         if not dist.is_initialized() or dist.get_rank() == 0:
131 | inferenced_indexes = set()
132 |
133 | for batch_idx, data_batch in enumerate(data_loader):
134 | if batch_idx >= iter_count: break
135 | data_batch.to(args.device)
136 |
137 | src_data_indices = data_batch.src_data_indices
138 | x0 = data_batch.src_matrices
139 | y_len = data_batch.src_lens
140 | batch_size, n, n = x0.shape
141 | src_smiles_list = data_batch.src_smiles_list
142 | tgt_smiles_list = data_batch.tgt_smiles_list
143 |
144 |
145 | # if (batch_size*n*n) <= 5*360*360:
146 | if (batch_size*n*n) <= 15*130*130:
147 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 1)
148 | else:
149 | traj_list = predict_batch(args, batch_idx, data_batch, model, flow, 2)
150 |
151 | if torch.distributed.is_initialized() and dist.get_world_size() > 1:
152 | gathered_results = [None for _ in range(dist.get_world_size())]
153 | dist.gather_object(
154 | (src_data_indices, traj_list, x0, y_len, src_smiles_list, tgt_smiles_list),
155 | gathered_results if dist.get_rank() == 0 else None,
156 | dst=0
157 | )
158 | else:
159 | gathered_results = [(src_data_indices, traj_list, x0, y_len, src_smiles_list, tgt_smiles_list)]
160 |
161 |             if dist.is_initialized() and dist.get_rank() > 0:
162 | continue
163 |
164 | for result in gathered_results:
165 | src_data_indices, traj_list, x0, y_len, src_smiles_list, tgt_smiles_list = result
166 | batch_size, n, n = x0.shape
167 |
168 | last_step = traj_list[-1]
169 | product_BE_matrices = custom_round(last_step)
170 | product_BE_matrices_batch = torch.split(product_BE_matrices, args.sample_size)
171 |
172 | for idx in range(batch_size):
173 | reac_smi, product_smi, product_BE_matrices = \
174 | src_smiles_list[idx], tgt_smiles_list[idx], product_BE_matrices_batch[idx]
175 |
176 | data_idx = int(src_data_indices[idx].detach().cpu())
177 | if data_idx in inferenced_indexes: continue
178 | else: inferenced_indexes.add(data_idx)
179 |
180 | reac_mol = Chem.MolFromSmiles(reac_smi, ps)
181 | prod_mol = Chem.MolFromSmiles(product_smi, ps)
182 |
183 | tgt_smiles = standardize_smiles(prod_mol)
184 |
185 | matrices, counts = torch.unique(product_BE_matrices, dim=0, return_counts=True)
186 | matrices, counts = matrices.cpu().numpy(), counts.cpu().numpy()
187 |
188 | not_sym = 0
189 |
190 | correct = wrong_smi_conserved = wrong_smi_non_conserved = 0
191 | no_smi_conserved = no_smi_non_conserved = 0
192 |
193 | pred_smi_dict = defaultdict(int)
194 | pred_conserved_dict = defaultdict(bool)
195 | # Evaluation on unique predicted BE matrices
196 | for i in range(matrices.shape[0]):
197 |                             pred_prod_be_matrix, count = matrices[i], counts[i] # predicted product matrix and its count
198 | num_nodes = y_len[idx]
199 | pred_prod_be_matrix = pred_prod_be_matrix[:num_nodes, :num_nodes]
200 | reac_be_matrix = x0[idx][:num_nodes, :num_nodes].detach().cpu().numpy()
201 |
202 | # print(f"Matrix{i} - {count}")
203 | pred_prod_be_matrix = redist_fix(pred_prod_be_matrix, reac_smi, reac_be_matrix)
204 |
205 | assert pred_prod_be_matrix.shape == reac_be_matrix.shape, "pred and reac not the same shape"
206 |
207 | if not is_sym(pred_prod_be_matrix):
208 | not_sym += 1
209 |
210 | try:
211 | pred_mol = BEmatrix_to_mol(reac_mol, pred_prod_be_matrix)
212 | pred_smi = standardize_smiles(pred_mol)
213 |
214 | pred_mol = Chem.MolFromSmiles(pred_smi, ps)
215 | pred_smi = standardize_smiles(pred_mol)
216 | tgt_mol = Chem.MolFromSmiles(tgt_smiles, ps)
217 | tgt_smiles = standardize_smiles(tgt_mol)
218 |
219 |
220 | if pred_smi == tgt_smiles and pred_prod_be_matrix.sum() == reac_be_matrix.sum():
221 | correct += count
222 | pred_smi_dict[pred_smi] += count
223 | pred_conserved_dict[pred_smi] = True
224 | elif pred_prod_be_matrix.sum() == reac_be_matrix.sum(): # conserve electron, gives wrong smiles
225 | wrong_smi_conserved += count
226 | pred_smi_dict[pred_smi] += count
227 | pred_conserved_dict[pred_smi] = True
228 |                         else: # gives a SMILES but does not conserve electrons
229 |                             wrong_smi_non_conserved += count  # added metric: valid SMILES, electrons not conserved
230 |                     except Exception:
231 | if pred_prod_be_matrix.sum() == reac_be_matrix.sum():
232 | no_smi_conserved += count
233 | else:
234 | no_smi_non_conserved += count
235 |
236 | metric = [correct, wrong_smi_conserved, wrong_smi_non_conserved, no_smi_conserved, no_smi_non_conserved]
237 | predictions = [(smi, pred_smi_dict[smi], pred_conserved_dict[smi]) for smi in pred_smi_dict]
238 | if write_o is not None:
239 | write_o.write(f"{metric}|{not_sym}|{predictions}\n")
240 | write_o.flush()
241 | accuracy.append(metric)
242 |
243 | return accuracy
244 |
245 |
246 | def main(args):
247 | args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
248 | device = args.device
249 | if args.local_rank != -1:
250 | dist.init_process_group(backend=args.backend, init_method='env://', timeout=datetime.timedelta(0, 7200))
251 | torch.cuda.set_device(args.local_rank)
252 | torch.backends.cudnn.benchmark = True
253 |
254 | if args.do_validate:
255 | phase = "valid"
256 | checkpoints = glob.glob(os.path.join(args.model_path, "*.pt"))
257 | checkpoints = sorted(
258 | checkpoints,
259 | key=lambda ckpt: int(ckpt.split(".")[-2].split("_")[-1]),
260 | reverse=True
261 | )
262 |         assert len(args.steps2validate) > 0, "Nothing to validate on"
263 | checkpoints = [ckpt for ckpt in checkpoints
264 | if ckpt.split(".")[-2].split("_")[0] in args.steps2validate] # lr0.001
265 | else:
266 | phase = "test"
267 | checkpoints = [os.path.join(args.model_path, args.model_name)]
268 |
269 |
270 | for ckpt_i, checkpoint in enumerate(checkpoints):
271 | state = torch.load(checkpoint, weights_only=False, map_location=device)
272 | pretrain_args = state["args"]
273 | pretrain_args.load_from = None
274 | pretrain_args.device = device
275 |
276 | pretrain_state_dict = state["state_dict"]
277 | pretrain_args.local_rank = args.local_rank
278 |
279 | attn_model, flow, state = init_model(pretrain_args)
280 | if hasattr(attn_model, "module"):
281 | attn_model = attn_model.module # unwrap DDP attn_model to enable accessing attn_model func directly
282 |
283 | pretrain_state_dict = {k.replace("module.", ""): v for k, v in pretrain_state_dict.items()}
284 | attn_model.load_state_dict(pretrain_state_dict)
285 | log_rank_0(f"Loaded pretrained state_dict from {checkpoint}")
286 |
287 | os.makedirs(args.result_path, exist_ok=True)
288 |         results_path = os.path.join(args.result_path, f'{phase}-{args.sample_size}-{os.path.splitext(os.path.basename(checkpoint))[0]}.txt')
289 | if os.path.isfile(results_path):
290 | with open(results_path, 'r') as fp:
291 | n_lines = len(fp.readlines())
292 | file_mod = 'a'
293 | start = n_lines
294 | log_rank_0(f"Continuing previous runs at reaction {start}...")
295 | else:
296 | log_rank_0("Starting new run...")
297 | file_mod = 'w'
298 | start = 0
299 |
300 | if args.do_validate:
301 | with open(args.val_path, 'r') as test_o:
302 | test_smiles_list = test_o.readlines()[start:]
303 | else:
304 | with open(args.test_path, 'r') as test_o:
305 | test_smiles_list = test_o.readlines()[start:]
306 |
307 |         assert len(test_smiles_list) > 0, "Nothing to run inference on"
308 |
309 | test_dataset = ReactionDataset(args, test_smiles_list)
310 | test_loader = init_loader(args, test_dataset,
311 | batch_size=args.test_batch_size,
312 | shuffle=False, epoch=None, use_sort=False)
313 |
314 | with open(results_path, file_mod) as result_o:
315 | metrics = get_predictions(args, attn_model, flow, test_loader, write_o=result_o)
316 |             if not dist.is_initialized() or dist.get_rank() == 0:
317 | metrics = np.array(metrics)
318 | topk_accuracies = np.mean(metrics[:, 0].astype(bool)) # correct smiles
319 | log_rank_0(f"Topk accuracies: {(topk_accuracies * 100): .2f}")
320 |
321 |
322 | if __name__ == "__main__":
323 | args = Args
324 | args.local_rank = int(os.environ["LOCAL_RANK"]) if os.environ.get("LOCAL_RANK") else -1
325 | logger = setup_logger(args, "eval")
326 | log_args(args, 'evaluation')
327 | main(args)
328 |
--------------------------------------------------------------------------------
/model/attn_encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from utils.attn_utils import PositionwiseFeedForward
5 | from utils.attn_utils import sequence_mask
6 | from utils.data_utils import ELEM_LIST
7 | from model.flow_matching import zero_center_func
8 |
9 | def timestep_embedding(timesteps, dim, max_period=10000):
10 | """Create sinusoidal timestep embeddings.
11 |
12 | :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
13 | :param dim: the dimension of the output.
14 | :param max_period: controls the minimum frequency of the embeddings.
15 | :return: an [N x dim] Tensor of positional embeddings.
16 | """
17 | half = dim // 2
18 | freqs = torch.exp(
19 | -math.log(max_period)
20 | * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
21 | / half
22 | )
23 | args = timesteps[:, None].float() * freqs[None]
24 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
25 | if dim % 2:
26 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
27 | return embedding
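   |
   | # Usage sketch: for a batch of three (possibly fractional) timesteps and
   | # dim=128,
   | #   emb = timestep_embedding(torch.tensor([0.0, 0.5, 1.0]), dim=128)
   | # returns a (3, 128) tensor of concatenated [cos | sin] features.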
28 |
29 | def zero_center_output(x_batch, ori_node_mask_batch):
30 | x_batch = x_batch.masked_fill(ori_node_mask_batch, 1e-19)
31 | node_mask_batch = (~ori_node_mask_batch).long()
32 | map_zero_center = torch.vmap(zero_center_func)
33 | return map_zero_center(x_batch, node_mask_batch).masked_fill(~(node_mask_batch.bool()), 1e-19)
34 |
35 | class RBFExpansion(nn.Module):
36 | def __init__(self, args):
37 | """
38 | Adapted from Schnet.
39 | https://github.com/atomistic-machine-learning/SchNet/blob/master/src/schnet/nn/layers/rbf.py
40 | """
41 | super().__init__()
42 | self.args = args
43 | self.device = args.device
44 | self.low = args.rbf_low
45 | self.high = args.rbf_high
46 | self.gap = args.rbf_gap
47 |
48 | self.xrange = self.high - self.low
49 |
50 | self.centers = torch.linspace(self.low, self.high,
51 | int(torch.ceil(torch.tensor(self.xrange / self.gap)))).to(self.device)
52 | self.dim = len(self.centers)
53 |
54 | def forward(self, matrix, matrix_mask):
55 |         matrix = matrix.masked_fill(matrix_mask, 1e9)  # push padded entries far from every center so their RBF response is ~0
56 | matrix = matrix.unsqueeze(-1) # Add a new dimension at the end
57 | # Compute the RBF
58 | matrix = matrix - self.centers
59 | rbf = torch.exp(-(matrix ** 2) / self.gap)
60 | return rbf
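   |
   |     # Shape sketch (hypothetical settings): with rbf_low=0, rbf_high=10 and
   |     # rbf_gap=0.5 there are ceil(10 / 0.5) = 20 centers, so a (b, n, n)
   |     # bond-electron matrix expands to (b, n, n, 20); padded entries sit at
   |     # 1e9, far from every center, so their RBF response is ~0.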
61 |
62 | class MultiHeadedRelAttention(nn.Module):
63 | def __init__(self, args, head_count, model_dim, dropout, u, v):
64 | super().__init__()
65 | self.args = args
66 |
67 | assert model_dim % head_count == 0
68 | self.dim_per_head = model_dim // head_count
69 | self.model_dim = model_dim
70 | self.head_count = head_count
71 |
72 | self.linear_keys = nn.Linear(model_dim, model_dim)
73 | self.linear_values = nn.Linear(model_dim, model_dim)
74 | self.linear_query = nn.Linear(model_dim, model_dim)
75 |
76 | self.softmax = nn.Softmax(dim=-1)
77 | self.dropout = nn.Dropout(dropout)
78 | self.final_linear = nn.Linear(model_dim, model_dim)
79 |
80 |         self.u = u if u is not None else \
81 |             nn.Parameter(torch.randn(self.model_dim), requires_grad=True)
82 |         self.v = v if v is not None else \
83 |             nn.Parameter(torch.randn(self.model_dim), requires_grad=True)
84 |
85 | def forward(self, inputs, mask, rel_emb):
86 | """
87 | Compute the context vector and the attention vectors.
88 |
89 | Args:
90 | inputs (FloatTensor): set of `key_len`
91 | key vectors ``(batch, key_len, dim)``
92 | mask: binary mask 1/0 indicating which keys have
93 | zero / non-zero attention ``(batch, query_len, key_len)``
94 |             rel_emb: relative (bond + timestep) embeddings, ``(batch, key_len, key_len, dim)``
95 | Returns:
96 | (FloatTensor, FloatTensor):
97 |
98 | * output context vectors ``(batch, query_len, dim)``
99 | * Attention vector in heads ``(batch, head, query_len, key_len)``.
100 | """
101 |
102 | batch_size = inputs.size(0)
103 | dim_per_head = self.dim_per_head
104 | head_count = self.head_count
105 |
106 | def shape(x):
107 | """Projection."""
108 | return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)
109 |
110 | def unshape(x):
111 | """Compute context."""
112 | return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)
113 |
114 | # 1) Project key, value, and query. Seems that we don't need layer_cache here
115 | query = self.linear_query(inputs)
116 | key = self.linear_keys(inputs)
117 | value = self.linear_values(inputs)
118 |
119 | key = shape(key) # (b, t_k, h) -> (b, head, t_k, h/head)
120 | value = shape(value)
121 | query = shape(query) # (b, t_q, h) -> (b, head, t_q, h/head)
122 |
123 | key_len = key.size(2)
124 | query_len = query.size(2)
125 |
126 | # 2) Calculate and scale scores.
127 | query = query / math.sqrt(dim_per_head)
128 |
129 | if rel_emb is None:
130 | scores = torch.matmul(
131 | query, key.transpose(2, 3)) # (b, head, t_q, t_k)
132 | # scores = scores + rel_emb.unsqueeze(1)
133 | else:
134 | # a + c
135 | u = self.u.reshape(1, head_count, 1, dim_per_head)
136 | a_c = torch.matmul(query + u, key.transpose(2, 3))
137 |
138 | # rel_emb = self.relative_pe(rel_emb) # (b, t_q, t_k) -> (b, t_q, t_k, h)
139 | rel_emb = rel_emb.reshape( # (b, t_q, t_k, h) -> (b, t_q, t_k, head, h/head)
140 | batch_size, query_len, key_len, head_count, dim_per_head)
141 |
142 | # b + d
143 | query = query.unsqueeze(-2) # (b, head, t_q, h/head) -> (b, head, t_q, 1, h/head)
144 | rel_emb_t = rel_emb.permute(0, 3, 1, 4, 2) # (b, t_q, t_k, head, h/head) -> (b, head, t_q, h/head, t_k)
145 |
146 | v = self.v.reshape(1, head_count, 1, 1, dim_per_head)
147 | b_d = torch.matmul(query + v, rel_emb_t
148 | ).squeeze(-2) # (b, head, t_q, 1, t_k) -> (b, head, t_q, t_k)
149 |
150 | scores = a_c + b_d
151 |
152 | scores = scores.float()
153 |
154 | mask = mask.unsqueeze(1) # (B, 1, 1, T_values)
155 | scores = scores.masked_fill(mask, -1e18)
156 |
157 | # 3) Apply attention dropout and compute context vectors.
158 | attn = self.softmax(scores)
159 | drop_attn = self.dropout(attn)
160 |
161 | context_original = torch.matmul(drop_attn, value) # -> (b, head, t_q, h/head)
162 | context = unshape(context_original) # -> (b, t_q, h)
163 |
164 | output = self.final_linear(context)
165 | attns = attn.view(batch_size, head_count, query_len, key_len)
166 |
167 | return output, attns
168 |
169 |
170 | class SALayerXL(nn.Module):
171 | """
172 | A single layer of the self-attention encoder.
173 |
174 | Args:
175 | d_model (int): the dimension of keys/values/queries in
176 | MultiHeadedAttention, also the input size of
177 | the first-layer of the PositionwiseFeedForward.
178 | heads (int): the number of head for MultiHeadedAttention.
179 | d_ff (int): the second-layer of the PositionwiseFeedForward.
180 | dropout: dropout probability(0-1.0).
181 | """
182 |
183 | def __init__(self, args, d_model, heads, d_ff, dropout, attention_dropout, u, v):
184 | super().__init__()
185 |
186 | self.self_attn = MultiHeadedRelAttention(
187 | args,
188 | heads, d_model, dropout=attention_dropout,
189 | u=u,
190 | v=v
191 | )
192 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
193 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
194 | self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
195 | self.dropout = nn.Dropout(dropout)
196 |
197 | def forward(self, inputs, mask, rel_emb):
198 | """
199 | Args:
200 | inputs (FloatTensor): ``(batch_size, src_len, model_dim)``
201 | mask (LongTensor): ``(batch_size, 1, src_len)``
202 | rel_emb (LongTensor): ``(batch_size, src_len, src_len)``
203 |
204 | Returns:
205 | (FloatTensor):
206 |
207 | * outputs ``(batch_size, src_len, model_dim)``
208 | """
209 | normed_inputs = self.layer_norm(inputs)
210 | context, _ = self.self_attn(normed_inputs, mask=mask, rel_emb=rel_emb)
211 | out = self.dropout(context) + inputs
212 | return self.feed_forward(self.layer_norm_2(out)) + out
213 |
214 | class Block(nn.Module):
215 | def __init__(self, size: int):
216 | super().__init__()
217 |
218 | self.ff = nn.Linear(size, size)
219 | self.act = nn.GELU()
220 | self.layer_norm = nn.LayerNorm(size, eps=1e-6)
221 |
222 | def forward(self, x: torch.Tensor):
223 | return x + self.layer_norm(self.act(self.ff(x)))
224 |
225 | class AttnEncoderXL(nn.Module):
226 | def __init__(self, args):
227 | super().__init__()
228 | self.args = args
229 | self.device = args.device
230 | self.num_layers = args.enc_num_layers
231 | self.post_processing_layers = args.post_processing_layers
232 | self.d_model = args.emb_dim
233 | self.heads = args.enc_heads
234 | self.d_ff = args.enc_filter_size
235 | self.attention_dropout = args.attn_dropout
236 |
237 |         self.atom_embedding = nn.Embedding(len(ELEM_LIST), self.d_model, padding_idx=0)
238 | self.rbf = RBFExpansion(args)
239 |
240 |         self.time_dim = self.d_model - self.rbf.dim  # rbf features and time features concatenate back to d_model
241 | self.time_embed = nn.Sequential(
242 | nn.Linear(self.time_dim, self.time_dim),
243 | nn.SiLU(),
244 | nn.Linear(self.time_dim, self.time_dim),
245 | )
246 |
247 | self.dropout = nn.Dropout(p=args.dropout)
248 | if args.rel_pos in ["enc_only", "emb_only"]:
249 | self.u = nn.Parameter(torch.randn(self.d_model), requires_grad=True)
250 | self.v = nn.Parameter(torch.randn(self.d_model), requires_grad=True)
251 | else:
252 | self.u = None
253 | self.v = None
254 |
255 | if args.shared_attention_layer == 1:
256 | self.attention_layer = SALayerXL(
257 | args, self.d_model, self.heads, self.d_ff, args.dropout, self.attention_dropout,
258 | self.u, self.v)
259 | else:
260 | self.attention_layers = nn.ModuleList(
261 | [SALayerXL(
262 | args, self.d_model, self.heads, self.d_ff, args.dropout, self.attention_dropout,
263 | self.u, self.v)
264 | for i in range(self.num_layers)])
265 | self.layer_norm = nn.LayerNorm(self.d_model, eps=1e-6)
266 |
267 | self.query_w = torch.nn.Sequential(
268 | *[Block(self.d_model) for _ in range(self.post_processing_layers)]
269 | )
270 | self.key_w = torch.nn.Sequential(
271 | *[Block(self.d_model) for _ in range(self.post_processing_layers)]
272 | )
273 |
274 | self.query_diag_w = torch.nn.Sequential(
275 | *[Block(self.d_model) for _ in range(self.post_processing_layers)]
276 | )
277 | self.key_diag_w = torch.nn.Sequential(
278 | *[Block(self.d_model) for _ in range(self.post_processing_layers)]
279 | )
280 | self.value_diag_w = torch.nn.Sequential(
281 | *[Block(self.d_model) for _ in range(self.post_processing_layers)]
282 | )
283 | self.final_diag_w = torch.nn.Linear(self.d_model, 1)
284 |
285 | self.rel_emb_w = torch.nn.Sequential(
286 | *[*[Block(self.d_model) for _ in range(self.post_processing_layers)],
287 | torch.nn.Linear(self.d_model, 1)]
288 | )
289 | self.softmax = nn.Softmax(dim=-1)
290 |
291 | rbf_layers = []
292 | # rbf_layers.append(torch.nn.Linear(self.rbf.dim+1, self.rbf.dim))
293 | for _ in range(self.post_processing_layers):
294 | rbf_layers.append(Block(self.rbf.dim))
295 | self.rbf_linear = torch.nn.Sequential(*rbf_layers)
296 | self.rbf_final_linear = torch.nn.Linear(self.rbf.dim, 1)
297 |
298 | def id2emb(self, src_token_id):
299 | return self.atom_embedding(src_token_id)
300 |
301 | def forward(self, src, lengths, bond_matrix, timestep):
302 | """adapt from onmt TransformerEncoder
303 | src_token_id: (b, t, h)
304 | lengths: (b,)
305 |
306 | NEW on Jan'23: return: (b, t, h)
307 | """
308 | if timestep.dim() == 0:
309 | timestep = timestep.repeat(lengths.shape[0])
310 |
311 | b, n, _ = bond_matrix.shape
312 | timestep = self.time_embed(timestep_embedding(timestep, self.time_dim))
313 | timestep = timestep.unsqueeze(1).unsqueeze(1) # unsqueeze to match bond n x n
314 |         timestep = timestep.repeat(1, n, n, 1) # repeat across the n x n bond grid
315 |
316 | mask = ~sequence_mask(lengths).unsqueeze(1)
317 |
318 | matrix_masks = ~(~mask * ~mask.transpose(1, 2)).bool()
319 |         rbf_bond_matrix = self.rbf(bond_matrix, matrix_masks)  # b, n, n -> b, n, n, rbf-dim
320 |         rbf_bond_matrix = self.rbf_linear(rbf_bond_matrix)  # residual blocks keep the shape: b, n, n, rbf-dim
321 |
322 | rel_emb = torch.cat((rbf_bond_matrix, timestep), dim=-1) # b, n, n, d
323 |
324 | # src = self.atom_embedding(src_token_id)
325 | # h_place = (src_token_id == 1).float().unsqueeze(-1).repeat(1, 1, src.shape[-1])
326 |
327 | b, n, d = src.shape
328 |
329 | # a_i - raw atom embeddings
330 | a_i = src * math.sqrt(self.d_model)
331 | a_i = self.dropout(a_i)
332 |
333 | if self.args.shared_attention_layer == 1:
334 | layer = self.attention_layer
335 | for i in range(self.num_layers):
336 | a_i = layer(a_i, mask, rel_emb)
337 | else:
338 | for layer in self.attention_layers:
339 | a_i = layer(a_i, mask, rel_emb)
340 | a_i = self.layer_norm(a_i) # b,n,d
341 |
342 | # a_i - atom embeddings after multiheaded attention on atom embeddings + rbf expansion
343 |
344 | # diagonal prediction
345 | query_diag = self.query_diag_w(a_i) # b,n,d @ d,d -> b,n,d
346 | key_diag = self.key_diag_w(a_i) # b,n,d @ d,d -> b,n,d
347 | value_diag = self.value_diag_w(a_i) # b,n,d @ d,d -> b,n,d
348 |
349 | diag_scores = torch.matmul(query_diag, key_diag.transpose(1, 2)) # b,n,d @ b,d,n -> b,n,n
350 |         diag_scores = diag_scores.masked_fill(matrix_masks, 1e-9)  # note: fills masked scores with ~0 rather than a large negative value, so masked positions keep near-uniform softmax weight
351 | diag_scores = self.softmax(diag_scores) / math.sqrt(self.d_model)
352 | context = torch.matmul(diag_scores, value_diag) # b,n,n @ b,n,d -> b,n,d
353 | diag = self.final_diag_w(context).view(b, n) # b,n,d @ d,1 -> b,n,1 -> b,n
354 |
355 | # non diagonal prediction
356 | query = self.query_w(a_i) # b,n,d @ d,d -> b,n,d
357 | key = self.key_w(a_i)
358 |
359 | scores = torch.matmul(query, key.transpose(1, 2)) # b,n,d @ b,d,n -> b,n,n
360 | a_ij = scores / math.sqrt(self.d_model)
361 |
362 | rbfw_ij = self.rel_emb_w(rel_emb).view(b, n, n) # b,n,n,d @ d,1 -> b,n,n,1 -> b,n,n
363 | out = a_ij + rbfw_ij
364 |
365 | for i in range(b):
366 | indices = torch.arange(n)
367 | out[i, indices, indices] = 0
368 | out[i].diagonal().add_(diag[i])
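   |         # The per-sample loop above is equivalent to this vectorized sketch:
   |         #   idx = torch.arange(n, device=out.device)
   |         #   out[:, idx, idx] = diag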
369 |
370 | out = zero_center_output(out, matrix_masks)
371 | out = (out + out.transpose(1, 2))
372 |
373 | return out
374 |
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from rdkit import Chem
4 | import numpy as np
5 | import sys
6 | from rdkit import RDLogger
7 | RDLogger.DisableLog('rdApp.*')
8 | from multiprocessing import Pool, cpu_count
9 |
10 | np.set_printoptions(threshold=sys.maxsize, linewidth=500)
11 | torch.set_printoptions(profile="full")
12 |
13 | ELEM_LIST = ['PAD', 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Na', 'Mg', 'Al', 'Si', \
14 | 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', \
15 | 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Sr', 'Y', 'Zr', 'Mo', 'Tc', 'Ru', \
16 | 'Rh', 'Pd', 'Ag', 'In', 'Sn', 'Sb', 'Te', 'I', 'Cs', 'Ba', 'La', 'Ce', 'Eu', \
17 | 'Yb', 'Ta', 'W', 'Os', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'V', 'Sm']
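   |
   | # Index 0 is 'PAD': it matches padding_idx=0 in the atom embedding and the
   | # constant_values=0 used when padding token ids in ReactionDataset.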
18 |
19 | MATRIX_PAD = -30
20 | bt_to_electron = {Chem.rdchem.BondType.SINGLE: 2,
21 | Chem.rdchem.BondType.DOUBLE: 4,
22 | Chem.rdchem.BondType.TRIPLE: 6,
23 | Chem.rdchem.BondType.AROMATIC: 3}
24 |
25 | tbl = Chem.GetPeriodicTable()
26 |
27 | def bond_features(bond):
28 | bt = bond.GetBondType()
29 |
30 | return bt_to_electron[bt]
31 |
32 | def count_lone_pairs(a):
33 | v=tbl.GetNOuterElecs(a.GetAtomicNum())
34 | c=a.GetFormalCharge()
35 | b=sum([bond.GetBondTypeAsDouble() for bond in a.GetBonds()])
36 | h=a.GetTotalNumHs()
37 | return v-c-b-h
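   |
   | # Worked example: the oxygen of water has v=6 outer electrons, charge c=0,
   | # and bonds plus hydrogens summing to 2, giving 6 - 0 - 2 = 4 lone-pair
   | # electrons (two lone pairs).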
38 |
39 | ps = Chem.SmilesParserParams()
40 | ps.removeHs = False
41 | ps.sanitize = True
42 | def get_BE_matrix(r):
43 | rmol = Chem.MolFromSmiles(r, ps)
44 | Chem.Kekulize(rmol)
45 | max_natoms = len(rmol.GetAtoms())
46 | f = np.zeros((max_natoms,max_natoms))
47 |
48 | for atom in rmol.GetAtoms():
49 | lone_pair = count_lone_pairs(atom)
50 | f[atom.GetIntProp('molAtomMapNumber') - 1, atom.GetIntProp('molAtomMapNumber') - 1] = lone_pair
51 |
52 | for bond in rmol.GetBonds():
53 | a1 = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') - 1
54 | a2 = bond.GetEndAtom().GetIntProp('molAtomMapNumber') - 1
55 | f[(a1,a2)] = f[(a2,a1)] = bond_features(bond)/2 # so that bond electron diff matrix sums up to 0
56 |
57 | return f
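   |
   | # Worked example (atom-mapped water, [H:1][O:2][H:3]): each O-H single bond
   | # contributes 2/2 = 1 to its off-diagonal entries and oxygen keeps its 4
   | # lone-pair electrons on the diagonal:
   | #   [[0, 1, 0],
   | #    [1, 4, 1],
   | #    [0, 1, 0]]
   | # The matrix sums to 8, the total valence electron count of water.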
58 |
59 | electron_to_bo = {val:key for key, val in bt_to_electron.items()}
60 |
61 | def get_formal_charge(a, electron):
62 | v=tbl.GetNOuterElecs(a.GetAtomicNum())
63 | b=sum([bond.GetBondTypeAsDouble() for bond in a.GetBonds()])
64 | h=a.GetTotalNumHs()
65 |     f = v - electron - b - h
66 | return f
67 |
68 | def mol_prop_compute(matrix):
69 | """
70 | vectorized way of computing atom dict and bond dict from matrix
71 | """
72 | n = matrix.shape[0]
73 |
74 | # 1) Compute symmetric bond sums once:
75 | Mplus = matrix + matrix.T
76 |
77 |     # 2) Extract all off-diagonal i<j entries
130 | def process_smiles(smiles):
131 |     src_smi, tgt_smi = smiles.strip().split('|')[0].split('>>')
132 |
133 | error = ""
134 | try:
135 | _ = get_BE_matrix(src_smi)
136 | _ = get_BE_matrix(tgt_smi)
137 | src_vocab_id_list, src_len = smi2vocabid(src_smi)
138 | tgt_vocab_id_list, tgt_len = smi2vocabid(tgt_smi)
139 | assert (src_vocab_id_list == tgt_vocab_id_list).all()
140 | except Exception as e:
141 | error = e
142 | src_smi, tgt_smi = '', ''
143 | src_vocab_id_list, src_len = [], 0
144 | tgt_vocab_id_list, tgt_len = [], 0
145 |
146 | # Return a tuple of results for this smiles pair
147 | return {
148 | 'src_smi': src_smi,
149 | 'tgt_smi': tgt_smi,
150 | 'src_vocab_id_list': src_vocab_id_list,
151 | 'tgt_vocab_id_list': tgt_vocab_id_list,
152 | 'src_len': src_len,
153 | 'tgt_len': tgt_len
154 | # "error": error
155 | }
156 |
157 | class ReactionBatch:
158 | def __init__(self,
159 | src_data_indices: torch.Tensor,
160 | src_token_ids: torch.Tensor,
161 | src_lens: torch.Tensor,
162 | src_matrices: torch.Tensor,
163 | tgt_matrices: torch.Tensor,
164 | matrix_masks: torch.Tensor,
165 | src_smiles_list: list,
166 | tgt_smiles_list: list,
167 | ):
168 | self.src_data_indices = src_data_indices
169 | self.src_token_ids = src_token_ids
170 | self.src_lens = src_lens
171 | self.src_matrices = src_matrices
172 | self.tgt_matrices = tgt_matrices
173 | self.matrix_masks = matrix_masks
174 | self.src_smiles_list = src_smiles_list
175 | self.tgt_smiles_list = tgt_smiles_list
176 |
177 | def to(self, device):
178 | self.src_data_indices = self.src_data_indices.to(device)
179 | self.src_token_ids = self.src_token_ids.to(device)
180 | self.src_lens = self.src_lens.to(device)
181 | self.src_matrices = self.src_matrices.to(device)
182 | self.tgt_matrices = self.tgt_matrices.to(device)
183 | self.matrix_masks = self.matrix_masks.to(device)
184 |
185 | def pin_memory(self):
186 | self.src_data_indices = self.src_data_indices.pin_memory()
187 | self.src_token_ids = self.src_token_ids.pin_memory()
188 | self.src_lens = self.src_lens.pin_memory()
189 | self.src_matrices = self.src_matrices.pin_memory()
190 | self.tgt_matrices = self.tgt_matrices.pin_memory()
191 | self.matrix_masks = self.matrix_masks.pin_memory()
192 |
193 | return self
194 |
195 | class ReactionDataset(Dataset):
196 | def __init__(self, args, smiles_list, parallel=True, reactant_only=False):
197 | self.args = args
198 | self.device = args.device
199 | self.reactant_only = reactant_only
200 | self.smiles_list = smiles_list
201 | self.src_smis = []
202 | self.tgt_smis = []
203 |
204 | self.src_token_ids = []
205 | self.tgt_token_ids = []
206 |
207 | self.src_lens = []
208 | self.tgt_lens = []
209 |
210 | if reactant_only:
211 | self.parse_reactant_only()
212 | else:
213 | if parallel:
214 | self.parse_data_parallel()
215 | else:
216 | self.parse_data()
217 |
218 | self.src_lens = np.asarray(self.src_lens)
219 |
220 | self.data_size = len(self.src_smis)
221 | self.data_indices = np.arange(self.data_size)
222 |
223 | def parse_reactant_only(self):
224 | for src_smi in self.smiles_list:
225 | src_smi = src_smi.strip()
226 | try:
227 | _ = get_BE_matrix(src_smi)
228 | src_vocab_id_list, src_len = smi2vocabid(src_smi)
229 | except Exception as e:
230 | print(e)
231 | continue
232 |
233 | self.src_smis.append(src_smi)
234 | self.src_token_ids.append(src_vocab_id_list)
235 | self.src_lens.append(src_len)
236 |             self.tgt_lens.append(src_len)  # reactant-only: target length equals source length
237 |
238 | assert len(self.src_smis) > 0, "Empty Data"
239 |
240 | def parse_data(self):
241 | for smiles in self.smiles_list:
242 | src_smi, tgt_smi = smiles.strip().split('|')[0].split('>>')
243 |
244 | try:
245 | src_matrix = get_BE_matrix(src_smi)
246 | tgt_matrix = get_BE_matrix(tgt_smi)
247 | src_vocab_id_list, src_len = smi2vocabid(src_smi)
248 | tgt_vocab_id_list, tgt_len = smi2vocabid(tgt_smi)
249 | assert (src_vocab_id_list == tgt_vocab_id_list).all()
250 | assert src_len == tgt_len, "src len and tgt len should be the same"
251 | except Exception as e:
252 | print(e)
253 | continue
254 |
255 | self.src_smis.append(src_smi)
256 | self.tgt_smis.append(tgt_smi)
257 | self.src_token_ids.append(src_vocab_id_list)
258 | self.tgt_token_ids.append(tgt_vocab_id_list)
259 | self.src_lens.append(src_len)
260 | self.tgt_lens.append(tgt_len)
261 |
262 | assert len(self.src_smis) == len(self.tgt_smis) == len(self.src_lens) == len(self.tgt_lens) \
263 |             == len(self.src_token_ids) == len(self.tgt_token_ids)
264 |
265 | def parse_data_parallel(self):
266 |
267 | p = Pool(cpu_count())
268 |         results = p.imap(process_smiles, self.smiles_list)
269 | p.close()
270 | p.join()
271 |
272 | # Prepare the final data structures
273 | count = 0
274 | total = 0
275 | for result in results:
276 | total += 1
277 |             if len(result['src_vocab_id_list']) == 0 or result['src_len'] == 0:
278 | # print(f"{result['src_smi']}>>{result['tgt_smi']}")
279 | # print(result['error'])
280 | count += 1
281 | continue
282 | self.src_smis.append(result['src_smi'])
283 | self.tgt_smis.append(result['tgt_smi'])
284 | self.src_token_ids.append(result['src_vocab_id_list'])
285 | self.tgt_token_ids.append(result['tgt_vocab_id_list'])
286 | self.src_lens.append(result['src_len'])
287 | self.tgt_lens.append(result['tgt_len'])
288 |
289 |         print(f"{count * 100 / total:.2f}% of the data is unparseable")
290 |
291 | def sort(self):
292 | self.data_indices = np.argsort(self.src_lens)
293 |
294 | def shuffle_in_bucket(self, bucket_size: int):
295 | for i in range(0, self.data_size, bucket_size):
296 | np.random.shuffle(self.data_indices[i:i + bucket_size])
297 |
298 | def batch(self, batch_type: str, batch_size: int, verbose=False):
299 |
300 | self.batch_sizes = []
301 | if batch_type.startswith("tokens"):
302 | sample_size = 0
303 | max_batch_src_len = 0
304 | max_batch_tgt_len = 0
305 |
306 | for data_idx in self.data_indices:
307 | src_len = self.src_lens[data_idx]
308 | tgt_len = self.tgt_lens[data_idx]
309 |
310 | max_batch_src_len = max(src_len, max_batch_src_len)
311 | max_batch_tgt_len = max(tgt_len, max_batch_tgt_len)
312 |
313 | if batch_type == "tokens" and \
314 | max_batch_src_len * (sample_size + 1) <= batch_size:
315 | sample_size += 1
316 | elif batch_type == "tokens_sum" and \
317 | (max_batch_src_len + max_batch_tgt_len) * (sample_size + 1) <= batch_size:
318 | sample_size += 1
319 | else:
320 | self.batch_sizes.append(sample_size)
321 |
322 | sample_size = 1
323 | max_batch_src_len = src_len
324 | max_batch_tgt_len = tgt_len
325 |
326 | # lastly
327 | self.batch_sizes.append(sample_size)
328 | self.batch_sizes = np.array(self.batch_sizes)
329 | assert np.sum(self.batch_sizes) == self.data_size, \
330 | f"Size mismatch! Data size: {self.data_size}, sum batch sizes: {np.sum(self.batch_sizes)}"
331 |
332 | self.batch_ends = np.cumsum(self.batch_sizes)
333 | self.batch_starts = np.concatenate([[0], self.batch_ends[:-1]])
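   |
   |             # e.g. with batch_type="tokens" and batch_size=4096 (hypothetical),
   |             # molecules of ~300 atoms are grouped 13 per batch, since
   |             # 300 * 13 <= 4096 < 300 * 14.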
334 |
335 | else:
336 | raise ValueError(f"batch_type {batch_type} not supported!")
337 |
338 |
339 | def __len__(self):
340 | return len(self.batch_sizes)
341 |
342 | def __getitem__(self, idx : int):
343 | batch_index = idx
344 |
345 | batch_start = self.batch_starts[batch_index]
346 | batch_end = self.batch_ends[batch_index]
347 |
348 | data_indices = self.data_indices[batch_start:batch_end]
349 |
350 | # print(self.src_lens[data_indices], data_indices)
351 | max_len = max(self.src_lens[data_indices])
352 |
353 | src_token_id_batch = []
354 | src_len_batch = []
355 | src_matrix_batch = []
356 | tgt_matrix_batch = []
357 | src_smiles_batch = []
358 | tgt_smiles_batch = []
359 | for data_index in data_indices:
360 | # src_token_id, _ = smi2vocabid(self.src_smis[data_index])
361 | src_token_id = self.src_token_ids[data_index]
362 | src_len = self.src_lens[data_index]
363 | src_token_id = np.pad(src_token_id, (0, max_len - src_len),
364 | mode='constant', constant_values=0) # constant value 0 based on 'PAD' in ELEM_LIST
365 | src_token_id = torch.as_tensor(src_token_id, dtype=torch.long)
366 |
367 | src_token_id_batch.append(src_token_id)
368 |
369 | src_matrix = get_BE_matrix(self.src_smis[data_index])
370 | src_matrix = np.pad(src_matrix, ((0, max_len - src_len), (0, max_len - src_len)),
371 | mode='constant', constant_values=MATRIX_PAD)
372 | src_len_batch.append(src_len)
373 | src_matrix_batch.append(src_matrix)
374 | src_smiles_batch.append(self.src_smis[data_index])
375 |
376 | if not self.reactant_only:
377 | tgt_matrix = get_BE_matrix(self.tgt_smis[data_index])
378 | tgt_matrix = np.pad(tgt_matrix, ((0, max_len - src_len), (0, max_len - src_len)),
379 | mode='constant', constant_values=MATRIX_PAD)
380 | tgt_matrix_batch.append(tgt_matrix)
381 | tgt_smiles_batch.append(self.tgt_smis[data_index])
382 |
383 | src_data_indices = torch.as_tensor(data_indices, dtype=torch.long)
384 | src_len_batch = torch.as_tensor(src_len_batch, dtype=torch.long)
385 | src_token_id_batch = torch.stack(src_token_id_batch)
386 | src_matrix_batch = torch.as_tensor(np.stack(src_matrix_batch), dtype=torch.float)
387 | if not self.reactant_only:
388 | tgt_matrix_batch = torch.as_tensor(np.stack(tgt_matrix_batch), dtype=torch.float)
389 | else: tgt_matrix_batch = src_matrix_batch
390 |
391 |         node_mask = (src_matrix_batch[:, :, 0] != MATRIX_PAD)  # padded rows are MATRIX_PAD everywhere, so column 0 identifies real atoms
392 | matrix_masks = (node_mask.unsqueeze(1) * node_mask.unsqueeze(2)).long()
393 |
394 | reaction_batch = ReactionBatch(
395 | src_data_indices=src_data_indices,
396 | src_token_ids=src_token_id_batch,
397 | src_lens=src_len_batch,
398 | src_matrices=src_matrix_batch,
399 | tgt_matrices=tgt_matrix_batch,
400 | matrix_masks=matrix_masks,
401 | src_smiles_list=src_smiles_batch,
402 | tgt_smiles_list=tgt_smiles_batch
403 | )
404 |
405 | return reaction_batch
406 |
--------------------------------------------------------------------------------