├── .gitignore ├── GCN ├── batch.py ├── dataloader.py ├── feature_extract.py ├── finetune.py ├── finetune_mutag_ptc.py ├── loader.py ├── model.py ├── model_gin │ ├── contextpred.pth │ ├── edgepred.pth │ ├── infomax.pth │ ├── masking.pth │ ├── supervised.pth │ ├── supervised_contextpred.pth │ ├── supervised_edgepred.pth │ ├── supervised_infomax.pth │ └── supervised_masking.pth ├── parse_result.py ├── pretrain_contextpred.py ├── pretrain_deepgraphinfomax.py ├── pretrain_edgepred.py ├── pretrain_masking.py ├── pretrain_supervised.py ├── splitters.py └── util.py ├── LICENSE ├── README.md ├── agent.py ├── assets └── pipeline.png ├── colab.ipynb ├── datasets ├── acrylates.txt ├── chain_extenders.txt ├── isocyanates.txt └── polymers_117.txt ├── environment.yml ├── fuseprop ├── __init__.py ├── chemutils.py ├── dataset.py ├── decoder.py ├── encoder.py ├── gnn.py ├── inc_graph.py ├── mol_graph.py ├── nnutils.py ├── rnn.py └── vocab.py ├── grammar_generation.py ├── main.py ├── private ├── __init__.py ├── grammar.py ├── hypergraph.py ├── metrics.py ├── molecule_graph.py ├── subgraph_set.py ├── symbol.py └── utils.py ├── retro_star ├── alg │ ├── __init__.py │ ├── mol_node.py │ ├── mol_tree.py │ ├── molstar.py │ ├── reaction_node.py │ └── syn_route.py ├── api.py ├── common │ ├── __init__.py │ ├── parse_args.py │ ├── prepare_utils.py │ └── smiles_to_fp.py ├── data_loader │ ├── __init__.py │ └── value_data_loader.py ├── environment.yml ├── model │ ├── __init__.py │ └── value_mlp.py ├── packages │ ├── mlp_retrosyn │ │ ├── mlp_retrosyn.egg-info │ │ │ ├── PKG-INFO │ │ │ ├── SOURCES.txt │ │ │ ├── dependency_links.txt │ │ │ └── top_level.txt │ │ ├── mlp_retrosyn │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ │ └── toy │ │ │ │ │ └── sample.csv │ │ │ ├── extract_template.py │ │ │ ├── mlp_inference.py │ │ │ ├── mlp_policies.py │ │ │ ├── mlp_train.py │ │ │ └── scripts │ │ │ │ ├── run_extract_templates.sh │ │ │ │ ├── run_mlp_inference.sh │ │ │ │ └── run_mlp_train.sh │ │ └── setup.py │ └── rdchiral │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── rdchiral │ │ ├── __init__.py │ │ ├── backup │ │ │ ├── __init__.py │ │ │ ├── bonds.py │ │ │ ├── chiral.py │ │ │ ├── clean.py │ │ │ ├── initialization.py │ │ │ ├── main.py │ │ │ ├── template_extractor.py │ │ │ └── utils.py │ │ ├── bonds.py │ │ ├── chiral.py │ │ ├── clean.py │ │ ├── initialization.py │ │ ├── main.py │ │ ├── old │ │ │ ├── __init__.py │ │ │ ├── chiral.py │ │ │ ├── clean.py │ │ │ ├── initialization.py │ │ │ ├── main.py │ │ │ ├── template_extractor.py │ │ │ └── utils.py │ │ ├── template_extractor.py │ │ ├── test │ │ │ ├── __init__.py │ │ │ └── test_smiles_from_50k_uspto.txt │ │ └── utils.py │ │ ├── setup.py │ │ ├── templates │ │ ├── Examine templates.ipynb │ │ ├── README.md │ │ ├── clean_and_extract_uspto.py │ │ ├── example_template_extractions_bad.json │ │ └── example_template_extractions_good.json │ │ └── test │ │ ├── Test rdchiral notebook.ipynb │ │ ├── test_rdchiral.py │ │ └── test_rdchiral_cases.json ├── retro_plan.py ├── train.py ├── trainer │ ├── __init__.py │ └── trainer.py └── utils │ ├── __init__.py │ └── logger.py ├── retro_star_listener.py ├── retro_star_listener.sh ├── simple_main.py └── visualization.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | __pycache__/ 3 | ./*/__pycache__/ 4 | ./*/*/__pycache__/ 5 | 6 | .ipynb_checkpoints/ 7 | .git 8 | .DS_store 9 | ./*/.DS_store 10 | .DS_Store 11 | ./*/.DS_Store 12 | ./*/*/.DS_Store 13 | ./*/*/.DS_Store 14 | retro_star/dataset/ 15 | retro_star/one_step_model/ 16 | retro_star/saved_models/ 17 | retro_star/retro_data.zip 18 | 19 | # Environments 20 | .env 21 | .venv 22 | -------------------------------------------------------------------------------- /GCN/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch.utils.data.dataloader import default_collate 3 | 4 | from batch import BatchSubstructContext, BatchMasking, BatchAE 5 | 6 | class DataLoaderSubstructContext(torch.utils.data.DataLoader): 7 | r"""Data loader which merges data objects from a 8 | :class:`torch_geometric.data.dataset` to a mini-batch. 9 | Args: 10 | dataset (Dataset): The dataset from which to load the data. 11 | batch_size (int, optional): How may samples per batch to load. 12 | (default: :obj:`1`) 13 | shuffle (bool, optional): If set to :obj:`True`, the data will be 14 | reshuffled at every epoch (default: :obj:`True`) 15 | """ 16 | 17 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 18 | super(DataLoaderSubstructContext, self).__init__( 19 | dataset, 20 | batch_size, 21 | shuffle, 22 | collate_fn=lambda data_list: BatchSubstructContext.from_data_list(data_list), 23 | **kwargs) 24 | 25 | class DataLoaderMasking(torch.utils.data.DataLoader): 26 | r"""Data loader which merges data objects from a 27 | :class:`torch_geometric.data.dataset` to a mini-batch. 28 | Args: 29 | dataset (Dataset): The dataset from which to load the data. 30 | batch_size (int, optional): How may samples per batch to load. 31 | (default: :obj:`1`) 32 | shuffle (bool, optional): If set to :obj:`True`, the data will be 33 | reshuffled at every epoch (default: :obj:`True`) 34 | """ 35 | 36 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 37 | super(DataLoaderMasking, self).__init__( 38 | dataset, 39 | batch_size, 40 | shuffle, 41 | collate_fn=lambda data_list: BatchMasking.from_data_list(data_list), 42 | **kwargs) 43 | 44 | 45 | class DataLoaderAE(torch.utils.data.DataLoader): 46 | r"""Data loader which merges data objects from a 47 | :class:`torch_geometric.data.dataset` to a mini-batch. 48 | Args: 49 | dataset (Dataset): The dataset from which to load the data. 50 | batch_size (int, optional): How may samples per batch to load. 51 | (default: :obj:`1`) 52 | shuffle (bool, optional): If set to :obj:`True`, the data will be 53 | reshuffled at every epoch (default: :obj:`True`) 54 | """ 55 | 56 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 57 | super(DataLoaderAE, self).__init__( 58 | dataset, 59 | batch_size, 60 | shuffle, 61 | collate_fn=lambda data_list: BatchAE.from_data_list(data_list), 62 | **kwargs) 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /GCN/feature_extract.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from rdkit import Chem 7 | from tqdm import tqdm 8 | import numpy as np 9 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool 10 | from GCN.model import GNN, GNN_feature 11 | from GCN.loader import mol_to_graph_data_obj_simple 12 | 13 | 14 | class feature_extractor(): 15 | def __init__(self, pretrained_model_path): 16 | self.pretrained_model_path = pretrained_model_path 17 | 18 | def preprocessing(self, graph_mol): 19 | # rdkit_mol = AllChem.MolFromSmiles(graph_sml) 20 | # kekulize the molecule to distinguish the aromatic bonds 21 | data = mol_to_graph_data_obj_simple(Chem.MolFromSmiles(Chem.MolToSmiles(graph_mol))) 22 | # print("data.x, data.edge_index, data.edge_attr:", data.x, data.edge_index, data.edge_attr) 23 | # print(Chem.MolToSmiles(graph_mol)) 24 | # import pdb; pdb.set_trace() 25 | return data 26 | 27 | def extract(self, graph_mol): 28 | model = GNN_feature(num_layer=5, emb_dim=300, num_tasks=1, JK='last', drop_ratio=0, graph_pooling='mean', gnn_type='gin') 29 | model.from_pretrained(self.pretrained_model_path) 30 | # model.cuda(device=0) 31 | model.eval() 32 | graph_data = self.preprocessing(graph_mol) 33 | # graph_data = graph_data.cuda(device=0) 34 | with torch.no_grad(): 35 | node_features = model(graph_data.x, graph_data.edge_index, graph_data.edge_attr) 36 | del model 37 | return node_features 38 | 39 | -------------------------------------------------------------------------------- /GCN/model_gin/contextpred.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/GCN/model_gin/contextpred.pth -------------------------------------------------------------------------------- /GCN/model_gin/edgepred.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/GCN/model_gin/edgepred.pth -------------------------------------------------------------------------------- /GCN/model_gin/infomax.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/GCN/model_gin/infomax.pth -------------------------------------------------------------------------------- /GCN/model_gin/masking.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/GCN/model_gin/masking.pth -------------------------------------------------------------------------------- /GCN/model_gin/supervised.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/GCN/model_gin/supervised.pth -------------------------------------------------------------------------------- /GCN/model_gin/supervised_contextpred.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/GCN/model_gin/supervised_contextpred.pth -------------------------------------------------------------------------------- /GCN/model_gin/supervised_edgepred.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/GCN/model_gin/supervised_edgepred.pth -------------------------------------------------------------------------------- /GCN/model_gin/supervised_infomax.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/GCN/model_gin/supervised_infomax.pth -------------------------------------------------------------------------------- /GCN/model_gin/supervised_masking.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/GCN/model_gin/supervised_masking.pth -------------------------------------------------------------------------------- /GCN/parse_result.py: -------------------------------------------------------------------------------- 1 | ### Parsing the result! 2 | import tensorflow as tf 3 | import os 4 | import numpy as np 5 | import pickle 6 | 7 | def get_test_acc(event_file): 8 | val_auc_list = np.zeros(100) 9 | test_auc_list = np.zeros(100) 10 | for e in list(tf.train.summary_iterator(event_file)): 11 | if len(e.summary.value) == 0: 12 | continue 13 | if e.summary.value[0].tag == "data/val_auc": 14 | val_auc_list[e.step-1] = e.summary.value[0].simple_value 15 | if e.summary.value[0].tag == "data/test_auc": 16 | test_auc_list[e.step-1] = e.summary.value[0].simple_value 17 | 18 | best_epoch = np.argmax(val_auc_list) 19 | 20 | return test_auc_list[best_epoch] 21 | 22 | if __name__ == "__main__": 23 | 24 | dataset_list = ["muv", "bace", "bbbp", "clintox", "hiv", "sider", "tox21", "toxcast"] 25 | #10 random seed 26 | seed_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 27 | config_list = [] 28 | 29 | config_list.append("gin_nopretrain") 30 | config_list.append("gin_infomax") 31 | config_list.append("gin_edgepred") 32 | config_list.append("gin_masking") 33 | config_list.append("gin_contextpred") 34 | config_list.append("gin_supervised") 35 | config_list.append("gin_supervised_infomax") 36 | config_list.append("gin_supervised_edgepred") 37 | config_list.append("gin_supervised_masking") 38 | config_list.append("gin_supervised_contextpred") 39 | config_list.append("gcn_nopretrain") 40 | config_list.append("gcn_supervised_contextpred") 41 | config_list.append("graphsage_nopretrain") 42 | config_list.append("graphsage_supervised_contextpred") 43 | config_list.append("gat_nopretrain") 44 | config_list.append("gat_supervised_contextpred") 45 | 46 | result_mat = np.zeros((len(seed_list), len(config_list), len(dataset_list))) 47 | 48 | for i, seed in enumerate(seed_list): 49 | for j, config in enumerate(config_list): 50 | for k, dataset in enumerate(dataset_list): 51 | dir_name = "runs/finetune_cls_runseed" + str(seed) + "/" + dataset + "/" + config 52 | print(dir_name) 53 | file_in_dir = os.listdir(dir_name) 54 | event_file_list = [] 55 | for f in file_in_dir: 56 | if "events" in f: 57 | event_file_list.append(f) 58 | 59 | event_file = event_file_list[0] 60 | 61 | result_mat[i, j, k] = get_test_acc(dir_name + "/" + event_file) 62 | 63 | with open("result_summary", "wb") as f: 64 | pickle.dump({"result_mat": result_mat, "seed_list": seed_list, "config_list": config_list, "dataset_list": dataset_list}, f) 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /GCN/pretrain_deepgraphinfomax.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from loader import MoleculeDataset 4 | from torch_geometric.data import DataLoader 5 | from torch_geometric.nn.inits import uniform 6 | from torch_geometric.nn import global_mean_pool 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | 13 | from tqdm import tqdm 14 | import numpy as np 15 | 16 | from model import GNN 17 | from sklearn.metrics import roc_auc_score 18 | 19 | from splitters import scaffold_split, random_split, random_scaffold_split 20 | import pandas as pd 21 | 22 | from tensorboardX import SummaryWriter 23 | 24 | 25 | def cycle_index(num, shift): 26 | arr = torch.arange(num) + shift 27 | arr[-shift:] = torch.arange(shift) 28 | return arr 29 | 30 | class Discriminator(nn.Module): 31 | def __init__(self, hidden_dim): 32 | super(Discriminator, self).__init__() 33 | self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim)) 34 | self.reset_parameters() 35 | 36 | def reset_parameters(self): 37 | size = self.weight.size(0) 38 | uniform(size, self.weight) 39 | 40 | def forward(self, x, summary): 41 | h = torch.matmul(summary, self.weight) 42 | return torch.sum(x*h, dim = 1) 43 | 44 | class Infomax(nn.Module): 45 | def __init__(self, gnn, discriminator): 46 | super(Infomax, self).__init__() 47 | self.gnn = gnn 48 | self.discriminator = discriminator 49 | self.loss = nn.BCEWithLogitsLoss() 50 | self.pool = global_mean_pool 51 | 52 | 53 | def train(args, model, device, loader, optimizer): 54 | model.train() 55 | 56 | train_acc_accum = 0 57 | train_loss_accum = 0 58 | 59 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 60 | batch = batch.to(device) 61 | node_emb = model.gnn(batch.x, batch.edge_index, batch.edge_attr) 62 | summary_emb = torch.sigmoid(model.pool(node_emb, batch.batch)) 63 | 64 | positive_expanded_summary_emb = summary_emb[batch.batch] 65 | 66 | shifted_summary_emb = summary_emb[cycle_index(len(summary_emb), 1)] 67 | negative_expanded_summary_emb = shifted_summary_emb[batch.batch] 68 | 69 | positive_score = model.discriminator(node_emb, positive_expanded_summary_emb) 70 | negative_score = model.discriminator(node_emb, negative_expanded_summary_emb) 71 | 72 | optimizer.zero_grad() 73 | loss = model.loss(positive_score, torch.ones_like(positive_score)) + model.loss(negative_score, torch.zeros_like(negative_score)) 74 | loss.backward() 75 | 76 | optimizer.step() 77 | 78 | train_loss_accum += float(loss.detach().cpu().item()) 79 | acc = (torch.sum(positive_score > 0) + torch.sum(negative_score < 0)).to(torch.float32)/float(2*len(positive_score)) 80 | train_acc_accum += float(acc.detach().cpu().item()) 81 | 82 | return train_acc_accum/step, train_loss_accum/step 83 | 84 | 85 | def main(): 86 | # Training settings 87 | parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks') 88 | parser.add_argument('--device', type=int, default=0, 89 | help='which gpu to use if any (default: 0)') 90 | parser.add_argument('--batch_size', type=int, default=256, 91 | help='input batch size for training (default: 256)') 92 | parser.add_argument('--epochs', type=int, default=100, 93 | help='number of epochs to train (default: 100)') 94 | parser.add_argument('--lr', type=float, default=0.001, 95 | help='learning rate (default: 0.001)') 96 | parser.add_argument('--decay', type=float, default=0, 97 | help='weight decay (default: 0)') 98 | parser.add_argument('--num_layer', type=int, default=5, 99 | help='number of GNN message passing layers (default: 5).') 100 | parser.add_argument('--emb_dim', type=int, default=300, 101 | help='embedding dimensions (default: 300)') 102 | parser.add_argument('--dropout_ratio', type=float, default=0, 103 | help='dropout ratio (default: 0)') 104 | parser.add_argument('--JK', type=str, default="last", 105 | help='how the node features across layers are combined. last, sum, max or concat') 106 | parser.add_argument('--dataset', type=str, default = 'zinc_standard_agent', help='root directory of dataset. For now, only classification.') 107 | parser.add_argument('--output_model_file', type = str, default = '', help='filename to output the pre-trained model') 108 | parser.add_argument('--gnn_type', type=str, default="gin") 109 | parser.add_argument('--seed', type=int, default=0, help = "Seed for splitting dataset.") 110 | parser.add_argument('--num_workers', type=int, default = 8, help='number of workers for dataset loading') 111 | args = parser.parse_args() 112 | 113 | 114 | torch.manual_seed(0) 115 | np.random.seed(0) 116 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 117 | if torch.cuda.is_available(): 118 | torch.cuda.manual_seed_all(0) 119 | 120 | 121 | #set up dataset 122 | dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset) 123 | 124 | print(dataset) 125 | 126 | loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 127 | 128 | #set up model 129 | gnn = GNN(args.num_layer, args.emb_dim, JK = args.JK, drop_ratio = args.dropout_ratio, gnn_type = args.gnn_type) 130 | 131 | discriminator = Discriminator(args.emb_dim) 132 | 133 | model = Infomax(gnn, discriminator) 134 | 135 | model.to(device) 136 | 137 | #set up optimizer 138 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay) 139 | print(optimizer) 140 | 141 | for epoch in range(1, args.epochs+1): 142 | print("====epoch " + str(epoch)) 143 | 144 | train_acc, train_loss = train(args, model, device, loader, optimizer) 145 | 146 | print(train_acc) 147 | print(train_loss) 148 | 149 | 150 | if not args.output_model_file == "": 151 | torch.save(gnn.state_dict(), args.output_model_file + ".pth") 152 | 153 | if __name__ == "__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /GCN/pretrain_edgepred.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from loader import MoleculeDataset 4 | from dataloader import DataLoaderAE 5 | from util import NegativeEdge 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | from model import GNN, GNN_graphpred 16 | from sklearn.metrics import roc_auc_score 17 | 18 | from splitters import scaffold_split, random_split, random_scaffold_split 19 | import pandas as pd 20 | 21 | from tensorboardX import SummaryWriter 22 | 23 | criterion = nn.BCEWithLogitsLoss() 24 | 25 | def train(args, model, device, loader, optimizer): 26 | model.train() 27 | 28 | train_acc_accum = 0 29 | train_loss_accum = 0 30 | 31 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 32 | batch = batch.to(device) 33 | node_emb = model(batch.x, batch.edge_index, batch.edge_attr) 34 | 35 | positive_score = torch.sum(node_emb[batch.edge_index[0, ::2]] * node_emb[batch.edge_index[1, ::2]], dim = 1) 36 | negative_score = torch.sum(node_emb[batch.negative_edge_index[0]] * node_emb[batch.negative_edge_index[1]], dim = 1) 37 | 38 | optimizer.zero_grad() 39 | loss = criterion(positive_score, torch.ones_like(positive_score)) + criterion(negative_score, torch.zeros_like(negative_score)) 40 | loss.backward() 41 | optimizer.step() 42 | 43 | train_loss_accum += float(loss.detach().cpu().item()) 44 | acc = (torch.sum(positive_score > 0) + torch.sum(negative_score < 0)).to(torch.float32)/float(2*len(positive_score)) 45 | train_acc_accum += float(acc.detach().cpu().item()) 46 | 47 | return train_acc_accum/step, train_loss_accum/step 48 | 49 | 50 | def main(): 51 | # Training settings 52 | parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks') 53 | parser.add_argument('--device', type=int, default=0, 54 | help='which gpu to use if any (default: 0)') 55 | parser.add_argument('--batch_size', type=int, default=256, 56 | help='input batch size for training (default: 256)') 57 | parser.add_argument('--epochs', type=int, default=100, 58 | help='number of epochs to train (default: 100)') 59 | parser.add_argument('--lr', type=float, default=0.001, 60 | help='learning rate (default: 0.001)') 61 | parser.add_argument('--decay', type=float, default=0, 62 | help='weight decay (default: 0)') 63 | parser.add_argument('--num_layer', type=int, default=5, 64 | help='number of GNN message passing layers (default: 5).') 65 | parser.add_argument('--emb_dim', type=int, default=300, 66 | help='embedding dimensions (default: 300)') 67 | parser.add_argument('--dropout_ratio', type=float, default=0, 68 | help='dropout ratio (default: 0)') 69 | parser.add_argument('--JK', type=str, default="last", 70 | help='how the node features across layers are combined. last, sum, max or concat') 71 | parser.add_argument('--dataset', type=str, default = 'zinc_standard_agent', help='root directory of dataset. For now, only classification.') 72 | parser.add_argument('--output_model_file', type = str, default = '', help='filename to output the pre-trained model') 73 | parser.add_argument('--gnn_type', type=str, default="gin") 74 | parser.add_argument('--num_workers', type=int, default = 8, help='number of workers for dataset loading') 75 | args = parser.parse_args() 76 | 77 | 78 | torch.manual_seed(0) 79 | np.random.seed(0) 80 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 81 | if torch.cuda.is_available(): 82 | torch.cuda.manual_seed_all(0) 83 | 84 | #set up dataset 85 | dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset, transform = NegativeEdge()) 86 | 87 | print(dataset[0]) 88 | 89 | loader = DataLoaderAE(dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 90 | 91 | #set up model 92 | model = GNN(args.num_layer, args.emb_dim, JK = args.JK, drop_ratio = args.dropout_ratio, gnn_type = args.gnn_type) 93 | 94 | model.to(device) 95 | 96 | #set up optimizer 97 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay) 98 | print(optimizer) 99 | 100 | for epoch in range(1, args.epochs+1): 101 | print("====epoch " + str(epoch)) 102 | 103 | train_acc, train_loss = train(args, model, device, loader, optimizer) 104 | 105 | print(train_acc) 106 | print(train_loss) 107 | 108 | if not args.output_model_file == "": 109 | torch.save(model.state_dict(), args.output_model_file + ".pth") 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /GCN/pretrain_supervised.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from loader import MoleculeDataset 4 | from torch_geometric.data import DataLoader 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | 11 | from tqdm import tqdm 12 | import numpy as np 13 | 14 | from model import GNN, GNN_graphpred 15 | from sklearn.metrics import roc_auc_score 16 | 17 | from splitters import scaffold_split, random_split, random_scaffold_split 18 | import pandas as pd 19 | 20 | from tensorboardX import SummaryWriter 21 | 22 | criterion = nn.BCEWithLogitsLoss(reduction = "none") 23 | 24 | def train(args, model, device, loader, optimizer): 25 | model.train() 26 | 27 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 28 | batch = batch.to(device) 29 | pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch) 30 | y = batch.y.view(pred.shape).to(torch.float64) 31 | 32 | #Whether y is non-null or not. 33 | is_valid = y**2 > 0 34 | #Loss matrix 35 | loss_mat = criterion(pred.double(), (y+1)/2) 36 | #loss matrix after removing null target 37 | loss_mat = torch.where(is_valid, loss_mat, torch.zeros(loss_mat.shape).to(loss_mat.device).to(loss_mat.dtype)) 38 | 39 | optimizer.zero_grad() 40 | loss = torch.sum(loss_mat)/torch.sum(is_valid) 41 | loss.backward() 42 | 43 | optimizer.step() 44 | 45 | 46 | def eval(args, model, device, loader, normalized_weight): 47 | model.eval() 48 | y_true = [] 49 | y_scores = [] 50 | 51 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 52 | batch = batch.to(device) 53 | 54 | with torch.no_grad(): 55 | pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch) 56 | 57 | y_true.append(batch.y.view(pred.shape).cpu()) 58 | y_scores.append(pred.cpu()) 59 | 60 | y_true = torch.cat(y_true, dim = 0).numpy() 61 | y_scores = torch.cat(y_scores, dim = 0).numpy() 62 | 63 | roc_list = [] 64 | weight = [] 65 | for i in range(y_true.shape[1]): 66 | #AUC is only defined when there is at least one positive data. 67 | if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == -1) > 0: 68 | is_valid = y_true[:,i]**2 > 0 69 | roc_list.append(roc_auc_score((y_true[is_valid,i] + 1)/2, y_scores[is_valid,i])) 70 | weight.append(normalized_weight[i]) 71 | 72 | if len(roc_list) < y_true.shape[1]: 73 | print("Some target is missing!") 74 | print("Missing ratio: %f" %(1 - float(len(roc_list))/y_true.shape[1])) 75 | 76 | weight = np.array(weight) 77 | roc_list = np.array(roc_list) 78 | 79 | return weight.dot(roc_list) 80 | 81 | 82 | def main(): 83 | # Training settings 84 | parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks') 85 | parser.add_argument('--device', type=int, default=0, 86 | help='which gpu to use if any (default: 0)') 87 | parser.add_argument('--batch_size', type=int, default=32, 88 | help='input batch size for training (default: 32)') 89 | parser.add_argument('--epochs', type=int, default=100, 90 | help='number of epochs to train (default: 100)') 91 | parser.add_argument('--lr', type=float, default=0.001, 92 | help='learning rate (default: 0.001)') 93 | parser.add_argument('--decay', type=float, default=0, 94 | help='weight decay (default: 0)') 95 | parser.add_argument('--num_layer', type=int, default=5, 96 | help='number of GNN message passing layers (default: 5).') 97 | parser.add_argument('--emb_dim', type=int, default=300, 98 | help='embedding dimensions (default: 300)') 99 | parser.add_argument('--dropout_ratio', type=float, default=0.2, 100 | help='dropout ratio (default: 0.2)') 101 | parser.add_argument('--graph_pooling', type=str, default="mean", 102 | help='graph level pooling (sum, mean, max, set2set, attention)') 103 | parser.add_argument('--JK', type=str, default="last", 104 | help='how the node features across layers are combined. last, sum, max or concat') 105 | parser.add_argument('--dataset', type=str, default = 'chembl_filtered', help='root directory of dataset. For now, only classification.') 106 | parser.add_argument('--gnn_type', type=str, default="gin") 107 | parser.add_argument('--input_model_file', type=str, default = '', help='filename to read the model (if there is any)') 108 | parser.add_argument('--output_model_file', type = str, default = '', help='filename to output the pre-trained model') 109 | parser.add_argument('--num_workers', type=int, default = 8, help='number of workers for dataset loading') 110 | args = parser.parse_args() 111 | 112 | 113 | torch.manual_seed(0) 114 | np.random.seed(0) 115 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 116 | if torch.cuda.is_available(): 117 | torch.cuda.manual_seed_all(0) 118 | 119 | #Bunch of classification tasks 120 | if args.dataset == "chembl_filtered": 121 | num_tasks = 1310 122 | else: 123 | raise ValueError("Invalid dataset name.") 124 | 125 | #set up dataset 126 | dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset) 127 | 128 | loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 129 | 130 | #set up model 131 | model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type) 132 | if not args.input_model_file == "": 133 | model.from_pretrained(args.input_model_file + ".pth") 134 | 135 | model.to(device) 136 | 137 | #set up optimizer 138 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay) 139 | print(optimizer) 140 | 141 | 142 | for epoch in range(1, args.epochs+1): 143 | print("====epoch " + str(epoch)) 144 | 145 | train(args, model, device, loader, optimizer) 146 | 147 | if not args.output_model_file == "": 148 | torch.save(model.gnn.state_dict(), args.output_model_file + ".pth") 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Minghao Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data-Efficient Graph Grammar Learning for Molecular Generation 2 | This repository contains the implementation code for paper [Data-Efficient Graph Grammar Learning for Molecular Generation 3 | ](https://openreview.net/forum?id=l4IHywGq6a) (__ICLR 2022 oral__). 4 | 5 | In this work, we propose a data-efficient generative model (__DEG__) that can be learned from datasets with orders of 6 | magnitude smaller sizes than common benchmarks. At the heart of this method is a learnable graph grammar that generates molecules from a sequence of production rules. Our learned graph grammar yields state-of-the-art results on generating high-quality molecules for 7 | three monomer datasets that contain only ∼20 samples each. 8 | 9 | ![overview](assets/pipeline.png) 10 | 11 | ## Installation 12 | 13 | ### Prerequisites 14 | - __Retro*:__ The training of our DEG relies on [Retro*](https://github.com/binghong-ml/retro_star) to calculate the metric. Follow the instruction [here](#conda) to install. 15 | 16 | - __Pretrained GNN:__ We use [this codebase](https://github.com/snap-stanford/pretrain-gnns) for the pretrained GNN used in our paper. The necessary code & pretrained models are built in the current repo. 17 | 18 | 19 | ### Conda 20 | You can use ``conda`` to install the dependencies for DEG from the provided ``environment.yml`` file, which can give you the exact python environment we run the code for the paper: 21 | ```bash 22 | git clone git@github.com:gmh14/data_efficient_grammar.git 23 | cd data_efficient_grammar 24 | conda env create -f environment.yml 25 | conda activate DEG 26 | pip install -e retro_star/packages/mlp_retrosyn 27 | pip install -e retro_star/packages/rdchiral 28 | ``` 29 | >Note: it may take a decent amount of time to build necessary wheels using conda. 30 | 31 | ### Install ``Retro*``: 32 | - Download and unzip the files from this [link](https://www.dropbox.com/s/ar9cupb18hv96gj/retro_data.zip?dl=0), 33 | and put all the folders (```dataset/```, ```one_step_model/``` and ```saved_models/```) under the ```retro_star``` directory. 34 | 35 | - Install dependencies: 36 | ```bash 37 | conda deactivate 38 | conda env create -f retro_star/environment.yml 39 | conda activate retro_star_env 40 | pip install -e retro_star/packages/mlp_retrosyn 41 | pip install -e retro_star/packages/rdchiral 42 | pip install setproctitle 43 | ``` 44 | 45 | 46 | ## Train 47 | 48 | For Acrylates, Chain Extenders, and Isocyanates, 49 | ```bash 50 | conda activate DEG 51 | python main.py --training_data=./datasets/**dataset_path** 52 | ``` 53 | where ``**dataset_path**`` can be ``acrylates.txt``, ``chain_extenders.txt``, or ``isocyanates.txt``. 54 | 55 | For Polymer dataset, 56 | ```bash 57 | conda activate DEG 58 | python main.py --training_data=./datasets/polymers_117.txt --motif 59 | ``` 60 | 61 | Since ``Retro*`` is a major bottleneck of the training speed, we separate it from the main process, run multiple ``Retro*`` processes, and use file communication to evaluate the generated grammar during training. This is a compromise on the inefficiency of the built-in python multiprocessing package. We need to run the following command in another terminal window, 62 | ```bash 63 | conda activate retro_star_env 64 | bash retro_star_listener.sh **num_processes** 65 | ``` 66 | >Note: opening multiple ``Retro*`` is EXTREMELY memory consuming (~5G each). We suggest to start from using only one process by ``bash retro_star_listener.sh 1`` and monitor the memory usage, then accordingly increase the number to maximize the efficiency. We use ``35`` in the paper. 67 | 68 | After finishing the training, to kill all the generated processes related to ``Retro*``, run 69 | ```bash 70 | killall retro_star_listener 71 | ``` 72 | 73 | 74 | ## Use DEG 75 | Download and unzip the log & checkpoint files from this [link](https://drive.google.com/file/d/12g28WNAgRGzaLtuG6ESg25W-uzlNrpLQ/view?usp=sharing). See ``visualization.ipynb`` for more details. 76 | 77 | 78 | ## Acknowledgements 79 | The implementation of DEG is partly based on [Molecular Optimization Using Molecular Hypergraph Grammar](https://github.com/ibm-research-tokyo/graph_grammar) and [Hierarchical Generation of Molecular Graphs using Structural Motifs 80 | ](https://github.com/wengong-jin/hgraph2graph). 81 | 82 | 83 | ## Citation 84 | If you find the idea or code useful for your research, please cite [our paper](https://openreview.net/forum?id=l4IHywGq6a): 85 | ```bib 86 | @inproceedings{guo2021data, 87 | title={Data-Efficient Graph Grammar Learning for Molecular Generation}, 88 | author={Guo, Minghao and Thost, Veronika and Li, Beichen and Das, Payel and Chen, Jie and Matusik, Wojciech}, 89 | booktitle={International Conference on Learning Representations}, 90 | year={2021} 91 | } 92 | ``` 93 | 94 | 95 | ## Contact 96 | Please contact guomh2014@gmail.com if you have any questions. Enjoy! -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch.distributions import Categorical 4 | import numpy as np 5 | 6 | 7 | class Agent(nn.Module): 8 | def __init__(self, feat_dim, hidden_size): 9 | super(Agent, self).__init__() 10 | self.affine1 = nn.Linear(feat_dim + 2, hidden_size) 11 | self.dropout = nn.Dropout(p=0.5) 12 | self.affine2 = nn.Linear(hidden_size, 2) 13 | self.saved_log_probs = {} 14 | 15 | def forward(self, x): 16 | x = self.affine1(x) 17 | x = F.relu(x) 18 | scores = self.affine2(x) 19 | return F.softmax(scores, dim=1) 20 | 21 | 22 | def sample(agent, subgraph_feature, iter_num, sample_number): 23 | # subgraph_feature: N * (2+feat_dim), N is the number of subgraphs inside all inputs 24 | prob = agent(subgraph_feature) 25 | m = Categorical(prob) 26 | a = m.sample() 27 | take_action = (np.sum(a.numpy()) != 0) 28 | if take_action: 29 | if sample_number not in agent.saved_log_probs.keys(): 30 | agent.saved_log_probs[sample_number] = {} 31 | if iter_num not in agent.saved_log_probs[sample_number].keys(): 32 | agent.saved_log_probs[sample_number][iter_num] = [m.log_prob(a)] 33 | else: 34 | agent.saved_log_probs[sample_number][iter_num].append(m.log_prob(a)) 35 | return a.numpy(), take_action 36 | 37 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/assets/pipeline.png -------------------------------------------------------------------------------- /colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Instructions\n", 9 | "\n", 10 | "**README** \n", 11 | "\n", 12 | "* Download the code from this [link](https://github.com/gmh14/data_efficient_grammar).\n", 13 | "* Download and unzip the log & checkpoint files from this [link](https://drive.google.com/file/d/12g28WNAgRGzaLtuG6ESg25W-uzlNrpLQ/view). \n" 14 | ] 15 | }, 16 | { 17 | "attachments": {}, 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Setup" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "!cd data_efficient_grammar\n", 31 | "!conda create -n DEG_test python=3.6\n", 32 | "!conda activate DEG_test \n", 33 | "!conda install scipy pandas numpy scikit-learn\n", 34 | "!conda install pytorch torchvision torchaudio cpuonly -c pytorch\n", 35 | "!conda install -c rdkit rdkit\n", 36 | "\n", 37 | "!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cpu.html\n", 38 | "!pip install torch-geometric\n", 39 | "!pip install torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.10.0+cpu.html\n", 40 | "\n", 41 | "!pip install setproctitle\n", 42 | "!pip install graphviz" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "!conda install scipy pandas numpy scikit-learn\n", 52 | "!conda install pytorch torchvision torchaudio cpuonly -c pytorch\n", 53 | "!conda install -c rdkit rdkit\n", 54 | "!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cpu.html\n", 55 | "!pip install torch-geometric\n", 56 | "!pip install torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.10.0+cpu.html\n", 57 | "!pip install setproctitle\n", 58 | "!pip install graphviz\n", 59 | "!pip install pickle5" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "!pip install -e retro_star/packages/mlp_retrosyn\n", 69 | "!pip install -e retro_star/packages/rdchiral" 70 | ] 71 | }, 72 | { 73 | "attachments": {}, 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "## 2. Play with the trained model" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "from private.hypergraph import Hypergraph, hg_to_mol\n", 87 | "from grammar_generation import random_produce\n", 88 | "\n", 89 | "\n", 90 | "from rdkit import Chem\n", 91 | "from rdkit.Chem import Draw\n", 92 | "import numpy as np\n", 93 | "from copy import deepcopy\n", 94 | "import pickle5 as pickle\n", 95 | "import torch\n", 96 | "from os import listdir" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "expr_name_dict = dict()\n", 106 | "expr_name_dict['polymer_117motif'] = 'grammar-log/log_117motifs'\n", 107 | "expr_name_dict['iso'] = 'grammar-log/log_iso'\n", 108 | "expr_name_dict['acrylates'] = 'grammar-log/log_acy'\n", 109 | "expr_name_dict['chain_extender'] = 'grammar-log/log_ce'\n", 110 | "\n", 111 | "expr_names = list(expr_name_dict.keys())\n", 112 | "generated_mols = dict()\n", 113 | "for expr_name in expr_names:\n", 114 | " print('dealing with {}'.format(expr_name))\n", 115 | " ckpt_list = listdir(expr_name_dict[expr_name])\n", 116 | " max_R = 0\n", 117 | " max_R_ckpt = None\n", 118 | " for ckpt in ckpt_list:\n", 119 | " if 'grammar' in ckpt:\n", 120 | " curr_R = float(ckpt.split('_')[4][:-4])\n", 121 | " if curr_R > max_R:\n", 122 | " max_R = curr_R\n", 123 | " max_R_ckpt = ckpt\n", 124 | " print('loading {}'.format(max_R_ckpt))\n", 125 | " with open('{}/{}'.format(expr_name_dict[expr_name], max_R_ckpt), 'rb') as fr:\n", 126 | " grammar = pickle.load(fr)\n", 127 | " for i in range(8):\n", 128 | " mol, _ = random_produce(grammar)\n", 129 | " if expr_name not in generated_mols.keys():\n", 130 | " generated_mols[expr_name] = [mol]\n", 131 | " else:\n", 132 | " generated_mols[expr_name].append(mol)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "exp = 'polymer_117motif' # 'iso', 'acrylates', 'chain_extender'\n", 142 | "Chem.Draw.MolsToGridImage(generated_mols[exp], molsPerRow=4, subImgSize=(200,200))" 143 | ] 144 | }, 145 | { 146 | "attachments": {}, 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "## 3. Train your own model (w/o optimization)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "!python main.py --training_data=./datasets/**dataset_path**" 160 | ] 161 | }, 162 | { 163 | "attachments": {}, 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "Check your model in \"log-num_generated_samples100-_timestamp_\"" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [] 174 | } 175 | ], 176 | "metadata": { 177 | "kernelspec": { 178 | "display_name": "Python 3", 179 | "language": "python", 180 | "name": "python3" 181 | }, 182 | "language_info": { 183 | "name": "python", 184 | "version": "3.11.1 (main, Dec 23 2022, 09:40:27) [Clang 14.0.0 (clang-1400.0.29.202)]" 185 | }, 186 | "orig_nbformat": 4, 187 | "vscode": { 188 | "interpreter": { 189 | "hash": "1a1af0ee75eeea9e2e1ee996c87e7a2b11a0bebd85af04bb136d915cefc0abce" 190 | } 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 2 195 | } 196 | -------------------------------------------------------------------------------- /datasets/acrylates.txt: -------------------------------------------------------------------------------- 1 | C=CC(=O)OCC1=CC=CC=C1 2 | C=CC(=O)OC1=CC=CC=C1 3 | CC(=C)C(=O)OC1=CC=CC=C1 4 | C=CC(=O)OCCC1=CC=CC=C1 5 | CCCCCCCCOC(=O)C(=C)C 6 | CCC(C)OC(=O)C=C 7 | CC(=C)C(=O)OCC1=CC=CC=C1 8 | C=CC(=O)OC1=C(C(=C(C(=C1F)F)F)F)F 9 | CC(C)COC(=O)C(=C)C 10 | CCCCCCCCCCCCOC(=O)C(=C)C 11 | CCC(C)OC(=O)C(=C)C 12 | CCCOC(=O)C(=C)C 13 | CC1CC(CC(C1)(C)C)OC(=O)C(=C)C 14 | CC(C)CCCCCCCOC(=O)C=C 15 | CCCOC(=O)C=C 16 | COCCOC(=O)C=C 17 | CC(=C)C(=O)OCCOC1=CC=CC=C1 18 | CCCCCCOC(=O)C=C 19 | CCCCOCCOC(=O)C(=C)C 20 | CC(=C)C(=O)OC 21 | COC(=O)C=C 22 | CCCCOC(=O)C=C 23 | CCOCCOC(=O)C(=C)C 24 | CC(=C)C(=O)OC1CC2CCC1(C2(C)C)C 25 | CCCCC(CC)COC(=O)C(=C)C 26 | CC(C)(COCCCOC(=O)C=C)COCCCOC(=O)C=C 27 | C=CC(=O)OCCCCCCOC(=O)C=C 28 | C=CC(=O)OCC(CO)(COC(=O)C=C)COC(=O)C=C 29 | CCC(COCCCOC(=O)C=C)(COCCCOC(=O)C=C)COCCCOC(=O)C=C 30 | CCC(COCC(CC)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C 31 | C=CC(=O)OCC(CO)(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)COC(=O)C=C 32 | C=CC(=O)OCC(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C -------------------------------------------------------------------------------- /datasets/chain_extenders.txt: -------------------------------------------------------------------------------- 1 | OCCO 2 | OC(C)CCO 3 | OCCCCO 4 | OCCNC(=O)NCCCCCCNC(=O)NCCO 5 | OCCN1C(=O)NC(C1(=O))CCCCNC(=O)NCCO 6 | Oc1ccc(cc1)CCC(=O)OCCOC(=O)CCc1ccc(cc1)O 7 | OC(=O)C(N)CCCCN 8 | OC(=O)C(N)CCCN 9 | N1CCNCC1 10 | Nc1ccc(cc1)SSc2ccc(cc2)N 11 | Nc1ccc(cc1)Cc2ccc(cc2)N -------------------------------------------------------------------------------- /datasets/isocyanates.txt: -------------------------------------------------------------------------------- 1 | O=C=NC1=CC=CC(CC2=CC=C(C=C2N=C=O)CC3=CC=C(C=C3)N=C=O)=C1 2 | O=C=NC1=CC(CC2=C(C=C(C=C2)CC3=CC=C(C=C3N=C=O)CC4=CC=C(C=C4)N=C=O)N=C=O)=CC=C1 3 | O=C=NC1=CC=C(C=C1)CC2=CC=C(C=C2N=C=O)CC3=C(C=C(C=C3)CC4=CC=C(C=C4N=C=O)CC5=CC=C(C=C5)N=C=O)N=C=O 4 | O=C=NCCCCCCN=C=O 5 | O=C=NCCCCCCCCCCCCN=C=O 6 | O=C=NCCCCCCCCCCCCCCCCCCN=C=O 7 | O=C=NCCCCCCCCCCCCCCCCCCCCCCCCN=C=O 8 | CC1(CC(CC(CN=C=O)(C1)C)N=C=O)C 9 | CC1=C(C=C(C=C1)CN=C=O)N=C=O 10 | O=C=NC1CCC(CC2CCC(CC2)N=C=O)CC1 11 | CCOC(C(N=C=O)CCCCN=C=O)=O -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: DEG 2 | channels: 3 | - rdkit 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h7b6447c_0 10 | - ca-certificates=2021.10.26=h06a4308_2 11 | - cairo=1.16.0=hf32fb01_1 12 | - certifi=2021.5.30=py36h06a4308_0 13 | - dataclasses=0.8=pyh4f3eec9_6 14 | - fontconfig=2.13.1=h6c09931_0 15 | - freetype=2.11.0=h70c0345_0 16 | - glib=2.69.1=h5202010_0 17 | - gmp=6.2.1=h2531618_2 18 | - gnutls=3.6.15=he1e5248_0 19 | - icu=58.2=he6710b0_3 20 | - intel-openmp=2020.2=254 21 | - joblib=1.0.1=pyhd3eb1b0_0 22 | - jpeg=9d=h7f8727e_0 23 | - lame=3.100=h7b6447c_0 24 | - lcms2=2.12=h3be6417_0 25 | - ld_impl_linux-64=2.35.1=h7274673_9 26 | - libboost=1.65.1=habcd387_4 27 | - libffi=3.3=he6710b0_2 28 | - libgcc-ng=9.1.0=hdf63c60_0 29 | - libgfortran-ng=7.3.0=hdf63c60_0 30 | - libiconv=1.15=h63c8f33_5 31 | - libidn2=2.3.2=h7f8727e_0 32 | - libpng=1.6.37=hbc83047_0 33 | - libstdcxx-ng=9.1.0=hdf63c60_0 34 | - libtasn1=4.16.0=h27cfd23_0 35 | - libtiff=4.2.0=h85742a9_0 36 | - libunistring=0.9.10=h27cfd23_0 37 | - libuuid=1.0.3=h7f8727e_2 38 | - libuv=1.40.0=h7b6447c_0 39 | - libwebp-base=1.2.0=h27cfd23_0 40 | - libxcb=1.14=h7b6447c_0 41 | - libxml2=2.9.10=hb55368b_3 42 | - lz4-c=1.9.3=h295c915_1 43 | - mkl=2020.2=256 44 | - mkl-service=2.3.0=py36he8ac12f_0 45 | - mkl_fft=1.2.0=py36h23d657b_0 46 | - mkl_random=1.1.1=py36h0573a6f_0 47 | - ncurses=6.3=h7f8727e_2 48 | - nettle=3.7.3=hbbd107a_1 49 | - numpy=1.16.2=py36h7e9f1db_0 50 | - numpy-base=1.16.2=py36hde5b4d6_0 51 | - olefile=0.46=pyhd3eb1b0_0 52 | - openh264=2.1.0=hd408876_0 53 | - openjpeg=2.4.0=h3ad879b_0 54 | - openssl=1.1.1l=h7f8727e_0 55 | - pandas=0.23.4=py36h04863e7_0 56 | - pcre=8.45=h295c915_0 57 | - pillow=8.3.1=py36h2c7a002_0 58 | - pip=21.2.2=py36h06a4308_0 59 | - pixman=0.40.0=h7f8727e_1 60 | - py-boost=1.65.1=py36hf484d3e_4 61 | - python=3.6.13=h12debd9_1 62 | - python-dateutil=2.8.2=pyhd3eb1b0_0 63 | - pytz=2021.3=pyhd3eb1b0_0 64 | - readline=8.1=h27cfd23_0 65 | - scikit-learn=0.23.2=py36h0573a6f_0 66 | - scipy=1.2.1=py36h7c811a0_0 67 | - setuptools=58.0.4=py36h06a4308_0 68 | - six=1.16.0=pyhd3eb1b0_0 69 | - sqlite=3.36.0=hc218d9a_0 70 | - threadpoolctl=2.2.0=pyh0d69192_0 71 | - tk=8.6.11=h1ccaba5_0 72 | - typing_extensions=3.10.0.2=pyh06a4308_0 73 | - wheel=0.37.0=pyhd3eb1b0_1 74 | - xz=5.2.5=h7b6447c_0 75 | - zlib=1.2.11=h7b6447c_3 76 | - zstd=1.4.9=haebb681_0 77 | - cpuonly=2.0=0 78 | - ffmpeg=4.3=hf484d3e_0 79 | - pytorch=1.10.0=py3.6_cpu_0 80 | - pytorch-mutex=1.0=cpu 81 | - torchaudio=0.10.0=py36_cpu 82 | - torchvision=0.11.1=py36_cpu 83 | - rdkit=2018.03.4.0=py36h71b666b_1 84 | - pip: 85 | - alabaster==0.7.12 86 | - argon2-cffi==21.1.0 87 | - async-generator==1.10 88 | - attrs==21.2.0 89 | - babel==2.9.1 90 | - backcall==0.2.0 91 | - bleach==4.1.0 92 | - cffi==1.15.0 93 | - charset-normalizer==2.0.9 94 | - chemprop==1.3.1 95 | - click==8.0.3 96 | - cloudpickle==2.0.0 97 | - cycler==0.11.0 98 | - decorator==4.4.2 99 | - defusedxml==0.7.1 100 | - docutils==0.17.1 101 | - entrypoints==0.3 102 | - flask==2.0.2 103 | - future==0.18.2 104 | - googledrivedownloader==0.4 105 | - graphviz==0.19 106 | - hyperopt==0.2.7 107 | - idna==3.3 108 | - imagesize==1.3.0 109 | - importlib-metadata==4.8.2 110 | - ipykernel==5.5.6 111 | - ipython==7.16.2 112 | - ipython-genutils==0.2.0 113 | - ipywidgets==7.6.5 114 | - isodate==0.6.0 115 | - itsdangerous==2.0.1 116 | - jedi==0.17.2 117 | - jinja2==3.0.3 118 | - jsonschema==3.2.0 119 | - jupyter==1.0.0 120 | - jupyter-client==7.1.0 121 | - jupyter-console==6.4.0 122 | - jupyter-core==4.9.1 123 | - jupyterlab-pygments==0.1.2 124 | - jupyterlab-widgets==1.0.2 125 | - kiwisolver==1.3.1 126 | - markupsafe==2.0.1 127 | - matplotlib==3.3.4 128 | - mistune==0.8.4 129 | - mypy-extensions==0.4.3 130 | - nbclient==0.5.9 131 | - nbconvert==6.0.7 132 | - nbformat==5.1.3 133 | - nest-asyncio==1.5.4 134 | - networkx==2.5.1 135 | - notebook==6.4.6 136 | - packaging==21.3 137 | - pandas-flavor==0.2.0 138 | - pandocfilters==1.5.0 139 | - parso==0.7.1 140 | - pbr==5.8.0 141 | - pexpect==4.8.0 142 | - pickle5==0.0.12 143 | - pickleshare==0.7.5 144 | - prometheus-client==0.12.0 145 | - prompt-toolkit==3.0.23 146 | - protobuf==3.19.1 147 | - ptyprocess==0.7.0 148 | - py4j==0.10.9.3 149 | - pycparser==2.21 150 | - pygments==2.10.0 151 | - pyparsing==3.0.6 152 | - pyrsistent==0.18.0 153 | - pysmiles==1.0.1 154 | - pyyaml==6.0 155 | - pyzmq==22.3.0 156 | - qtconsole==5.2.1 157 | - qtpy==1.11.3 158 | - rdflib==5.0.0 159 | - requests==2.26.0 160 | - send2trash==1.8.0 161 | - setproctitle==1.2.2 162 | - snowballstemmer==2.2.0 163 | - sphinx==4.3.1 164 | - sphinxcontrib-applehelp==1.0.2 165 | - sphinxcontrib-devhelp==1.0.2 166 | - sphinxcontrib-htmlhelp==2.0.0 167 | - sphinxcontrib-jsmath==1.0.1 168 | - sphinxcontrib-qthelp==1.0.3 169 | - sphinxcontrib-serializinghtml==1.1.5 170 | - tensorboardx==2.4.1 171 | - terminado==0.12.1 172 | - testpath==0.5.0 173 | - torch==1.10.0 174 | - torch-geometric==2.0.2 175 | - torch-scatter==2.0.9 176 | - torch-sparse==0.6.12 177 | - tornado==6.1 178 | - tqdm==4.62.3 179 | - traitlets==4.3.3 180 | - typed-argument-parser==1.7.1 181 | - typing-inspect==0.7.1 182 | - urllib3==1.26.7 183 | - wcwidth==0.2.5 184 | - webencodings==0.5.1 185 | - werkzeug==2.0.2 186 | - widgetsnbextension==3.5.2 187 | - xarray==0.16.2 188 | - yacs==0.1.8 189 | - zipp==3.6.0 190 | prefix: $CONDA_PREFIX 191 | 192 | -------------------------------------------------------------------------------- /fuseprop/__init__.py: -------------------------------------------------------------------------------- 1 | from fuseprop.mol_graph import MolGraph 2 | from fuseprop.vocab import common_atom_vocab 3 | from fuseprop.gnn import AtomVGNN 4 | from fuseprop.dataset import * 5 | from fuseprop.chemutils import find_clusters, random_subgraph, extract_subgraph, enum_subgraph, dual_random_subgraph, unique_rationales, merge_rationales, get_mol, get_smiles, find_fragments 6 | -------------------------------------------------------------------------------- /fuseprop/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, random, gc 3 | import pickle 4 | 5 | from rdkit import Chem 6 | from torch.utils.data import Dataset 7 | from fuseprop.chemutils import random_subgraph, extract_subgraph, enum_root 8 | from fuseprop.mol_graph import MolGraph 9 | 10 | class MoleculeDataset(Dataset): 11 | 12 | def __init__(self, data, avocab, batch_size): 13 | self.batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)] 14 | self.avocab = avocab 15 | 16 | def __len__(self): 17 | return len(self.batches) 18 | 19 | def __getitem__(self, idx): 20 | init_smiles, final_smiles = zip(*self.batches[idx]) 21 | init_batch = [Chem.MolFromSmiles(x) for x in init_smiles] 22 | mol_batch = [Chem.MolFromSmiles(x) for x in final_smiles] 23 | init_atoms = [mol.GetSubstructMatch(x) for mol,x in zip(mol_batch, init_batch)] 24 | mol_batch = [MolGraph(x, atoms) for x, atoms in zip(final_smiles, init_atoms)] 25 | mol_batch = [x for x in mol_batch if len(x.root_atoms) > 0] 26 | if len(mol_batch) < len(self.batches[idx]): 27 | num = len(self.batches[idx]) - len(mol_batch) 28 | print("MoleculeDataset: %d graph removed" % (num,)) 29 | return MolGraph.tensorize(mol_batch, self.avocab) if len(mol_batch) > 0 else None 30 | 31 | 32 | class ReconstructDataset(Dataset): 33 | 34 | def __init__(self, data, avocab, batch_size): 35 | self.batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)] 36 | self.avocab = avocab 37 | 38 | def __len__(self): 39 | return len(self.batches) 40 | 41 | def __getitem__(self, idx): 42 | subgraphs = [] 43 | init_smiles = [] 44 | for smiles in self.batches[idx]: 45 | mol = Chem.MolFromSmiles(smiles) 46 | selected_atoms = random_subgraph(mol, ratio=0.5) 47 | sub_smiles, root_atoms = extract_subgraph(smiles, selected_atoms) 48 | subgraph = MolGraph(smiles, selected_atoms, root_atoms, shuffle_roots=False) 49 | subgraphs.append(subgraph) 50 | init_smiles.append(sub_smiles) 51 | return MolGraph.tensorize(subgraphs), self.batches[idx], init_smiles 52 | 53 | 54 | class SubgraphDataset(Dataset): 55 | 56 | def __init__(self, data, avocab, batch_size, num_decode): 57 | data = [x for smiles in data for x in enum_root(smiles, num_decode)] 58 | self.batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)] 59 | self.avocab = avocab 60 | 61 | def __len__(self): 62 | return len(self.batches) 63 | 64 | def __getitem__(self, idx): 65 | return self.batches[idx] 66 | 67 | 68 | class DataFolder(object): 69 | 70 | def __init__(self, data_folder, batch_size, shuffle=True): 71 | self.data_folder = data_folder 72 | self.data_files = [fn for fn in os.listdir(data_folder)] 73 | self.batch_size = batch_size 74 | self.shuffle = shuffle 75 | 76 | def __iter__(self): 77 | for fn in self.data_files: 78 | fn = os.path.join(self.data_folder, fn) 79 | with open(fn, 'rb') as f: 80 | batches = pickle.load(f) 81 | 82 | if self.shuffle: random.shuffle(batches) #shuffle data before batch 83 | for batch in batches: 84 | yield batch 85 | 86 | del batches 87 | gc.collect() 88 | 89 | -------------------------------------------------------------------------------- /fuseprop/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import rdkit.Chem as Chem 4 | import torch.nn.functional as F 5 | from fuseprop.nnutils import * 6 | from fuseprop.mol_graph import MolGraph 7 | from fuseprop.rnn import GRU, LSTM 8 | 9 | class MPNEncoder(nn.Module): 10 | 11 | def __init__(self, rnn_type, input_size, node_fdim, hidden_size, depth): 12 | super(MPNEncoder, self).__init__() 13 | self.hidden_size = hidden_size 14 | self.input_size = input_size 15 | self.depth = depth 16 | self.W_o = nn.Sequential( 17 | nn.Linear(node_fdim + hidden_size, hidden_size), 18 | nn.ReLU() 19 | ) 20 | 21 | if rnn_type == 'GRU': 22 | self.rnn = GRU(input_size, hidden_size, depth) 23 | elif rnn_type == 'LSTM': 24 | self.rnn = LSTM(input_size, hidden_size, depth) 25 | else: 26 | raise ValueError('unsupported rnn cell type ' + rnn_type) 27 | 28 | def forward(self, fnode, fmess, agraph, bgraph, mask): 29 | h = self.rnn(fmess, bgraph) 30 | h = self.rnn.get_hidden_state(h) 31 | nei_message = index_select_ND(h, 0, agraph) 32 | nei_message = nei_message.sum(dim=1) 33 | node_hiddens = torch.cat([fnode, nei_message], dim=1) 34 | node_hiddens = self.W_o(node_hiddens) 35 | 36 | if mask is None: 37 | mask = torch.ones(node_hiddens.size(0), 1, device=fnode.device) 38 | mask[0, 0] = 0 #first node is padding 39 | 40 | return node_hiddens * mask, h 41 | 42 | 43 | class GraphEncoder(nn.Module): 44 | def __init__(self, avocab, rnn_type, embed_size, hidden_size, depth): 45 | super(GraphEncoder, self).__init__() 46 | self.avocab = avocab 47 | self.hidden_size = hidden_size 48 | self.atom_size = atom_size = avocab.size() + MolGraph.MAX_POS 49 | self.bond_size = bond_size = len(MolGraph.BOND_LIST) 50 | 51 | self.E_a = torch.eye( avocab.size() ).cuda() 52 | self.E_b = torch.eye( len(MolGraph.BOND_LIST) ).cuda() 53 | self.E_pos = torch.eye( MolGraph.MAX_POS ).cuda() 54 | 55 | self.encoder = MPNEncoder(rnn_type, atom_size + bond_size, atom_size, hidden_size, depth) 56 | 57 | def embed_graph(self, graph_tensors): 58 | fnode, fmess, agraph, bgraph, _ = graph_tensors 59 | fnode1 = self.E_a.index_select(index=fnode[:, 0], dim=0) 60 | fnode2 = self.E_pos.index_select(index=fnode[:, 1], dim=0) 61 | hnode = torch.cat([fnode1, fnode2], dim=-1) 62 | 63 | fmess1 = hnode.index_select(index=fmess[:, 0], dim=0) 64 | fmess2 = self.E_b.index_select(index=fmess[:, 2], dim=0) 65 | hmess = torch.cat([fmess1, fmess2], dim=-1) 66 | return hnode, hmess, agraph, bgraph 67 | 68 | def forward(self, graph_tensors): 69 | tensors = self.embed_graph(graph_tensors) 70 | hatom,_ = self.encoder(*tensors, mask=None) 71 | return hatom 72 | 73 | -------------------------------------------------------------------------------- /fuseprop/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import rdkit.Chem as Chem 4 | import torch.nn.functional as F 5 | from fuseprop.mol_graph import MolGraph 6 | from fuseprop.encoder import GraphEncoder 7 | from fuseprop.decoder import GraphDecoder 8 | from fuseprop.nnutils import * 9 | 10 | def make_cuda(graph_tensors): 11 | make_tensor = lambda x: x if type(x) is torch.Tensor else torch.tensor(x) 12 | graph_tensors = [make_tensor(x).cuda().long() for x in graph_tensors[:-1]] + [graph_tensors[-1]] 13 | return graph_tensors 14 | 15 | 16 | class AtomVGNN(nn.Module): 17 | 18 | def __init__(self, args): 19 | super(AtomVGNN, self).__init__() 20 | self.latent_size = args.latent_size 21 | self.encoder = GraphEncoder(args.atom_vocab, args.rnn_type, args.embed_size, args.hidden_size, args.depth) 22 | self.decoder = GraphDecoder(args.atom_vocab, args.rnn_type, args.embed_size, args.hidden_size, args.latent_size, args.diter) 23 | 24 | self.G_mean = nn.Linear(args.hidden_size, args.latent_size) 25 | self.G_var = nn.Linear(args.hidden_size, args.latent_size) 26 | 27 | def encode(self, graph_tensors): 28 | graph_vecs = self.encoder(graph_tensors) 29 | graph_vecs = [graph_vecs[st : st + le].sum(dim=0) for st,le in graph_tensors[-1]] 30 | return torch.stack(graph_vecs, dim=0) 31 | 32 | def decode(self, init_smiles): 33 | batch_size = len(init_smiles) 34 | z_graph_vecs = torch.randn(batch_size, self.latent_size).cuda() 35 | return self.decoder.decode(z_graph_vecs, init_smiles) 36 | 37 | def rsample(self, z_vecs, W_mean, W_var, mean_only=False): 38 | batch_size = z_vecs.size(0) 39 | z_mean = W_mean(z_vecs) 40 | 41 | z_log_var = -torch.abs( W_var(z_vecs) ) 42 | kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size 43 | if mean_only: 44 | return z_mean, kl_loss 45 | else: 46 | epsilon = torch.randn_like(z_mean).cuda() 47 | z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon 48 | return z_vecs, kl_loss 49 | 50 | def forward(self, graphs, tensors, init_atoms, orders, beta): 51 | tensors = make_cuda(tensors) 52 | graph_vecs = self.encode(tensors) 53 | z_graph_vecs, kl_div = self.rsample(graph_vecs, self.G_mean, self.G_var) 54 | loss, wacc, tacc, sacc = self.decoder(z_graph_vecs, graphs, tensors, init_atoms, orders) 55 | return loss + beta * kl_div, kl_div.item(), wacc, tacc, sacc 56 | 57 | def test_reconstruct(self, graphs, tensors, init_atoms, orders, init_smiles): 58 | tensors = make_cuda(tensors) 59 | graph_vecs = self.encode(tensors) 60 | z_graph_vecs, kl_div = self.rsample(graph_vecs, self.G_mean, self.G_var, mean_only=True) 61 | loss, wacc, tacc, sacc = self.decoder(z_graph_vecs, graphs, tensors, init_atoms, orders) 62 | return self.decoder.decode(z_graph_vecs, init_smiles) 63 | 64 | def likelihood(self, graphs, tensors, init_atoms, orders): 65 | tensors = make_cuda(tensors) 66 | graph_vecs = self.encode(tensors) 67 | z_graph_vecs, kl_div = self.rsample(graph_vecs, self.G_mean, self.G_var, mean_only=True) 68 | loss, wacc, tacc, sacc = self.decoder(z_graph_vecs, graphs, tensors, init_atoms, orders) 69 | return -loss - kl_div # Important: loss is negative log likelihood 70 | 71 | -------------------------------------------------------------------------------- /fuseprop/inc_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import rdkit.Chem as Chem 3 | import networkx as nx 4 | 5 | from fuseprop.mol_graph import MolGraph 6 | from fuseprop.chemutils import * 7 | from collections import defaultdict 8 | 9 | class IncBase(object): 10 | 11 | def __init__(self, batch_size, node_fdim, edge_fdim, max_nodes, max_edges, max_nb): 12 | self.max_nb = max_nb 13 | self.graph = nx.DiGraph() 14 | self.graph.add_node(0) #make sure node is 1 index 15 | self.edge_dict = {None : 0} #make sure edge is 1 index 16 | 17 | self.fnode = torch.zeros(max_nodes * batch_size, node_fdim).cuda() 18 | self.fmess = self.fnode.new_zeros(max_edges * batch_size, edge_fdim) 19 | self.agraph = self.fnode.new_zeros(max_nodes * batch_size, max_nb).long() 20 | self.bgraph = self.fnode.new_zeros(max_edges * batch_size, max_nb).long() 21 | 22 | def add_node(self, feature): 23 | idx = len(self.graph) 24 | if idx >= len(self.fnode) - 1: 25 | self.fnode = torch.cat([self.fnode, self.fnode * 0], dim=0) 26 | self.agraph = torch.cat([self.agraph, self.agraph * 0], dim=0) 27 | 28 | self.graph.add_node(idx) 29 | self.fnode[idx, :len(feature)] = feature 30 | return idx 31 | 32 | def can_expand(self, idx): 33 | return self.graph.in_degree(idx) < self.max_nb 34 | 35 | def add_edge(self, i, j, feature=None): 36 | if (i,j) in self.edge_dict: 37 | return self.edge_dict[(i,j)] 38 | 39 | self.graph.add_edge(i, j) 40 | self.edge_dict[(i,j)] = idx = len(self.edge_dict) 41 | 42 | if idx >= len(self.fmess) - 1: 43 | self.fmess = torch.cat([self.fmess, self.fmess * 0], dim=0) 44 | self.bgraph = torch.cat([self.bgraph, self.bgraph * 0], dim=0) 45 | 46 | self.agraph[j, self.graph.in_degree(j) - 1] = idx 47 | if feature is not None: 48 | self.fmess[idx, :len(feature)] = feature 49 | 50 | in_edges = [self.edge_dict[(k,i)] for k in self.graph.predecessors(i) if k != j] 51 | self.bgraph[idx, :len(in_edges)] = self.fnode.new_tensor(in_edges) 52 | 53 | for k in self.graph.successors(j): 54 | if k == i: continue 55 | nei_idx = self.edge_dict[(j,k)] 56 | self.bgraph[nei_idx, self.graph.in_degree(j) - 2] = idx 57 | 58 | return idx 59 | 60 | 61 | class IncGraph(IncBase): 62 | 63 | def __init__(self, avocab, batch_size, node_fdim, edge_fdim, max_nodes=20, max_edges=50, max_nb=6): 64 | super(IncGraph, self).__init__(batch_size, node_fdim, edge_fdim, max_nodes, max_edges, max_nb) 65 | self.avocab = avocab 66 | self.mol = Chem.RWMol() 67 | self.mol.AddAtom( Chem.Atom('C') ) #make sure node is 1 index, consistent to self.graph 68 | self.batch = defaultdict(list) 69 | self.interior_atoms = defaultdict(list) 70 | 71 | def get_mol(self): 72 | mol_list = [None] * len(self.batch) 73 | for batch_idx, batch_atoms in self.batch.items(): 74 | mol = get_sub_mol(self.mol, batch_atoms) 75 | mol = sanitize(mol, kekulize=False) 76 | if mol is None: 77 | mol_list[batch_idx] = None 78 | else: 79 | for atom in mol.GetAtoms(): 80 | atom.SetAtomMapNum(0) 81 | mol_list[batch_idx] = Chem.MolToSmiles(mol) 82 | return mol_list 83 | 84 | def get_tensors(self): 85 | return self.fnode, self.fmess, self.agraph, self.bgraph 86 | 87 | def add_mol(self, bid, smiles): 88 | mol = get_mol(smiles) #Important: must kekulize! 89 | root_atoms = [] 90 | amap = {} 91 | for atom in mol.GetAtoms(): 92 | symbol, charge = atom.GetSymbol(), atom.GetFormalCharge() 93 | nth_atom = atom.GetAtomMapNum() 94 | idx = self.add_atom(bid, (symbol, charge), nth_atom) 95 | amap[atom.GetIdx()] = idx 96 | if nth_atom > 0: 97 | root_atoms.append( (idx, nth_atom) ) 98 | else: 99 | self.interior_atoms[bid].append(idx) 100 | 101 | for bond in mol.GetBonds(): 102 | a1 = amap[bond.GetBeginAtom().GetIdx()] 103 | a2 = amap[bond.GetEndAtom().GetIdx()] 104 | bt = bond.GetBondType() 105 | self.add_bond(a1, a2, MolGraph.BOND_LIST.index(bt)) 106 | 107 | root_atoms = sorted(root_atoms, key=lambda x:x[1]) 108 | root_atoms = next(zip(*root_atoms)) 109 | return root_atoms 110 | 111 | def add_atom(self, bid, atom_type, nth_atom=None): 112 | if nth_atom is None: 113 | nth_atom = len(self.batch[bid]) - len(self.interior_atoms[bid]) + 1 114 | new_atom = Chem.Atom(atom_type[0]) 115 | new_atom.SetFormalCharge(atom_type[1]) 116 | atom_feature = self.get_atom_feature(new_atom, nth_atom) 117 | aid = self.mol.AddAtom( new_atom ) 118 | assert aid == self.add_node( atom_feature ) 119 | self.batch[bid].append(aid) 120 | return aid 121 | 122 | def add_bond(self, a1, a2, bond_pred): 123 | assert 1 <= bond_pred <= 3 124 | if a1 == a2: return 125 | if self.can_expand(a1) == False or self.can_expand(a2) == False: 126 | return 127 | if self.mol.GetBondBetweenAtoms(a1, a2) is not None: 128 | return 129 | 130 | atom1, atom2 = self.mol.GetAtomWithIdx(a1), self.mol.GetAtomWithIdx(a2) 131 | if valence_check(atom1, bond_pred) and valence_check(atom2, bond_pred): 132 | bond_type = MolGraph.BOND_LIST[bond_pred] 133 | self.mol.AddBond(a1, a2, bond_type) 134 | self.add_edge( a1, a2, self.get_mess_feature(self.fnode[a1], bond_pred) ) 135 | self.add_edge( a2, a1, self.get_mess_feature(self.fnode[a2], bond_pred) ) 136 | 137 | # TOO SLOW! 138 | #if sanitize(self.mol.GetMol(), kekulize=False) is None: 139 | # self.mol.RemoveBond(a1, a2) 140 | # return 141 | 142 | 143 | def get_atom_feature(self, atom, nth_atom): 144 | nth_atom = min(MolGraph.MAX_POS - 1, nth_atom) 145 | f_atom = torch.zeros(self.avocab.size()) 146 | f_pos = torch.zeros( MolGraph.MAX_POS ) 147 | symbol, charge = atom.GetSymbol(), atom.GetFormalCharge() 148 | f_atom[ self.avocab[(symbol,charge)] ] = 1 149 | f_pos[ nth_atom ] = 1 150 | return torch.cat( [f_atom, f_pos], dim=-1 ).cuda() 151 | 152 | def get_mess_feature(self, atom_fea, bond_type): 153 | bond_fea = torch.zeros(len(MolGraph.BOND_LIST)).cuda() 154 | bond_fea[ bond_type ] = 1 155 | return torch.cat( [atom_fea, bond_fea], dim=-1 ) 156 | 157 | -------------------------------------------------------------------------------- /fuseprop/mol_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import rdkit 4 | import rdkit.Chem as Chem 5 | import networkx as nx 6 | from fuseprop.chemutils import * 7 | from fuseprop.nnutils import * 8 | from fuseprop.vocab import common_atom_vocab 9 | from collections import deque 10 | 11 | add = lambda x,y : x + y if type(x) is int else (x[0] + y, x[1] + y) 12 | 13 | class MolGraph(object): 14 | 15 | BOND_LIST = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC] 16 | MAX_POS = 40 17 | 18 | def __init__(self, smiles, init_atoms, root_atoms=None, shuffle_roots=True): 19 | self.smiles = smiles 20 | self.mol = get_mol(smiles) 21 | self.mol_graph = self.build_mol_graph() 22 | self.init_atoms = set(init_atoms) 23 | self.root_atoms = self.get_root_atoms() if root_atoms is None else root_atoms 24 | if len(self.root_atoms) > 0: 25 | if shuffle_roots: random.shuffle(self.root_atoms) 26 | self.order = self.get_bfs_order() 27 | 28 | def debug(self): 29 | for atom in self.mol.GetAtoms(): 30 | if atom.GetIdx() in self.init_atoms: 31 | atom.SetAtomMapNum(atom.GetIdx()) 32 | print( Chem.MolToSmiles(self.mol) ) 33 | print('root', self.root_atoms) 34 | print('init', self.init_atoms) 35 | for x in self.order: 36 | print(x) 37 | 38 | def get_root_atoms(self): 39 | roots = [] 40 | for idx in self.init_atoms: 41 | atom = self.mol.GetAtomWithIdx(idx) 42 | bad_neis = [y for y in atom.GetNeighbors() if y.GetIdx() not in self.init_atoms] 43 | if len(bad_neis) > 0: 44 | roots.append(idx) 45 | return roots 46 | 47 | def build_mol_graph(self): 48 | mol = self.mol 49 | graph = nx.DiGraph(Chem.rdmolops.GetAdjacencyMatrix(mol)) 50 | for atom in mol.GetAtoms(): 51 | graph.nodes[atom.GetIdx()]['label'] = (atom.GetSymbol(), atom.GetFormalCharge()) 52 | 53 | for bond in mol.GetBonds(): 54 | a1 = bond.GetBeginAtom().GetIdx() 55 | a2 = bond.GetEndAtom().GetIdx() 56 | btype = MolGraph.BOND_LIST.index( bond.GetBondType() ) 57 | graph[a1][a2]['label'] = btype 58 | graph[a2][a1]['label'] = btype 59 | 60 | return graph 61 | 62 | def get_bfs_order(self): 63 | order = [] 64 | visited = set(self.init_atoms) 65 | queue = deque( [self.mol.GetAtomWithIdx(k) for k in self.root_atoms] ) 66 | 67 | for a in self.init_atoms: 68 | self.mol_graph.nodes[a]['pos'] = 0 69 | 70 | for i,root in enumerate(self.root_atoms): 71 | self.mol_graph.nodes[root]['pos'] = i + 1 # interior atoms are 0, boundary atoms are 1,2,... 72 | 73 | pos_id = len(self.root_atoms) 74 | while len(queue) > 0: 75 | x = queue.popleft() 76 | x_idx = x.GetIdx() 77 | for y in x.GetNeighbors(): 78 | y_idx = y.GetIdx() 79 | if y_idx in visited: continue 80 | 81 | frontier = [x_idx] + [a.GetIdx() for a in list(queue)] 82 | bonds = [0] * len(frontier) 83 | y_neis = set([z.GetIdx() for z in y.GetNeighbors()]) 84 | 85 | for i,z_idx in enumerate(frontier): 86 | if z_idx in y_neis: 87 | bonds[i] = self.mol_graph[y_idx][z_idx]['label'] 88 | 89 | order.append( (x_idx, y_idx, frontier, bonds) ) 90 | pos_id += 1 91 | self.mol_graph.nodes[y_idx]['pos'] = min(MolGraph.MAX_POS - 1, pos_id) 92 | visited.add( y_idx ) 93 | queue.append(y) 94 | 95 | order.append( (x_idx, None, None, None) ) 96 | 97 | return order 98 | 99 | @staticmethod 100 | def tensorize(mol_batch, avocab=common_atom_vocab): 101 | graph_tensors, graph_batchG = MolGraph.tensorize_graph([x.mol_graph for x in mol_batch], avocab) 102 | graph_scope = graph_tensors[-1] 103 | 104 | add = lambda a,b : None if a is None else a + b 105 | add_list = lambda alist,b : None if alist is None else [a + b for a in alist] 106 | 107 | all_orders = [] 108 | all_init_atoms = [] 109 | for i,hmol in enumerate(mol_batch): 110 | offset = graph_scope[i][0] 111 | order = [(x + offset, add(y, offset), add_list(z, offset), t) for x,y,z,t in hmol.order] 112 | init_atoms = [x + offset for x in hmol.init_atoms] 113 | all_orders.append(order) 114 | all_init_atoms.append(init_atoms) 115 | 116 | return graph_batchG, graph_tensors, all_init_atoms, all_orders 117 | 118 | @staticmethod 119 | def tensorize_graph(graph_batch, vocab): 120 | fnode,fmess = [None],[(0,0,0)] 121 | agraph,bgraph = [[]], [[]] 122 | scope = [] 123 | edge_dict = {} 124 | all_G = [] 125 | 126 | for bid,G in enumerate(graph_batch): 127 | offset = len(fnode) 128 | scope.append( (offset, len(G)) ) 129 | G = nx.convert_node_labels_to_integers(G, first_label=offset) 130 | all_G.append(G) 131 | fnode.extend( [None for v in G.nodes] ) 132 | 133 | for v, attr in G.nodes(data='label'): 134 | G.nodes[v]['batch_id'] = bid 135 | fnode[v] = (vocab[attr], G.nodes[v]['pos']) 136 | agraph.append([]) 137 | 138 | for u, v, attr in G.edges(data='label'): 139 | fmess.append( (u, v, attr) ) 140 | edge_dict[(u, v)] = eid = len(edge_dict) + 1 141 | G[u][v]['mess_idx'] = eid 142 | agraph[v].append(eid) 143 | bgraph.append([]) 144 | 145 | for u, v in G.edges: 146 | eid = edge_dict[(u, v)] 147 | for w in G.predecessors(u): 148 | if w == v: continue 149 | bgraph[eid].append( edge_dict[(w, u)] ) 150 | 151 | fnode[0] = fnode[1] 152 | fnode = torch.LongTensor(fnode) 153 | fmess = torch.LongTensor(fmess) 154 | agraph = create_pad_tensor(agraph) 155 | bgraph = create_pad_tensor(bgraph) 156 | return (fnode, fmess, agraph, bgraph, scope), nx.union_all(all_G) 157 | 158 | -------------------------------------------------------------------------------- /fuseprop/nnutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def index_select_ND(source, dim, index): 6 | index_size = index.size() 7 | suffix_dim = source.size()[1:] 8 | final_size = index_size + suffix_dim 9 | target = source.index_select(dim, index.view(-1)) 10 | return target.view(final_size) 11 | 12 | def avg_pool(all_vecs, scope, dim): 13 | size = create_var(torch.Tensor([le for _,le in scope])) 14 | return all_vecs.sum(dim=dim) / size.unsqueeze(-1) 15 | 16 | def get_accuracy_bin(scores, labels): 17 | preds = torch.ge(scores, 0).long() 18 | acc = torch.eq(preds, labels).float() 19 | return torch.sum(acc) / labels.nelement() 20 | 21 | def get_accuracy(scores, labels): 22 | _,preds = torch.max(scores, dim=-1) 23 | acc = torch.eq(preds, labels).float() 24 | return torch.sum(acc) / labels.nelement() 25 | 26 | def get_accuracy_sym(scores, labels): 27 | max_scores,max_idx = torch.max(scores, dim=-1) 28 | lab_scores = scores[torch.arange(len(scores)), labels] 29 | acc = torch.eq(lab_scores, max_scores).float() 30 | return torch.sum(acc) / labels.nelement() 31 | 32 | def stack_pad_tensor(tensor_list): 33 | max_len = max([t.size(0) for t in tensor_list]) 34 | for i,tensor in enumerate(tensor_list): 35 | pad_len = max_len - tensor.size(0) 36 | tensor_list[i] = F.pad( tensor, (0,0,0,pad_len) ) 37 | return torch.stack(tensor_list, dim=0) 38 | 39 | def create_pad_tensor(alist): 40 | max_len = max([len(a) for a in alist]) + 1 41 | for a in alist: 42 | pad_len = max_len - len(a) 43 | a.extend([0] * pad_len) 44 | return torch.IntTensor(alist) 45 | 46 | def zip_tensors(tup_list): 47 | arr0, arr1, arr2 = list(zip(*tup_list)) 48 | if type(arr2[0]) is int: 49 | arr0 = torch.stack(arr0, dim=0) 50 | arr1 = torch.LongTensor(arr1).cuda() 51 | arr2 = torch.LongTensor(arr2).cuda() 52 | else: 53 | arr0 = torch.cat(arr0, dim=0) 54 | arr1 = [x for a in arr1 for x in a] 55 | arr1 = torch.LongTensor(arr1).cuda() 56 | arr2 = torch.cat(arr2, dim=0) 57 | return arr0, arr1, arr2 58 | 59 | def index_scatter(sub_data, all_data, index): 60 | d0, d1 = all_data.size() 61 | buf = torch.zeros_like(all_data).scatter_(0, index.repeat(d1, 1).t(), sub_data) 62 | mask = torch.ones(d0, device=all_data.device).scatter_(0, index, 0) 63 | return all_data * mask.unsqueeze(-1) + buf 64 | 65 | -------------------------------------------------------------------------------- /fuseprop/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fuseprop.nnutils import * 4 | 5 | class GRU(nn.Module): 6 | 7 | def __init__(self, input_size, hidden_size, depth): 8 | super(GRU, self).__init__() 9 | self.hidden_size = hidden_size 10 | self.input_size = input_size 11 | self.depth = depth 12 | 13 | self.W_z = nn.Linear(input_size + hidden_size, hidden_size) 14 | self.W_r = nn.Linear(input_size, hidden_size, bias=False) 15 | self.U_r = nn.Linear(hidden_size, hidden_size) 16 | self.W_h = nn.Linear(input_size + hidden_size, hidden_size) 17 | 18 | def get_init_state(self, fmess, init_state=None): 19 | h = torch.zeros(len(fmess), self.hidden_size, device=fmess.device) 20 | return h if init_state is None else torch.cat( (h, init_state), dim=0) 21 | 22 | def get_hidden_state(self, h): 23 | return h 24 | 25 | def GRU(self, x, h_nei): 26 | sum_h = h_nei.sum(dim=1) 27 | z_input = torch.cat([x,sum_h], dim=1) 28 | z = torch.sigmoid(self.W_z(z_input)) 29 | 30 | r_1 = self.W_r(x).view(-1, 1, self.hidden_size) 31 | r_2 = self.U_r(h_nei) 32 | r = torch.sigmoid(r_1 + r_2) 33 | 34 | gated_h = r * h_nei 35 | sum_gated_h = gated_h.sum(dim=1) 36 | h_input = torch.cat([x,sum_gated_h], dim=1) 37 | pre_h = torch.tanh(self.W_h(h_input)) 38 | new_h = (1.0 - z) * sum_h + z * pre_h 39 | return new_h 40 | 41 | def forward(self, fmess, bgraph): 42 | h = torch.zeros(fmess.size(0), self.hidden_size, device=fmess.device) 43 | mask = torch.ones(h.size(0), 1, device=h.device) 44 | mask[0, 0] = 0 #first message is padding 45 | 46 | for i in range(self.depth): 47 | h_nei = index_select_ND(h, 0, bgraph) 48 | h = self.GRU(fmess, h_nei) 49 | h = h * mask 50 | return h 51 | 52 | def sparse_forward(self, h, fmess, submess, bgraph): 53 | mask = h.new_ones(h.size(0)).scatter_(0, submess, 0) 54 | h = h * mask.unsqueeze(1) 55 | for i in range(self.depth): 56 | h_nei = index_select_ND(h, 0, bgraph) 57 | sub_h = self.GRU(fmess, h_nei) 58 | h = index_scatter(sub_h, h, submess) 59 | return h 60 | 61 | class LSTM(nn.Module): 62 | 63 | def __init__(self, input_size, hidden_size, depth): 64 | super(LSTM, self).__init__() 65 | self.hidden_size = hidden_size 66 | self.input_size = input_size 67 | self.depth = depth 68 | 69 | self.W_i = nn.Sequential( nn.Linear(input_size + hidden_size, hidden_size), nn.Sigmoid() ) 70 | self.W_o = nn.Sequential( nn.Linear(input_size + hidden_size, hidden_size), nn.Sigmoid() ) 71 | self.W_f = nn.Sequential( nn.Linear(input_size + hidden_size, hidden_size), nn.Sigmoid() ) 72 | self.W = nn.Sequential( nn.Linear(input_size + hidden_size, hidden_size), nn.Tanh() ) 73 | 74 | def get_init_state(self, fmess, init_state=None): 75 | h = torch.zeros(len(fmess), self.hidden_size, device=fmess.device) 76 | c = torch.zeros(len(fmess), self.hidden_size, device=fmess.device) 77 | if init_state is not None: 78 | h = torch.cat( (h, init_state), dim=0) 79 | c = torch.cat( (c, torch.zeros_like(init_state)), dim=0) 80 | return h,c 81 | 82 | def get_hidden_state(self, h): 83 | return h[0] 84 | 85 | def LSTM(self, x, h_nei, c_nei): 86 | h_sum_nei = h_nei.sum(dim=1) 87 | x_expand = x.unsqueeze(1).expand(-1, h_nei.size(1), -1) 88 | i = self.W_i( torch.cat([x, h_sum_nei], dim=-1) ) 89 | o = self.W_o( torch.cat([x, h_sum_nei], dim=-1) ) 90 | f = self.W_f( torch.cat([x_expand, h_nei], dim=-1) ) 91 | u = self.W( torch.cat([x, h_sum_nei], dim=-1) ) 92 | c = i * u + (f * c_nei).sum(dim=1) 93 | h = o * torch.tanh(c) 94 | return h, c 95 | 96 | def forward(self, fmess, bgraph): 97 | h = torch.zeros(fmess.size(0), self.hidden_size, device=fmess.device) 98 | c = torch.zeros(fmess.size(0), self.hidden_size, device=fmess.device) 99 | mask = torch.ones(h.size(0), 1, device=h.device) 100 | mask[0, 0] = 0 #first message is padding 101 | 102 | for i in range(self.depth): 103 | h_nei = index_select_ND(h, 0, bgraph) 104 | c_nei = index_select_ND(c, 0, bgraph) 105 | h,c = self.LSTM(fmess, h_nei, c_nei) 106 | h = h * mask 107 | c = c * mask 108 | return h,c 109 | 110 | def sparse_forward(self, h, fmess, submess, bgraph): 111 | h,c = h 112 | mask = h.new_ones(h.size(0)).scatter_(0, submess, 0) 113 | h = h * mask.unsqueeze(1) 114 | c = c * mask.unsqueeze(1) 115 | for i in range(self.depth): 116 | h_nei = index_select_ND(h, 0, bgraph) 117 | c_nei = index_select_ND(c, 0, bgraph) 118 | sub_h, sub_c = self.LSTM(fmess, h_nei, c_nei) 119 | h = index_scatter(sub_h, h, submess) 120 | c = index_scatter(sub_c, c, submess) 121 | return h,c 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /fuseprop/vocab.py: -------------------------------------------------------------------------------- 1 | import rdkit 2 | import rdkit.Chem as Chem 3 | import copy 4 | import torch 5 | 6 | class Vocab(object): 7 | 8 | def __init__(self, smiles_list): 9 | self.vocab = [x for x in smiles_list] #copy 10 | self.vmap = {x:i for i,x in enumerate(self.vocab)} 11 | 12 | def __getitem__(self, smiles): 13 | return self.vmap[smiles] 14 | 15 | def get_smiles(self, idx): 16 | return self.vocab[idx] 17 | 18 | def size(self): 19 | return len(self.vocab) 20 | 21 | class PairVocab(object): 22 | 23 | def __init__(self, smiles_pairs, cuda=True): 24 | cls = next(zip(*smiles_pairs)) 25 | self.hvocab = list( set(cls) ) 26 | self.hmap = {x:i for i,x in enumerate(self.hvocab)} 27 | 28 | self.vocab = [tuple(x) for x in smiles_pairs] #copy 29 | self.inter_size = [count_inters(x[1]) for x in self.vocab] 30 | self.vmap = {x:i for i,x in enumerate(self.vocab)} 31 | 32 | self.mask = torch.zeros(len(self.hvocab), len(self.vocab)) 33 | for h,s in smiles_pairs: 34 | hid = self.hmap[h] 35 | idx = self.vmap[(h,s)] 36 | self.mask[hid, idx] = 1000.0 37 | 38 | if cuda: self.mask = self.mask.cuda() 39 | self.mask = self.mask - 1000.0 40 | 41 | def __getitem__(self, x): 42 | assert type(x) is tuple 43 | return self.hmap[x[0]], self.vmap[x] 44 | 45 | def get_smiles(self, idx): 46 | return self.hvocab[idx] 47 | 48 | def get_ismiles(self, idx): 49 | return self.vocab[idx][1] 50 | 51 | def size(self): 52 | return len(self.hvocab), len(self.vocab) 53 | 54 | def get_mask(self, cls_idx): 55 | return self.mask.index_select(index=cls_idx, dim=0) 56 | 57 | def get_inter_size(self, icls_idx): 58 | return self.inter_size[icls_idx] 59 | 60 | COMMON_ATOMS = [('B', 0), ('B', -1), ('Br', 0), ('Br', -1), ('Br', 2), ('C', 0), ('C', 1), ('C', -1), ('Cl', 0), ('Cl', 1), ('Cl', -1), ('Cl', 2), ('Cl', 3), ('F', 0), ('F', 1), ('F', -1), ('I', -1), ('I', 0), ('I', 1), ('I', 2), ('I', 3), ('N', 0), ('N', 1), ('N', -1), ('O', 0), ('O', 1), ('O', -1), ('P', 0), ('P', 1), ('P', -1), ('S', 0), ('S', 1), ('S', -1), ('Se', 0), ('Se', 1), ('Se', -1), ('Si', 0), ('Si', -1)] 61 | common_atom_vocab = Vocab(COMMON_ATOMS) 62 | 63 | MAX_VALENCE = {'B': 3, 'Br':1, 'C':4, 'Cl':1, 'F':1, 'I':1, 'N':5, 'O':2, 'P':5, 'S':6, 'Se':4, 'Si':4} 64 | 65 | def count_inters(s): 66 | mol = Chem.MolFromSmiles(s) 67 | inters = [a for a in mol.GetAtoms() if a.GetAtomMapNum() > 0] 68 | return max(1, len(inters)) 69 | 70 | 71 | -------------------------------------------------------------------------------- /private/__init__.py: -------------------------------------------------------------------------------- 1 | from private.molecule_graph import MolGraph, InputGraph, MolKey, SubGraph 2 | from private.grammar import ProductionRuleCorpus, generate_rule, ProductionRule 3 | from private.subgraph_set import SubGraphSet 4 | from private.metrics import InternalDiversity 5 | from private.hypergraph import Hypergraph, hg_to_mol 6 | from private.utils import create_exp_dir, create_logger 7 | -------------------------------------------------------------------------------- /private/metrics.py: -------------------------------------------------------------------------------- 1 | from rdkit import DataStructs, Chem 2 | from rdkit.Chem import AllChem 3 | import numpy as np 4 | import torch.multiprocessing as mp 5 | from retro_star.api import RSPlanner 6 | 7 | 8 | class InternalDiversity(): 9 | def distance(self, mol1, mol2, dtype="Tanimoto"): 10 | assert dtype in ["Tanimoto"] 11 | if dtype == "Tanimoto": 12 | sim = DataStructs.FingerprintSimilarity(Chem.RDKFingerprint(mol1), Chem.RDKFingerprint(mol2)) 13 | return 1 - sim 14 | else: 15 | raise NotImplementedError 16 | 17 | def get_diversity(self, mol_list, dtype="Tanimoto"): 18 | similarity = 0 19 | mol_list = [AllChem.GetMorganFingerprintAsBitVect(x, 3, 2048) for x in mol_list] 20 | for i in range(len(mol_list)): 21 | sims = DataStructs.BulkTanimotoSimilarity(mol_list[i], mol_list[:i]) 22 | similarity += sum(sims) 23 | n = len(mol_list) 24 | n_pairs = n * (n - 1) / 2 25 | diversity = 1 - similarity / n_pairs 26 | return diversity 27 | 28 | 29 | if __name__ == "__main__": 30 | pass 31 | 32 | -------------------------------------------------------------------------------- /private/subgraph_set.py: -------------------------------------------------------------------------------- 1 | from .molecule_graph import MolKey 2 | 3 | class SubGraphSet(): 4 | def __init__(self, init_subgraphs, subgraphs_idx, inputs): 5 | self.subgraphs = init_subgraphs 6 | self.subgraphs_idx = subgraphs_idx 7 | self.inputs = inputs 8 | self.map_to_input = self.get_map_to_input() 9 | 10 | def get_map_to_input(self): 11 | ''' 12 | Input: 13 | init_subgraphs: a list with length equal to # input graphs, elements are Chem.Mol 14 | subgraphs_idx: a three-level list 15 | the first level: the same length with # input graphs, each element corresponds to a input graph 16 | the second level: the same length with # sub_graphs of the input graph, each element corresponds to a sub_graph 17 | the third level: the same length with # atoms of the sub_graph, each element corresponds to an atom index of input graph 18 | Output: 19 | map_to_input: a dict, [key_of_subgraphs][0][key_of_inputs][the_matched_subgraph][atom_idx_of_input] 20 | [key_of_subgraphs][1] 21 | ''' 22 | map_to_input = dict() 23 | for i, input_i in enumerate(self.inputs): 24 | key_input = MolKey(input_i) 25 | subgraphs_idx_i = self.subgraphs_idx[i] 26 | for j, subgraph_idx in enumerate(subgraphs_idx_i): 27 | key_subgraph = MolKey(self.subgraphs[i][j]) 28 | subg = input_i.subgraphs[input_i.subgraphs_idx.index(subgraph_idx)] 29 | if key_subgraph not in map_to_input.keys(): 30 | map_to_input[key_subgraph] = [dict(), 1] 31 | else: 32 | map_to_input[key_subgraph][1] += 1 33 | if key_input not in map_to_input[key_subgraph][0].keys(): 34 | map_to_input[key_subgraph][0][key_input] = list() 35 | map_to_input[key_subgraph][0][key_input].append((subgraph_idx, subg)) 36 | return map_to_input 37 | 38 | def update(self, input_graphs): 39 | new_subgraphs = [] 40 | new_subgraphs_idx = [] 41 | new_inputs = [] 42 | # Could merge with get_map_to_input() 43 | for i, input_i in enumerate(input_graphs): 44 | new_subgraphs.append(input_i.subgraphs) 45 | new_subgraphs_idx.append(input_i.subgraphs_idx) 46 | new_inputs.append(input_i) 47 | self.subgraphs = new_subgraphs 48 | self.subgraphs_idx = new_subgraphs_idx 49 | self.inputs = new_inputs 50 | self.map_to_input = self.get_map_to_input() 51 | 52 | -------------------------------------------------------------------------------- /private/symbol.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | 4 | class TSymbol(object): 5 | 6 | ''' terminal symbol 7 | 8 | Attributes 9 | ---------- 10 | degree : int 11 | the number of nodes in a hyperedge 12 | is_aromatic : bool 13 | whether or not the hyperedge is in an aromatic ring 14 | symbol : str 15 | atomic symbol 16 | num_explicit_Hs : int 17 | the number of hydrogens associated to this hyperedge 18 | formal_charge : int 19 | charge 20 | chirality : int 21 | chirality 22 | ''' 23 | 24 | def __init__(self, degree, is_aromatic, 25 | symbol, num_explicit_Hs, formal_charge, chirality): 26 | self.degree = degree 27 | self.is_aromatic = is_aromatic 28 | self.symbol = symbol 29 | self.num_explicit_Hs = num_explicit_Hs 30 | self.formal_charge = formal_charge 31 | self.chirality = chirality 32 | 33 | @property 34 | def terminal(self): 35 | return True 36 | 37 | def __eq__(self, other): 38 | if not isinstance(other, TSymbol): 39 | return False 40 | if self.degree != other.degree: 41 | return False 42 | if self.is_aromatic != other.is_aromatic: 43 | return False 44 | if self.symbol != other.symbol: 45 | return False 46 | if self.num_explicit_Hs != other.num_explicit_Hs: 47 | return False 48 | if self.formal_charge != other.formal_charge: 49 | return False 50 | if self.chirality != other.chirality: 51 | return False 52 | return True 53 | 54 | def __hash__(self): 55 | return self.__str__().__hash__() 56 | 57 | def __str__(self): 58 | return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\ 59 | f'symbol={self.symbol}, '\ 60 | f'num_explicit_Hs={self.num_explicit_Hs}, '\ 61 | f'formal_charge={self.formal_charge}, chirality={self.chirality}' 62 | 63 | 64 | class NTSymbol(object): 65 | 66 | ''' non-terminal symbol 67 | 68 | Attributes 69 | ---------- 70 | degree : int 71 | degree of the hyperedge 72 | is_aromatic : bool 73 | if True, at least one of the associated bonds must be aromatic. 74 | node_aromatic_list : list of bool 75 | indicate whether each of the nodes is aromatic or not. 76 | bond_type_list : list of int 77 | bond type of each node" 78 | ''' 79 | 80 | def __init__(self, degree: int, is_aromatic: bool, 81 | bond_symbol_list: list, 82 | for_ring=False): 83 | self.degree = degree 84 | self.is_aromatic = is_aromatic 85 | self.for_ring = for_ring 86 | self.bond_symbol_list = self.sort_list(bond_symbol_list) 87 | 88 | def sort_list(self, bond_symbol_list): 89 | bond_symbol_type_list = [bond.bond_type for bond in bond_symbol_list] 90 | sorted_idx = np.argsort(bond_symbol_type_list) 91 | new_bond_symbol_list = [bond_symbol_list[i] for i in sorted_idx] 92 | return new_bond_symbol_list 93 | 94 | @property 95 | def terminal(self) -> bool: 96 | return False 97 | 98 | @property 99 | def symbol(self): 100 | return f'R' 101 | 102 | def __eq__(self, other) -> bool: 103 | if not isinstance(other, NTSymbol): 104 | return False 105 | 106 | if self.degree != other.degree: 107 | return False 108 | if self.is_aromatic != other.is_aromatic: 109 | return False 110 | if self.for_ring != other.for_ring: 111 | return False 112 | if len(self.bond_symbol_list) != len(other.bond_symbol_list): 113 | return False 114 | for each_idx in range(len(self.bond_symbol_list)): 115 | if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]: 116 | return False 117 | return True 118 | 119 | def __hash__(self): 120 | return self.__str__().__hash__() 121 | 122 | def __str__(self) -> str: 123 | return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\ 124 | f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}'\ 125 | f'for_ring={self.for_ring}' 126 | 127 | 128 | class BondSymbol(object): 129 | 130 | 131 | ''' Bond symbol 132 | 133 | Attributes 134 | ---------- 135 | is_aromatic : bool 136 | if True, at least one of the associated bonds must be aromatic. 137 | bond_type : int 138 | bond type of each node" 139 | ''' 140 | 141 | def __init__(self, is_aromatic: bool, 142 | bond_type: int, 143 | stereo: int): 144 | self.is_aromatic = is_aromatic 145 | self.bond_type = bond_type 146 | self.stereo = stereo 147 | 148 | def __eq__(self, other) -> bool: 149 | if not isinstance(other, BondSymbol): 150 | return False 151 | 152 | if self.is_aromatic != other.is_aromatic: 153 | return False 154 | if self.bond_type != other.bond_type: 155 | return False 156 | if self.stereo != other.stereo: 157 | return False 158 | return True 159 | 160 | def __hash__(self): 161 | return self.__str__().__hash__() 162 | 163 | def __str__(self) -> str: 164 | return f'is_aromatic={self.is_aromatic}, '\ 165 | f'bond_type={self.bond_type}, '\ 166 | f'stereo={self.stereo}, ' 167 | -------------------------------------------------------------------------------- /private/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import List 3 | import numpy as np 4 | import os 5 | import logging 6 | import shutil 7 | 8 | 9 | def _node_match(node1, node2): 10 | # if the nodes are hyperedges, `atom_attr` determines the match 11 | if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge': 12 | return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol'] 13 | elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node': 14 | # bond_symbol 15 | return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol'] 16 | else: 17 | return False 18 | 19 | def _easy_node_match(node1, node2): 20 | # if the nodes are hyperedges, `atom_attr` determines the match 21 | if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge': 22 | return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None) 23 | elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node': 24 | # bond_symbol 25 | return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\ 26 | and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol'] 27 | else: 28 | return False 29 | 30 | 31 | def _node_match_prod_rule(node1, node2, ignore_order=False): 32 | # if the nodes are hyperedges, `atom_attr` determines the match 33 | if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge': 34 | return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol'] 35 | elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node': 36 | # ext_id, order4hrg, bond_symbol 37 | if ignore_order: 38 | return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol'] 39 | else: 40 | return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\ 41 | and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1) 42 | else: 43 | return False 44 | 45 | 46 | def _edge_match(edge1, edge2, ignore_order=False): 47 | #return True 48 | if ignore_order: 49 | return True 50 | else: 51 | return edge1["order"] == edge2["order"] 52 | 53 | def masked_softmax(logit, mask): 54 | ''' compute a probability distribution from logit 55 | 56 | Parameters 57 | ---------- 58 | logit : array-like, length D 59 | each element indicates how each dimension is likely to be chosen 60 | (the larger, the more likely) 61 | mask : array-like, length D 62 | each element is either 0 or 1. 63 | if 0, the dimension is ignored 64 | when computing the probability distribution. 65 | 66 | Returns 67 | ------- 68 | prob_dist : array, length D 69 | probability distribution computed from logit. 70 | if `mask[d] = 0`, `prob_dist[d] = 0`. 71 | ''' 72 | if logit.shape != mask.shape: 73 | raise ValueError('logit and mask must have the same shape') 74 | c = np.max(logit) 75 | exp_logit = np.exp(logit - c) * mask 76 | sum_exp_logit = exp_logit @ mask 77 | return exp_logit / sum_exp_logit 78 | 79 | 80 | def create_logger(name, log_file, level=logging.INFO): 81 | l = logging.getLogger(name) 82 | formatter = logging.Formatter( 83 | '[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s') 84 | fh = logging.FileHandler(log_file) 85 | fh.setFormatter(formatter) 86 | sh = logging.StreamHandler() 87 | sh.setFormatter(formatter) 88 | l.setLevel(level) 89 | l.addHandler(fh) 90 | l.addHandler(sh) 91 | return l 92 | 93 | def create_exp_dir(path, scripts_to_save=None): 94 | if not os.path.exists(path): 95 | os.makedirs(path) 96 | print('Experiment dir : {}'.format(path)) 97 | if scripts_to_save is not None: 98 | os.mkdir(os.path.join(path, 'scripts')) 99 | for script in scripts_to_save: 100 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 101 | shutil.copyfile(script, dst_file) 102 | -------------------------------------------------------------------------------- /retro_star/alg/__init__.py: -------------------------------------------------------------------------------- 1 | from retro_star.alg.molstar import molstar -------------------------------------------------------------------------------- /retro_star/alg/mol_node.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | 5 | class MolNode: 6 | def __init__(self, mol, init_value, parent=None, is_known=False, 7 | zero_known_value=True): 8 | self.mol = mol 9 | self.pred_value = init_value 10 | self.value = init_value 11 | self.succ_value = np.inf # total cost for existing solution 12 | self.parent = parent 13 | 14 | self.id = -1 15 | if self.parent is None: 16 | self.depth = 0 17 | else: 18 | self.depth = self.parent.depth 19 | 20 | self.is_known = is_known 21 | self.children = [] 22 | self.succ = is_known 23 | self.open = True # before expansion: True, after expansion: False 24 | if is_known: 25 | self.open = False 26 | if zero_known_value: 27 | self.value = 0 28 | self.succ_value = self.value 29 | 30 | if parent is not None: 31 | parent.children.append(self) 32 | 33 | def v_self(self): 34 | """ 35 | :return: V_self(self | subtree) 36 | """ 37 | return self.value 38 | 39 | def v_target(self): 40 | """ 41 | :return: V_target(self | whole tree) 42 | """ 43 | if self.parent is None: 44 | return self.value 45 | else: 46 | return self.parent.v_target() 47 | 48 | def init_values(self, no_child=False): 49 | assert self.open and (no_child or self.children) 50 | 51 | new_value = np.inf 52 | self.succ = False 53 | for reaction in self.children: 54 | new_value = np.min((new_value, reaction.v_self())) 55 | self.succ |= reaction.succ 56 | 57 | v_delta = new_value - self.value 58 | self.value = new_value 59 | 60 | if self.succ: 61 | for reaction in self.children: 62 | self.succ_value = np.min((self.succ_value, 63 | reaction.succ_value)) 64 | 65 | self.open = False 66 | 67 | return v_delta 68 | 69 | def backup(self, succ): 70 | assert not self.is_known 71 | 72 | new_value = np.inf 73 | for reaction in self.children: 74 | new_value = np.min((new_value, reaction.v_self())) 75 | new_succ = self.succ | succ 76 | updated = (self.value != new_value) or (self.succ != new_succ) 77 | 78 | new_succ_value = np.inf 79 | if new_succ: 80 | for reaction in self.children: 81 | new_succ_value = np.min((new_succ_value, reaction.succ_value)) 82 | updated = updated or (self.succ_value != new_succ_value) 83 | 84 | v_delta = new_value - self.value 85 | self.value = new_value 86 | self.succ = new_succ 87 | self.succ_value = new_succ_value 88 | 89 | if updated and self.parent: 90 | return self.parent.backup(v_delta, from_mol=self.mol) 91 | 92 | def serialize(self): 93 | text = '%d | %s' % (self.id, self.mol) 94 | # text = '%d | %s | pred %.2f | value %.2f | target %.2f' % \ 95 | # (self.id, self.mol, self.pred_value, self.v_self(), 96 | # self.v_target()) 97 | return text 98 | 99 | def get_ancestors(self): 100 | if self.parent is None: 101 | return {self.mol} 102 | 103 | ancestors = self.parent.parent.get_ancestors() 104 | ancestors.add(self.mol) 105 | return ancestors -------------------------------------------------------------------------------- /retro_star/alg/mol_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from queue import Queue 3 | import logging 4 | import networkx as nx 5 | from graphviz import Digraph 6 | from retro_star.alg.mol_node import MolNode 7 | from retro_star.alg.reaction_node import ReactionNode 8 | from retro_star.alg.syn_route import SynRoute 9 | 10 | 11 | class MolTree: 12 | def __init__(self, target_mol, known_mols, value_fn, zero_known_value=True): 13 | self.target_mol = target_mol 14 | self.known_mols = known_mols 15 | self.value_fn = value_fn 16 | self.zero_known_value = zero_known_value 17 | self.mol_nodes = [] 18 | self.reaction_nodes = [] 19 | 20 | self.root = self._add_mol_node(target_mol, None) 21 | self.succ = target_mol in known_mols 22 | self.search_status = 0 23 | 24 | if self.succ: 25 | logging.info('Synthesis route found: target in starting molecules') 26 | 27 | def _add_mol_node(self, mol, parent): 28 | is_known = mol in self.known_mols 29 | 30 | init_value = self.value_fn(mol) 31 | 32 | mol_node = MolNode( 33 | mol=mol, 34 | init_value=init_value, 35 | parent=parent, 36 | is_known=is_known, 37 | zero_known_value=self.zero_known_value 38 | ) 39 | self.mol_nodes.append(mol_node) 40 | mol_node.id = len(self.mol_nodes) 41 | 42 | return mol_node 43 | 44 | def _add_reaction_and_mol_nodes(self, cost, mols, parent, template, ancestors): 45 | assert cost >= 0 46 | 47 | for mol in mols: 48 | if mol in ancestors: 49 | return 50 | 51 | reaction_node = ReactionNode(parent, cost, template) 52 | for mol in mols: 53 | self._add_mol_node(mol, reaction_node) 54 | reaction_node.init_values() 55 | self.reaction_nodes.append(reaction_node) 56 | reaction_node.id = len(self.reaction_nodes) 57 | 58 | return reaction_node 59 | 60 | def expand(self, mol_node, reactant_lists, costs, templates): 61 | assert not mol_node.is_known and not mol_node.children 62 | 63 | if costs is None: # No expansion results 64 | assert mol_node.init_values(no_child=True) == np.inf 65 | if mol_node.parent: 66 | mol_node.parent.backup(np.inf, from_mol=mol_node.mol) 67 | return self.succ 68 | 69 | assert mol_node.open 70 | ancestors = mol_node.get_ancestors() 71 | for i in range(len(costs)): 72 | self._add_reaction_and_mol_nodes(costs[i], reactant_lists[i], 73 | mol_node, templates[i], ancestors) 74 | 75 | if len(mol_node.children) == 0: # No valid expansion results 76 | assert mol_node.init_values(no_child=True) == np.inf 77 | if mol_node.parent: 78 | mol_node.parent.backup(np.inf, from_mol=mol_node.mol) 79 | return self.succ 80 | 81 | v_delta = mol_node.init_values() 82 | if mol_node.parent: 83 | mol_node.parent.backup(v_delta, from_mol=mol_node.mol) 84 | 85 | if not self.succ and self.root.succ: 86 | logging.info('Synthesis route found!') 87 | self.succ = True 88 | 89 | return self.succ 90 | 91 | def get_best_route(self): 92 | if not self.succ: 93 | return None 94 | 95 | syn_route = SynRoute( 96 | target_mol=self.root.mol, 97 | succ_value=self.root.succ_value, 98 | search_status=self.search_status 99 | ) 100 | 101 | mol_queue = Queue() 102 | mol_queue.put(self.root) 103 | while not mol_queue.empty(): 104 | mol = mol_queue.get() 105 | if mol.is_known: 106 | syn_route.set_value(mol.mol, mol.succ_value) 107 | continue 108 | 109 | best_reaction = None 110 | for reaction in mol.children: 111 | if reaction.succ: 112 | if best_reaction is None or \ 113 | reaction.succ_value < best_reaction.succ_value: 114 | best_reaction = reaction 115 | assert best_reaction.succ_value == mol.succ_value 116 | 117 | reactants = [] 118 | for reactant in best_reaction.children: 119 | mol_queue.put(reactant) 120 | reactants.append(reactant.mol) 121 | 122 | syn_route.add_reaction( 123 | mol=mol.mol, 124 | value=mol.succ_value, 125 | template=best_reaction.template, 126 | reactants=reactants, 127 | cost=best_reaction.cost 128 | ) 129 | 130 | return syn_route 131 | 132 | def viz_search_tree(self, viz_file): 133 | G = Digraph('G', filename=viz_file) 134 | G.attr(rankdir='LR') 135 | G.attr('node', shape='box') 136 | G.format = 'pdf' 137 | 138 | node_queue = Queue() 139 | node_queue.put((self.root, None)) 140 | while not node_queue.empty(): 141 | node, parent = node_queue.get() 142 | 143 | if node.open: 144 | color = 'lightgrey' 145 | else: 146 | color = 'aquamarine' 147 | 148 | if hasattr(node, 'mol'): 149 | shape = 'box' 150 | else: 151 | shape = 'rarrow' 152 | 153 | if node.succ: 154 | color = 'lightblue' 155 | if hasattr(node, 'mol') and node.is_known: 156 | color = 'lightyellow' 157 | 158 | G.node(node.serialize(), shape=shape, color=color, style='filled') 159 | 160 | label = '' 161 | if hasattr(parent, 'mol'): 162 | label = '%.3f' % node.cost 163 | if parent is not None: 164 | G.edge(parent.serialize(), node.serialize(), label=label) 165 | 166 | if node.children is not None: 167 | for c in node.children: 168 | node_queue.put((c, node)) 169 | 170 | G.render() 171 | -------------------------------------------------------------------------------- /retro_star/alg/molstar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import logging 4 | from retro_star.alg.mol_tree import MolTree 5 | 6 | 7 | def molstar(target_mol, target_mol_id, starting_mols, expand_fn, value_fn, 8 | iterations, viz=False, viz_dir=None): 9 | mol_tree = MolTree( 10 | target_mol=target_mol, 11 | known_mols=starting_mols, 12 | value_fn=value_fn 13 | ) 14 | 15 | i = -1 16 | 17 | if not mol_tree.succ: 18 | for i in range(iterations): 19 | scores = [] 20 | for m in mol_tree.mol_nodes: 21 | if m.open: 22 | scores.append(m.v_target()) 23 | else: 24 | scores.append(np.inf) 25 | scores = np.array(scores) 26 | 27 | if np.min(scores) == np.inf: 28 | logging.info('No open nodes!') 29 | break 30 | 31 | metric = scores 32 | 33 | mol_tree.search_status = np.min(metric) 34 | m_next = mol_tree.mol_nodes[np.argmin(metric)] 35 | assert m_next.open 36 | 37 | result = expand_fn(m_next.mol) 38 | 39 | if result is not None and (len(result['scores']) > 0): 40 | reactants = result['reactants'] 41 | scores = result['scores'] 42 | costs = 0.0 - np.log(np.clip(np.array(scores), 1e-3, 1.0)) 43 | # costs = 1.0 - np.array(scores) 44 | if 'templates' in result.keys(): 45 | templates = result['templates'] 46 | else: 47 | templates = result['template'] 48 | 49 | reactant_lists = [] 50 | for j in range(len(scores)): 51 | reactant_list = list(set(reactants[j].split('.'))) 52 | reactant_lists.append(reactant_list) 53 | 54 | assert m_next.open 55 | succ = mol_tree.expand(m_next, reactant_lists, costs, templates) 56 | 57 | if succ: 58 | break 59 | 60 | # found optimal route 61 | if mol_tree.root.succ_value <= mol_tree.search_status: 62 | break 63 | 64 | else: 65 | mol_tree.expand(m_next, None, None, None) 66 | logging.info('Expansion fails on %s!' % m_next.mol) 67 | 68 | logging.info('Final search status | success value | iter: %s | %s | %d' 69 | % (str(mol_tree.search_status), str(mol_tree.root.succ_value), i+1)) 70 | 71 | best_route = None 72 | if mol_tree.succ: 73 | best_route = mol_tree.get_best_route() 74 | assert best_route is not None 75 | 76 | if viz: 77 | if not os.path.exists(viz_dir): 78 | os.makedirs(viz_dir) 79 | 80 | if mol_tree.succ: 81 | if best_route.optimal: 82 | f = '%s/mol_%d_route_optimal' % (viz_dir, target_mol_id) 83 | else: 84 | f = '%s/mol_%d_route' % (viz_dir, target_mol_id) 85 | best_route.viz_route(f) 86 | 87 | f = '%s/mol_%d_search_tree' % (viz_dir, target_mol_id) 88 | mol_tree.viz_search_tree(f) 89 | 90 | return mol_tree.succ, (best_route, i+1) 91 | -------------------------------------------------------------------------------- /retro_star/alg/reaction_node.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | 5 | class ReactionNode: 6 | def __init__(self, parent, cost, template): 7 | self.parent = parent 8 | 9 | self.depth = self.parent.depth + 1 10 | self.id = -1 11 | 12 | self.cost = cost 13 | self.template = template 14 | self.children = [] 15 | self.value = None # [V(m | subtree_m) for m in children].sum() + cost 16 | self.succ_value = np.inf # total cost for existing solution 17 | self.target_value = None # V_target(self | whole tree) 18 | self.succ = None # successfully found a valid synthesis route 19 | self.open = True # before expansion: True, after expansion: False 20 | parent.children.append(self) 21 | 22 | def v_self(self): 23 | """ 24 | :return: V_self(self | subtree) 25 | """ 26 | return self.value 27 | 28 | def v_target(self): 29 | """ 30 | :return: V_target(self | whole tree) 31 | """ 32 | return self.target_value 33 | 34 | def init_values(self): 35 | assert self.open 36 | 37 | self.value = self.cost 38 | self.succ = True 39 | for mol in self.children: 40 | self.value += mol.value 41 | self.succ &= mol.succ 42 | 43 | if self.succ: 44 | self.succ_value = self.cost 45 | for mol in self.children: 46 | self.succ_value += mol.succ_value 47 | 48 | self.target_value = self.parent.v_target() - self.parent.v_self() + \ 49 | self.value 50 | self.open = False 51 | 52 | def backup(self, v_delta, from_mol=None): 53 | self.value += v_delta 54 | self.target_value += v_delta 55 | 56 | self.succ = True 57 | for mol in self.children: 58 | self.succ &= mol.succ 59 | 60 | if self.succ: 61 | self.succ_value = self.cost 62 | for mol in self.children: 63 | self.succ_value += mol.succ_value 64 | 65 | if v_delta != 0: 66 | assert from_mol 67 | self.propagate(v_delta, exclude=from_mol) 68 | 69 | return self.parent.backup(self.succ) 70 | 71 | def propagate(self, v_delta, exclude=None): 72 | if exclude is None: 73 | self.target_value += v_delta 74 | 75 | for child in self.children: 76 | if exclude is None or child.mol != exclude: 77 | for grandchild in child.children: 78 | grandchild.propagate(v_delta) 79 | 80 | def serialize(self): 81 | return '%d' % (self.id) 82 | # return '%d | value %.2f | target %.2f' % \ 83 | # (self.id, self.v_self(), self.v_target()) -------------------------------------------------------------------------------- /retro_star/alg/syn_route.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from queue import Queue 3 | from graphviz import Digraph 4 | 5 | class SynRoute: 6 | def __init__(self, target_mol, succ_value, search_status): 7 | self.target_mol = target_mol 8 | self.mols = [target_mol] 9 | self.values = [None] 10 | self.templates = [None] 11 | self.parents = [-1] 12 | self.children = [None] 13 | self.optimal = False 14 | self.costs = {} 15 | 16 | self.succ_value = succ_value 17 | self.total_cost = 0 18 | self.length = 0 19 | self.search_status = search_status 20 | if self.succ_value <= self.search_status: 21 | self.optimal = True 22 | 23 | def _add_mol(self, mol, parent_id): 24 | self.mols.append(mol) 25 | self.values.append(None) 26 | self.templates.append(None) 27 | self.parents.append(parent_id) 28 | self.children.append(None) 29 | 30 | self.children[parent_id].append(len(self.mols)-1) 31 | 32 | def set_value(self, mol, value): 33 | assert mol in self.mols 34 | 35 | mol_id = self.mols.index(mol) 36 | self.values[mol_id] = value 37 | 38 | def add_reaction(self, mol, value, template, reactants, cost): 39 | assert mol in self.mols 40 | 41 | self.total_cost += cost 42 | self.length += 1 43 | 44 | parent_id = self.mols.index(mol) 45 | self.values[parent_id] = value 46 | self.templates[parent_id] = template 47 | self.children[parent_id] = [] 48 | self.costs[parent_id] = cost 49 | 50 | for reactant in reactants: 51 | self._add_mol(reactant, parent_id) 52 | 53 | def viz_route(self, viz_file): 54 | G = Digraph('G', filename=viz_file) 55 | G.attr('node', shape='box') 56 | G.format = 'pdf' 57 | 58 | names = [] 59 | for i in range(len(self.mols)): 60 | name = self.mols[i] 61 | # if self.templates[i] is not None: 62 | # name += ' | %s' % self.templates[i] 63 | names.append(name) 64 | 65 | node_queue = Queue() 66 | node_queue.put((0,-1)) # target mol idx, and parent idx 67 | while not node_queue.empty(): 68 | idx, parent_idx = node_queue.get() 69 | 70 | if parent_idx >= 0: 71 | G.edge(names[parent_idx], names[idx], label='cost') 72 | 73 | if self.children[idx] is not None: 74 | for c in self.children[idx]: 75 | node_queue.put((c, idx)) 76 | 77 | G.render() 78 | 79 | def serialize_reaction(self, idx): 80 | s = self.mols[idx] 81 | if self.children[idx] is None: 82 | return s 83 | s += '>%.4f>' % np.exp(-self.costs[idx]) 84 | s += self.mols[self.children[idx][0]] 85 | for i in range(1, len(self.children[idx])): 86 | s += '.' 87 | s += self.mols[self.children[idx][i]] 88 | 89 | return s 90 | 91 | def serialize(self): 92 | s = self.serialize_reaction(0) 93 | for i in range(1, len(self.mols)): 94 | if self.children[i] is not None: 95 | s += '|' 96 | s += self.serialize_reaction(i) 97 | 98 | return s -------------------------------------------------------------------------------- /retro_star/api.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import time 4 | from retro_star.common import prepare_starting_molecules, prepare_mlp, prepare_molstar_planner_fn, prepare_molstar_planner, smiles_to_fp 5 | from retro_star.model import ValueMLP 6 | from retro_star.utils import setup_logger 7 | 8 | import os 9 | dirpath = os.path.dirname(os.path.abspath(__file__)) 10 | 11 | class value_fnc(): 12 | def __init__(self, model, fp_dim, device): 13 | self.model = model 14 | self.fp_dim = fp_dim 15 | self.device = device 16 | 17 | def __call__(self, mol): 18 | fp = smiles_to_fp(mol, fp_dim=self.fp_dim).reshape(1, -1) 19 | fp = torch.FloatTensor(fp).to(self.device) 20 | v = self.model(fp).item() 21 | return v 22 | 23 | 24 | class RSPlanner: 25 | def __init__(self, 26 | gpu=-1, 27 | expansion_topk=50, 28 | iterations=500, 29 | use_value_fn=False, 30 | starting_molecules=dirpath+'/dataset/origin_dict.csv', 31 | mlp_templates=dirpath+'/one_step_model/template_rules_1.dat', 32 | mlp_model_dump=dirpath+'/one_step_model/saved_rollout_state_1_2048.ckpt', 33 | save_folder=dirpath+'/saved_models', 34 | value_model='best_epoch_final_4.pt', 35 | fp_dim=2048, 36 | viz=False, 37 | viz_dir='viz'): 38 | 39 | setup_logger() 40 | device = torch.device('cuda:%d' % gpu if gpu >= 0 else 'cpu') 41 | starting_mols = prepare_starting_molecules(starting_molecules) 42 | 43 | one_step = prepare_mlp(mlp_templates, mlp_model_dump) 44 | 45 | if use_value_fn: 46 | model = ValueMLP( 47 | n_layers=1, 48 | fp_dim=fp_dim, 49 | latent_dim=128, 50 | dropout_rate=0.1, 51 | device=device 52 | ).to(device) 53 | model_f = '%s/%s' % (save_folder, value_model) 54 | logging.info('Loading value nn from %s' % model_f) 55 | model.load_state_dict(torch.load(model_f, map_location=device)) 56 | model.eval() 57 | value_fn = value_fnc(model, fp_dim, device) 58 | 59 | else: 60 | value_fn = lambda x: 0. 61 | 62 | self.plan_handle = prepare_molstar_planner_fn( 63 | one_step=one_step, 64 | value_fn=value_fn, 65 | starting_mols=starting_mols, 66 | expansion_topk=expansion_topk, 67 | iterations=iterations, 68 | viz=viz, 69 | viz_dir=viz_dir 70 | ) 71 | 72 | def plan(self, target_mol): 73 | t0 = time.time() 74 | succ, msg = self.plan_handle(target_mol) 75 | 76 | if succ: 77 | result = { 78 | 'succ': succ, 79 | 'time': time.time() - t0, 80 | 'iter': msg[1], 81 | 'routes': msg[0].serialize(), 82 | 'route_cost': msg[0].total_cost, 83 | 'route_len': msg[0].length 84 | } 85 | return result 86 | 87 | else: 88 | logging.info('Synthesis path for %s not found. Please try increasing ' 89 | 'the number of iterations.' % target_mol) 90 | return None 91 | 92 | 93 | if __name__ == '__main__': 94 | planner = RSPlanner( 95 | gpu=0, 96 | use_value_fn=True, 97 | iterations=100, 98 | expansion_topk=50 99 | ) 100 | 101 | result = planner.plan('CCCC[C@@H](C(=O)N1CCC[C@H]1C(=O)O)[C@@H](F)C(=O)OC') 102 | print(result) 103 | 104 | result = planner.plan('CCOC(=O)c1nc(N2CC[C@H](NC(=O)c3nc(C(F)(F)F)c(CC)[nH]3)[C@H](OC)C2)sc1C') 105 | print(result) 106 | 107 | result = planner.plan('CC(C)c1ccc(-n2nc(O)c3c(=O)c4ccc(Cl)cc4[nH]c3c2=O)cc1') 108 | print(result) 109 | 110 | -------------------------------------------------------------------------------- /retro_star/common/__init__.py: -------------------------------------------------------------------------------- 1 | # from retro_star.common.parse_args import args 2 | from retro_star.common.prepare_utils import * 3 | from retro_star.common.smiles_to_fp import smiles_to_fp, batch_smiles_to_fp 4 | -------------------------------------------------------------------------------- /retro_star/common/parse_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import sys 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | 9 | # ===================== gpu id ===================== # 10 | parser.add_argument('--gpu', type=int, default=-1) 11 | 12 | # =================== random seed ================== # 13 | parser.add_argument('--seed', type=int, default=1234) 14 | 15 | # ==================== dataset ===================== # 16 | parser.add_argument('--test_routes', 17 | default='dataset/routes_possible_test_hard.pkl') 18 | parser.add_argument('--starting_molecules', default='dataset/origin_dict.csv') 19 | 20 | # ================== value dataset ================= # 21 | parser.add_argument('--value_root', default='dataset') 22 | parser.add_argument('--value_train', default='train_mol_fp_value_step') 23 | parser.add_argument('--value_val', default='val_mol_fp_value_step') 24 | 25 | # ================== one-step model ================ # 26 | parser.add_argument('--mlp_model_dump', 27 | default='one_step_model/saved_rollout_state_1_2048.ckpt') 28 | parser.add_argument('--mlp_templates', 29 | default='one_step_model/template_rules_1.dat') 30 | 31 | # ===================== all algs =================== # 32 | parser.add_argument('--iterations', type=int, default=500) 33 | parser.add_argument('--expansion_topk', type=int, default=50) 34 | parser.add_argument('--viz', action='store_true') 35 | parser.add_argument('--viz_dir', default='viz') 36 | 37 | # ===================== model ====================== # 38 | parser.add_argument('--fp_dim', type=int, default=2048) 39 | parser.add_argument('--n_layers', type=int, default=1) 40 | parser.add_argument('--latent_dim', type=int, default=128) 41 | 42 | # ==================== training ==================== # 43 | parser.add_argument('--n_epochs', type=int, default=1) 44 | parser.add_argument('--batch_size', type=int, default=128) 45 | parser.add_argument('--lr', type=float, default=1e-3) 46 | parser.add_argument('--save_epoch_int', type=int, default=1) 47 | parser.add_argument('--save_folder', default='saved_models') 48 | 49 | # ==================== evaluation =================== # 50 | parser.add_argument('--use_value_fn', action='store_true') 51 | parser.add_argument('--value_model', default='best_epoch_final_4.pt') 52 | parser.add_argument('--result_folder', default='results') 53 | 54 | args = parser.parse_args() 55 | 56 | # setup device 57 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 58 | -------------------------------------------------------------------------------- /retro_star/common/prepare_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pandas as pd 3 | import logging 4 | from mlp_retrosyn.mlp_inference import MLPModel 5 | from retro_star.alg import molstar 6 | 7 | def prepare_starting_molecules(filename): 8 | logging.info('Loading starting molecules from %s' % filename) 9 | 10 | if filename[-3:] == 'csv': 11 | starting_mols = set(list(pd.read_csv(filename)['mol'])) 12 | else: 13 | assert filename[-3:] == 'pkl' 14 | with open(filename, 'rb') as f: 15 | starting_mols = pickle.load(f) 16 | 17 | logging.info('%d starting molecules loaded' % len(starting_mols)) 18 | return starting_mols 19 | 20 | def prepare_mlp(templates, model_dump): 21 | logging.info('Templates: %s' % templates) 22 | logging.info('Loading trained mlp model from %s' % model_dump) 23 | one_step = MLPModel(model_dump, templates, device=-1) 24 | return one_step 25 | 26 | def prepare_molstar_planner(one_step, value_fn, starting_mols, expansion_topk, 27 | iterations, viz=False, viz_dir=None): 28 | expansion_handle = lambda x: one_step.run(x, topk=expansion_topk) 29 | 30 | plan_handle = lambda x, y=0: molstar( 31 | target_mol=x, 32 | target_mol_id=y, 33 | starting_mols=starting_mols, 34 | expand_fn=expansion_handle, 35 | value_fn=value_fn, 36 | iterations=iterations, 37 | viz=viz, 38 | viz_dir=viz_dir 39 | ) 40 | return plan_handle 41 | 42 | class prepare_molstar_planner_fn(): 43 | def __init__(self, one_step, value_fn, starting_mols, expansion_topk, iterations, viz=False, viz_dir=None): 44 | self.one_step = one_step 45 | self.value_fn = value_fn 46 | self.starting_mols = starting_mols 47 | self.expansion_topk = expansion_topk 48 | self.iterations = iterations 49 | self.viz = viz 50 | self.viz_dir = viz_dir 51 | 52 | def expansion_handle(self, x): 53 | return self.one_step.run(x, topk=self.expansion_topk) 54 | 55 | def __call__(self, x): 56 | plan_handle = molstar( 57 | target_mol=x, 58 | target_mol_id=0, 59 | starting_mols=self.starting_mols, 60 | expand_fn=self.expansion_handle, 61 | value_fn=self.value_fn, 62 | iterations=self.iterations, 63 | viz=self.viz, 64 | viz_dir=self.viz_dir 65 | ) 66 | return plan_handle 67 | -------------------------------------------------------------------------------- /retro_star/common/smiles_to_fp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | from rdkit.Chem import AllChem 4 | 5 | def smiles_to_fp(s, fp_dim=2048, pack=False): 6 | mol = Chem.MolFromSmiles(s) 7 | fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=fp_dim) 8 | onbits = list(fp.GetOnBits()) 9 | arr = np.zeros(fp.GetNumBits(), dtype=np.bool) 10 | arr[onbits] = 1 11 | 12 | if pack: 13 | arr = np.packbits(arr) 14 | 15 | return arr 16 | 17 | def batch_smiles_to_fp(s_list, fp_dim): 18 | fps = [] 19 | for s in s_list: 20 | fps.append(smiles_to_fp(s, fp_dim)) 21 | fps = np.array(fps) 22 | 23 | assert fps.shape[0] == len(s_list) and fps.shape[1] == fp_dim 24 | 25 | return fps -------------------------------------------------------------------------------- /retro_star/data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | from retro_star.data_loader.value_data_loader import ValueDataLoader -------------------------------------------------------------------------------- /retro_star/data_loader/value_data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import pickle 5 | import logging 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | 9 | def unpack_fps(packed_fps): 10 | # packed_fps = np.array(packed_fps) 11 | shape = (*(packed_fps.shape[:-1]), -1) 12 | fps = np.unpackbits(packed_fps.reshape((-1, packed_fps.shape[-1])), 13 | axis=-1) 14 | fps = torch.FloatTensor(fps).view(shape) 15 | 16 | return fps 17 | 18 | class ValueDataset(Dataset): 19 | def __init__(self, fp_value_f): 20 | assert os.path.exists('%s.pt' % fp_value_f) 21 | logging.info('Loading value dataset from %s.pt'% fp_value_f) 22 | data_dict = torch.load('%s.pt' % fp_value_f) 23 | self.fps = unpack_fps(data_dict['fps']) 24 | self.values = data_dict['values'] 25 | 26 | filter = self.values[:,0] > 0 27 | self.fps = self.fps[filter] 28 | self.values = self.values[filter] 29 | 30 | self.reaction_costs = data_dict['reaction_costs'] 31 | self.target_values = data_dict['target_values'] 32 | # self.reactant_fps = unpack_fps(data_dict['reactant_fps']) 33 | self.reactant_packed_fps = data_dict['reactant_fps'] 34 | self.reactant_masks = data_dict['reactant_masks'] 35 | self.reactant_fps = None 36 | self.reshuffle() 37 | 38 | assert self.fps.shape[0] == self.values.shape[0] 39 | logging.info('%d (fp, value) pairs loaded' % self.fps.shape[0]) 40 | logging.info('%d nagative samples loaded' % self.reactant_fps.shape[0]) 41 | print(self.fps.shape, self.values.shape, 42 | self.reactant_fps.shape, self.reactant_masks.shape) 43 | 44 | logging.info( 45 | 'mean: %f, std:%f, min: %f, max: %f, zeros: %f' % 46 | (self.values.mean(), self.values.std(), self.values.min(), 47 | self.values.max(), (self.values==0).sum()*1. / self.fps.shape[0]) 48 | ) 49 | 50 | def reshuffle(self): 51 | shuffle_idx = np.random.permutation(self.reaction_costs.shape[0]) 52 | self.reaction_costs = self.reaction_costs[shuffle_idx] 53 | self.target_values = self.target_values[shuffle_idx] 54 | self.reactant_packed_fps = self.reactant_packed_fps[shuffle_idx] 55 | self.reactant_masks = self.reactant_masks[shuffle_idx] 56 | 57 | self.reactant_fps = unpack_fps( 58 | self.reactant_packed_fps[:self.fps.shape[0],:,:]) 59 | 60 | def __len__(self): 61 | return self.fps.shape[0] 62 | 63 | def __getitem__(self, index): 64 | return self.fps[index], self.values[index], \ 65 | self.reaction_costs[index], self.target_values[index], \ 66 | self.reactant_fps[index], self.reactant_masks[index] 67 | 68 | 69 | class ValueDataLoader(DataLoader): 70 | def __init__(self, fp_value_f, batch_size, shuffle=True): 71 | self.dataset = ValueDataset(fp_value_f) 72 | 73 | super(ValueDataLoader, self).__init__( 74 | dataset=self.dataset, 75 | batch_size=batch_size, 76 | shuffle=shuffle 77 | ) 78 | 79 | def reshuffle(self): 80 | self.dataset.reshuffle() 81 | 82 | -------------------------------------------------------------------------------- /retro_star/environment.yml: -------------------------------------------------------------------------------- 1 | name: retro_star_env 2 | channels: 3 | - rdkit 4 | - conda-forge 5 | - pytorch 6 | - anaconda 7 | - defaults 8 | dependencies: 9 | - python=3.7 10 | - torchvision 11 | - cpuonly 12 | - pytorch 13 | - pandas 14 | - rdkit 15 | - tqdm 16 | - networkx 17 | - graphviz 18 | - python-graphviz 19 | prefix: $CONDA_PREFIX 20 | 21 | -------------------------------------------------------------------------------- /retro_star/model/__init__.py: -------------------------------------------------------------------------------- 1 | from retro_star.model.value_mlp import ValueMLP 2 | -------------------------------------------------------------------------------- /retro_star/model/value_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import logging 5 | 6 | 7 | class ValueMLP(nn.Module): 8 | def __init__(self, n_layers, fp_dim, latent_dim, dropout_rate, device): 9 | super(ValueMLP, self).__init__() 10 | self.n_layers = n_layers 11 | self.fp_dim = fp_dim 12 | self.latent_dim = latent_dim 13 | self.dropout_rate = dropout_rate 14 | self.device = device 15 | 16 | logging.info('Initializing value model: latent_dim=%d' % self.latent_dim) 17 | 18 | layers = [] 19 | layers.append(nn.Linear(fp_dim, latent_dim)) 20 | # layers.append(nn.BatchNorm1d(latent_dim, 21 | # track_running_stats=False)) 22 | layers.append(nn.ReLU()) 23 | layers.append(nn.Dropout(self.dropout_rate)) 24 | for _ in range(self.n_layers - 1): 25 | layers.append(nn.Linear(latent_dim, latent_dim)) 26 | # layers.append(nn.BatchNorm1d(latent_dim, 27 | # track_running_stats=False)) 28 | layers.append(nn.ReLU()) 29 | layers.append(nn.Dropout(self.dropout_rate)) 30 | layers.append(nn.Linear(latent_dim, 1)) 31 | 32 | self.layers = nn.Sequential(*layers) 33 | 34 | def forward(self, fps): 35 | x = fps 36 | x = self.layers(x) 37 | x = torch.log(1 + torch.exp(x)) 38 | 39 | return x 40 | -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: mlp-retrosyn 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | License: UNKNOWN 7 | Platform: UNKNOWN 8 | 9 | UNKNOWN 10 | 11 | -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | mlp_retrosyn.egg-info/PKG-INFO 3 | mlp_retrosyn.egg-info/SOURCES.txt 4 | mlp_retrosyn.egg-info/dependency_links.txt 5 | mlp_retrosyn.egg-info/top_level.txt -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | mlp_retrosyn 2 | -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/retro_star/packages/mlp_retrosyn/mlp_retrosyn/__init__.py -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn/extract_template.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified version of: 3 | 4 | """ 5 | 6 | import os 7 | from tqdm import tqdm 8 | from pprint import pprint 9 | from collections import defaultdict 10 | import pandas as pd 11 | 12 | if __name__ == '__main__': 13 | import argparse 14 | from pprint import pprint 15 | parser = argparse.ArgumentParser(description="Policies for retrosynthesis Planner") 16 | parser.add_argument('--data_folder', default='../data/uspto_all/', 17 | type=str, help='Specify the path of all template rules.') 18 | parser.add_argument('--file_name',default='proc_all_cano_smiles_w_tmpl.csv', 19 | type=str, 20 | help='Specify the filen name') 21 | args = parser.parse_args() 22 | data_folder = args.data_folder 23 | file_name = args.file_name 24 | 25 | 26 | templates = defaultdict(tuple) 27 | transforms = [] 28 | datafile = os.path.join(data_folder,file_name) 29 | df = pd.read_csv(datafile) 30 | rxn_smiles = list(df['rxn_smiles']) 31 | retro_templates = list(df['retro_templates']) 32 | for i in tqdm(range(len(df))): 33 | rxn = rxn_smiles[i] 34 | rule = retro_templates[i] 35 | product = rxn.strip().split('>')[-1] 36 | transforms.append((rule,product)) 37 | print(len(transforms)) 38 | with open(os.path.join(data_folder,'templates.dat'), 'w') as f: 39 | f.write('\n'.join(['\t'.join(rxn_prod) for rxn_prod in transforms])) 40 | 41 | # Generate rules for MCTS 42 | templates = defaultdict(int) 43 | for rule, _ in tqdm(transforms): 44 | templates[rule] += 1 45 | print("The number of templates is {}".format(len(templates))) 46 | # # 47 | template_rules = [rule for rule, cnt in templates.items() if cnt >= 1] 48 | print("all template rules with count >= 1: ", len(template_rules)) 49 | with open(os.path.join(data_folder,'template_rules_1.dat'), 'w') as f: 50 | f.write('\n'.join(template_rules)) -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn/mlp_inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from rdkit import Chem 6 | import rdchiral 7 | from rdchiral.main import rdchiralRunText, rdchiralRun 8 | from rdchiral.initialization import rdchiralReaction, rdchiralReactants 9 | from .mlp_policies import load_parallel_model , preprocess 10 | from collections import defaultdict, OrderedDict 11 | 12 | def merge(reactant_d): 13 | ret = [] 14 | for reactant, l in reactant_d.items(): 15 | ss, ts = zip(*l) 16 | ret.append((reactant, sum(ss), list(ts)[0])) 17 | reactants, scores, templates = zip(*sorted(ret,key=lambda item : item[1], reverse=True)) 18 | return list(reactants), list(scores), list(templates) 19 | 20 | 21 | 22 | class MLPModel(object): 23 | def __init__(self,state_path, template_path, device=-1, fp_dim=2048): 24 | super(MLPModel, self).__init__() 25 | self.fp_dim = fp_dim 26 | self.net, self.idx2rules = load_parallel_model(state_path,template_path, fp_dim) 27 | self.net.eval() 28 | self.device = device 29 | if device >= 0: 30 | self.net.to(device) 31 | 32 | def run(self, x, topk=10): 33 | arr = preprocess(x, self.fp_dim) 34 | arr = np.reshape(arr,[-1, arr.shape[0]]) 35 | arr = torch.tensor(arr, dtype=torch.float32) 36 | if self.device >= 0: 37 | arr = arr.to(self.device) 38 | preds = self.net(arr) 39 | preds = F.softmax(preds,dim=1) 40 | if self.device >= 0: 41 | preds = preds.cpu() 42 | probs, idx = torch.topk(preds,k=topk) 43 | # probs = F.softmax(probs,dim=1) 44 | rule_k = [self.idx2rules[id] for id in idx[0].numpy().tolist()] 45 | reactants = [] 46 | scores = [] 47 | templates = [] 48 | for i , rule in enumerate(rule_k): 49 | out1 = [] 50 | try: 51 | out1 = rdchiralRunText(rule, x) 52 | # out1 = rdchiralRunText(rule, Chem.MolToSmiles(Chem.MolFromSmarts(x))) 53 | if len(out1) == 0: continue 54 | # if len(out1) > 1: print("more than two reactants."),print(out1) 55 | out1 = sorted(out1) 56 | for reactant in out1: 57 | reactants.append(reactant) 58 | scores.append(probs[0][i].item()/len(out1)) 59 | templates.append(rule) 60 | # out1 = rdchiralRunText(x, rule) 61 | except ValueError: 62 | pass 63 | if len(reactants) == 0: return None 64 | reactants_d = defaultdict(list) 65 | for r, s, t in zip(reactants, scores, templates): 66 | if '.' in r: 67 | str_list = sorted(r.strip().split('.')) 68 | reactants_d['.'.join(str_list)].append((s, t)) 69 | else: 70 | reactants_d[r].append((s, t)) 71 | 72 | reactants, scores, templates = merge(reactants_d) 73 | total = sum(scores) 74 | scores = [s / total for s in scores] 75 | return {'reactants':reactants, 76 | 'scores' : scores, 77 | 'template' : templates} 78 | 79 | 80 | 81 | if __name__ == '__main__': 82 | import argparse 83 | from pprint import pprint 84 | parser = argparse.ArgumentParser(description="Policies for retrosynthesis Planner") 85 | parser.add_argument('--template_rule_path', default='../data/uspto_all/template_rules_1.dat', 86 | type=str, help='Specify the path of all template rules.') 87 | parser.add_argument('--model_path', default='../model/saved_rollout_state_1_2048.ckpt', 88 | type=str, help='specify where the trained model is') 89 | args = parser.parse_args() 90 | state_path = args.model_path 91 | template_path = args.template_rule_path 92 | model = MLPModel(state_path,template_path,device=-1) 93 | x = '[F-:1]' 94 | # x = '[CH2:10]([S:14]([O:3][CH2:2][CH2:1][Cl:4])(=[O:16])=[O:15])[CH:11]([CH3:13])[CH3:12]' 95 | # x = '[S:3](=[O:4])(=[O:5])([O:6][CH2:7][CH:8]([CH2:9][CH2:10][CH2:11][CH3:12])[CH2:13][CH3:14])[OH:15]' 96 | # x = 'OCC(=O)OCCCO' 97 | # x = 'CC(=O)NC1=CC=C(O)C=C1' 98 | x = 'S=C(Cl)(Cl)' 99 | # x = "NCCNC(=O)c1ccc(/C=N/Nc2ncnc3c2cnn3-c2ccccc2)cc1" 100 | # x = 'CCOC(=O)c1cnc2c(F)cc(Br)cc2c1O' 101 | # x = 'COc1cc2ncnc(Oc3cc(NC(=O)Nc4cc(C(C)(C)C(F)(F)F)on4)ccc3F)c2cc1OC' 102 | # x = 'COC(=O)c1ccc(CN2C(=O)C3(COc4cc5c(cc43)OCCO5)c3ccccc32)o1' 103 | x = 'O=C1Nc2ccccc2C12COc1cc3c(cc12)OCCO3' 104 | # x = 'CO[C@H](CC(=O)O)C(=O)O' 105 | # x = 'O=C(O)c1cc(OCC(F)(F)F)c(C2CC2)cn1' 106 | y = model.run(x,10) 107 | pprint(y) 108 | -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn/mlp_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from tqdm import tqdm 4 | from .mlp_policies import train_mlp 5 | from pprint import pprint 6 | if __name__ == '__main__': 7 | import argparse 8 | parser = argparse.ArgumentParser(description="train function for retrosynthesis Planner policies") 9 | parser.add_argument('--template_path',default= 'data/cooked_data/templates.dat', 10 | type=str, help='Specify the path of the template.data') 11 | parser.add_argument('--template_rule_path', default='data/cooked_data/template_rules_1.dat', 12 | type=str, help='Specify the path of all template rules.') 13 | parser.add_argument('--model_dump_folder',default='./model', 14 | type=str, help='specify where to save the trained models') 15 | parser.add_argument('--fp_dim',default=2048, type=int, 16 | help="specify the fingerprint feature dimension") 17 | parser.add_argument('--batch_size', default=1024, type=int, 18 | help="specify the batch size") 19 | parser.add_argument('--dropout_rate', default=0.4, type=float, 20 | help="specify the dropout rate") 21 | parser.add_argument('--learning_rate', default=0.001, type=float, 22 | help="specify the learning rate") 23 | args = parser.parse_args() 24 | template_path = args.template_path 25 | template_rule_path = args.template_rule_path 26 | model_dump_folder = args.model_dump_folder 27 | fp_dim = args.fp_dim 28 | batch_size = args.batch_size 29 | dropout_rate = args.dropout_rate 30 | lr = args.learning_rate 31 | print('Loading data...') 32 | prod_to_rules = defaultdict(set) 33 | ### read the template data. 34 | with open(template_path, 'r') as f: 35 | for l in tqdm(f, desc="reading the mapping from prod to rules"): 36 | rule, prod = l.strip().split('\t') 37 | prod_to_rules[prod].add(rule) 38 | if not os.path.exists(model_dump_folder): 39 | os.mkdir(model_dump_folder) 40 | pprint(args) 41 | train_mlp(prod_to_rules, 42 | template_rule_path, 43 | fp_dim=fp_dim, 44 | batch_size=batch_size, 45 | lr=lr, 46 | dropout_rate=dropout_rate, 47 | saved_model=os.path.join(model_dump_folder, 'saved_rollout_state_1')) 48 | -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn/scripts/run_extract_templates.sh: -------------------------------------------------------------------------------- 1 | python ../extract_template.py --data_folder '../data/toy/' \ 2 | --file_name 'sample.csv' -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn/scripts/run_mlp_inference.sh: -------------------------------------------------------------------------------- 1 | python ../mlp_inference.py --template_rule_path '../data/toy/template_rules_1.dat' \ 2 | --model_path '../model/toy/saved_rollout_state_1_2048.ckpt' -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/mlp_retrosyn/scripts/run_mlp_train.sh: -------------------------------------------------------------------------------- 1 | python ../mlp_train.py --template_path '../data/toy/templates.dat' \ 2 | --template_rule_path '../data/toy/template_rules_1.dat' \ 3 | --model_dump_folder '../model/toy/' \ 4 | --fp_dim 2048 \ 5 | --batch_size 1024 \ 6 | --dropout_rate 0.4 \ 7 | --learning_rate 0.001 8 | -------------------------------------------------------------------------------- /retro_star/packages/mlp_retrosyn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | import os 4 | BASEPATH = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | setup(name='mlp_retrosyn', 7 | py_modules=['mlp_retrosyn'], 8 | install_requires=[], 9 | ) 10 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Connor Coley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/README.md: -------------------------------------------------------------------------------- 1 | # rdchiral 2 | Wrapper for RDKit's RunReactants to improve stereochemistry handling 3 | 4 | See ```rdchiral/main.py``` for a brief description of expected behavior and a few basic examples of how to use the wrapper. 5 | 6 | See ```rdchiral/test/test_rdchiral.py``` for a small set of test cases described [here](https://chemrxiv.org/articles/RDChiral_An_RDKit_Wrapper_for_Handling_Stereochemistry_in_Retrosynthetic_Template_Extraction_and_Application/7949024) 7 | 8 | 9 | ## install 10 | cd to the root of this project, then do 11 | 12 | `pip install -e .` 13 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/__init__.py: -------------------------------------------------------------------------------- 1 | from rdkit import RDLogger 2 | lg = RDLogger.logger() 3 | lg.setLevel(RDLogger.CRITICAL) -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/backup/__init__.py: -------------------------------------------------------------------------------- 1 | from rdkit import RDLogger 2 | lg = RDLogger.logger() 3 | lg.setLevel(RDLogger.CRITICAL) -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/backup/chiral.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from rdkit.Chem.rdchem import ChiralType, BondType, BondDir 3 | 4 | from rdchiral.utils import vprint, parity4, PLEVEL 5 | 6 | def template_atom_could_have_been_tetra(a, strip_if_spec=False, cache=True): 7 | ''' 8 | Could this atom have been a tetrahedral center? 9 | If yes, template atom is considered achiral and will not match a chiral rct 10 | If no, the tempalte atom is auxilliary and we should not use it to remove 11 | a matched reaction. For example, a fully-generalized terminal [C:1] 12 | ''' 13 | 14 | if a.HasProp('tetra_possible'): 15 | return a.GetBoolProp('tetra_possible') 16 | if a.GetDegree() < 3 or (a.GetDegree() == 3 and 'H' not in a.GetSmarts()): 17 | if cache: 18 | a.SetBoolProp('tetra_possible', False) 19 | if strip_if_spec: # Clear chiral tag in case improperly set 20 | a.SetChiralTag(ChiralType.CHI_UNSPECIFIED) 21 | return False 22 | if cache: 23 | a.SetBoolProp('tetra_possible', True) 24 | return True 25 | 26 | 27 | 28 | def copy_chirality(a_src, a_new): 29 | 30 | # Not possible to be a tetrahedral center anymore? 31 | if a_new.GetDegree() < 3: 32 | return 33 | if a_new.GetDegree() == 3 and \ 34 | any(b.GetBondType() != BondType.SINGLE for b in a_new.GetBonds()): 35 | return 36 | 37 | if PLEVEL >= 3: print('For mapnum {}, copying src {} chirality tag to new'.format( 38 | a_src.GetAtomMapNum(), a_src.GetChiralTag())) 39 | a_new.SetChiralTag(a_src.GetChiralTag()) 40 | 41 | if atom_chirality_matches(a_src, a_new) == -1: 42 | if PLEVEL >= 3: print('For mapnum {}, inverting chirality'.format(a_new.GetAtomMapNum())) 43 | a_new.InvertChirality() 44 | 45 | def atom_chirality_matches(a_tmp, a_mol): 46 | ''' 47 | Checks for consistency in chirality between a template atom and a molecule atom. 48 | 49 | Also checks to see if chirality needs to be inverted in copy_chirality 50 | 51 | Returns +1 if it is a match and there is no need for inversion (or ambiguous) 52 | Returns -1 if it is a match but they are the opposite 53 | Returns 0 if an explicit NOT match 54 | Returns 2 if ambiguous or achiral-achiral 55 | ''' 56 | if a_mol.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 57 | if a_tmp.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 58 | if PLEVEL >= 3: print('atom {} is achiral & achiral -> match'.format(a_mol.GetAtomMapNum())) 59 | return 2 # achiral template, achiral molecule -> match 60 | # What if the template was chiral, but the reactant isn't just due to symmetry? 61 | if not a_mol.HasProp('_ChiralityPossible'): 62 | # It's okay to make a match, as long as the product is achiral (even 63 | # though the product template will try to impose chirality) 64 | if PLEVEL >= 3: print('atom {} is specified in template, but cant possibly be chiral in mol'.format(a_mol.GetAtomMapNum())) 65 | return 2 66 | 67 | # Discussion: figure out if we want this behavior - should a chiral template 68 | # be applied to an achiral molecule? For the retro case, if we have 69 | # a retro reaction that requires a specific stereochem, return False; 70 | # however, there will be many cases where the reaction would probably work 71 | if PLEVEL >= 3: print('atom {} is achiral in mol, but specified in template'.format(a_mol.GetAtomMapNum())) 72 | return 0 73 | if a_tmp.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 74 | if PLEVEL >= 3: print('Reactant {} atom chiral, rtemplate achiral...'.format(a_tmp.GetAtomMapNum())) 75 | if template_atom_could_have_been_tetra(a_tmp): 76 | if PLEVEL >= 3: print('...and that atom could have had its chirality specified! no_match') 77 | return 0 78 | if PLEVEL >= 3: print('...but the rtemplate atom could not have had chirality specified, match anyway') 79 | return 2 80 | 81 | mapnums_tmp = [a.GetAtomMapNum() for a in a_tmp.GetNeighbors()] 82 | mapnums_mol = [a.GetAtomMapNum() for a in a_mol.GetNeighbors()] 83 | 84 | # When there are fewer than 3 heavy neighbors, chirality is ambiguous... 85 | if len(mapnums_tmp) < 3 or len(mapnums_mol) < 3: 86 | return 2 87 | 88 | # Degree of 3 -> remaining atom is a hydrogen, add to list 89 | if len(mapnums_tmp) < 4: 90 | mapnums_tmp.append(-1) # H 91 | if len(mapnums_mol) < 4: 92 | mapnums_mol.append(-1) # H 93 | 94 | try: 95 | if PLEVEL >= 10: print(str(mapnums_tmp)) 96 | if PLEVEL >= 10: print(str(mapnums_mol)) 97 | if PLEVEL >= 10: print(str(a_tmp.GetChiralTag())) 98 | if PLEVEL >= 10: print(str(a_mol.GetChiralTag())) 99 | only_in_src = [i for i in mapnums_tmp if i not in mapnums_mol][::-1] # reverse for popping 100 | only_in_mol = [i for i in mapnums_mol if i not in mapnums_tmp] 101 | if len(only_in_src) <= 1 and len(only_in_mol) <= 1: 102 | tmp_parity = parity4(mapnums_tmp) 103 | mol_parity = parity4([i if i in mapnums_tmp else only_in_src.pop() for i in mapnums_mol]) 104 | if PLEVEL >= 10: print(str(tmp_parity)) 105 | if PLEVEL >= 10: print(str(mol_parity)) 106 | parity_matches = tmp_parity == mol_parity 107 | tag_matches = a_tmp.GetChiralTag() == a_mol.GetChiralTag() 108 | chirality_matches = parity_matches == tag_matches 109 | if PLEVEL >= 2: print('mapnum {} chiral match? {}'.format(a_tmp.GetAtomMapNum(), chirality_matches)) 110 | return 1 if chirality_matches else -1 111 | else: 112 | if PLEVEL >= 2: print('mapnum {} chiral match? Based on mapnum lists, ambiguous -> True'.format(a_tmp.GetAtomMapNum())) 113 | return 2 # ambiguous case, just return for now 114 | 115 | except IndexError as e: 116 | print(a_tmp.GetPropsAsDict()) 117 | print(a_mol.GetPropsAsDict()) 118 | print(a_tmp.GetChiralTag()) 119 | print(a_mol.GetChiralTag()) 120 | print(str(e)) 121 | print(str(mapnums_tmp)) 122 | print(str(mapnums_mol)) 123 | raise KeyError('Pop from empty set - this should not happen!') -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/backup/clean.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import rdkit.Chem as Chem 3 | import re 4 | from itertools import chain 5 | 6 | from rdchiral.utils import vprint, PLEVEL 7 | 8 | 9 | def canonicalize_outcome_smiles(smiles, ensure=True): 10 | # Uniquify via SMILES string - a little sloppy 11 | # Need a full SMILES->MOL->SMILES cycle to get a true canonical string 12 | # also, split by '.' and sort when outcome contains multiple molecules 13 | if ensure: 14 | outcome = Chem.MolFromSmiles(smiles) 15 | if outcome is None: 16 | if PLEVEL >= 1: print('~~ could not parse self?') 17 | if PLEVEL >= 1: print('Attempted SMILES: {}', smiles) 18 | return None 19 | 20 | smiles = Chem.MolToSmiles(outcome, True) 21 | 22 | return '.'.join(sorted(smiles.split('.'))) 23 | 24 | def combine_enantiomers_into_racemic(final_outcomes): 25 | ''' 26 | If two products are identical except for an inverted CW/CCW or an 27 | opposite cis/trans, then just strip that from the product. Return 28 | the achiral one instead. 29 | 30 | This is not very sophisticated, since the chirality could affect the bond 31 | order and thus the canonical SMILES. But, whatever. It also does not look 32 | to invert multiple stereocenters at once 33 | ''' 34 | 35 | for smiles in list(final_outcomes)[:]: 36 | 37 | # Look for @@ tetrahedral center 38 | for match in re.finditer(r'@@', smiles): 39 | smiles_inv = '%s@%s' % (smiles[:match.start()], smiles[match.end():]) 40 | if smiles_inv in final_outcomes: 41 | if smiles in final_outcomes: 42 | final_outcomes.remove(smiles) 43 | final_outcomes.remove(smiles_inv) 44 | # Re-parse smiles so that hydrogens can become implicit 45 | smiles = smiles[:match.start()] + smiles[match.end():] 46 | outcome = Chem.MolFromSmiles(smiles) 47 | if outcome is None: 48 | raise ValueError('Horrible mistake when fixing duplicate!') 49 | smiles = '.'.join(sorted(Chem.MolToSmiles(outcome, True).split('.'))) 50 | final_outcomes.add(smiles) 51 | 52 | # Look for // or \\ trans bond 53 | # where [^=\.] is any non-double bond or period or slash 54 | for match in chain(re.finditer(r'(\/)([^=\.\\\/]+=[^=\.\\\/]+)(\/)', smiles), 55 | re.finditer(r'(\\)([^=\.\\\/]+=[^=\.\\\/]+)(\\)', smiles)): 56 | # See if cis version is present in list of outcomes 57 | opposite = {'\\': '/', '/': '\\'} 58 | smiles_cis1 = '%s%s%s%s%s' % (smiles[:match.start()], 59 | match.group(1), match.group(2), opposite[match.group(3)], 60 | smiles[match.end():]) 61 | smiles_cis2 = '%s%s%s%s%s' % (smiles[:match.start()], 62 | opposite[match.group(1)], match.group(2), match.group(3), 63 | smiles[match.end():]) 64 | # Also look for equivalent trans 65 | smiles_trans2 = '%s%s%s%s%s' % (smiles[:match.start()], 66 | opposite[match.group(1)], match.group(2), 67 | opposite[match.group(3)], smiles[match.end():]) 68 | # Kind of weird remove conditionals... 69 | remove = False 70 | if smiles_cis1 in final_outcomes: 71 | final_outcomes.remove(smiles_cis1) 72 | remove = True 73 | if smiles_cis2 in final_outcomes: 74 | final_outcomes.remove(smiles_cis2) 75 | remove = True 76 | if smiles_trans2 in final_outcomes and smiles in final_outcomes: 77 | final_outcomes.remove(smiles_trans2) 78 | if remove: 79 | final_outcomes.remove(smiles) 80 | smiles = smiles[:match.start()] + match.group(2) + smiles[match.end():] 81 | outcome = Chem.MolFromSmiles(smiles) 82 | if outcome is None: 83 | raise ValueError('Horrible mistake when fixing duplicate!') 84 | smiles = '.'.join(sorted(Chem.MolToSmiles(outcome, True).split('.'))) 85 | final_outcomes.add(smiles) 86 | return final_outcomes 87 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/backup/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | PLEVEL = 0 4 | def vprint(level, txt, *args): 5 | if PLEVEL >= level: 6 | print(txt.format(*args)) 7 | 8 | def parity4(data): 9 | ''' 10 | Thanks to http://www.dalkescientific.com/writings/diary/archive/2016/08/15/fragment_parity_calculation.html 11 | ''' 12 | if data[0] < data[1]: 13 | if data[2] < data[3]: 14 | if data[0] < data[2]: 15 | if data[1] < data[2]: 16 | return 0 # (0, 1, 2, 3) 17 | else: 18 | if data[1] < data[3]: 19 | return 1 # (0, 2, 1, 3) 20 | else: 21 | return 0 # (0, 3, 1, 2) 22 | else: 23 | if data[0] < data[3]: 24 | if data[1] < data[3]: 25 | return 0 # (1, 2, 0, 3) 26 | else: 27 | return 1 # (1, 3, 0, 2) 28 | else: 29 | return 0 # (2, 3, 0, 1) 30 | else: 31 | if data[0] < data[3]: 32 | if data[1] < data[2]: 33 | if data[1] < data[3]: 34 | return 1 # (0, 1, 3, 2) 35 | else: 36 | return 0 # (0, 2, 3, 1) 37 | else: 38 | return 1 # (0, 3, 2, 1) 39 | else: 40 | if data[0] < data[2]: 41 | if data[1] < data[2]: 42 | return 1 # (1, 2, 3, 0) 43 | else: 44 | return 0 # (1, 3, 2, 0) 45 | else: 46 | return 1 # (2, 3, 1, 0) 47 | else: 48 | if data[2] < data[3]: 49 | if data[0] < data[3]: 50 | if data[0] < data[2]: 51 | return 1 # (1, 0, 2, 3) 52 | else: 53 | if data[1] < data[2]: 54 | return 0 # (2, 0, 1, 3) 55 | else: 56 | return 1 # (2, 1, 0, 3) 57 | else: 58 | if data[1] < data[2]: 59 | return 1 # (3, 0, 1, 2) 60 | else: 61 | if data[1] < data[3]: 62 | return 0 # (3, 1, 0, 2) 63 | else: 64 | return 1 # (3, 2, 0, 1) 65 | else: 66 | if data[0] < data[2]: 67 | if data[0] < data[3]: 68 | return 0 # (1, 0, 3, 2) 69 | else: 70 | if data[1] < data[3]: 71 | return 1 # (2, 0, 3, 1) 72 | else: 73 | return 0 # (2, 1, 3, 0) 74 | else: 75 | if data[1] < data[2]: 76 | if data[1] < data[3]: 77 | return 0 # (3, 0, 2, 1) 78 | else: 79 | return 1 # (3, 1, 2, 0) 80 | else: 81 | return 0 # (3, 2, 1, 0) 82 | 83 | def bond_to_label(bond): 84 | '''This function takes an RDKit bond and creates a label describing 85 | the most important attributes''' 86 | 87 | a1_label = str(bond.GetBeginAtom().GetAtomicNum()) 88 | a2_label = str(bond.GetEndAtom().GetAtomicNum()) 89 | if bond.GetBeginAtom().GetAtomMapNum(): 90 | a1_label += str(bond.GetBeginAtom().GetAtomMapNum()) 91 | if bond.GetEndAtom().GetAtomMapNum(): 92 | a2_label += str(bond.GetEndAtom().GetAtomMapNum()) 93 | atoms = sorted([a1_label, a2_label]) 94 | 95 | return '{}{}{}'.format(atoms[0], bond.GetSmarts(), atoms[1]) 96 | 97 | 98 | def atoms_are_different(atom1, atom2): 99 | '''Compares two RDKit atoms based on basic properties''' 100 | 101 | if atom1.GetSmarts() != atom2.GetSmarts(): return True # should be very general 102 | if atom1.GetAtomicNum() != atom2.GetAtomicNum(): return True # must be true for atom mapping 103 | if atom1.GetTotalNumHs() != atom2.GetTotalNumHs(): return True 104 | if atom1.GetFormalCharge() != atom2.GetFormalCharge(): return True 105 | if atom1.GetDegree() != atom2.GetDegree(): return True 106 | if atom1.GetNumRadicalElectrons() != atom2.GetNumRadicalElectrons(): return True 107 | if atom1.GetIsAromatic() != atom2.GetIsAromatic(): return True 108 | 109 | # Check bonds and nearest neighbor identity 110 | bonds1 = sorted([bond_to_label(bond) for bond in atom1.GetBonds()]) 111 | bonds2 = sorted([bond_to_label(bond) for bond in atom2.GetBonds()]) 112 | if bonds1 != bonds2: return True 113 | 114 | return False -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/chiral.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from rdkit.Chem.rdchem import ChiralType, BondType, BondDir 3 | 4 | from rdchiral.utils import vprint, parity4, PLEVEL 5 | 6 | def template_atom_could_have_been_tetra(a, strip_if_spec=False, cache=True): 7 | ''' 8 | Could this atom have been a tetrahedral center? 9 | If yes, template atom is considered achiral and will not match a chiral rct 10 | If no, the tempalte atom is auxilliary and we should not use it to remove 11 | a matched reaction. For example, a fully-generalized terminal [C:1] 12 | ''' 13 | 14 | if a.HasProp('tetra_possible'): 15 | return a.GetBoolProp('tetra_possible') 16 | if a.GetDegree() < 3 or (a.GetDegree() == 3 and 'H' not in a.GetSmarts()): 17 | if cache: 18 | a.SetBoolProp('tetra_possible', False) 19 | if strip_if_spec: # Clear chiral tag in case improperly set 20 | a.SetChiralTag(ChiralType.CHI_UNSPECIFIED) 21 | return False 22 | if cache: 23 | a.SetBoolProp('tetra_possible', True) 24 | return True 25 | 26 | 27 | 28 | def copy_chirality(a_src, a_new): 29 | 30 | # Not possible to be a tetrahedral center anymore? 31 | if a_new.GetDegree() < 3: 32 | return 33 | if a_new.GetDegree() == 3 and \ 34 | any(b.GetBondType() != BondType.SINGLE for b in a_new.GetBonds()): 35 | return 36 | 37 | if PLEVEL >= 3: print('For mapnum {}, copying src {} chirality tag to new'.format( 38 | a_src.GetAtomMapNum(), a_src.GetChiralTag())) 39 | a_new.SetChiralTag(a_src.GetChiralTag()) 40 | 41 | if atom_chirality_matches(a_src, a_new) == -1: 42 | if PLEVEL >= 3: print('For mapnum {}, inverting chirality'.format(a_new.GetAtomMapNum())) 43 | a_new.InvertChirality() 44 | 45 | def atom_chirality_matches(a_tmp, a_mol): 46 | ''' 47 | Checks for consistency in chirality between a template atom and a molecule atom. 48 | 49 | Also checks to see if chirality needs to be inverted in copy_chirality 50 | 51 | Returns +1 if it is a match and there is no need for inversion (or ambiguous) 52 | Returns -1 if it is a match but they are the opposite 53 | Returns 0 if an explicit NOT match 54 | Returns 2 if ambiguous or achiral-achiral 55 | ''' 56 | if a_mol.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 57 | if a_tmp.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 58 | if PLEVEL >= 3: print('atom {} is achiral & achiral -> match'.format(a_mol.GetAtomMapNum())) 59 | return 2 # achiral template, achiral molecule -> match 60 | # What if the template was chiral, but the reactant isn't just due to symmetry? 61 | if not a_mol.HasProp('_ChiralityPossible'): 62 | # It's okay to make a match, as long as the product is achiral (even 63 | # though the product template will try to impose chirality) 64 | if PLEVEL >= 3: print('atom {} is specified in template, but cant possibly be chiral in mol'.format(a_mol.GetAtomMapNum())) 65 | return 2 66 | 67 | # Discussion: figure out if we want this behavior - should a chiral template 68 | # be applied to an achiral molecule? For the retro case, if we have 69 | # a retro reaction that requires a specific stereochem, return False; 70 | # however, there will be many cases where the reaction would probably work 71 | if PLEVEL >= 3: print('atom {} is achiral in mol, but specified in template'.format(a_mol.GetAtomMapNum())) 72 | return 0 73 | if a_tmp.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 74 | if PLEVEL >= 3: print('Reactant {} atom chiral, rtemplate achiral...'.format(a_tmp.GetAtomMapNum())) 75 | if template_atom_could_have_been_tetra(a_tmp): 76 | if PLEVEL >= 3: print('...and that atom could have had its chirality specified! no_match') 77 | return 0 78 | if PLEVEL >= 3: print('...but the rtemplate atom could not have had chirality specified, match anyway') 79 | return 2 80 | 81 | mapnums_tmp = [a.GetAtomMapNum() for a in a_tmp.GetNeighbors()] 82 | mapnums_mol = [a.GetAtomMapNum() for a in a_mol.GetNeighbors()] 83 | 84 | # When there are fewer than 3 heavy neighbors, chirality is ambiguous... 85 | if len(mapnums_tmp) < 3 or len(mapnums_mol) < 3: 86 | return 2 87 | 88 | # Degree of 3 -> remaining atom is a hydrogen, add to list 89 | if len(mapnums_tmp) < 4: 90 | mapnums_tmp.append(-1) # H 91 | if len(mapnums_mol) < 4: 92 | mapnums_mol.append(-1) # H 93 | 94 | try: 95 | if PLEVEL >= 10: print(str(mapnums_tmp)) 96 | if PLEVEL >= 10: print(str(mapnums_mol)) 97 | if PLEVEL >= 10: print(str(a_tmp.GetChiralTag())) 98 | if PLEVEL >= 10: print(str(a_mol.GetChiralTag())) 99 | only_in_src = [i for i in mapnums_tmp if i not in mapnums_mol][::-1] # reverse for popping 100 | only_in_mol = [i for i in mapnums_mol if i not in mapnums_tmp] 101 | if len(only_in_src) <= 1 and len(only_in_mol) <= 1: 102 | tmp_parity = parity4(mapnums_tmp) 103 | mol_parity = parity4([i if i in mapnums_tmp else only_in_src.pop() for i in mapnums_mol]) 104 | if PLEVEL >= 10: print(str(tmp_parity)) 105 | if PLEVEL >= 10: print(str(mol_parity)) 106 | parity_matches = tmp_parity == mol_parity 107 | tag_matches = a_tmp.GetChiralTag() == a_mol.GetChiralTag() 108 | chirality_matches = parity_matches == tag_matches 109 | if PLEVEL >= 2: print('mapnum {} chiral match? {}'.format(a_tmp.GetAtomMapNum(), chirality_matches)) 110 | return 1 if chirality_matches else -1 111 | else: 112 | if PLEVEL >= 2: print('mapnum {} chiral match? Based on mapnum lists, ambiguous -> True'.format(a_tmp.GetAtomMapNum())) 113 | return 2 # ambiguous case, just return for now 114 | 115 | except IndexError as e: 116 | print(a_tmp.GetPropsAsDict()) 117 | print(a_mol.GetPropsAsDict()) 118 | print(a_tmp.GetChiralTag()) 119 | print(a_mol.GetChiralTag()) 120 | print(str(e)) 121 | print(str(mapnums_tmp)) 122 | print(str(mapnums_mol)) 123 | raise KeyError('Pop from empty set - this should not happen!') -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/clean.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import rdkit.Chem as Chem 3 | import re 4 | from itertools import chain 5 | 6 | from rdchiral.utils import vprint, PLEVEL 7 | 8 | 9 | def canonicalize_outcome_smiles(smiles, ensure=True): 10 | # Uniquify via SMILES string - a little sloppy 11 | # Need a full SMILES->MOL->SMILES cycle to get a true canonical string 12 | # also, split by '.' and sort when outcome contains multiple molecules 13 | if ensure: 14 | outcome = Chem.MolFromSmiles(smiles) 15 | if outcome is None: 16 | if PLEVEL >= 1: print('~~ could not parse self?') 17 | if PLEVEL >= 1: print('Attempted SMILES: {}', smiles) 18 | return None 19 | 20 | smiles = Chem.MolToSmiles(outcome, True) 21 | 22 | return '.'.join(sorted(smiles.split('.'))) 23 | 24 | def combine_enantiomers_into_racemic(final_outcomes): 25 | ''' 26 | If two products are identical except for an inverted CW/CCW or an 27 | opposite cis/trans, then just strip that from the product. Return 28 | the achiral one instead. 29 | 30 | This is not very sophisticated, since the chirality could affect the bond 31 | order and thus the canonical SMILES. But, whatever. It also does not look 32 | to invert multiple stereocenters at once 33 | ''' 34 | 35 | for smiles in list(final_outcomes)[:]: 36 | 37 | # Look for @@ tetrahedral center 38 | for match in re.finditer(r'@@', smiles): 39 | smiles_inv = '%s@%s' % (smiles[:match.start()], smiles[match.end():]) 40 | if smiles_inv in final_outcomes: 41 | if smiles in final_outcomes: 42 | final_outcomes.remove(smiles) 43 | final_outcomes.remove(smiles_inv) 44 | # Re-parse smiles so that hydrogens can become implicit 45 | smiles = smiles[:match.start()] + smiles[match.end():] 46 | outcome = Chem.MolFromSmiles(smiles) 47 | if outcome is None: 48 | raise ValueError('Horrible mistake when fixing duplicate!') 49 | smiles = '.'.join(sorted(Chem.MolToSmiles(outcome, True).split('.'))) 50 | final_outcomes.add(smiles) 51 | 52 | # Look for // or \\ trans bond 53 | # where [^=\.] is any non-double bond or period or slash 54 | for match in chain(re.finditer(r'(\/)([^=\.\\\/]+=[^=\.\\\/]+)(\/)', smiles), 55 | re.finditer(r'(\\)([^=\.\\\/]+=[^=\.\\\/]+)(\\)', smiles)): 56 | # See if cis version is present in list of outcomes 57 | opposite = {'\\': '/', '/': '\\'} 58 | smiles_cis1 = '%s%s%s%s%s' % (smiles[:match.start()], 59 | match.group(1), match.group(2), opposite[match.group(3)], 60 | smiles[match.end():]) 61 | smiles_cis2 = '%s%s%s%s%s' % (smiles[:match.start()], 62 | opposite[match.group(1)], match.group(2), match.group(3), 63 | smiles[match.end():]) 64 | # Also look for equivalent trans 65 | smiles_trans2 = '%s%s%s%s%s' % (smiles[:match.start()], 66 | opposite[match.group(1)], match.group(2), 67 | opposite[match.group(3)], smiles[match.end():]) 68 | # Kind of weird remove conditionals... 69 | remove = False 70 | if smiles_cis1 in final_outcomes: 71 | final_outcomes.remove(smiles_cis1) 72 | remove = True 73 | if smiles_cis2 in final_outcomes: 74 | final_outcomes.remove(smiles_cis2) 75 | remove = True 76 | if smiles_trans2 in final_outcomes and smiles in final_outcomes: 77 | final_outcomes.remove(smiles_trans2) 78 | if remove: 79 | final_outcomes.remove(smiles) 80 | smiles = smiles[:match.start()] + match.group(2) + smiles[match.end():] 81 | outcome = Chem.MolFromSmiles(smiles) 82 | if outcome is None: 83 | raise ValueError('Horrible mistake when fixing duplicate!') 84 | smiles = '.'.join(sorted(Chem.MolToSmiles(outcome, True).split('.'))) 85 | final_outcomes.add(smiles) 86 | return final_outcomes 87 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/old/__init__.py: -------------------------------------------------------------------------------- 1 | from rdkit import RDLogger 2 | lg = RDLogger.logger() 3 | lg.setLevel(RDLogger.CRITICAL) -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/old/chiral.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from rdkit.Chem.rdchem import ChiralType, BondType, BondDir 3 | 4 | from rdchiral.old.utils import vprint, parity4 5 | 6 | def template_atom_could_have_been_tetra(a, strip_if_spec=False): 7 | ''' 8 | Could this atom have been a tetrahedral center? 9 | If yes, template atom is considered achiral and will not match a chiral rct 10 | If no, the tempalte atom is auxilliary and we should not use it to remove 11 | a matched reaction. For example, a fully-generalized terminal [C:1] 12 | ''' 13 | 14 | if a.HasProp('tetra_possible'): 15 | return a.GetBoolProp('tetra_possible') 16 | if a.GetDegree() < 3 or (a.GetDegree() == 3 and 'H' not in a.GetSmarts()): 17 | a.SetBoolProp('tetra_possible', False) 18 | if strip_if_spec: # Clear chiral tag in case improperly set 19 | a.SetChiralTag(ChiralType.CHI_UNSPECIFIED) 20 | return False 21 | a.SetBoolProp('tetra_possible', True) 22 | return True 23 | 24 | 25 | 26 | def copy_chirality(a_src, a_new): 27 | 28 | # Not possible to be a tetrahedral center anymore? 29 | if a_new.GetDegree() < 3: 30 | return 31 | if a_new.GetDegree() == 3 and \ 32 | any(b.GetBondType() != BondType.SINGLE for b in a_new.GetBonds()): 33 | return 34 | 35 | vprint(3, 'For isotope {}, copying src {} chirality tag to new', 36 | a_src.GetIsotope(), a_src.GetChiralTag()) 37 | a_new.SetChiralTag(a_src.GetChiralTag()) 38 | 39 | if not atom_chirality_matches(a_src, a_new): 40 | vprint(3, 'For isotope {}, inverting chirality', a_new.GetIsotope()) 41 | a_new.InvertChirality() 42 | 43 | def atom_chirality_matches(a_tmp, a_mol): 44 | ''' 45 | Checks for consistency in chirality between a template atom and a molecule atom. 46 | 47 | Also checks to see if chirality needs to be inverted in copy_chirality 48 | ''' 49 | if a_mol.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 50 | if a_tmp.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 51 | vprint(3, 'atom {} is achiral & achiral -> match', a_mol.GetIsotope()) 52 | return True # achiral template, achiral molecule -> match 53 | # What if the template was chiral, but the reactant isn't just due to symmetry? 54 | if not a_mol.HasProp('_ChiralityPossible'): 55 | # It's okay to make a match, as long as the product is achiral (even 56 | # though the product template will try to impose chirality) 57 | vprint(3, 'atom {} is specified in template, but cant possibly be chiral in mol', a_mol.GetIsotope()) 58 | return True 59 | 60 | # TODO: figure out if we want this behavior - should a chiral template 61 | # be applied to an achiral molecule? For the retro case, if we have 62 | # a retro reaction that requires a specific stereochem, return False; 63 | # however, there will be many cases where the reaction would probably work 64 | vprint(3, 'atom {} is achiral in mol, but specified in template', a_mol.GetIsotope()) 65 | return False 66 | if a_tmp.GetChiralTag() == ChiralType.CHI_UNSPECIFIED: 67 | vprint(3, 'Reactant {} atom chiral, rtemplate achiral...', a_tmp.GetIsotope()) 68 | if template_atom_could_have_been_tetra(a_tmp): 69 | vprint(3, '...and that atom could have had its chirality specified! no_match') 70 | return False 71 | vprint(3, '...but the rtemplate atom could not have had chirality specified, match anyway') 72 | return True 73 | 74 | isotopes_tmp = [a.GetIsotope() for a in a_tmp.GetNeighbors()] 75 | isotopes_mol = [a.GetIsotope() for a in a_mol.GetNeighbors()] 76 | 77 | # When there are fewer than 3 heavy neighbors, chirality is ambiguous... 78 | if len(isotopes_tmp) < 3 or len(isotopes_mol) < 3: 79 | return True 80 | 81 | # Degree of 3 -> remaining atom is a hydrogen, add to list 82 | if len(isotopes_tmp) < 4: 83 | isotopes_tmp.append(-1) # H 84 | if len(isotopes_mol) < 4: 85 | isotopes_mol.append(-1) # H 86 | 87 | try: 88 | vprint(10, str(isotopes_tmp)) 89 | vprint(10, str(isotopes_mol)) 90 | vprint(10, str(a_tmp.GetChiralTag())) 91 | vprint(10, str(a_mol.GetChiralTag())) 92 | only_in_src = [i for i in isotopes_tmp if i not in isotopes_mol][::-1] # reverse for popping 93 | only_in_mol = [i for i in isotopes_mol if i not in isotopes_tmp] 94 | if len(only_in_src) <= 1 and len(only_in_mol) <= 1: 95 | tmp_parity = parity4(isotopes_tmp) 96 | mol_parity = parity4([i if i in isotopes_tmp else only_in_src.pop() for i in isotopes_mol]) 97 | vprint(10, str(tmp_parity)) 98 | vprint(10, str(mol_parity)) 99 | parity_matches = tmp_parity == mol_parity 100 | tag_matches = a_tmp.GetChiralTag() == a_mol.GetChiralTag() 101 | chirality_matches = parity_matches == tag_matches 102 | vprint(2, 'Isotope {} chiral match? {}', a_tmp.GetIsotope(), chirality_matches) 103 | return chirality_matches 104 | else: 105 | vprint(2, 'Isotope {} chiral match? Based on isotope lists, ambiguous -> True', a_tmp.GetIsotope()) 106 | return True # ambiguous case, just return for now 107 | # TODO: fix this? 108 | 109 | except IndexError as e: 110 | print(a_tmp.GetPropsAsDict()) 111 | print(a_mol.GetPropsAsDict()) 112 | print(a_tmp.GetChiralTag()) 113 | print(a_mol.GetChiralTag()) 114 | print(str(e)) 115 | print(str(isotopes_tmp)) 116 | print(str(isotopes_mol)) 117 | raise KeyError('Pop from empty set - this should not happen!') 118 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/old/clean.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import rdkit.Chem as Chem 3 | import re 4 | from itertools import chain 5 | 6 | from rdchiral.old.utils import vprint 7 | 8 | 9 | def canonicalize_outcome_smiles(outcome): 10 | # Uniquify via SMILES string - a little sloppy 11 | # Need a full SMILES->MOL->SMILES cycle to get a true canonical string 12 | # also, split by '.' and sort when outcome contains multiple molecules 13 | smiles = Chem.MolToSmiles(outcome, True) 14 | outcome = Chem.MolFromSmiles(smiles) 15 | if outcome is None: 16 | vprint(1, '~~ could not parse self?') 17 | vprint(1, 'Attempted SMILES: {}', smiles) 18 | return None 19 | return '.'.join(sorted(Chem.MolToSmiles(outcome, True).split('.'))) 20 | 21 | def combine_enantiomers_into_racemic(final_outcomes): 22 | ''' 23 | If two products are identical except for an inverted CW/CCW or an 24 | opposite cis/trans, then just strip that from the product. Return 25 | the achiral one instead. 26 | 27 | This is not very sophisticated, since the chirality could affect the bond 28 | order and thus the canonical SMILES. But, whatever. It also does not look 29 | to invert multiple stereocenters at once 30 | ''' 31 | 32 | for smiles in list(final_outcomes)[:]: 33 | # Look for @@ tetrahedral center 34 | for match in re.finditer(r'@@', smiles): 35 | smiles_inv = '%s@%s' % (smiles[:match.start()], smiles[match.end():]) 36 | if smiles_inv in final_outcomes: 37 | final_outcomes.remove(smiles) 38 | final_outcomes.remove(smiles_inv) 39 | # Re-parse smiles so that hydrogens can become implicit 40 | smiles = smiles[:match.start()] + smiles[match.end():] 41 | outcome = Chem.MolFromSmiles(smiles) 42 | if outcome is None: 43 | raise ValueError('Horrible mistake when fixing duplicate!') 44 | smiles = '.'.join(sorted(Chem.MolToSmiles(outcome, True).split('.'))) 45 | final_outcomes.add(smiles) 46 | 47 | # Look for // or \\ trans bond 48 | # where [^=\.] is any non-double bond or period or slash 49 | for match in chain(re.finditer(r'(\/)([^=\.\\\/]+=[^=\.\\\/]+)(\/)', smiles), 50 | re.finditer(r'(\\)([^=\.\\\/]+=[^=\.\\\/]+)(\\)', smiles)): 51 | # See if cis version is present in list of outcomes 52 | opposite = {'\\': '/', '/': '\\'} 53 | smiles_cis1 = '%s%s%s%s%s' % (smiles[:match.start()], 54 | match.group(1), match.group(2), opposite[match.group(3)], 55 | smiles[match.end():]) 56 | smiles_cis2 = '%s%s%s%s%s' % (smiles[:match.start()], 57 | opposite[match.group(1)], match.group(2), match.group(3), 58 | smiles[match.end():]) 59 | # Also look for equivalent trans 60 | smiles_trans2 = '%s%s%s%s%s' % (smiles[:match.start()], 61 | opposite[match.group(1)], match.group(2), 62 | opposite[match.group(3)], smiles[match.end():]) 63 | # Kind of weird remove conditionals... 64 | remove = False 65 | if smiles_cis1 in final_outcomes: 66 | final_outcomes.remove(smiles_cis1) 67 | remove = True 68 | if smiles_cis2 in final_outcomes: 69 | final_outcomes.remove(smiles_cis2) 70 | remove = True 71 | if smiles_trans2 in final_outcomes and smiles in final_outcomes: 72 | final_outcomes.remove(smiles_trans2) 73 | if remove: 74 | final_outcomes.remove(smiles) 75 | smiles = smiles[:match.start()] + match.group(2) + smiles[match.end():] 76 | outcome = Chem.MolFromSmiles(smiles) 77 | if outcome is None: 78 | raise ValueError('Horrible mistake when fixing duplicate!') 79 | smiles = '.'.join(sorted(Chem.MolToSmiles(outcome, True).split('.'))) 80 | final_outcomes.add(smiles) 81 | return final_outcomes 82 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/old/initialization.py: -------------------------------------------------------------------------------- 1 | import rdkit.Chem as Chem 2 | import rdkit.Chem.AllChem as AllChem 3 | from rdkit.Chem.rdchem import ChiralType, BondType, BondDir 4 | 5 | from rdchiral.old.chiral import template_atom_could_have_been_tetra 6 | from rdchiral.old.utils import vprint 7 | 8 | class rdchiralReaction(): 9 | ''' 10 | Class to store everything that should be pre-computed for a reaction. This 11 | makes library application much faster, since we can pre-do a lot of work 12 | instead of doing it for every mol-template pair 13 | ''' 14 | def __init__(self, reaction_smarts): 15 | # Keep smarts, useful for reporting 16 | self.reaction_smarts = reaction_smarts 17 | 18 | # Initialize - assigns stereochemistry and fills in missing rct map numbers 19 | self.rxn = initialize_rxn_from_smarts(reaction_smarts) 20 | 21 | # Combine template fragments so we can play around with isotopes 22 | self.template_r, self.template_p = get_template_frags_from_rxn(self.rxn) 23 | 24 | # Define molAtomMapNumber->atom dictionary for template rct and prd 25 | self.atoms_rt_map = {a.GetIntProp('molAtomMapNumber'): a \ 26 | for a in self.template_r.GetAtoms() if a.HasProp('molAtomMapNumber')} 27 | self.atoms_pt_map = {a.GetIntProp('molAtomMapNumber'): a \ 28 | for a in self.template_p.GetAtoms() if a.HasProp('molAtomMapNumber')} 29 | 30 | # Call template_atom_could_have_been_tetra to pre-assign value to atom 31 | [template_atom_could_have_been_tetra(a) for a in self.template_r.GetAtoms()] 32 | [template_atom_could_have_been_tetra(a) for a in self.template_p.GetAtoms()] 33 | 34 | class rdchiralReactants(): 35 | ''' 36 | Class to store everything that should be pre-computed for a reactant mol 37 | so that library application is faster 38 | ''' 39 | def __init__(self, reactant_smiles): 40 | # Keep original smiles, useful for reporting 41 | self.reactant_smiles = reactant_smiles 42 | 43 | # Initialize into RDKit mol 44 | self.reactants = initialize_reactants_from_smiles(reactant_smiles) 45 | 46 | # Set isotope->atom dictionary 47 | # all reactant atoms must be mapped after initialization, so this is safe 48 | self.atoms_r = {a.GetIsotope(): a for a in self.reactants.GetAtoms()} 49 | 50 | # Create copy of molecule without chiral information, used with 51 | # RDKit's naive runReactants 52 | self.reactants_achiral = initialize_reactants_from_smiles(reactant_smiles) 53 | [a.SetChiralTag(ChiralType.CHI_UNSPECIFIED) for a in self.reactants_achiral.GetAtoms()] 54 | # TODO: strip bond chirality? 55 | 56 | # Pre-list reactant bonds (for stitching broken products) 57 | self.bonds_by_isotope = [ 58 | (b.GetBeginAtom().GetIsotope(), b.GetEndAtom().GetIsotope(), b) \ 59 | for b in self.reactants.GetBonds() 60 | ] 61 | 62 | def initialize_rxn_from_smarts(reaction_smarts): 63 | # Initialize reaction 64 | rxn = AllChem.ReactionFromSmarts(reaction_smarts) 65 | rxn.Initialize() 66 | if rxn.Validate()[1] != 0: 67 | raise ValueError('validation failed') 68 | vprint(2, 'Validated rxn without errors') 69 | 70 | unmapped = 700 71 | for rct in rxn.GetReactants(): 72 | rct.UpdatePropertyCache() 73 | Chem.AssignStereochemistry(rct) 74 | # Fill in atom map numbers 75 | for a in rct.GetAtoms(): 76 | if not a.HasProp('molAtomMapNumber'): 77 | a.SetIntProp('molAtomMapNumber', unmapped) 78 | unmapped += 1 79 | vprint(2, 'Added {} map nums to unmapped reactants', unmapped-700) 80 | if unmapped > 800: 81 | raise ValueError('Why do you have so many unmapped atoms in the template reactants?') 82 | 83 | return rxn 84 | 85 | def initialize_reactants_from_smiles(reactant_smiles): 86 | # Initialize reactants 87 | reactants = Chem.MolFromSmiles(reactant_smiles) 88 | Chem.AssignStereochemistry(reactants, flagPossibleStereoCenters=True) 89 | reactants.UpdatePropertyCache() 90 | # To have the product atoms match reactant atoms, we 91 | # need to populate the Isotope field, since this field 92 | # gets copied over during the reaction. 93 | [a.SetIsotope(i+1) for (i, a) in enumerate(reactants.GetAtoms())] 94 | vprint(2, 'Initialized reactants, assigned isotopes, stereochem, flagpossiblestereocenters') 95 | return reactants 96 | 97 | def get_template_frags_from_rxn(rxn): 98 | # Copy reaction template so we can play around with isotopes 99 | for i, rct in enumerate(rxn.GetReactants()): 100 | if i == 0: 101 | template_r = rct 102 | else: 103 | template_r = AllChem.CombineMols(template_r, rct) 104 | for i, prd in enumerate(rxn.GetProducts()): 105 | if i == 0: 106 | template_p = prd 107 | else: 108 | template_p = AllChem.CombineMols(template_p, prd) 109 | return template_r, template_p 110 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/old/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | PLEVEL = 0 4 | def vprint(level, txt, *args): 5 | if PLEVEL >= level: 6 | print(txt.format(*args)) 7 | 8 | def parity4(data): 9 | ''' 10 | Thanks to http://www.dalkescientific.com/writings/diary/archive/2016/08/15/fragment_parity_calculation.html 11 | ''' 12 | if data[0] < data[1]: 13 | if data[2] < data[3]: 14 | if data[0] < data[2]: 15 | if data[1] < data[2]: 16 | return 0 # (0, 1, 2, 3) 17 | else: 18 | if data[1] < data[3]: 19 | return 1 # (0, 2, 1, 3) 20 | else: 21 | return 0 # (0, 3, 1, 2) 22 | else: 23 | if data[0] < data[3]: 24 | if data[1] < data[3]: 25 | return 0 # (1, 2, 0, 3) 26 | else: 27 | return 1 # (1, 3, 0, 2) 28 | else: 29 | return 0 # (2, 3, 0, 1) 30 | else: 31 | if data[0] < data[3]: 32 | if data[1] < data[2]: 33 | if data[1] < data[3]: 34 | return 1 # (0, 1, 3, 2) 35 | else: 36 | return 0 # (0, 2, 3, 1) 37 | else: 38 | return 1 # (0, 3, 2, 1) 39 | else: 40 | if data[0] < data[2]: 41 | if data[1] < data[2]: 42 | return 1 # (1, 2, 3, 0) 43 | else: 44 | return 0 # (1, 3, 2, 0) 45 | else: 46 | return 1 # (2, 3, 1, 0) 47 | else: 48 | if data[2] < data[3]: 49 | if data[0] < data[3]: 50 | if data[0] < data[2]: 51 | return 1 # (1, 0, 2, 3) 52 | else: 53 | if data[1] < data[2]: 54 | return 0 # (2, 0, 1, 3) 55 | else: 56 | return 1 # (2, 1, 0, 3) 57 | else: 58 | if data[1] < data[2]: 59 | return 1 # (3, 0, 1, 2) 60 | else: 61 | if data[1] < data[3]: 62 | return 0 # (3, 1, 0, 2) 63 | else: 64 | return 1 # (3, 2, 0, 1) 65 | else: 66 | if data[0] < data[2]: 67 | if data[0] < data[3]: 68 | return 0 # (1, 0, 3, 2) 69 | else: 70 | if data[1] < data[3]: 71 | return 1 # (2, 0, 3, 1) 72 | else: 73 | return 0 # (2, 1, 3, 0) 74 | else: 75 | if data[1] < data[2]: 76 | if data[1] < data[3]: 77 | return 0 # (3, 0, 2, 1) 78 | else: 79 | return 1 # (3, 1, 2, 0) 80 | else: 81 | return 0 # (3, 2, 1, 0) 82 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmh14/data_efficient_grammar/ef34648f987278496e1216cbeb7f82c9429da4b0/retro_star/packages/rdchiral/rdchiral/test/__init__.py -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/rdchiral/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | PLEVEL = 0 4 | def vprint(level, txt, *args): 5 | if PLEVEL >= level: 6 | print(txt.format(*args)) 7 | 8 | def parity4(data): 9 | ''' 10 | Thanks to http://www.dalkescientific.com/writings/diary/archive/2016/08/15/fragment_parity_calculation.html 11 | ''' 12 | if data[0] < data[1]: 13 | if data[2] < data[3]: 14 | if data[0] < data[2]: 15 | if data[1] < data[2]: 16 | return 0 # (0, 1, 2, 3) 17 | else: 18 | if data[1] < data[3]: 19 | return 1 # (0, 2, 1, 3) 20 | else: 21 | return 0 # (0, 3, 1, 2) 22 | else: 23 | if data[0] < data[3]: 24 | if data[1] < data[3]: 25 | return 0 # (1, 2, 0, 3) 26 | else: 27 | return 1 # (1, 3, 0, 2) 28 | else: 29 | return 0 # (2, 3, 0, 1) 30 | else: 31 | if data[0] < data[3]: 32 | if data[1] < data[2]: 33 | if data[1] < data[3]: 34 | return 1 # (0, 1, 3, 2) 35 | else: 36 | return 0 # (0, 2, 3, 1) 37 | else: 38 | return 1 # (0, 3, 2, 1) 39 | else: 40 | if data[0] < data[2]: 41 | if data[1] < data[2]: 42 | return 1 # (1, 2, 3, 0) 43 | else: 44 | return 0 # (1, 3, 2, 0) 45 | else: 46 | return 1 # (2, 3, 1, 0) 47 | else: 48 | if data[2] < data[3]: 49 | if data[0] < data[3]: 50 | if data[0] < data[2]: 51 | return 1 # (1, 0, 2, 3) 52 | else: 53 | if data[1] < data[2]: 54 | return 0 # (2, 0, 1, 3) 55 | else: 56 | return 1 # (2, 1, 0, 3) 57 | else: 58 | if data[1] < data[2]: 59 | return 1 # (3, 0, 1, 2) 60 | else: 61 | if data[1] < data[3]: 62 | return 0 # (3, 1, 0, 2) 63 | else: 64 | return 1 # (3, 2, 0, 1) 65 | else: 66 | if data[0] < data[2]: 67 | if data[0] < data[3]: 68 | return 0 # (1, 0, 3, 2) 69 | else: 70 | if data[1] < data[3]: 71 | return 1 # (2, 0, 3, 1) 72 | else: 73 | return 0 # (2, 1, 3, 0) 74 | else: 75 | if data[1] < data[2]: 76 | if data[1] < data[3]: 77 | return 0 # (3, 0, 2, 1) 78 | else: 79 | return 1 # (3, 1, 2, 0) 80 | else: 81 | return 0 # (3, 2, 1, 0) 82 | 83 | def bond_to_label(bond): 84 | '''This function takes an RDKit bond and creates a label describing 85 | the most important attributes''' 86 | 87 | a1_label = str(bond.GetBeginAtom().GetAtomicNum()) 88 | a2_label = str(bond.GetEndAtom().GetAtomicNum()) 89 | if bond.GetBeginAtom().GetAtomMapNum(): 90 | a1_label += str(bond.GetBeginAtom().GetAtomMapNum()) 91 | if bond.GetEndAtom().GetAtomMapNum(): 92 | a2_label += str(bond.GetEndAtom().GetAtomMapNum()) 93 | atoms = sorted([a1_label, a2_label]) 94 | 95 | return '{}{}{}'.format(atoms[0], bond.GetSmarts(), atoms[1]) 96 | 97 | 98 | def atoms_are_different(atom1, atom2): 99 | '''Compares two RDKit atoms based on basic properties''' 100 | 101 | if atom1.GetSmarts() != atom2.GetSmarts(): return True # should be very general 102 | if atom1.GetAtomicNum() != atom2.GetAtomicNum(): return True # must be true for atom mapping 103 | if atom1.GetTotalNumHs() != atom2.GetTotalNumHs(): return True 104 | if atom1.GetFormalCharge() != atom2.GetFormalCharge(): return True 105 | if atom1.GetDegree() != atom2.GetDegree(): return True 106 | if atom1.GetNumRadicalElectrons() != atom2.GetNumRadicalElectrons(): return True 107 | if atom1.GetIsAromatic() != atom2.GetIsAromatic(): return True 108 | 109 | # Check bonds and nearest neighbor identity 110 | bonds1 = sorted([bond_to_label(bond) for bond in atom1.GetBonds()]) 111 | bonds2 = sorted([bond_to_label(bond) for bond in atom2.GetBonds()]) 112 | if bonds1 != bonds2: return True 113 | 114 | return False -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | import os 4 | BASEPATH = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | setup(name='rdchiral', 7 | py_modules=['rdchiral'], 8 | install_requires=[ 9 | ], 10 | ) 11 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/templates/README.md: -------------------------------------------------------------------------------- 1 | # Template Extractor 2 | -------------------- 3 | 4 | Code to clean and extract templates from USPTO reaction database. 5 | 6 | ### Getting Started 7 | 8 | ##### Pre-requisites 9 | 10 | * 7z archive extraction tool 11 | * rdkit 12 | * numpy 13 | * pandas 14 | * joblib 15 | 16 | 17 | ##### Step 1 18 | 19 | Download USPTO reaction database from https://figshare.com/articles/Chemical_reactions_from_US_patents_1976-Sep2016_/5104873, extract 7z archive, and place `1976_Sep2016_USPTOgrants_smiles.rsmi` into `data/` folder 20 | 21 | ```bash 22 | $ mkdir data/ && cd data/ 23 | $ wget -O 1976_Sep2016_USPTOgrants_smiles.7z https://ndownloader.figshare.com/files/8664379 24 | $ 7z e 1976_Sep2016_USPTOgrants_smiles.7z 25 | $ cd ../ 26 | ``` 27 | 28 | ##### Step 2 29 | 30 | Run `clean_and_extract_uspto.py` script. This will try to use all the CPU cores on your machine. On 32 cores it takes roughly 1 hour to run. 31 | 32 | ```bash 33 | $ python clean_and_extract_uspto.py 34 | ``` 35 | 36 | This will generate `data/uspto.reactions.json.gz` and `data/uspto.templates.json.gz`. These two files can also be downloaded directly from [here](https://chemrxiv.org/articles/RDChiral_An_RDKit_Wrapper_for_Handling_Stereochemistry_in_Retrosynthetic_Template_Extraction_and_Application/7949024) if you do not wish to re-run the extraction code. 37 | 38 | TODO: 39 | * better documentation 40 | * argparse 41 | * template grouping 42 | * merge templates with reactions to generate training data 43 | -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/templates/clean_and_extract_uspto.py: -------------------------------------------------------------------------------- 1 | from rdkit import RDLogger 2 | lg = RDLogger.logger() 3 | lg.setLevel(RDLogger.ERROR) 4 | 5 | import json 6 | import gzip 7 | import hashlib 8 | import pandas as pd 9 | from rdkit import Chem 10 | from joblib import Parallel, delayed 11 | from time import time 12 | 13 | 14 | import template_extractor 15 | 16 | def can_parse(rsmi): 17 | react, spec, prod = rsmi.split('>') 18 | if Chem.MolFromSmiles(react) and Chem.MolFromSmiles(prod): 19 | return True 20 | else: 21 | return False 22 | 23 | t0 = time() 24 | 25 | uspto = pd.read_csv('data/1976_Sep2016_USPTOgrants_smiles.rsmi', sep='\t') 26 | 27 | uspto['ReactionSmiles'] = uspto['ReactionSmiles'].str.split(' ', expand=True)[0] 28 | split_smiles = uspto['ReactionSmiles'].str.split('>', expand=True) 29 | uspto['reactants'] = split_smiles[0] 30 | uspto['spectators'] = split_smiles[1] 31 | uspto['products'] = split_smiles[2] 32 | 33 | parsable = Parallel(n_jobs=-1, verbose=1)(delayed(can_parse)(rsmi) for rsmi in uspto['ReactionSmiles'].values) 34 | # parsable = uspto['ReactionSmiles'].map(can_parse) 35 | 36 | uspto = uspto[parsable] 37 | print('{} parsable reactions'.format(len(uspto))) 38 | 39 | hexhash = (uspto['ReactionSmiles']+uspto['PatentNumber']).apply(lambda x: hashlib.sha256(x.encode('utf-8')).hexdigest()) 40 | 41 | uspto['source'] = 'uspto' 42 | uspto['source_id'] = hexhash 43 | 44 | uspto = uspto.reset_index().rename(columns={'index': '_id'}) 45 | 46 | reactions = uspto[['_id', 'reactants', 'products', 'spectators', 'source', 'source_id']] 47 | 48 | reactions.to_json('data/uspto.reactions.json.gz', orient='records', compression='gzip') 49 | 50 | with gzip.open('data/uspto.reactions.json.gz') as f: 51 | reactions = json.load(f) 52 | 53 | def extract(reaction): 54 | try: 55 | return template_extractor.extract_from_reaction(reaction) 56 | except KeyboardInterrupt: 57 | print('Interrupted') 58 | raise KeyboardInterrupt 59 | except Exception as e: 60 | print(e) 61 | return {'reaction_id': reaction['_id']} 62 | 63 | templates = Parallel(n_jobs=-1, verbose=4)(delayed(extract)(reaction) for reaction in reactions) 64 | 65 | with gzip.open('data/uspto.templates.json.gz', 'w') as f: 66 | json.dump(templates, f) 67 | 68 | print('elapsed seconds: {}'.format(int(time()-t0))) -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/templates/example_template_extractions_bad.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "reaction_id": 69, 4 | "reactants": "Cl-[CH;D3;+0:1](-[C;D1;H3:2])-[c;H0;D3;+0:3]1:[n;H0;D2;+0:4]:c:s:c:1.O=[S;H0;D3;+0:5](-[CH3;D1;+0:6])-[CH3;D1;+0:7].[C-;H0;D1:8]#[N;H0;D1;+0:9]", 5 | "intra_only": false, 6 | "products": "[C;D1;H3:2]-[CH;D3;+0:1](-[C;H0;D2;+0:3]#[N;H0;D1;+0:4])-[c;H0;D3;+0:8]1:[cH;D2;+0:7]:[s;H0;D2;+0:5]:[cH;D2;+0:6]:[n;H0;D2;+0:9]:1", 7 | "reaction_smarts": "[C;D1;H3:2]-[CH;D3;+0:1](-[C;H0;D2;+0:3]#[N;H0;D1;+0:4])-[c;H0;D3;+0:8]1:[cH;D2;+0:7]:[s;H0;D2;+0:5]:[cH;D2;+0:6]:[n;H0;D2;+0:9]:1>>Cl-[CH;D3;+0:1](-[C;D1;H3:2])-[c;H0;D3;+0:3]1:[n;H0;D2;+0:4]:c:s:c:1.O=[S;H0;D3;+0:5](-[CH3;D1;+0:6])-[CH3;D1;+0:7].[C-;H0;D1:8]#[N;H0;D1;+0:9]", 8 | "dimer_only": false, 9 | "necessary_reagent": "" 10 | }, 11 | { 12 | "reaction_id": 97, 13 | "reactants": "C-[CH2;D2;+0:1]-[C:2]", 14 | "intra_only": true, 15 | "products": "[C:2]-[CH3;D1;+0:1]", 16 | "reaction_smarts": "[C:2]-[CH3;D1;+0:1]>>C-[CH2;D2;+0:1]-[C:2]", 17 | "dimer_only": false, 18 | "necessary_reagent": "" 19 | } 20 | ] -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/templates/example_template_extractions_good.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "reaction_id": 25, 4 | "reactants": "O=C1-C-C-C(=O)-N-1-[Br;H0;D1;+0:1].[c:2]:[cH;D2;+0:3]:[c:4]", 5 | "intra_only": false, 6 | "products": "[Br;H0;D1;+0:1]-[c;H0;D3;+0:3](:[c:2]):[c:4]", 7 | "reaction_smarts": "[Br;H0;D1;+0:1]-[c;H0;D3;+0:3](:[c:2]):[c:4]>>O=C1-C-C-C(=O)-N-1-[Br;H0;D1;+0:1].[c:2]:[cH;D2;+0:3]:[c:4]", 8 | "dimer_only": false, 9 | "necessary_reagent": "" 10 | }, 11 | { 12 | "reaction_id": 35, 13 | "reactants": "C-C-[O;H0;D2;+0:1]-[C:2]=[O;D1;H0:3]", 14 | "intra_only": true, 15 | "products": "[O;D1;H0:3]=[C:2]-[OH;D1;+0:1]", 16 | "reaction_smarts": "[O;D1;H0:3]=[C:2]-[OH;D1;+0:1]>>C-C-[O;H0;D2;+0:1]-[C:2]=[O;D1;H0:3]", 17 | "dimer_only": false, 18 | "necessary_reagent": "" 19 | }, 20 | { 21 | "reaction_id": 49, 22 | "reactants": "O-S(=O)(=O)-[OH;D1;+0:1].[C:2]-[C;H0;D2;+0:3]#[N;H0;D1;+0:4]", 23 | "intra_only": false, 24 | "products": "[C:2]-[C;H0;D3;+0:3](-[NH2;D1;+0:4])=[O;H0;D1;+0:1]", 25 | "reaction_smarts": "[C:2]-[C;H0;D3;+0:3](-[NH2;D1;+0:4])=[O;H0;D1;+0:1]>>O-S(=O)(=O)-[OH;D1;+0:1].[C:2]-[C;H0;D2;+0:3]#[N;H0;D1;+0:4]", 26 | "dimer_only": false, 27 | "necessary_reagent": "" 28 | }, 29 | { 30 | "reaction_id": 51, 31 | "reactants": "Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3].[NH2;D1;+0:4]-[c:5]", 32 | "intra_only": false, 33 | "products": "[C:2]-[C;H0;D3;+0:1](=[O;D1;H0:3])-[NH;D2;+0:4]-[c:5]", 34 | "reaction_smarts": "[C:2]-[C;H0;D3;+0:1](=[O;D1;H0:3])-[NH;D2;+0:4]-[c:5]>>Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3].[NH2;D1;+0:4]-[c:5]", 35 | "dimer_only": false, 36 | "necessary_reagent": "" 37 | }, 38 | { 39 | "reaction_id": 67, 40 | "reactants": "O=[CH2;D1;+0:1].[C:2]-[NH;D2;+0:3]-[C:4].[NH2;D1;+0:5]-[C:6]=[S;D1;H0:7]", 41 | "intra_only": false, 42 | "products": "[C:2]-[N;H0;D3;+0:3](-[C:4])-[CH2;D2;+0:1]-[NH;D2;+0:5]-[C:6]=[S;D1;H0:7]", 43 | "reaction_smarts": "[C:2]-[N;H0;D3;+0:3](-[C:4])-[CH2;D2;+0:1]-[NH;D2;+0:5]-[C:6]=[S;D1;H0:7]>>O=[CH2;D1;+0:1].[C:2]-[NH;D2;+0:3]-[C:4].[NH2;D1;+0:5]-[C:6]=[S;D1;H0:7]", 44 | "dimer_only": false, 45 | "necessary_reagent": "" 46 | }, 47 | { 48 | "reaction_id": 115, 49 | "reactants": "Br-[CH2;D2;+0:1]-[CH2;D2;+0:2]-Br.[OH;D1;+0:3]-[c:4]:[c:5]-[OH;D1;+0:6]", 50 | "intra_only": false, 51 | "products": "[CH2;D2;+0:1]1-[CH2;D2;+0:2]-[O;H0;D2;+0:6]-[c:5]:[c:4]-[O;H0;D2;+0:3]-1", 52 | "reaction_smarts": "[CH2;D2;+0:1]1-[CH2;D2;+0:2]-[O;H0;D2;+0:6]-[c:5]:[c:4]-[O;H0;D2;+0:3]-1>>Br-[CH2;D2;+0:1]-[CH2;D2;+0:2]-Br.[OH;D1;+0:3]-[c:4]:[c:5]-[OH;D1;+0:6]", 53 | "dimer_only": false, 54 | "necessary_reagent": "" 55 | }, 56 | { 57 | "reaction_id": 119, 58 | "reactants": "Cl-P(-Cl)(=O)-[Cl;H0;D1;+0:1].O=[c;H0;D3;+0:2]1:[c:3]:[c:4]:[#7;a:5]:[c:6]:[nH;D2;+0:7]:1", 59 | "intra_only": false, 60 | "products": "[Cl;H0;D1;+0:1]-[c;H0;D3;+0:2]1:[c:3]:[c:4]:[#7;a:5]:[c:6]:[n;H0;D2;+0:7]:1", 61 | "reaction_smarts": "[Cl;H0;D1;+0:1]-[c;H0;D3;+0:2]1:[c:3]:[c:4]:[#7;a:5]:[c:6]:[n;H0;D2;+0:7]:1>>Cl-P(-Cl)(=O)-[Cl;H0;D1;+0:1].O=[c;H0;D3;+0:2]1:[c:3]:[c:4]:[#7;a:5]:[c:6]:[nH;D2;+0:7]:1", 62 | "dimer_only": false, 63 | "necessary_reagent": "" 64 | }, 65 | { 66 | "reaction_id": 126, 67 | "reactants": "C-c1:c:c:c(-S(=O)(=O)-O-[CH2;D2;+0:1]-[C:2]):c:c:1.[#7;a:3]:[c:4]-[OH;D1;+0:5]", 68 | "intra_only": false, 69 | "products": "[#7;a:3]:[c:4]-[O;H0;D2;+0:5]-[CH2;D2;+0:1]-[C:2]", 70 | "reaction_smarts": "[#7;a:3]:[c:4]-[O;H0;D2;+0:5]-[CH2;D2;+0:1]-[C:2]>>C-c1:c:c:c(-S(=O)(=O)-O-[CH2;D2;+0:1]-[C:2]):c:c:1.[#7;a:3]:[c:4]-[OH;D1;+0:5]", 71 | "dimer_only": false, 72 | "necessary_reagent": "" 73 | }, 74 | { 75 | "reaction_id": 134, 76 | "reactants": "[C:1]-[NH2;D1;+0:2].[C:3]1-[CH2;D2;+0:4]-[O;H0;D2;+0:5]-1", 77 | "intra_only": false, 78 | "products": "[C:1]-[NH;D2;+0:2]-[CH2;D2;+0:4]-[C:3]-[OH;D1;+0:5]", 79 | "reaction_smarts": "[C:1]-[NH;D2;+0:2]-[CH2;D2;+0:4]-[C:3]-[OH;D1;+0:5]>>[C:1]-[NH2;D1;+0:2].[C:3]1-[CH2;D2;+0:4]-[O;H0;D2;+0:5]-1", 80 | "dimer_only": false, 81 | "necessary_reagent": "" 82 | }, 83 | { 84 | "reaction_id": 168, 85 | "reactants": "C-C(=O)-[O;H0;D2;+0:1]-[c:2]", 86 | "intra_only": true, 87 | "products": "[OH;D1;+0:1]-[c:2]", 88 | "reaction_smarts": "[OH;D1;+0:1]-[c:2]>>C-C(=O)-[O;H0;D2;+0:1]-[c:2]", 89 | "dimer_only": false, 90 | "necessary_reagent": "" 91 | }, 92 | { 93 | "reaction_id": 195, 94 | "reactants": "O=[CH;D2;+0:1]-[c:2].[#8:3]-[C:4](=[O;D1;H0:5])-[CH2;D2;+0:6]-[C:7]#[N;D1;H0:8]", 95 | "intra_only": false, 96 | "products": "[#8:3]-[C:4](=[O;D1;H0:5])-[C;H0;D3;+0:6](-[C:7]#[N;D1;H0:8])=[CH;D2;+0:1]-[c:2]", 97 | "reaction_smarts": "[#8:3]-[C:4](=[O;D1;H0:5])-[C;H0;D3;+0:6](-[C:7]#[N;D1;H0:8])=[CH;D2;+0:1]-[c:2]>>O=[CH;D2;+0:1]-[c:2].[#8:3]-[C:4](=[O;D1;H0:5])-[CH2;D2;+0:6]-[C:7]#[N;D1;H0:8]", 98 | "dimer_only": false, 99 | "necessary_reagent": "" 100 | }, 101 | { 102 | "reaction_id": 242, 103 | "reactants": "[#7;a:1]:[c;H0;D3;+0:2](-[Li]):[c:3].[O;H0;D1;+0:4]=[CH;D2;+0:5]-[c:6]", 104 | "intra_only": false, 105 | "products": "[#7;a:1]:[c;H0;D3;+0:2](:[c:3])-[CH;D3;+0:5](-[OH;D1;+0:4])-[c:6]", 106 | "reaction_smarts": "[#7;a:1]:[c;H0;D3;+0:2](:[c:3])-[CH;D3;+0:5](-[OH;D1;+0:4])-[c:6]>>[#7;a:1]:[c;H0;D3;+0:2](-[Li]):[c:3].[O;H0;D1;+0:4]=[CH;D2;+0:5]-[c:6]", 107 | "dimer_only": false, 108 | "necessary_reagent": "" 109 | }, 110 | { 111 | "reaction_id": 250, 112 | "reactants": "O=C1-C-C-C(=O)-N-1-[Br;H0;D1;+0:1].[CH3;D1;+0:2]-[c:3]", 113 | "intra_only": false, 114 | "products": "[Br;H0;D1;+0:1]-[CH2;D2;+0:2]-[c:3]", 115 | "reaction_smarts": "[Br;H0;D1;+0:1]-[CH2;D2;+0:2]-[c:3]>>O=C1-C-C-C(=O)-N-1-[Br;H0;D1;+0:1].[CH3;D1;+0:2]-[c:3]", 116 | "dimer_only": false, 117 | "necessary_reagent": "" 118 | }, 119 | { 120 | "reaction_id": 1076, 121 | "reactants": "C-C(=O)-[O;H0;D2;+0:1]-[c:2]:[c:3]-[C;H0;D3;+0:4](-Cl)=[O;D1;H0:5].[N;D1;H0:6]#[C:7]-[CH2;D2;+0:8]-[C;H0;D2;+0:9]#[N;H0;D1;+0:10]", 122 | "intra_only": false, 123 | "products": "[N;D1;H0:6]#[C:7]-[c;H0;D3;+0:8]1:[c;H0;D3;+0:9](-[NH2;D1;+0:10]):[o;H0;D2;+0:1]:[c:2]:[c:3]:[c;H0;D3;+0:4]:1=[O;D1;H0:5]", 124 | "reaction_smarts": "[N;D1;H0:6]#[C:7]-[c;H0;D3;+0:8]1:[c;H0;D3;+0:9](-[NH2;D1;+0:10]):[o;H0;D2;+0:1]:[c:2]:[c:3]:[c;H0;D3;+0:4]:1=[O;D1;H0:5]>>C-C(=O)-[O;H0;D2;+0:1]-[c:2]:[c:3]-[C;H0;D3;+0:4](-Cl)=[O;D1;H0:5].[N;D1;H0:6]#[C:7]-[CH2;D2;+0:8]-[C;H0;D2;+0:9]#[N;H0;D1;+0:10]", 125 | "dimer_only": false, 126 | "necessary_reagent": "" 127 | } 128 | ] -------------------------------------------------------------------------------- /retro_star/packages/rdchiral/test/test_rdchiral.py: -------------------------------------------------------------------------------- 1 | import os, sys, json 2 | sys.path = [os.path.dirname(os.path.dirname((__file__)))] + sys.path 3 | 4 | from rdchiral.main import rdchiralReaction, rdchiralReactants, rdchiralRunText, rdchiralRun 5 | 6 | with open(os.path.join(os.path.dirname(__file__), 'test_rdchiral_cases.json'), 'r') as fid: 7 | test_cases = json.load(fid) 8 | 9 | all_passed = True 10 | for i, test_case in enumerate(test_cases): 11 | 12 | print('\n# Test {:2d}/{}'.format(i+1, len(test_cases))) 13 | 14 | # Directly use SMILES/SMARTS 15 | reaction_smarts = test_case['smarts'] 16 | reactant_smiles = test_case['smiles'] 17 | if rdchiralRunText(reaction_smarts, reactant_smiles) == test_case['expected']: 18 | print(' from text: passed') 19 | else: 20 | print(' from text: failed') 21 | all_passed = False 22 | 23 | # Pre-initialize & repeat 24 | rxn = rdchiralReaction(reaction_smarts) 25 | reactants = rdchiralReactants(reactant_smiles) 26 | if all(rdchiralRun(rxn, reactants) == test_case['expected'] for j in range(3)): 27 | print(' from init: passed') 28 | else: 29 | print(' from init: failed') 30 | all_passed = False 31 | 32 | all_passed = 'All passed!' if all_passed else 'Failed!' 33 | print('\n# Final result: {}'.format(all_passed)) -------------------------------------------------------------------------------- /retro_star/retro_plan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import logging 5 | import time 6 | import pickle 7 | import os 8 | from retro_star.common import args, prepare_starting_molecules, prepare_mlp, \ 9 | prepare_molstar_planner, smiles_to_fp 10 | from retro_star.model import ValueMLP 11 | from retro_star.utils import setup_logger 12 | 13 | 14 | def retro_plan(): 15 | device = torch.device('cuda' if args.gpu >= 0 else 'cpu') 16 | 17 | starting_mols = prepare_starting_molecules(args.starting_molecules) 18 | 19 | routes = pickle.load(open(args.test_routes, 'rb')) 20 | logging.info('%d routes extracted from %s loaded' % (len(routes), 21 | args.test_routes)) 22 | 23 | one_step = prepare_mlp(args.mlp_templates, args.mlp_model_dump) 24 | 25 | # create result folder 26 | if not os.path.exists(args.result_folder): 27 | os.mkdir(args.result_folder) 28 | 29 | if args.use_value_fn: 30 | model = ValueMLP( 31 | n_layers=args.n_layers, 32 | fp_dim=args.fp_dim, 33 | latent_dim=args.latent_dim, 34 | dropout_rate=0.1, 35 | device=device 36 | ).to(device) 37 | model_f = '%s/%s' % (args.save_folder, args.value_model) 38 | logging.info('Loading value nn from %s' % model_f) 39 | model.load_state_dict(torch.load(model_f, map_location=device)) 40 | model.eval() 41 | 42 | def value_fn(mol): 43 | fp = smiles_to_fp(mol, fp_dim=args.fp_dim).reshape(1,-1) 44 | fp = torch.FloatTensor(fp).to(device) 45 | v = model(fp).item() 46 | return v 47 | else: 48 | value_fn = lambda x: 0. 49 | 50 | plan_handle = prepare_molstar_planner( 51 | one_step=one_step, 52 | value_fn=value_fn, 53 | starting_mols=starting_mols, 54 | expansion_topk=args.expansion_topk, 55 | iterations=args.iterations, 56 | viz=args.viz, 57 | viz_dir=args.viz_dir 58 | ) 59 | 60 | result = { 61 | 'succ': [], 62 | 'cumulated_time': [], 63 | 'iter': [], 64 | 'routes': [], 65 | 'route_costs': [], 66 | 'route_lens': [] 67 | } 68 | num_targets = len(routes) 69 | t0 = time.time() 70 | for (i, route) in enumerate(routes): 71 | 72 | target_mol = route[0].split('>')[0] 73 | succ, msg = plan_handle(target_mol, i) 74 | 75 | result['succ'].append(succ) 76 | result['cumulated_time'].append(time.time() - t0) 77 | result['iter'].append(msg[1]) 78 | result['routes'].append(msg[0]) 79 | if succ: 80 | result['route_costs'].append(msg[0].total_cost) 81 | result['route_lens'].append(msg[0].length) 82 | else: 83 | result['route_costs'].append(None) 84 | result['route_lens'].append(None) 85 | 86 | tot_num = i + 1 87 | tot_succ = np.array(result['succ']).sum() 88 | avg_time = (time.time() - t0) * 1.0 / tot_num 89 | avg_iter = np.array(result['iter'], dtype=float).mean() 90 | logging.info('Succ: %d/%d/%d | avg time: %.2f s | avg iter: %.2f' % 91 | (tot_succ, tot_num, num_targets, avg_time, avg_iter)) 92 | 93 | f = open(args.result_folder + '/plan.pkl', 'wb') 94 | pickle.dump(result, f) 95 | f.close() 96 | 97 | if __name__ == '__main__': 98 | np.random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | random.seed(args.seed) 101 | setup_logger('plan.log') 102 | 103 | retro_plan() 104 | -------------------------------------------------------------------------------- /retro_star/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import random 5 | import pickle 6 | import torch.nn.functional as F 7 | import logging 8 | from retro_star.common import args 9 | from retro_star.model import ValueMLP 10 | from retro_star.data_loader import ValueDataLoader 11 | from retro_star.trainer import Trainer 12 | from retro_star.utils import setup_logger 13 | 14 | def train(): 15 | device = torch.device('cuda' if args.gpu >= 0 else 'cpu') 16 | 17 | model = ValueMLP( 18 | n_layers=args.n_layers, 19 | fp_dim=args.fp_dim, 20 | latent_dim=args.latent_dim, 21 | dropout_rate=0.1, 22 | device=device 23 | ) 24 | 25 | assert os.path.exists('%s/%s.pt' % (args.value_root, args.value_train)) 26 | 27 | train_data_loader = ValueDataLoader( 28 | fp_value_f='%s/%s' % (args.value_root, args.value_train), 29 | batch_size=args.batch_size 30 | ) 31 | 32 | val_data_loader = ValueDataLoader( 33 | fp_value_f='%s/%s' % (args.value_root, args.value_val), 34 | batch_size=args.batch_size 35 | ) 36 | 37 | trainer = Trainer( 38 | model=model, 39 | train_data_loader=train_data_loader, 40 | val_data_loader=val_data_loader, 41 | n_epochs=args.n_epochs, 42 | lr=args.lr, 43 | save_epoch_int=args.save_epoch_int, 44 | model_folder=args.save_folder, 45 | device=device 46 | ) 47 | 48 | trainer.train() 49 | 50 | 51 | if __name__ == '__main__': 52 | np.random.seed(args.seed) 53 | torch.manual_seed(args.seed) 54 | random.seed(args.seed) 55 | setup_logger('train.log') 56 | 57 | train() 58 | -------------------------------------------------------------------------------- /retro_star/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from retro_star.trainer.trainer import Trainer -------------------------------------------------------------------------------- /retro_star/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | from tqdm import tqdm 7 | import logging 8 | 9 | 10 | class Trainer: 11 | def __init__(self, model, train_data_loader, val_data_loader, n_epochs, lr, 12 | save_epoch_int, model_folder, device): 13 | self.train_data_loader = train_data_loader 14 | self.val_data_loader = val_data_loader 15 | self.n_epochs = n_epochs 16 | self.lr = lr 17 | self.save_epoch_int = save_epoch_int 18 | self.model_folder = model_folder 19 | self.device = device 20 | self.model = model.to(self.device) 21 | 22 | if not os.path.exists(model_folder): 23 | os.makedirs(model_folder) 24 | 25 | self.optim = optim.Adam( 26 | filter(lambda p: p.requires_grad, self.model.parameters()), 27 | lr=lr 28 | ) 29 | 30 | def _pass(self, data, train=True): 31 | self.optim.zero_grad() 32 | 33 | for i in range(len(data)): 34 | data[i] = data[i].to(self.device) 35 | 36 | fps, values, r_costs, t_values, r_fps, r_masks = data 37 | v_pred = self.model(fps) 38 | loss = F.mse_loss(v_pred, values) 39 | 40 | batch_size, n_reactants, fp_dim = r_fps.shape 41 | r_values = self.model(r_fps.view(-1, fp_dim)).view((batch_size, 42 | n_reactants)) 43 | r_values = r_values * r_masks 44 | r_values = torch.sum(r_values, dim=1, keepdim=True) 45 | 46 | """ 47 | r_values: sum of reactant values in a negative reaction sample 48 | r_costs: reaction cost 49 | t_values: true product value 50 | 7. (const): margin, -log(1e-3) 51 | """ 52 | 53 | r_gap = - r_values - r_costs + t_values + 7. 54 | r_gap = torch.clamp(r_gap, min=0) 55 | loss += (r_gap**2).mean() 56 | 57 | if train: 58 | loss.backward() 59 | self.optim.step() 60 | 61 | return loss.item() 62 | 63 | def _train_epoch(self): 64 | self.model.train() 65 | 66 | losses = [] 67 | pbar = tqdm(self.train_data_loader) 68 | for data in pbar: 69 | loss = self._pass(data) 70 | losses.append(loss) 71 | pbar.set_description('[loss: %f]' % (loss)) 72 | 73 | return np.array(losses).mean() 74 | 75 | def _val_epoch(self): 76 | self.model.eval() 77 | 78 | losses = [] 79 | pbar = tqdm(self.val_data_loader) 80 | for data in pbar: 81 | loss = self._pass(data, train=False) 82 | losses.append(loss) 83 | pbar.set_description('[loss: %f]' % (loss)) 84 | 85 | return np.array(losses).mean() 86 | 87 | def train(self): 88 | best_val_loss = np.inf 89 | for epoch in range(self.n_epochs): 90 | self.train_data_loader.reshuffle() 91 | 92 | train_loss = self._train_epoch() 93 | val_loss = self._val_epoch() 94 | logging.info( 95 | '[Epoch %d/%d] [training loss: %f] [validation loss: %f]' % 96 | (epoch, self.n_epochs, train_loss, val_loss) 97 | ) 98 | 99 | # if val_loss < best_val_loss or epoch==self.n_epochs-1: 100 | # best_val_loss = val_loss 101 | # save_file = self.model_folder + '/best_epoch_%d.pt' % epoch 102 | # torch.save(self.model.state_dict(), save_file) 103 | 104 | if (epoch + 1) % self.save_epoch_int == 0: 105 | save_file = self.model_folder + '/epoch_%d.pt' % epoch 106 | torch.save(self.model.state_dict(), save_file) 107 | -------------------------------------------------------------------------------- /retro_star/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import setup_logger 2 | -------------------------------------------------------------------------------- /retro_star/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logger(fname=None, silent=False): 5 | if fname is None: 6 | logging.basicConfig( 7 | level=logging.INFO if not silent else logging.CRITICAL, 8 | format='%(name)-12s: %(levelname)-8s %(message)s', 9 | datefmt='%m-%d %H:%M', 10 | filemode='w' 11 | ) 12 | else: 13 | logging.basicConfig( 14 | level=logging.INFO if not silent else logging.CRITICAL, 15 | format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', 16 | datefmt='%m-%d %H:%M', 17 | filename=fname, 18 | filemode='w' 19 | ) 20 | console = logging.StreamHandler() 21 | console.setLevel(logging.INFO) 22 | formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') 23 | console.setFormatter(formatter) 24 | logging.getLogger('').addHandler(console) 25 | -------------------------------------------------------------------------------- /retro_star_listener.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as mp 3 | import numpy as np 4 | import fcntl 5 | import argparse 6 | import setproctitle 7 | from retro_star.api import RSPlanner 8 | from rdkit import Chem 9 | 10 | 11 | def lock(f): 12 | try: 13 | fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) 14 | except IOError: 15 | return False 16 | return True 17 | 18 | 19 | class Synthesisability(): 20 | def __init__(self): 21 | self.planner = RSPlanner( 22 | gpu=-1, 23 | starting_molecules='./retro_star/dataset/origin_dict.csv', 24 | use_value_fn=True, 25 | iterations=50, 26 | expansion_topk=50) 27 | 28 | def get_syn_rate(self, mol_list): 29 | assert type(mol_list) == list 30 | syn_flag = [] 31 | for i, mol_sml in enumerate(mol_list): 32 | result = self.planner.plan(Chem.MolToSmiles(mol_sml)) 33 | if result: 34 | syn_flag.append(result['succ']) 35 | else: 36 | syn_flag.append(False) 37 | return np.mean(syn_flag) 38 | 39 | 40 | def main(proc_id, filename, output_filename): 41 | syn = Synthesisability() 42 | while(True): 43 | selected_mol = None 44 | with open(filename, 'r') as f: 45 | editable = lock(f) 46 | if editable: 47 | lines = f.readlines() 48 | num_samples = len(lines) 49 | new_lines = [] 50 | for idx, line in enumerate(lines): 51 | splitted_line = line.strip().split() 52 | if len(splitted_line) == 1 and (selected_mol is None): 53 | selected_mol = (idx, splitted_line[0]) 54 | new_line = "{} {}\n".format(splitted_line[0], "working") 55 | else: 56 | new_line = "{}\n".format(" ".join(splitted_line)) 57 | new_lines.append(new_line) 58 | with open(filename, 'w') as fw: 59 | for _new_line in new_lines: 60 | fw.write(_new_line) 61 | fcntl.flock(f, fcntl.LOCK_UN) 62 | if selected_mol is None: 63 | continue 64 | 65 | print("====Working for sample {}/{}====".format(selected_mol[0], num_samples)) 66 | try: 67 | result = syn.planner.plan(selected_mol[1]) 68 | except: 69 | result = None 70 | 71 | while(True): 72 | with open(output_filename, 'a') as f: 73 | editable = lock(f) 74 | if editable: 75 | f.write("{} {} {}\n".format(selected_mol[0], selected_mol[1], "False" if result is None else "True")) 76 | fcntl.flock(f, fcntl.LOCK_UN) 77 | break 78 | 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser(description='retro* listener') 82 | parser.add_argument('--proc_id', type=int, default=1, help="process id") 83 | parser.add_argument('--filename', type=str, default="generated_samples.txt", help="file name to lister") 84 | parser.add_argument('--output_filename', type=str, default="output_syn.txt", help="file name to output") 85 | args = parser.parse_args() 86 | setproctitle.setproctitle("retro_star_listener") 87 | main(args.proc_id, args.filename, args.output_filename) 88 | -------------------------------------------------------------------------------- /retro_star_listener.sh: -------------------------------------------------------------------------------- 1 | for ((i =1; i <= $1; i++)); 2 | do 3 | python -u retro_star_listener.py --proc_id=$i & 4 | done 5 | --------------------------------------------------------------------------------