├── .gitignore
├── Readme.md
├── configs
│   ├── vp_qm9_cdgs.py
│   └── vp_zinc_cdgs.py
├── data
│   └── raw
│       ├── qm9.csv
│       ├── qm9_property.csv
│       ├── valid_idx_qm9.json
│       ├── valid_idx_zinc250k.json
│       ├── zinc250k_property.csv
│       ├── zinc_800_graphaf.csv
│       └── zinc_800_jt.csv
├── datasets.py
├── dpm_solvers.py
├── evaluation
│   ├── __init__.py
│   ├── evaluator.py
│   └── mol_metrics.py
├── exp
│   └── vpsde_qm9_cdgs
│       └── checkpoints
│           └── checkpoint_200.pth
├── fpscores.pkl.gz
├── losses.py
├── main.py
├── models
│   ├── __init__.py
│   ├── cdgs.py
│   ├── ema.py
│   ├── hmpb.py
│   ├── layers.py
│   ├── transformer_layers.py
│   └── utils.py
├── mol_config.csv
├── requirements.txt
├── run_lib.py
├── sampling.py
├── sascorer.py
├── sde_lib.py
├── utils.py
└── visualize.py

/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | *.xml
3 | *.iml
4 | 
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # CDGS
2 | 
3 | [Conditional Diffusion Based on Discrete Graph Structures for Molecular Graph Generation](https://arxiv.org/abs/2301.00427) - AAAI 2023
4 | 
5 | The extended version: Learning Joint 2D & 3D Diffusion Models for
6 | Complete Molecule Generation [[Paper]](https://arxiv.org/abs/2305.12347) [[Code]](https://github.com/GRAPH-0/JODO).
7 | 
8 | ## Dependencies
9 | 
10 | * PyTorch 1.11
11 | * PyG 2.1
12 | 
13 | For NSPDK evaluation:
14 | 
15 | ```pip install git+https://github.com/fabriziocosta/EDeN.git --user```
16 | 
17 | For other dependencies, see `requirements.txt`.
18 | 
19 | 
20 | ## Training
21 | 
22 | ### QM9
23 | 
24 | ```shell
25 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_qm9_cdgs.py --mode train --workdir exp/vpsde_qm9_cdgs
26 | ```
27 | 
28 | * Set the GPU id via `CUDA_VISIBLE_DEVICES`.
29 | * `workdir` is the directory where checkpoints are saved; it can be changed to `YOUR_PATH`. We provide a pretrained checkpoint in `exp/vpsde_qm9_cdgs`.
30 | * More hyperparameters can be found in the config file `configs/vp_qm9_cdgs.py`.
31 | 
32 | ### ZINC250k
33 | 
34 | ```shell
35 | # 256 hidden dimensions
36 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_zinc_cdgs.py --mode train --workdir exp/vpsde_zinc_cdgs_256 --config.training.n_iters 2500000
37 | 
38 | # 128 hidden dimensions
39 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_zinc_cdgs.py --mode train --workdir exp/vpsde_zinc_cdgs_128 --config.training.batch_size 128 --config.training.eval_batch_size 128 --config.training.n_iters 2500000
40 | ```
41 | 
42 | The pretrained checkpoints are provided in [Google Drive 256ch](https://drive.google.com/drive/folders/1bA_6ldtwF6gTMToZGG7w1dqmOpIuWZJd?usp=sharing)
43 | and [Google Drive 128ch](https://drive.google.com/drive/folders/1WRKkqJyJMue_evkqULRSftyHvRPqir6m?usp=share_link).
44 | 
45 | ## Sampling
46 | 
47 | ### QM9
48 | 
49 | 1. EM sampling with 1000 steps
50 | 
51 | ```shell
52 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_qm9_cdgs.py --mode eval --workdir exp/vpsde_qm9_cdgs --config.eval.begin_ckpt 200 --config.eval.end_ckpt 200
53 | ```
54 | 
55 | * Add `--config.eval.nspdk` to enable NSPDK evaluation.
56 | * Change the number of sampling steps via `--config.model.num_scales YOUR_STEPS`.
57 | * Adjust the sampling batch size via `--config.eval.batch_size` to control GPU memory usage.
58 | 
59 | 2. DPM-Solver examples
60 | 
61 | ```shell
62 | # Order 3; 50 steps
63 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_qm9_cdgs.py --mode eval --workdir exp/vpsde_qm9_cdgs --config.eval.begin_ckpt 200 --config.eval.end_ckpt 200 --config.sampling.method dpm3 --config.sampling.ode_step 50
64 | 
65 | # Order 2; 20 steps
66 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_qm9_cdgs.py --mode eval --workdir exp/vpsde_qm9_cdgs --config.eval.begin_ckpt 200 --config.eval.end_ckpt 200 --config.sampling.method dpm2 --config.sampling.ode_step 20
67 | 
68 | # Order 1; 10 steps
69 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_qm9_cdgs.py --mode eval --workdir exp/vpsde_qm9_cdgs --config.eval.begin_ckpt 200 --config.eval.end_ckpt 200 --config.sampling.method dpm1 --config.sampling.ode_step 10
70 | ```
71 | 
72 | ### ZINC250k
73 | 
74 | 1. EM sampling examples
75 | 
76 | ```shell
77 | # 1000 steps
78 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_zinc_cdgs.py --mode eval --workdir exp/vpsde_zinc_cdgs_256 --config.eval.begin_ckpt 250 --config.eval.end_ckpt 250 --config.eval.batch_size 800
79 | 
80 | # 200 steps
81 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_zinc_cdgs.py --mode eval --workdir exp/vpsde_zinc_cdgs_256 --config.eval.begin_ckpt 250 --config.eval.end_ckpt 250 --config.eval.batch_size 800 --config.model.num_scales 200
82 | ```
83 | 
84 | 2. DPM-Solver examples
85 | 
86 | ```shell
87 | # Order 3; 50 steps
88 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vp_zinc_cdgs.py --mode eval --workdir exp/vpsde_zinc_cdgs_256 --config.eval.begin_ckpt 250 --config.eval.end_ckpt 250 --config.eval.batch_size 800 --config.sampling.method dpm3 --config.sampling.ode_step 50
89 | ```
90 | 
91 | ### Results
92 | We provide molecules generated by CDGS: [Google Drive](https://drive.google.com/drive/folders/1eafc2bMETEyUVXvD9fW5Fc72v70rvPzH?usp=share_link).
93 | 
94 | ## Citation
95 | 
96 | ```bibtex
97 | @article{huang2023conditional,
98 |   title={Conditional Diffusion Based on Discrete Graph Structures for Molecular Graph Generation},
99 |   author={Huang, Han and Sun, Leilei and Du, Bowen and Lv, Weifeng},
100 |   journal={arXiv preprint arXiv:2301.00427},
101 |   year={2023}
102 | }
103 | ```
104 | 
105 | 
--------------------------------------------------------------------------------
/configs/vp_qm9_cdgs.py:
--------------------------------------------------------------------------------
1 | """Training GNN on QM9 with continuous VPSDE."""
2 | 
3 | import ml_collections
4 | import torch
5 | 
6 | 
7 | def get_config():
8 |     config = ml_collections.ConfigDict()
9 | 
10 |     config.model_type = 'mol_sde'
11 | 
12 |     # training
13 |     config.training = training = ml_collections.ConfigDict()
14 |     training.sde = 'vpsde'
15 |     training.continuous = True
16 |     training.reduce_mean = False
17 | 
18 |     training.batch_size = 128
19 |     training.eval_batch_size = 512
20 |     training.n_iters = 1000000
21 |     training.snapshot_freq = 5000  # set a larger value to save checkpoints less frequently
22 |     training.log_freq = 200
23 |     training.eval_freq = 5000
24 |     ## store additional checkpoints for preemption
25 |     training.snapshot_freq_for_preemption = 2000
26 |     ## produce samples at each snapshot.
27 | training.snapshot_sampling = True 28 | training.likelihood_weighting = False 29 | 30 | # sampling 31 | config.sampling = sampling = ml_collections.ConfigDict() 32 | sampling.method = 'pc' 33 | sampling.predictor = 'euler_maruyama' 34 | sampling.corrector = 'none' 35 | sampling.rtol = 1e-5 36 | sampling.atol = 1e-5 37 | sampling.ode_method = 'rk4' 38 | sampling.ode_step = 0.01 39 | 40 | sampling.n_steps_each = 1 41 | sampling.noise_removal = True 42 | sampling.probability_flow = False 43 | sampling.atom_snr = 0.16 44 | sampling.bond_snr = 0.16 45 | sampling.vis_row = 4 46 | sampling.vis_col = 4 47 | 48 | # evaluation 49 | config.eval = evaluate = ml_collections.ConfigDict() 50 | evaluate.begin_ckpt = 15 51 | evaluate.end_ckpt = 40 52 | evaluate.batch_size = 10000 # 1024 53 | evaluate.enable_sampling = True 54 | evaluate.num_samples = 10000 55 | evaluate.mmd_distance = 'RBF' 56 | evaluate.max_subgraph = False 57 | evaluate.save_graph = False 58 | evaluate.nn_eval = False 59 | evaluate.nspdk = False 60 | 61 | # data 62 | config.data = data = ml_collections.ConfigDict() 63 | data.centered = True 64 | data.dequantization = False 65 | 66 | data.root = 'data' 67 | data.name = 'QM9' 68 | data.split_ratio = 0.8 69 | data.max_node = 9 70 | data.atom_channels = 4 71 | data.bond_channels = 2 72 | data.atom_list = [6, 7, 8, 9] 73 | data.norm = (0.5, 1.0) 74 | 75 | # model 76 | config.model = model = ml_collections.ConfigDict() 77 | model.name = 'CDGS' 78 | model.ema_rate = 0.9999 79 | model.normalization = 'GroupNorm' 80 | model.nonlinearity = 'swish' 81 | model.nf = 64 82 | model.num_gnn_layers = 6 83 | model.conditional = True 84 | model.embedding_type = 'positional' 85 | model.rw_depth = 8 86 | model.graph_layer = 'GINE' 87 | model.edge_th = -1. 88 | model.heads = 8 89 | model.dropout = 0.1 90 | 91 | model.num_scales = 1000 # SDE total steps (N) 92 | model.sigma_min = 0.01 93 | model.sigma_max = 50 94 | model.node_beta_min = 0.1 95 | model.node_beta_max = 20. 96 | model.edge_beta_min = 0.1 97 | model.edge_beta_max = 20. 98 | 99 | # optimization 100 | config.optim = optim = ml_collections.ConfigDict() 101 | optim.weight_decay = 0 102 | optim.optimizer = 'Adam' 103 | optim.lr = 1e-4 104 | optim.beta1 = 0.9 105 | optim.eps = 1e-8 106 | optim.warmup = 1000 107 | optim.grad_clip = 1. # SET Larger values to converge faster, e.g., 10. 108 | 109 | config.seed = 42 110 | config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 111 | 112 | return config 113 | -------------------------------------------------------------------------------- /configs/vp_zinc_cdgs.py: -------------------------------------------------------------------------------- 1 | """Training GNN on ZINC250k with continuous VPSDE.""" 2 | 3 | import ml_collections 4 | import torch 5 | 6 | 7 | def get_config(): 8 | config = ml_collections.ConfigDict() 9 | 10 | config.model_type = 'mol_sde' 11 | 12 | # training 13 | config.training = training = ml_collections.ConfigDict() 14 | training.sde = 'vpsde' 15 | training.continuous = True 16 | training.reduce_mean = False 17 | 18 | training.batch_size = 64 19 | training.eval_batch_size = 64 20 | training.n_iters = 2000000 21 | training.snapshot_freq = 5000 # SET Larger values to save less checkpoints 22 | training.log_freq = 200 23 | training.eval_freq = 5000 24 | ## store additional checkpoints for preemption 25 | training.snapshot_freq_for_preemption = 2000 26 | ## produce samples at each snapshot. 
27 | training.snapshot_sampling = True 28 | training.likelihood_weighting = False 29 | 30 | # sampling 31 | config.sampling = sampling = ml_collections.ConfigDict() 32 | sampling.method = 'pc' 33 | sampling.predictor = 'euler_maruyama' 34 | sampling.corrector = 'none' 35 | sampling.rtol = 1e-5 36 | sampling.atol = 1e-5 37 | sampling.ode_method = 'rk4' 38 | sampling.ode_step = 0.01 39 | 40 | sampling.n_steps_each = 1 41 | sampling.noise_removal = True 42 | sampling.probability_flow = False 43 | sampling.atom_snr = 0.16 44 | sampling.bond_snr = 0.16 45 | sampling.vis_row = 4 46 | sampling.vis_col = 4 47 | 48 | # evaluation 49 | config.eval = evaluate = ml_collections.ConfigDict() 50 | evaluate.begin_ckpt = 15 51 | evaluate.end_ckpt = 40 52 | evaluate.batch_size = 2000 # 1024 53 | evaluate.enable_sampling = True 54 | evaluate.num_samples = 10000 55 | evaluate.mmd_distance = 'RBF' 56 | evaluate.max_subgraph = False 57 | evaluate.save_graph = False 58 | evaluate.nn_eval = False 59 | evaluate.nspdk = False 60 | 61 | # data 62 | config.data = data = ml_collections.ConfigDict() 63 | data.centered = True 64 | data.dequantization = False 65 | 66 | data.root = 'data' 67 | data.name = 'ZINC250K' 68 | data.split_ratio = 0.8 69 | data.max_node = 38 70 | data.atom_channels = 9 71 | data.bond_channels = 2 72 | data.atom_list = [6, 7, 8, 9, 15, 16, 17, 35, 53] 73 | data.norm = (0.5, 1.0) 74 | 75 | # model 76 | config.model = model = ml_collections.ConfigDict() 77 | model.name = 'CDGS' 78 | model.ema_rate = 0.9999 79 | model.normalization = 'GroupNorm' 80 | model.nonlinearity = 'swish' 81 | model.nf = 256 82 | model.num_gnn_layers = 10 83 | model.conditional = True 84 | model.embedding_type = 'positional' 85 | model.rw_depth = 20 86 | model.graph_layer = 'GINE' 87 | model.edge_th = -1. 88 | model.heads = 8 89 | model.dropout = 0.1 90 | 91 | model.num_scales = 1000 # SDE total steps (N) 92 | model.sigma_min = 0.01 93 | model.sigma_max = 50 94 | model.node_beta_min = 0.1 95 | model.node_beta_max = 20. 96 | model.edge_beta_min = 0.1 97 | model.edge_beta_max = 20. 98 | 99 | # optimization 100 | config.optim = optim = ml_collections.ConfigDict() 101 | optim.weight_decay = 0 102 | optim.optimizer = 'Adam' 103 | optim.lr = 1e-4 104 | optim.beta1 = 0.9 105 | optim.eps = 1e-8 106 | optim.warmup = 1000 107 | optim.grad_clip = 1. # SET Larger values to converge faster, e.g., 10. 108 | 109 | config.seed = 42 110 | config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 111 | 112 | return config 113 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import torch 3 | import json 4 | import os 5 | import numpy as np 6 | import os.path as osp 7 | import pandas as pd 8 | import pickle as pk 9 | from itertools import repeat 10 | from rdkit import Chem 11 | import torch_geometric.transforms as T 12 | from torch_geometric.data import Data, InMemoryDataset, download_url 13 | from torch_geometric.utils import from_networkx, degree, to_networkx 14 | 15 | 16 | bond_type_to_int = {Chem.BondType.SINGLE: 0, Chem.BondType.DOUBLE: 1, Chem.BondType.TRIPLE: 2} 17 | 18 | 19 | def get_data_scaler(config): 20 | """Data normalizer. Assume data are always in [0, 1].""" 21 | 22 | centered = config.data.centered 23 | if hasattr(config.data, "shift"): 24 | shift = config.data.shift 25 | else: 26 | shift = 0. 
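    # When config.data.norm = (atom_norm, bond_norm) is provided, features assumed to lie in [0, 1]
    # are first mapped to [-1, 1] (if data.centered) and then scaled separately for atom and bond channels.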
27 | 28 | if hasattr(config.data, 'norm'): 29 | atom_norm, bond_norm = config.data.norm 30 | assert shift == 0. 31 | 32 | def scale_fn(x, atom=False): 33 | if centered: 34 | x = x * 2. - 1. 35 | else: 36 | x = x 37 | if atom: 38 | x = x * atom_norm 39 | else: 40 | x = x * bond_norm 41 | return x 42 | return scale_fn 43 | else: 44 | if centered: 45 | # Rescale to [-1, 1] 46 | return lambda x: x * 2. - 1. + shift 47 | else: 48 | assert shift == 0. 49 | return lambda x: x 50 | 51 | 52 | def get_data_inverse_scaler(config): 53 | """Inverse data normalizer.""" 54 | 55 | centered = config.data.centered 56 | if hasattr(config.data, "shift"): 57 | shift = config.data.shift 58 | else: 59 | shift = 0. 60 | 61 | if hasattr(config.data, 'norm'): 62 | atom_norm, bond_norm = config.data.norm 63 | 64 | assert shift == 0. 65 | 66 | def inverse_scale_fn(x, atom=False): 67 | if atom: 68 | x = x / atom_norm 69 | else: 70 | x = x / bond_norm 71 | if centered: 72 | x = (x + 1.) / 2. 73 | else: 74 | x = x 75 | return x 76 | 77 | return inverse_scale_fn 78 | else: 79 | if centered: 80 | # Rescale [-1, 1] to [0, 1] 81 | return lambda x: (x + 1. - shift) / 2. 82 | else: 83 | assert shift == 0. 84 | return lambda x: x 85 | 86 | 87 | def networkx_graphs(dataset): 88 | return [to_networkx(dataset[i], to_undirected=True, remove_self_loops=True) for i in range(len(dataset))] 89 | 90 | 91 | class StructureDataset(InMemoryDataset): 92 | def __init__(self, 93 | root, 94 | dataset_name, 95 | transform=None, 96 | pre_transform=None, 97 | pre_filter=None): 98 | 99 | self.dataset_name = dataset_name 100 | 101 | super(StructureDataset, self).__init__(root, transform, pre_transform, pre_filter) 102 | 103 | if not os.path.exists(self.raw_paths[0]): 104 | raise ValueError("Without raw files.") 105 | if os.path.exists(self.processed_paths[0]): 106 | self.data, self.slices = torch.load(self.processed_paths[0]) 107 | else: 108 | self.process() 109 | 110 | @property 111 | def raw_file_names(self): 112 | return [self.dataset_name + '.pkl'] 113 | 114 | @property 115 | def processed_file_names(self): 116 | return [self.dataset_name + '.pt'] 117 | 118 | @property 119 | def num_node_features(self): 120 | if self.data.x is None: 121 | return 0 122 | return self.data.x.size(1) 123 | 124 | def __repr__(self) -> str: 125 | arg_repr = str(len(self)) if len(self) > 1 else '' 126 | return f'{self.dataset_name}({arg_repr})' 127 | 128 | def process(self): 129 | # Read data into 'Data' list 130 | input_path = self.raw_paths[0] 131 | with open(input_path, 'rb') as f: 132 | graphs_nx = pk.load(f) 133 | data_list = [from_networkx(G) for G in graphs_nx] 134 | 135 | if self.pre_filter is not None: 136 | data_list = [data for data in data_list if self.pre_filter(data)] 137 | 138 | if self.pre_transform is not None: 139 | data_list = [self.pre_transform(data) for data in data_list] 140 | 141 | self.data, self.slices = self.collate(data_list) 142 | torch.save((self.data, self.slices), self.processed_paths[0]) 143 | 144 | @torch.no_grad() 145 | def max_degree(self): 146 | data_list = [self.get(i) for i in range(len(self))] 147 | 148 | def graph_max_degree(g_data): 149 | return max(degree(g_data.edge_index[1], num_nodes=g_data.num_nodes)) 150 | 151 | degree_list = [graph_max_degree(data) for data in data_list] 152 | return int(max(degree_list).item()) 153 | 154 | def n_node_pmf(self): 155 | node_list = [self.get(i).num_nodes for i in range(len(self))] 156 | n_node_pmf = np.bincount(node_list) 157 | n_node_pmf = n_node_pmf / n_node_pmf.sum() 158 | return 
n_node_pmf 159 | 160 | 161 | class MolDataset(InMemoryDataset): 162 | # from DIG: Dive into Graphs 163 | """ 164 | A Pytorch Geometric data interface for datasets used in molecule generation. 165 | 166 | .. note:: 167 | Some datasets may not come with any node labels, like :obj:`moses`. 168 | Since they don't have any properties in the original data file. The process of the 169 | dataset can only save the current input property and will load the same property 170 | label when the processed dataset is used. You can change the augment :obj:`processed_filename` 171 | to re-process the dataset with intended property. 172 | 173 | Args: 174 | root (string, optional): Root directory where the dataset should be saved. 175 | name (string, optional): The name of the dataset. Available dataset names are as follows: 176 | :obj:`zinc250k`, :obj:`zinc_800_graphaf`, :obj:`zinc_800_jt`, 177 | :obj:`zinc250k_property`, :obj:`qm9_property`, :obj:`qm9`, :obj:`moses`. 178 | bond_ch (int): The channels for bond matrices. {1, 2, 4} 179 | prop_name (string, optional): The molecular property desired and used as the optimization target. 180 | (eg. "obj:`penalized_logp`) 181 | conf_dict (dictionary, optional): dictionary that stores all the configuration for the corresponding dataset 182 | transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` 183 | object and returns a transformed version. The data object will be transformed before every access. 184 | pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` 185 | object and returns a transformed version.The data object will be transformed before being saved to disk. 186 | pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and 187 | returns a boolean value, indicating whether the data object should be included in the final dataset. 
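        processed_filename (string, optional): The file name of the processed dataset cached under
            :obj:`processed_dir`. (default: :obj:`data.pt`)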
188 | 189 | """ 190 | 191 | def __init__(self, root, name, bond_ch, prop_name='penalized_logp', 192 | conf_dict=None, transform=None, pre_transform=None, 193 | pre_filter=None, processed_filename='data.pt'): 194 | 195 | self.processed_filename = processed_filename 196 | self.root = root 197 | self.name = name 198 | self.prop_name = prop_name 199 | self.bond_ch = bond_ch 200 | 201 | if conf_dict is None: 202 | config_file = pd.read_csv(os.path.join(os.path.dirname(__file__), 'mol_config.csv'), index_col=0) 203 | if self.name not in config_file: 204 | error_mssg = 'Invalid dataset name {}.\n'.format(self.name) 205 | error_mssg += 'Available datasets are as follows:\n' 206 | error_mssg += '\n'.join(config_file.keys()) 207 | raise ValueError(error_mssg) 208 | config = config_file[self.name] 209 | else: 210 | config = conf_dict 211 | 212 | self.url = config['url'] 213 | self.available_prop = str(prop_name) in ast.literal_eval(config['prop_list']) 214 | self.smile_col = config['smile'] 215 | self.num_max_node = int(config['num_max_node']) 216 | self.atom_list = ast.literal_eval(config['atom_list']) 217 | 218 | super(MolDataset, self).__init__(root, transform, pre_transform, pre_filter) 219 | if not osp.exists(self.raw_paths[0]): 220 | self.download() 221 | if osp.exists(self.processed_paths[0]): 222 | self.data, self.slices, self.all_smiles = torch.load(self.processed_paths[0]) 223 | else: 224 | self.process() 225 | 226 | @property 227 | def raw_dir(self): 228 | name = 'raw' 229 | return osp.join(self.root, name) 230 | 231 | @property 232 | def processed_dir(self): 233 | name = 'processed' 234 | return osp.join(self.root, self.name, name) 235 | 236 | @property 237 | def raw_file_names(self): 238 | name = self.name + '.csv' 239 | return name 240 | 241 | @property 242 | def processed_file_names(self): 243 | return self.processed_filename 244 | 245 | def download(self): 246 | print('making raw files:', self.raw_dir) 247 | if not osp.exists(self.raw_dir): 248 | os.makedirs(self.raw_dir) 249 | url = self.url 250 | path = download_url(url, self.raw_dir) 251 | 252 | def process(self): 253 | """Process the dataset from raw data file to the :obj:`self.processed_dir` folder.""" 254 | 255 | print('Processing...') 256 | self.data, self.slices = self.pre_process() 257 | 258 | if self.pre_filter is not None: 259 | data_list = [self.get(idx) for idx in range(len(self))] 260 | data_list = [data for data in data_list if self.pre_filter(data)] 261 | self.data, self.slices = self.collate(data_list) 262 | 263 | if self.pre_transform is not None: 264 | data_list = [self.get(idx) for idx in range(len(self))] 265 | data_list = [self.pre_transform(data) for data in data_list] 266 | self.data, self.slices = self.collate(data_list) 267 | 268 | print('making processed files:', self.processed_dir) 269 | if not osp.exists(self.processed_dir): 270 | os.makedirs(self.processed_dir) 271 | 272 | torch.save((self.data, self.slices, self.all_smiles), self.processed_paths[0]) 273 | print('Done!') 274 | 275 | def __repr__(self): 276 | return '{}({})'.format(self.name, len(self)) 277 | 278 | def get(self, idx): 279 | """Get the data object at index idx. 
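        Besides the node features, the returned object carries the original SMILES string under `smile`
        and, when `bond_ch` is 1 or 2, a compacted dense adjacency tensor under `adj`.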
""" 280 | 281 | data = self.data.__class__() 282 | 283 | if hasattr(self.data, '__num_nodes__'): 284 | data.num_nodes = self.data.__num_nodes__[idx] 285 | 286 | for key in self.data.keys: 287 | item, slices = self.data[key], self.slices[key] 288 | if torch.is_tensor(item): 289 | s = list(repeat(slice(None), item.dim())) 290 | s[self.data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) 291 | else: 292 | s = slice(slices[idx], slices[idx + 1]) 293 | data[key] = item[s] 294 | 295 | data['smile'] = self.all_smiles[idx] 296 | 297 | if self.bond_ch == 1: 298 | with torch.no_grad(): 299 | adj = data.adj 300 | ch = adj.shape[0] 301 | adj = torch.argmax(adj, dim=0) 302 | adj[adj == 3] = -1 303 | adj = (adj + 1).float() 304 | data['adj'] = adj.unsqueeze(0) / (ch - 1) 305 | elif self.bond_ch == 2: 306 | with torch.no_grad(): 307 | adj = data.adj 308 | ch = adj.shape[0] 309 | adj = torch.argmax(adj, dim=0) 310 | adj[adj == 3] = -1 311 | adj_1 = ((adj + 1) != 0).float() 312 | adj = (adj + 1).float() 313 | adj = torch.stack([adj / (ch - 1), adj_1]) 314 | data['adj'] = adj 315 | 316 | return data 317 | 318 | def pre_process(self): 319 | input_path = self.raw_paths[0] 320 | input_df = pd.read_csv(input_path, sep=',', dtype='str') 321 | smile_list = list(input_df[self.smile_col]) 322 | if self.available_prop: 323 | prop_list = list(input_df[self.prop_name]) 324 | 325 | self.all_smiles = smile_list 326 | data_list = [] 327 | 328 | for i in range(len(smile_list)): 329 | 330 | smile = smile_list[i] 331 | mol = Chem.MolFromSmiles(smile) 332 | Chem.Kekulize(mol) 333 | num_atom = mol.GetNumAtoms() 334 | if num_atom > self.num_max_node: 335 | continue 336 | else: 337 | # atoms 338 | atom_array = np.zeros((self.num_max_node, len(self.atom_list)), dtype=np.float32) 339 | atom_mask = np.zeros(self.num_max_node, dtype=np.float32) 340 | atom_mask[:num_atom] = 1. 341 | 342 | atom_idx = 0 343 | for atom in mol.GetAtoms(): 344 | atom_feature = atom.GetAtomicNum() 345 | atom_array[atom_idx, self.atom_list.index(atom_feature)] = 1 346 | atom_idx += 1 347 | 348 | x = torch.tensor(atom_array) 349 | 350 | # bonds 351 | adj_array = np.zeros([4, self.num_max_node, self.num_max_node], dtype=np.float32) 352 | for bond in mol.GetBonds(): 353 | bond_type = bond.GetBondType() 354 | 355 | ch = bond_type_to_int[bond_type] 356 | i = bond.GetBeginAtomIdx() 357 | j = bond.GetEndAtomIdx() 358 | adj_array[ch, i, j] = 1. 359 | adj_array[ch, j, i] = 1. 360 | 361 | adj_array[-1, :, :] = 1 - np.sum(adj_array, axis=0) 362 | # adj_array += np.eye(self.num_max_node) 363 | 364 | data = Data(x=x) 365 | data.adj = torch.tensor(adj_array) 366 | data.num_atom = num_atom 367 | data.atom_mask = torch.tensor(atom_mask) 368 | if self.available_prop: 369 | data.y = torch.tensor([float(prop_list[i])]) 370 | data_list.append(data) 371 | 372 | data, slices = self.collate(data_list) 373 | return data, slices 374 | 375 | def get_split_idx(self): 376 | """ 377 | Gets the train-valid set split indices of the dataset. 378 | Return: 379 | A dictionary for training-validation split with key `train_idx` and `valid_idx`. 
380 | """ 381 | 382 | if self.name.find('zinc250k') != -1: 383 | path = os.path.join(self.root, 'raw/valid_idx_zinc250k.json') 384 | with open(path) as f: 385 | valid_idx = json.load(f) 386 | 387 | elif self.name.find('qm9') != -1: 388 | path = os.path.join(self.root, 'raw/valid_idx_qm9.json') 389 | with open(path) as f: 390 | valid_idx = json.load(f)['valid_idxs'] 391 | valid_idx = list(map(int, valid_idx)) 392 | 393 | else: 394 | print('No available split file for this dataset, please check.') 395 | return None 396 | 397 | train_idx = list(set(np.arange(self.__len__())).difference(set(valid_idx))) 398 | 399 | return {'train_idx': torch.tensor(train_idx, dtype=torch.long), 400 | 'valid_idx': torch.tensor(valid_idx, dtype=torch.long)} 401 | 402 | def n_node_pmf(self): 403 | # if 'qm9' in self.name: 404 | # n_node_pmf = [0. for _ in range(10)] 405 | # n_node_pmf[-1] = 1. 406 | # return np.array(n_node_pmf) 407 | node_list = [self.get(i).num_atom.item() for i in range(len(self))] 408 | n_node_pmf = np.bincount(node_list) 409 | n_node_pmf = n_node_pmf / n_node_pmf.sum() 410 | return n_node_pmf 411 | 412 | 413 | class QM9(MolDataset): 414 | def __init__(self, root='./', bond_ch=4, prop_name='penalized_logp', conf_dict=None, transform=None, 415 | pre_transform=None, pre_filter=None, processed_filename='data.pt'): 416 | name = 'qm9_property' 417 | super(QM9, self).__init__(root, name, bond_ch, prop_name, conf_dict, transform, pre_transform, pre_filter, 418 | processed_filename) 419 | 420 | 421 | class ZINC250k(MolDataset): 422 | """ 423 | The attributes of the output data: 424 | x: the node features. 425 | y: the property labels for the graph. 426 | adj: the edge features in the form of dense adjacent matrices. 427 | batch: the assignment vector which maps each node to its respective graph identifier and can help reconstruct 428 | single graphs. 429 | num_atom: number of atoms for each graph. 430 | smile: original SMILE sequences for the graphs. 431 | """ 432 | 433 | def __init__(self, root='./', bond_ch=4, prop_name='penalized_logp', conf_dict=None, transform=None, 434 | pre_transform=None, pre_filter=None, processed_filename='data.pt'): 435 | name = 'zinc250k_property' 436 | super(ZINC250k, self).__init__(root, name, bond_ch, prop_name, conf_dict, transform, pre_transform, pre_filter, 437 | processed_filename) 438 | 439 | 440 | class MOSES(MolDataset): 441 | def __init__(self, root='./', bond_ch=4, prop_name=None, conf_dict=None, transform=None, pre_transform=None, pre_filter=None, 442 | processed_filename='data.pt'): 443 | name = 'moses' 444 | super(MOSES, self).__init__(root, name, bond_ch, prop_name, conf_dict, transform, pre_transform, pre_filter, 445 | processed_filename) 446 | 447 | 448 | class ZINC800(MolDataset): 449 | """ 450 | ZINC800 contains 800 selected molecules with lowest penalized logP scores. While method `jt` selects from the test 451 | set and `graphaf` selects from the train set. 452 | """ 453 | 454 | def __init__(self, root='./', method='jt', bond_ch=4, prop_name='penalized_logp', conf_dict=None, transform=None, 455 | pre_transform=None, pre_filter=None, processed_filename='data.pt'): 456 | name = 'zinc_800' 457 | name = name + '_' + method 458 | 459 | super(ZINC800, self).__init__(root, name, bond_ch, prop_name, conf_dict, transform, pre_transform, pre_filter, 460 | processed_filename) 461 | 462 | 463 | def get_opt_dataset(config): 464 | """Create data loaders for similarity constrained molecule optimization. 
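    Only the ZINC800 subsets (method `jt` or `graphaf`) are supported.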
465 | 466 | Args: 467 | config: A ml_collection.ConfigDict parsed from config files. 468 | 469 | Returns: 470 | dataset 471 | """ 472 | transform = T.Compose([ 473 | T.ToDevice(config.device) 474 | ]) 475 | 476 | assert 'zinc_800' in config.data.name 477 | 478 | if 'jt' in config.data.name: 479 | dataset = ZINC800(config.data.root, 'jt', bond_ch=config.data.bond_channels, transform=transform) 480 | elif 'graphaf' in config.data.name: 481 | dataset = ZINC800(config.data.root, 'graphaf', bond_ch=config.data.bond_channels, transform=transform) 482 | else: 483 | error_mssg = 'Invalid method type {}.\n'.format(config.data.name) 484 | error_mssg += 'Available datasets are as follows:\n' 485 | error_mssg += '\n'.join(['jt', 'graphaf']) 486 | raise ValueError(error_mssg) 487 | 488 | return dataset 489 | 490 | 491 | def get_dataset(config): 492 | """Create data loaders for training and evaluation. 493 | 494 | Args: 495 | config: A ml_collection.ConfigDict parsed from config files. 496 | 497 | Returns: 498 | train_ds, eval_ds, test_ds, n_node_pmf 499 | """ 500 | # define data transforms 501 | transform = T.Compose([ 502 | # T.ToDense(config.data.max_node), 503 | T.ToDevice(config.device) 504 | ]) 505 | 506 | # Build up data iterators 507 | if config.model_type == 'mol_sde' or config.model_type == 'sep_mol_sde': 508 | if config.data.name == 'QM9': 509 | dataset = QM9(config.data.root, bond_ch=config.data.bond_channels, transform=transform) 510 | elif config.data.name == 'ZINC250K': 511 | if hasattr(config.data, 'property'): 512 | property = config.data.property 513 | else: 514 | property = 'penalized_logp' 515 | if property == 'qed': 516 | dataset = ZINC250k(config.data.root, prop_name=property, bond_ch=config.data.bond_channels, 517 | transform=transform, processed_filename='qed_data.pt') 518 | else: 519 | dataset = ZINC250k(config.data.root, prop_name=property, 520 | bond_ch=config.data.bond_channels, transform=transform) 521 | elif config.data.name == 'MOSES': 522 | dataset = MOSES(config.data.root, bond_ch=config.data.bond_channels, transform=transform) 523 | else: 524 | raise ValueError('Undefined dataset name.') 525 | 526 | all_smiles = dataset.all_smiles 527 | splits = dataset.get_split_idx() 528 | train_idx = splits['train_idx'] 529 | test_idx = splits['valid_idx'] 530 | 531 | train_dataset = dataset[train_idx] 532 | train_dataset.sub_smiles = [all_smiles[idx] for idx in train_idx] 533 | test_dataset = dataset[test_idx] 534 | test_dataset.sub_smiles = [all_smiles[idx] for idx in test_idx] 535 | 536 | eval_idx = train_idx[torch.randperm(len(train_idx))[:len(test_idx)]] 537 | eval_dataset = dataset[eval_idx] 538 | eval_dataset.sub_smiles = [all_smiles[idx] for idx in eval_idx] 539 | else: 540 | dataset = StructureDataset(config.data.root, config.data.name, transform=transform) 541 | num_train = int(len(dataset) * config.data.split_ratio) 542 | num_test = len(dataset) - num_train 543 | train_dataset = dataset[:num_train] 544 | eval_dataset = dataset[:num_test] 545 | test_dataset = dataset[num_train:] 546 | 547 | n_node_pmf = train_dataset.n_node_pmf() 548 | 549 | return train_dataset, eval_dataset, test_dataset, n_node_pmf 550 | -------------------------------------------------------------------------------- /dpm_solvers.py: -------------------------------------------------------------------------------- 1 | # DPM solvers: stiff semi-linear ODE 2 | # Note: hyperparams of Atom_SDE and Bond_SDE should keep the same for DPM-Solver-1, DPM-Solver-2 and DPM-Solver-3 !!! 
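# Quick reference for the first-order update implemented below (see dpm1_update / dpm_mol_solver_1):
#   lambda_t = log(alpha_t / sigma_t) is the log-SNR, h_i = lambda_{t_i} - lambda_{t_{i-1}},
#   x_{t_i} = (alpha_{t_i} / alpha_{t_{i-1}}) * x_{t_{i-1}} - sigma_{t_i} * (exp(h_i) - 1) * theta(x_{t_{i-1}}, t_{i-1}),
# where theta is the network's noise prediction. The second- and third-order variants
# (dpm_mol_solver_2 / dpm_mol_solver_3) add correction terms evaluated at intermediate
# times s1, s2 placed at fractions r1, r2 of the log-SNR step.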
3 | 4 | import torch 5 | import numpy as np 6 | import functools 7 | 8 | from models.utils import get_multi_theta_fn, get_multi_score_fn, get_theta_fn 9 | 10 | 11 | def sample_nodes(n_nodes_pmf, atom_shape, device): 12 | n_nodes = torch.multinomial(n_nodes_pmf, atom_shape[0], replacement=True) 13 | atom_mask = torch.zeros((atom_shape[0], atom_shape[1]), device=device) 14 | for i in range(atom_shape[0]): 15 | atom_mask[i][:n_nodes[i]] = 1. 16 | bond_mask = (atom_mask[:, None, :] * atom_mask[:, :, None]).unsqueeze(1) 17 | bond_mask = torch.tril(bond_mask, -1) 18 | bond_mask = bond_mask + bond_mask.transpose(-1, -2) 19 | return n_nodes, atom_mask, bond_mask 20 | 21 | 22 | def expand_dim(x, n_dim): 23 | if n_dim == 3: 24 | x = x[:, None, None] 25 | elif n_dim == 4: 26 | x = x[:, None, None, None] 27 | return x 28 | 29 | 30 | def dpm1_update(x_last, t_last, t_i, sde, theta): 31 | # dpm_solver 1 order update function 32 | expand_fn = functools.partial(expand_dim, n_dim=len(x_last.shape)) 33 | 34 | lambda_i, alpha_i, std_i = sde.log_snr(t_i) 35 | lambda_last, alpha_last, _ = sde.log_snr(t_last) 36 | h_i = lambda_i - lambda_last 37 | 38 | x_i = expand_fn(alpha_i / alpha_last) * x_last - expand_fn(std_i * torch.expm1(h_i)) * theta 39 | return x_i 40 | 41 | 42 | def dpm_mol_solver_1(atom_sde, bond_sde, theta_fn, x_atom_last, x_bond_last, 43 | t_last, t_i, atom_mask, bond_mask): 44 | # run solver func once 45 | 46 | vec_t_last = torch.ones(x_atom_last.shape[0], device=x_atom_last.device) * t_last 47 | vec_t_i = torch.ones(x_atom_last.shape[0], device=x_atom_last.device) * t_i 48 | atom_fn = functools.partial(expand_dim, n_dim=len(x_atom_last.shape)) 49 | bond_fn = functools.partial(expand_dim, n_dim=len(x_bond_last.shape)) 50 | 51 | lambda_i, alpha_i, std_i = atom_sde.log_snr(vec_t_i) 52 | lambda_last, alpha_last, _ = atom_sde.log_snr(vec_t_last) 53 | h_i = lambda_i - lambda_last 54 | 55 | atom_theta, bond_theta = theta_fn((x_atom_last, x_bond_last), vec_t_last, atom_mask=atom_mask, bond_mask=bond_mask) 56 | tmp_linear = alpha_i / alpha_last 57 | tmp_nonlinear = std_i * torch.expm1(h_i) 58 | x_atom_i = atom_fn(tmp_linear) * x_atom_last - atom_fn(tmp_nonlinear) * atom_theta 59 | x_bond_i = bond_fn(tmp_linear) * x_bond_last - bond_fn(tmp_nonlinear) * bond_theta 60 | 61 | return x_atom_i, x_bond_i 62 | 63 | 64 | def dpm_mol_solver_2(atom_sde, bond_sde, theta_fn, x_atom_last, x_bond_last, 65 | t_last, t_i, atom_mask, bond_mask, r1=0.5): 66 | vec_t_last = torch.ones(x_atom_last.shape[0], device=x_atom_last.device) * t_last 67 | vec_t_i = torch.ones(x_atom_last.shape[0], device=x_atom_last.device) * t_i 68 | atom_fn = functools.partial(expand_dim, n_dim=len(x_atom_last.shape)) 69 | bond_fn = functools.partial(expand_dim, n_dim=len(x_bond_last.shape)) 70 | 71 | lambda_i, alpha_i, std_i = atom_sde.log_snr(vec_t_i) 72 | lambda_last, alpha_last, _ = atom_sde.log_snr(vec_t_last) 73 | h_i = lambda_i - lambda_last 74 | 75 | s_i = atom_sde.lambda2t(lambda_last + r1 * h_i) 76 | _, alpha_si, std_si = atom_sde.log_snr(s_i) 77 | atom_theta_0, bond_theta_0 = theta_fn((x_atom_last, x_bond_last), vec_t_last, 78 | atom_mask=atom_mask, bond_mask=bond_mask) 79 | 80 | tmp_lin = alpha_si / alpha_last 81 | tmp_nonlin = std_si * torch.expm1(r1 * h_i) 82 | u_atom_i = atom_fn(tmp_lin) * x_atom_last - atom_fn(tmp_nonlin) * atom_theta_0 83 | u_bond_i = bond_fn(tmp_lin) * x_bond_last - bond_fn(tmp_nonlin) * bond_theta_0 84 | 85 | atom_theta_si, bond_theta_si = theta_fn((u_atom_i, u_bond_i), s_i, atom_mask=atom_mask, 
bond_mask=bond_mask) 86 | 87 | tmp_lin = alpha_i / alpha_last 88 | tmp_nonlin1 = std_i * torch.expm1(h_i) 89 | tmp_nonlin2 = (std_i / (2. * r1)) * torch.expm1(h_i) 90 | x_atom_i = atom_fn(tmp_lin) * x_atom_last - atom_fn(tmp_nonlin1) * atom_theta_0 - \ 91 | atom_fn(tmp_nonlin2) * (atom_theta_si - atom_theta_0) 92 | x_bond_i = bond_fn(tmp_lin) * x_bond_last - bond_fn(tmp_nonlin1) * bond_theta_0 - \ 93 | bond_fn(tmp_nonlin2) * (bond_theta_si - bond_theta_0) 94 | 95 | return x_atom_i, x_bond_i 96 | 97 | 98 | def dpm_mol_solver_3(atom_sde, bond_sde, theta_fn, x_atom_last, x_bond_last, 99 | t_last, t_i, atom_mask, bond_mask, r1=1./3., r2=2./3.): 100 | vec_t_last = torch.ones(x_atom_last.shape[0], device=x_atom_last.device) * t_last 101 | vec_t_i = torch.ones(x_atom_last.shape[0], device=x_atom_last.device) * t_i 102 | atom_fn = functools.partial(expand_dim, n_dim=len(x_atom_last.shape)) 103 | bond_fn = functools.partial(expand_dim, n_dim=len(x_bond_last.shape)) 104 | 105 | lambda_i, alpha_i, std_i = atom_sde.log_snr(vec_t_i) 106 | lambda_last, alpha_last, _ = atom_sde.log_snr(vec_t_last) 107 | h_i = lambda_i - lambda_last 108 | 109 | s1 = atom_sde.lambda2t(lambda_last + r1 * h_i) 110 | s2 = atom_sde.lambda2t(lambda_last + r2 * h_i) 111 | 112 | _, alpha_s1, std_s1 = atom_sde.log_snr(s1) 113 | _, alpha_s2, std_s2 = atom_sde.log_snr(s2) 114 | 115 | atom_theta_0, bond_theta_0 = theta_fn((x_atom_last, x_bond_last), vec_t_last, 116 | atom_mask=atom_mask, bond_mask=bond_mask) 117 | 118 | tmp_lin = alpha_s1 / alpha_last 119 | tmp_nonlin = std_s1 * torch.expm1(r1 * h_i) 120 | u_atom_1 = atom_fn(tmp_lin) * x_atom_last - atom_fn(tmp_nonlin) * atom_theta_0 121 | u_bond_1 = bond_fn(tmp_lin) * x_bond_last - bond_fn(tmp_nonlin) * bond_theta_0 122 | 123 | atom_theta_s1, bond_theta_s1 = theta_fn((u_atom_1, u_bond_1), s1, atom_mask=atom_mask, bond_mask=bond_mask) 124 | D_atom_1 = atom_theta_s1 - atom_theta_0 125 | D_bond_1 = bond_theta_s1 - bond_theta_0 126 | 127 | tmp_lin = alpha_s2 / alpha_last 128 | tmp_nonlin1 = std_s2 * torch.expm1(r2 * h_i) 129 | tmp_nonlin2 = (std_s2 * r2 / r1) * (torch.expm1(r2 * h_i) / (r2 * h_i) - 1) 130 | u_atom_2 = atom_fn(tmp_lin) * x_atom_last - atom_fn(tmp_nonlin1) * atom_theta_0 - atom_fn(tmp_nonlin2) * D_atom_1 131 | u_bond_2 = bond_fn(tmp_lin) * x_bond_last - bond_fn(tmp_nonlin1) * bond_theta_0 - bond_fn(tmp_nonlin2) * D_bond_1 132 | 133 | atom_theta_s2, bond_theta_s2 = theta_fn((u_atom_2, u_bond_2), s2, atom_mask=atom_mask, bond_mask=bond_mask) 134 | D_atom_2 = atom_theta_s2 - atom_theta_0 135 | D_bond_2 = bond_theta_s2 - bond_theta_0 136 | 137 | tmp_lin = alpha_i / alpha_last 138 | tmp_nonlin1 = std_i * torch.expm1(h_i) 139 | tmp_nonlin2 = (std_i / r2) * (torch.expm1(h_i) / h_i - 1) 140 | x_atom_i = atom_fn(tmp_lin) * x_atom_last - atom_fn(tmp_nonlin1) * atom_theta_0 - atom_fn(tmp_nonlin2) * D_atom_2 141 | x_bond_i = bond_fn(tmp_lin) * x_bond_last - bond_fn(tmp_nonlin1) * bond_theta_0 - bond_fn(tmp_nonlin2) * D_bond_2 142 | 143 | return x_atom_i, x_bond_i 144 | 145 | 146 | def dpm_solver_3(sde, theta_fn, x_last, t_last, t_i, mask, r1=1./3., r2=2./3.): 147 | vec_t_last = torch.ones(x_last.shape[0], device=x_last.device) * t_last 148 | vec_t_i = torch.ones(x_last.shape[0], device=x_last.device) * t_i 149 | expand_fn = functools.partial(expand_dim, n_dim=len(x_last.shape)) 150 | 151 | lambda_i, alpha_i, std_i = sde.log_snr(vec_t_i) 152 | lambda_last, alpha_last, _ = sde.log_snr(vec_t_last) 153 | h_i = lambda_i - lambda_last 154 | 155 | s1 = sde.lambda2t(lambda_last + r1 * h_i) 
156 | s2 = sde.lambda2t(lambda_last + r2 * h_i) 157 | 158 | _, alpha_s1, std_s1 = sde.log_snr(s1) 159 | _, alpha_s2, std_s2 = sde.log_snr(s2) 160 | 161 | theta_0 = theta_fn(x_last, vec_t_last, mask=mask) 162 | 163 | tmp_lin = alpha_s1 / alpha_last 164 | tmp_nonlin = std_s1 * torch.expm1(r1 * h_i) 165 | u_1 = expand_fn(tmp_lin) * x_last - expand_fn(tmp_nonlin) * theta_0 166 | 167 | theta_s1 = theta_fn(u_1, s1, mask=mask) 168 | D_1 = theta_s1 - theta_0 169 | 170 | tmp_lin = alpha_s2 / alpha_last 171 | tmp_nonlin1 = std_s2 * torch.expm1(r2 * h_i) 172 | tmp_nonlin2 = (std_s2 * r2 / r1) * (torch.expm1(r2 * h_i) / (r2 * h_i) - 1) 173 | u_2 = expand_fn(tmp_lin) * x_last - expand_fn(tmp_nonlin1) * theta_0 - expand_fn(tmp_nonlin2) * D_1 174 | 175 | theta_s2 = theta_fn(u_2, s2, mask=mask) 176 | D_2 = theta_s2 - theta_0 177 | 178 | tmp_lin = alpha_i / alpha_last 179 | tmp_nonlin1 = std_i * torch.expm1(h_i) 180 | tmp_nonlin2 = (std_i / r2) * (torch.expm1(h_i) / h_i - 1) 181 | x_i = expand_fn(tmp_lin) * x_last - expand_fn(tmp_nonlin1) * theta_0 - expand_fn(tmp_nonlin2) * D_2 182 | 183 | return x_i 184 | 185 | 186 | def get_mol_sampler_dpm1(atom_sde, bond_sde, atom_shape, bond_shape, inverse_scaler, 187 | time_step, eps=1e-3, denoise=False, device='cuda'): 188 | # arrange time schedule 189 | start_lambda = atom_sde.log_snr_np(atom_sde.T) 190 | stop_lambda = atom_sde.log_snr_np(eps) 191 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=int(time_step + 1)) 192 | time_steps = [atom_sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 193 | 194 | # time_steps = np.linspace(start=atom_sde.T, stop=eps, num=int(time_step + 1)) 195 | 196 | def sampler(model, n_nodes_pmf, z=None): 197 | with torch.no_grad(): 198 | # set up dpm theta func 199 | theta_fn = get_multi_theta_fn(atom_sde, bond_sde, model, train=False, continuous=True) 200 | 201 | # initial sample 202 | assert z is None 203 | # If not represent, sample the latent code from the prior distribution of the SDE. 
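            # (when z is None, x_atom / x_bond start from the SDE prior -- standard Gaussian noise for the
            #  VP-SDE -- and entries outside the sampled graph sizes are zeroed out by atom_mask / bond_mask)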
204 | x_atom = atom_sde.prior_sampling(atom_shape).to(device) 205 | x_bond = bond_sde.prior_sampling(bond_shape).to(device) 206 | 207 | # Sample the number of nodes, if z is None 208 | n_nodes, atom_mask, bond_mask = sample_nodes(n_nodes_pmf, atom_shape, device) 209 | x_atom = x_atom * atom_mask.unsqueeze(-1) 210 | x_bond = x_bond * bond_mask 211 | 212 | # run solver func according to time schedule 213 | t_last = time_steps[0] 214 | for t_i in time_steps[1:]: 215 | x_atom, x_bond = dpm_mol_solver_1(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 216 | atom_mask, bond_mask) 217 | t_last = t_i 218 | 219 | if denoise: 220 | pass 221 | 222 | x_atom = inverse_scaler(x_atom, atom=True) * atom_mask.unsqueeze(-1) 223 | x_bond = inverse_scaler(x_bond, atom=False) * bond_mask 224 | return x_atom, x_bond, len(time_steps) - 1, n_nodes 225 | 226 | return sampler 227 | 228 | 229 | def get_mol_sampler_dpm2(atom_sde, bond_sde, atom_shape, bond_shape, inverse_scaler, 230 | time_step, eps=1e-3, denoise=False, device='cuda'): 231 | # arrange time schedule 232 | num_step = int(time_step // 2) 233 | 234 | start_lambda = atom_sde.log_snr_np(atom_sde.T) 235 | stop_lambda = atom_sde.log_snr_np(eps) 236 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=num_step+1) 237 | time_steps = [atom_sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 238 | 239 | # time_steps = np.linspace(start=atom_sde.T, stop=eps, num=num_step + 1) 240 | 241 | def sampler(model, n_nodes_pmf, z=None): 242 | with torch.no_grad(): 243 | # set up dpm theta func 244 | theta_fn = get_multi_theta_fn(atom_sde, bond_sde, model, train=False, continuous=True) 245 | 246 | # initial sample 247 | assert z is None 248 | # If not represent, sample the latent code from the prior distribution of the SDE. 
249 | x_atom = atom_sde.prior_sampling(atom_shape).to(device) 250 | x_bond = bond_sde.prior_sampling(bond_shape).to(device) 251 | 252 | # Sample the number of nodes, if z is None 253 | n_nodes, atom_mask, bond_mask = sample_nodes(n_nodes_pmf, atom_shape, device) 254 | x_atom = x_atom * atom_mask.unsqueeze(-1) 255 | x_bond = x_bond * bond_mask 256 | 257 | # run solver func according to time schedule 258 | t_last = time_steps[0] 259 | for t_i in time_steps[1:]: 260 | x_atom, x_bond = dpm_mol_solver_2(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 261 | atom_mask, bond_mask) 262 | t_last = t_i 263 | 264 | if denoise: 265 | pass 266 | 267 | x_atom = inverse_scaler(x_atom, atom=True) * atom_mask.unsqueeze(-1) 268 | x_bond = inverse_scaler(x_bond, atom=False) * bond_mask 269 | return x_atom, x_bond, num_step * 2, n_nodes 270 | 271 | return sampler 272 | 273 | 274 | def get_mol_sampler_dpm3(atom_sde, bond_sde, atom_shape, bond_shape, inverse_scaler, 275 | time_step, eps=1e-3, denoise=False, device='cuda'): 276 | # arrange time schedule 277 | num_step = int(time_step // 3) 278 | 279 | def sampler(model, n_nodes_pmf=None, time_point=None, z=None, atom_mask=None, bond_mask=None, theta_fn=None): 280 | if time_point is None: 281 | start_lambda = atom_sde.log_snr_np(atom_sde.T) 282 | stop_lambda = atom_sde.log_snr_np(eps) 283 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=num_step + 1) 284 | time_steps = [atom_sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 285 | else: 286 | start_time, stop_time = time_point 287 | start_lambda = atom_sde.log_snr_np(start_time) 288 | stop_lambda = atom_sde.log_snr_np(stop_time) 289 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=num_step + 1) 290 | time_steps = [atom_sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 291 | 292 | with torch.no_grad(): 293 | # set up dpm theta func 294 | if theta_fn is None: 295 | theta_fn = get_multi_theta_fn(atom_sde, bond_sde, model, train=False, continuous=True) 296 | else: 297 | theta_fn = theta_fn 298 | 299 | # initial sample 300 | if z is None: 301 | # If not represent, sample the latent code from the prior distribution of the SDE. 
302 | x_atom = atom_sde.prior_sampling(atom_shape).to(device) 303 | x_bond = bond_sde.prior_sampling(bond_shape).to(device) 304 | 305 | # Sample the number of nodes, if z is None 306 | n_nodes, atom_mask, bond_mask = sample_nodes(n_nodes_pmf, atom_shape, device) 307 | x_atom = x_atom * atom_mask.unsqueeze(-1) 308 | x_bond = x_bond * bond_mask 309 | else: 310 | # just use the concurrent prior z and node_mask, bond_mask 311 | x_atom, x_bond = z 312 | n_nodes = atom_mask.sum(-1).long() 313 | 314 | # run solver func according to time schedule 315 | t_last = time_steps[0] 316 | for t_i in time_steps[1:]: 317 | x_atom, x_bond = dpm_mol_solver_3(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 318 | atom_mask, bond_mask) 319 | t_last = t_i 320 | 321 | if denoise: 322 | pass 323 | 324 | x_atom = inverse_scaler(x_atom, atom=True) * atom_mask.unsqueeze(-1) 325 | x_bond = inverse_scaler(x_bond, atom=False) * bond_mask 326 | return x_atom, x_bond, num_step * 3, n_nodes 327 | 328 | return sampler 329 | 330 | 331 | def get_mol_encoder_dpm3(atom_sde, bond_sde, time_step, eps=1e-3, device='cuda'): 332 | # arrange time schedule 333 | num_step = int(time_step // 3) 334 | 335 | def sampler(model, batch, time_point=None): 336 | if time_point is None: 337 | start_lambda = atom_sde.log_snr_np(atom_sde.T) 338 | stop_lambda = atom_sde.log_snr_np(eps) 339 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=num_step + 1) 340 | time_steps = [atom_sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 341 | time_steps.reverse() 342 | else: 343 | start_time, stop_time = time_point 344 | start_lambda = atom_sde.log_snr_np(start_time) 345 | stop_lambda = atom_sde.log_snr_np(stop_time) 346 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=num_step + 1) 347 | time_steps = [atom_sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 348 | 349 | with torch.no_grad(): 350 | # set up dpm theta func 351 | theta_fn = get_multi_theta_fn(atom_sde, bond_sde, model, train=False, continuous=True) 352 | 353 | # run forward deterministic diffusion process 354 | x_atom, atom_mask, x_bond, bond_mask = batch 355 | 356 | # run solver func according to time schedule 357 | t_last = time_steps[0] 358 | for t_i in time_steps[1:]: 359 | x_atom, x_bond = dpm_mol_solver_3(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 360 | atom_mask, bond_mask) 361 | # pdb.set_trace() 362 | t_last = t_i 363 | 364 | return x_atom, x_bond, num_step * 3 365 | 366 | return sampler 367 | 368 | 369 | def get_mol_sampler_dpm_mix(atom_sde, bond_sde, atom_shape, bond_shape, inverse_scaler, 370 | time_step, eps=1e-3, denoise=False, device='cuda'): 371 | # arrange time schedule 372 | num_step = int(time_step // 3) 373 | 374 | start_lambda = atom_sde.log_snr_np(atom_sde.T) 375 | stop_lambda = atom_sde.log_snr_np(eps) 376 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=num_step+1) 377 | time_steps = [atom_sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 378 | 379 | R = int(time_step) % 3 380 | # time_steps = np.linspace(start=atom_sde.T, stop=eps, num=num_step + 1) 381 | 382 | def sampler(model, n_nodes_pmf, z=None): 383 | with torch.no_grad(): 384 | # set up dpm theta func 385 | theta_fn = get_multi_theta_fn(atom_sde, bond_sde, model, train=False, continuous=True) 386 | 387 | # initial sample 388 | assert z is None 389 | # If not represent, sample the latent code from the prior distribution of the SDE. 
390 | x_atom = atom_sde.prior_sampling(atom_shape).to(device) 391 | x_bond = bond_sde.prior_sampling(bond_shape).to(device) 392 | 393 | # Sample the number of nodes, if z is None 394 | n_nodes, atom_mask, bond_mask = sample_nodes(n_nodes_pmf, atom_shape, device) 395 | x_atom = x_atom * atom_mask.unsqueeze(-1) 396 | x_bond = x_bond * bond_mask 397 | 398 | # run solver func according to time schedule 399 | t_last = time_steps[0] 400 | 401 | if R == 0: 402 | for t_i in time_steps[1:-2]: 403 | x_atom, x_bond = dpm_mol_solver_3(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 404 | atom_mask, bond_mask) 405 | t_last = t_i 406 | t_i = time_steps[-2] 407 | x_atom, x_bond = dpm_mol_solver_2(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 408 | atom_mask, bond_mask) 409 | t_last = t_i 410 | t_i = time_steps[-1] 411 | x_atom, x_bond = dpm_mol_solver_1(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 412 | atom_mask, bond_mask) 413 | else: 414 | for t_i in time_steps[1:-1]: 415 | x_atom, x_bond = dpm_mol_solver_3(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 416 | atom_mask, bond_mask) 417 | t_last = t_i 418 | t_i = time_steps[-1] 419 | if R == 1: 420 | x_atom, x_bond = dpm_mol_solver_1(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 421 | atom_mask, bond_mask) 422 | elif R == 2: 423 | x_atom, x_bond = dpm_mol_solver_2(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 424 | atom_mask, bond_mask) 425 | else: 426 | raise ValueError('Step Error in mix DPM-solver.') 427 | 428 | if denoise: 429 | pass 430 | 431 | x_atom = inverse_scaler(x_atom, atom=True) * atom_mask.unsqueeze(-1) 432 | x_bond = inverse_scaler(x_bond, atom=False) * bond_mask 433 | return x_atom, x_bond, time_step, n_nodes 434 | 435 | return sampler 436 | 437 | 438 | def get_sampler_dpm3(sde, shape, inverse_scaler, time_step, eps=1e-3, denoise=False, device='cuda'): 439 | # arrange time schedule 440 | num_step = int(time_step // 3) 441 | 442 | def sampler(model, n_nodes_pmf=None, time_point=None, z=None, mask=None, theta_fn=None): 443 | if time_point is None: 444 | start_lambda = sde.log_snr_np(sde.T) 445 | stop_lambda = sde.log_snr_np(eps) 446 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=num_step + 1) 447 | time_steps = [sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 448 | else: 449 | start_time, stop_time = time_point 450 | start_lambda = sde.log_snr_np(start_time) 451 | stop_lambda = sde.log_snr_np(stop_time) 452 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=num_step + 1) 453 | time_steps = [sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 454 | 455 | with torch.no_grad(): 456 | # set up dpm theta func 457 | if theta_fn is None: 458 | theta_fn = get_theta_fn(sde, model, train=False, continuous=True) 459 | else: 460 | theta_fn = theta_fn 461 | 462 | # initial sample 463 | if z is None: 464 | # If not represent, sample the latent code from the prior distribution of the SDE. 465 | x = sde.prior_sampling(shape).to(device) 466 | # Sample the number of nodes, if z is None 467 | n_nodes = torch.multinomial(n_nodes_pmf, shape[0], replacement=True) 468 | mask = torch.zeros((shape[0], shape[-1]), device=device) 469 | for i in range(shape[0]): 470 | mask[i][:n_nodes[i]] = 1. 
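                # turn the per-node mask into a dense pairwise (node x node) mask via an outer product,
                # then add a channel dimension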
471 | mask = (mask[:, None, :] * mask[:, :, None]).unsqueeze(1) 472 | 473 | else: 474 | x = z 475 | batch_size, _, max_num_nodes, _ = mask.shape 476 | node_mask = mask[:, 0, 0, :].clone() # without checking correctness 477 | node_mask[:, 0] = 1 478 | n_nodes = node_mask.sum(-1).long() 479 | 480 | # run solver func according to time schedule 481 | t_last = time_steps[0] 482 | for t_i in time_steps[1:]: 483 | x = dpm_solver_3(sde, theta_fn, x, t_last, t_i, mask) 484 | t_last = t_i 485 | 486 | if denoise: 487 | pass 488 | 489 | x = inverse_scaler(x) * mask 490 | return x, num_step * 3, n_nodes 491 | 492 | return sampler 493 | 494 | 495 | def get_mol_dpm3_twostage(atom_sde, bond_sde, atom_shape, bond_shape, inverse_scaler, 496 | time_step, eps=1e-3, denoise=False, device='cuda'): 497 | # arrange time schedule 498 | num_step = int(time_step // 3) 499 | 500 | def sampler(model, n_nodes_pmf, time_point, guided_theta_fn): 501 | 502 | start_lambda = atom_sde.log_snr_np(atom_sde.T) 503 | stop_lambda = atom_sde.log_snr_np(eps) 504 | lambda_sched = np.linspace(start=start_lambda, stop=stop_lambda, num=num_step + 1) 505 | time_steps = [atom_sde.lambda2t_np(lambda_ori) for lambda_ori in lambda_sched] 506 | 507 | with torch.no_grad(): 508 | # set up dpm theta func 509 | theta_fn = get_multi_theta_fn(atom_sde, bond_sde, model, train=False, continuous=True) 510 | 511 | # initial sample 512 | x_atom = atom_sde.prior_sampling(atom_shape).to(device) 513 | x_bond = bond_sde.prior_sampling(bond_shape).to(device) 514 | 515 | # Sample the number of nodes, if z is None 516 | n_nodes, atom_mask, bond_mask = sample_nodes(n_nodes_pmf, atom_shape, device) 517 | x_atom = x_atom * atom_mask.unsqueeze(-1) 518 | x_bond = x_bond * bond_mask 519 | 520 | # run solver func according to time schedule 521 | t_last = time_steps[0] 522 | for t_i in time_steps[1:]: 523 | if t_last > time_point: 524 | x_atom, x_bond = dpm_mol_solver_3(atom_sde, bond_sde, theta_fn, x_atom, x_bond, t_last, t_i, 525 | atom_mask, bond_mask) 526 | else: 527 | x_atom, x_bond = dpm_mol_solver_3(atom_sde, bond_sde, guided_theta_fn, x_atom, x_bond, t_last, t_i, 528 | atom_mask, bond_mask) 529 | t_last = t_i 530 | 531 | if denoise: 532 | pass 533 | 534 | x_atom = inverse_scaler(x_atom, atom=True) * atom_mask.unsqueeze(-1) 535 | x_bond = inverse_scaler(x_bond, atom=False) * bond_mask 536 | return x_atom, x_bond, num_step * 3, n_nodes 537 | 538 | return sampler 539 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import get_nspdk_eval 2 | from .mol_metrics import get_FCDMetric 3 | 4 | 5 | -------------------------------------------------------------------------------- /evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | 4 | 5 | ### code adapted from https://github.com/idea-iitd/graphgen/blob/master/metrics/mmd.py 6 | def compute_nspdk_mmd(samples1, samples2, metric, is_hist=True, n_jobs=None): 7 | from sklearn.metrics.pairwise import pairwise_kernels 8 | from eden.graph import vectorize 9 | 10 | def kernel_compute(X, Y=None, is_hist=True, metric='linear', n_jobs=None): 11 | X = vectorize(X, complexity=4, discrete=True) 12 | if Y is not None: 13 | Y = vectorize(Y, complexity=4, discrete=True) 14 | return pairwise_kernels(X, Y, metric='linear', n_jobs=n_jobs) 15 | 16 | X = kernel_compute(samples1, 
is_hist=is_hist, metric=metric, n_jobs=n_jobs) 17 | Y = kernel_compute(samples2, is_hist=is_hist, metric=metric, n_jobs=n_jobs) 18 | Z = kernel_compute(samples1, Y=samples2, is_hist=is_hist, metric=metric, n_jobs=n_jobs) 19 | 20 | return np.average(X) + np.average(Y) - 2 * np.average(Z) 21 | 22 | 23 | ##### code adapted from https://github.com/idea-iitd/graphgen/blob/master/metrics/stats.py 24 | def nspdk_stats(graph_ref_list, graph_pred_list): 25 | graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0] 26 | 27 | # prev = datetime.now() 28 | mmd_dist = compute_nspdk_mmd(graph_ref_list, graph_pred_list_remove_empty, metric='nspdk', is_hist=False, n_jobs=20) 29 | # elapsed = datetime.now() - prev 30 | # if PRINT_TIME: 31 | # print('Time computing degree mmd: ', elapsed) 32 | return mmd_dist 33 | 34 | 35 | def get_nspdk_eval(config): 36 | return nspdk_stats 37 | -------------------------------------------------------------------------------- /evaluation/mol_metrics.py: -------------------------------------------------------------------------------- 1 | from fcd_torch import FCD 2 | 3 | 4 | def compute_intermediate_FCD(smiles, n_jobs=1, device='cpu', batch_size=512): 5 | """ 6 | Precomputes statistics such as mean and variance for FCD. 7 | """ 8 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size} 9 | stats = FCD(**kwargs_fcd).precalc(smiles) 10 | return stats 11 | 12 | 13 | def get_FCDMetric(ref_smiles, n_jobs=1, device='cpu', batch_size=512): 14 | pref = compute_intermediate_FCD(ref_smiles, n_jobs, device, batch_size) 15 | 16 | def FCDMetric(gen_smiles): 17 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size} 18 | return FCD(**kwargs_fcd)(gen=gen_smiles, pref=pref) 19 | 20 | return FCDMetric 21 | -------------------------------------------------------------------------------- /exp/vpsde_qm9_cdgs/checkpoints/checkpoint_200.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GRAPH-0/CDGS/2d498aae8e6c0d56f875d5cc463a8e0ac22c197c/exp/vpsde_qm9_cdgs/checkpoints/checkpoint_200.pth -------------------------------------------------------------------------------- /fpscores.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GRAPH-0/CDGS/2d498aae8e6c0d56f875d5cc463a8e0ac22c197c/fpscores.pkl.gz -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | """All functions related to loss computation and optimization.""" 2 | 3 | import torch 4 | import torch.optim as optim 5 | import numpy as np 6 | from models import utils as mutils 7 | from sde_lib import VPSDE 8 | 9 | 10 | def get_optimizer(config, params): 11 | """Return a flax optimizer object based on `config`.""" 12 | if config.optim.optimizer == 'Adam': 13 | optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps, 14 | weight_decay=config.optim.weight_decay) 15 | else: 16 | raise NotImplementedError( 17 | f'Optimizer {config.optim.optimizer} not supported yet!' 
18 | ) 19 | return optimizer 20 | 21 | 22 | def optimization_manager(config): 23 | """Return an optimize_fn based on `config`.""" 24 | 25 | def optimize_fn(optimizer, params, step, lr=config.optim.lr, 26 | warmup=config.optim.warmup, 27 | grad_clip=config.optim.grad_clip): 28 | """Optimize with warmup and gradient clipping (disabled if negative).""" 29 | if warmup > 0: 30 | for g in optimizer.param_groups: 31 | g['lr'] = lr * np.minimum(step / warmup, 1.0) 32 | if grad_clip >= 0: 33 | torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip) 34 | optimizer.step() 35 | 36 | return optimize_fn 37 | 38 | 39 | def get_multi_sde_loss_fn(atom_sde, bond_sde, train, reduce_mean=True, continuous=True, eps=1e-5): 40 | """ Create a loss function for training with arbitrary node SDE and edge SDE. 41 | 42 | Args: 43 | atom_sde, bond_sde: An `sde_lib.SDE` object that represents the forward SDE. 44 | train: `True` for training loss and `False` for evaluation loss. 45 | reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions. 46 | continuous: `True` indicates that the model is defined to take continuous time steps. 47 | Otherwise, it requires ad-hoc interpolation to take continuous time steps. 48 | eps: A `float` number. The smallest time step to sample from. 49 | 50 | Returns: 51 | A loss function. 52 | """ 53 | 54 | def loss_fn(model, batch): 55 | """Compute the loss function. 56 | 57 | Args: 58 | model: A score model. 59 | batch: A mini-batch of training data, including node_features, adjacency matrices, node mask and adj mask. 60 | 61 | Returns: 62 | loss: A scalar that represents the average loss value across the mini-batch. 63 | """ 64 | 65 | atom_feat, atom_mask, bond_feat, bond_mask = batch 66 | score_fn = mutils.get_multi_score_fn(atom_sde, bond_sde, model, train=train, continuous=continuous) 67 | t = torch.rand(atom_feat.shape[0], device=atom_feat.device) * (atom_sde.T - eps) + eps 68 | 69 | # perturbing atom 70 | z_atom = torch.randn_like(atom_feat) # [B, N, C] 71 | mean_atom, std_atom = atom_sde.marginal_prob(atom_feat, t) 72 | perturbed_atom = (mean_atom + std_atom[:, None, None] * z_atom) * atom_mask[:, :, None] 73 | 74 | # perturbing bond 75 | z_bond = torch.randn_like(bond_feat) # [B, C, N, N] 76 | z_bond = torch.tril(z_bond, -1) 77 | z_bond = z_bond + z_bond.transpose(-1, -2) 78 | mean_bond, std_bond = bond_sde.marginal_prob(bond_feat, t) 79 | perturbed_bond = (mean_bond + std_bond[:, None, None, None] * z_bond) * bond_mask 80 | 81 | atom_score, bond_score = score_fn((perturbed_atom, perturbed_bond), t, atom_mask=atom_mask, bond_mask=bond_mask) 82 | 83 | # atom loss 84 | atom_mask = atom_mask[:, :, None].repeat(1, 1, atom_feat.shape[-1]) 85 | atom_mask = atom_mask.reshape(atom_mask.shape[0], -1) 86 | losses_atom = torch.square(atom_score * std_atom[:, None, None] + z_atom) 87 | losses_atom = losses_atom.reshape(losses_atom.shape[0], -1) 88 | if reduce_mean: 89 | losses_atom = torch.sum(losses_atom * atom_mask, dim=-1) / torch.sum(atom_mask, dim=-1) 90 | else: 91 | losses_atom = 0.5 * torch.sum(losses_atom * atom_mask, dim=-1) 92 | loss_atom = losses_atom.mean() 93 | 94 | # bond loss 95 | bond_mask = bond_mask.repeat(1, bond_feat.shape[1], 1, 1) 96 | bond_mask = bond_mask.reshape(bond_mask.shape[0], -1) 97 | losses_bond = torch.square(bond_score * std_bond[:, None, None, None] + z_bond) 98 | losses_bond = losses_bond.reshape(losses_bond.shape[0], -1) 99 | if reduce_mean: 100 | losses_bond = torch.sum(losses_bond * bond_mask, dim=-1) 
/ (torch.sum(bond_mask, dim=-1) + 1e-8) 101 | else: 102 | losses_bond = 0.5 * torch.sum(losses_bond * bond_mask, dim=-1) 103 | loss_bond = losses_bond.mean() 104 | 105 | return loss_atom + loss_bond 106 | 107 | return loss_fn 108 | 109 | 110 | def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False): 111 | """Create a one-step training/evaluation function. 112 | 113 | Args: 114 | sde: An `sde_lib.SDE` object that represents the forward SDE. 115 | Tuple (`sde_lib.SDE`, `sde_lib.SDE`) that represents the forward node SDE and edge SDE. 116 | optimize_fn: An optimization function. 117 | reduce_mean: If `True`, average the loss across data dimensions. 118 | Otherwise, sum the loss across data dimensions. 119 | continuous: `True` indicates that the model is defined to take continuous time steps. 120 | likelihood_weighting: If `True`, weight the mixture of score matching losses according to 121 | https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended by score-sde. 122 | 123 | Returns: 124 | A one-step function for training or evaluation. 125 | """ 126 | 127 | if continuous: 128 | if isinstance(sde, tuple): 129 | loss_fn = get_multi_sde_loss_fn(sde[0], sde[1], train, reduce_mean=reduce_mean, continuous=True) 130 | else: 131 | loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean, 132 | continuous=True, likelihood_weighting=likelihood_weighting) 133 | else: 134 | assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training." 135 | if isinstance(sde, VPSDE): 136 | loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean) 137 | elif isinstance(sde, tuple): 138 | raise ValueError("Discrete training for multi sde is not recommended.") 139 | else: 140 | raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.") 141 | 142 | def step_fn(state, batch): 143 | """Running one step of training or evaluation. 144 | 145 | For jax version: This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and 146 | jit-compiled together for faster execution. 147 | 148 | Args: 149 | state: A dictionary of training information, containing the score model, optimizer, 150 | EMA status, and number of optimization steps. 151 | batch: A mini-batch of training/evaluation data, including min-batch adjacency matrices and mask. 152 | 153 | Returns: 154 | loss: The average loss value of this state. 
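Example (illustrative sketch; names follow the training loop in run_lib.py):
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)
    loss = train_step_fn(state, batch)  # one optimizer update per call when train=True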
155 | """ 156 | model = state['model'] 157 | if train: 158 | optimizer = state['optimizer'] 159 | optimizer.zero_grad() 160 | loss = loss_fn(model, batch) 161 | loss.backward() 162 | optimize_fn(optimizer, model.parameters(), step=state['step']) 163 | state['step'] += 1 164 | state['ema'].update(model.parameters()) 165 | else: 166 | with torch.no_grad(): 167 | ema = state['ema'] 168 | ema.store(model.parameters()) 169 | ema.copy_to(model.parameters()) 170 | loss = loss_fn(model, batch) 171 | ema.restore(model.parameters()) 172 | 173 | return loss 174 | 175 | return step_fn 176 | 177 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Training and evaluation""" 2 | 3 | import run_lib 4 | from absl import app, flags 5 | from ml_collections.config_flags import config_flags 6 | import logging 7 | import os 8 | # import tensorflow as tf 9 | 10 | FLAGS = flags.FLAGS 11 | 12 | config_flags.DEFINE_config_file( 13 | 'config', None, 'Training configuration.', lock_config=True 14 | ) 15 | flags.DEFINE_string('workdir', None, 'Work directory.') 16 | flags.DEFINE_enum('mode', None, ['train', 'eval', 'train_regressor', 'const_opt', 'cond_sample'], 17 | 'Running mode: train or eval') 18 | flags.DEFINE_string('eval_folder', 'eval', 'The folder name for storing evaluation results') 19 | flags.mark_flags_as_required(['workdir', 'config', 'mode']) 20 | 21 | 22 | def main(argv): 23 | # Set random seed 24 | run_lib.set_random_seed(FLAGS.config) 25 | 26 | if FLAGS.mode == 'train': 27 | # Create the working directory 28 | # tf.io.gfile.makedirs(FLAGS.workdir) 29 | if not os.path.exists(FLAGS.workdir): 30 | os.makedirs(FLAGS.workdir) 31 | # Set logger so that it outputs to both console and file 32 | # Make logging work for both disk and Google Cloud Storage 33 | # gfile_stream = open(os.path.join(FLAGS.workdir, 'stdout.txt'), 'w') 34 | gfile_stream = open(os.path.join(FLAGS.workdir, 'stdout.txt'), 'a') 35 | handler = logging.StreamHandler(gfile_stream) 36 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s') 37 | handler.setFormatter(formatter) 38 | logger = logging.getLogger() 39 | logger.addHandler(handler) 40 | logger.setLevel('INFO') 41 | # Run the training pipeline 42 | run_lib.train(FLAGS.config, FLAGS.workdir) 43 | elif FLAGS.mode == 'eval': 44 | # Run the evaluation pipeline 45 | run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder) 46 | elif FLAGS.mode == 'train_regressor': 47 | # Run the noise graph regressor 48 | run_lib.train_regressor(FLAGS.config, FLAGS.workdir) 49 | elif FLAGS.mode == 'const_opt': 50 | run_lib.const_opt(FLAGS.config, FLAGS.workdir) 51 | elif FLAGS.mode == 'cond_sample': 52 | run_lib.mol_ode_cond_sample(FLAGS.config, FLAGS.workdir) 53 | else: 54 | raise ValueError(f"Mode {FLAGS.mode} not recognized.") 55 | 56 | 57 | if __name__ == '__main__': 58 | app.run(main) 59 | 60 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GRAPH-0/CDGS/2d498aae8e6c0d56f875d5cc463a8e0ac22c197c/models/__init__.py -------------------------------------------------------------------------------- /models/cdgs.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import functools 4 | from 
torch_geometric.utils import dense_to_sparse 5 | 6 | from . import utils, layers 7 | from .hmpb import HybridMPBlock 8 | 9 | get_act = layers.get_act 10 | conv1x1 = layers.conv1x1 11 | 12 | 13 | @utils.register_model(name='CDGS') 14 | class CDGS(nn.Module): 15 | """ 16 | Graph Noise Prediction Model. 17 | """ 18 | 19 | def __init__(self, config): 20 | super().__init__() 21 | 22 | self.config = config 23 | self.act = act = get_act(config) 24 | 25 | # get input channels(data.num_channels), hidden channels(model.nf), number of blocks(model.num_res_blocks) 26 | self.nf = nf = config.model.nf 27 | self.num_gnn_layers = num_gnn_layers = config.model.num_gnn_layers 28 | dropout = config.model.dropout 29 | self.embedding_type = embedding_type = config.model.embedding_type.lower() 30 | self.conditional = conditional = config.model.conditional 31 | self.edge_th = config.model.edge_th 32 | self.rw_depth = rw_depth = config.model.rw_depth 33 | 34 | modules = [] 35 | # timestep/noise_level embedding; only for continuous training 36 | if embedding_type == 'positional': 37 | embed_dim = nf 38 | else: 39 | raise ValueError(f'embedding type {embedding_type} unknown.') 40 | 41 | if conditional: 42 | modules.append(nn.Linear(embed_dim, nf * 2)) 43 | modules.append(nn.Linear(nf * 2, nf)) 44 | 45 | atom_ch = config.data.atom_channels 46 | bond_ch = config.data.bond_channels 47 | temb_dim = nf 48 | 49 | # project bond features 50 | assert bond_ch == 2 51 | bond_se_ch = int(nf * 0.4) 52 | bond_type_ch = int(0.5 * (nf - bond_se_ch)) 53 | modules.append(conv1x1(1, bond_type_ch)) 54 | modules.append(conv1x1(1, bond_type_ch)) 55 | modules.append(conv1x1(rw_depth + 1, bond_se_ch)) 56 | modules.append(nn.Linear(bond_se_ch + 2 * bond_type_ch, nf)) 57 | 58 | # project atom features 59 | atom_se_ch = int(nf * 0.2) 60 | atom_type_ch = nf - 2 * atom_se_ch 61 | modules.append(nn.Linear(bond_ch, atom_se_ch)) 62 | modules.append(nn.Linear(atom_ch, atom_type_ch)) 63 | modules.append(nn.Linear(rw_depth, atom_se_ch)) 64 | modules.append(nn.Linear(atom_type_ch + 2 * atom_se_ch, nf)) 65 | self.x_ch = nf 66 | 67 | # gnn network 68 | cat_dim = (nf * 2) // num_gnn_layers 69 | for _ in range(num_gnn_layers): 70 | modules.append(HybridMPBlock(nf, config.model.graph_layer, "FullTrans_1", config.model.heads, 71 | temb_dim=temb_dim, act=act, dropout=dropout, attn_dropout=dropout)) 72 | modules.append(nn.Linear(nf, cat_dim)) 73 | modules.append(nn.Linear(nf, cat_dim)) 74 | 75 | # atom output 76 | modules.append(nn.Linear(cat_dim * num_gnn_layers + atom_type_ch, nf)) 77 | modules.append(nn.Linear(nf, nf // 2)) 78 | modules.append(nn.Linear(nf // 2, atom_ch)) 79 | 80 | # bond output 81 | modules.append(conv1x1(cat_dim * num_gnn_layers + bond_type_ch, nf)) 82 | modules.append(conv1x1(nf, nf // 2)) 83 | modules.append(conv1x1(nf // 2, 1)) 84 | 85 | # structure output 86 | modules.append(conv1x1(cat_dim * num_gnn_layers + bond_type_ch, nf)) 87 | modules.append(conv1x1(nf, nf // 2)) 88 | modules.append(conv1x1(nf // 2, 1)) 89 | 90 | self.all_modules = nn.ModuleList(modules) 91 | 92 | def forward(self, x, time_cond, *args, **kwargs): 93 | 94 | atom_feat, bond_feat = x 95 | atom_mask = kwargs['atom_mask'] 96 | bond_mask = kwargs['bond_mask'] 97 | 98 | edge_exist = bond_feat[:, 1:, :, :] 99 | edge_cate = bond_feat[:, 0:1, :, :] 100 | 101 | # timestep/noise_level embedding; only for continuous training 102 | modules = self.all_modules 103 | m_idx = 0 104 | 105 | if self.embedding_type == 'positional': 106 | # Sinusoidal positional embeddings. 
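# Note: the call below follows the DDPM-style sinusoidal embedding in layers.get_timestep_embedding:
#   freqs[k] = exp(-k * log(10000) / (nf // 2 - 1)),  k = 0, ..., nf // 2 - 1
#   temb = concat(sin(time_cond[:, None] * freqs), cos(time_cond[:, None] * freqs))  # shape [B, nf]
# For continuous training, time_cond has already been scaled to roughly [0, 999] in models/utils.py.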
107 | timesteps = time_cond 108 | temb = layers.get_timestep_embedding(timesteps, self.nf) 109 | 110 | else: 111 | raise ValueError(f'embedding type {self.embedding_type} unknown.') 112 | 113 | if self.conditional: 114 | temb = modules[m_idx](temb) 115 | m_idx += 1 116 | temb = modules[m_idx](self.act(temb)) 117 | m_idx += 1 118 | else: 119 | temb = None 120 | 121 | if not self.config.data.centered: 122 | # rescale the input data to [-1, 1] 123 | atom_feat = atom_feat * 2. - 1. 124 | bond_feat = bond_feat * 2. - 1. 125 | 126 | # discretize dense adj 127 | with torch.no_grad(): 128 | adj = edge_exist.squeeze(1).clone() # [B, N, N] 129 | adj[adj >= 0.] = 1. 130 | adj[adj < 0.] = 0. 131 | adj = adj * bond_mask.squeeze(1) 132 | 133 | # extract RWSE and Shortest-Path Distance 134 | rw_landing, spd_onehot = utils.get_rw_feat(self.rw_depth, adj) 135 | 136 | # construct edge feature [B, N, N, F] 137 | adj_mask = bond_mask.permute(0, 2, 3, 1) 138 | dense_cate = modules[m_idx](edge_cate).permute(0, 2, 3, 1) * adj_mask 139 | m_idx += 1 140 | dense_exist = modules[m_idx](edge_exist).permute(0, 2, 3, 1) * adj_mask 141 | m_idx += 1 142 | dense_spd = modules[m_idx](spd_onehot).permute(0, 2, 3, 1) * adj_mask 143 | m_idx += 1 144 | dense_edge = modules[m_idx](torch.cat([dense_cate, dense_exist, dense_spd], dim=-1)) * adj_mask 145 | m_idx += 1 146 | 147 | # Use Degree as atom feature 148 | atom_degree = torch.sum(bond_feat, dim=-1).permute(0, 2, 1) # [B, N, C] 149 | atom_degree = modules[m_idx](atom_degree) # [B, N, nf] 150 | m_idx += 1 151 | atom_cate = modules[m_idx](atom_feat) 152 | m_idx += 1 153 | x_rwl = modules[m_idx](rw_landing) 154 | m_idx += 1 155 | x_atom = modules[m_idx](torch.cat([atom_degree, atom_cate, x_rwl], dim=-1)) 156 | m_idx += 1 157 | h_atom = x_atom.reshape(-1, self.x_ch) 158 | # Dense to sparse node [BxN, -1] 159 | 160 | dense_index = adj.nonzero(as_tuple=True) 161 | edge_index, _ = dense_to_sparse(adj) 162 | h_dense_edge = dense_edge 163 | 164 | # Run GNN layers 165 | atom_hids = [] 166 | bond_hids = [] 167 | for _ in range(self.num_gnn_layers): 168 | h_atom, h_dense_edge = modules[m_idx](h_atom, edge_index, h_dense_edge, dense_index, 169 | atom_mask, adj_mask, temb) 170 | m_idx += 1 171 | atom_hids.append(modules[m_idx](h_atom.reshape(x_atom.shape))) 172 | m_idx += 1 173 | bond_hids.append(modules[m_idx](h_dense_edge)) 174 | m_idx += 1 175 | 176 | atom_hids = torch.cat(atom_hids, dim=-1) 177 | bond_hids = torch.cat(bond_hids, dim=-1) 178 | 179 | # Output 180 | atom_score = self.act(modules[m_idx](torch.cat([atom_cate, atom_hids], dim=-1))) \ 181 | * atom_mask.unsqueeze(-1) 182 | m_idx += 1 183 | atom_score = self.act(modules[m_idx](atom_score)) 184 | m_idx += 1 185 | atom_score = modules[m_idx](atom_score) 186 | m_idx += 1 187 | 188 | bond_score = self.act(modules[m_idx](torch.cat([dense_cate, bond_hids], dim=-1).permute(0, 3, 1, 2))) \ 189 | * bond_mask 190 | m_idx += 1 191 | bond_score = self.act(modules[m_idx](bond_score)) 192 | m_idx += 1 193 | bond_score = modules[m_idx](bond_score) 194 | m_idx += 1 195 | 196 | exist_score = self.act(modules[m_idx](torch.cat([dense_exist, bond_hids], dim=-1).permute(0, 3, 1, 2))) \ 197 | * bond_mask 198 | m_idx += 1 199 | exist_score = self.act(modules[m_idx](exist_score)) 200 | m_idx += 1 201 | exist_score = modules[m_idx](exist_score) 202 | m_idx += 1 203 | 204 | # make score symmetric 205 | bond_score = torch.cat([bond_score, exist_score], dim=1) 206 | bond_score = (bond_score + bond_score.transpose(2, 3)) / 2. 
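# Note: averaging a tensor with its transpose over the node axes keeps the edge scores symmetric,
# matching the undirected adjacency. Minimal illustration with hypothetical values:
#   s = torch.tensor([[0., 2.], [4., 0.]])
#   (s + s.transpose(0, 1)) / 2.  ->  tensor([[0., 3.], [3., 0.]])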
207 | 208 | assert m_idx == len(modules) 209 | 210 | atom_score = atom_score * atom_mask.unsqueeze(-1) 211 | bond_score = bond_score * bond_mask 212 | 213 | return atom_score, bond_score 214 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ExponentialMovingAverage: 5 | """ 6 | Maintains (exponential) moving average of a set of parameters. 7 | """ 8 | 9 | def __init__(self, parameters, decay, use_num_updates=True): 10 | """ 11 | Args: 12 | parameters: Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`. 13 | decay: The exponential decay. 14 | use_num_updates: Whether to use number of updates when computing averages. 15 | """ 16 | if decay < 0.0 or decay > 1.0: 17 | raise ValueError('Decay must be between 0 and 1') 18 | self.decay = decay 19 | self.num_updates = 0 if use_num_updates else None 20 | self.shadow_params = [p.clone().detach() 21 | for p in parameters if p.requires_grad] 22 | self.collected_params = [] 23 | 24 | def update(self, parameters): 25 | """ 26 | Update currently maintained parameters. 27 | 28 | Call this every time the parameters are updated, such as the result of the `optimizer.step()` call. 29 | 30 | Args: 31 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to 32 | initialize this object. 33 | """ 34 | decay = self.decay 35 | if self.num_updates is not None: 36 | self.num_updates += 1 37 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 38 | one_minus_decay = 1.0 - decay 39 | with torch.no_grad(): 40 | parameters = [p for p in parameters if p.requires_grad] 41 | for s_param, param in zip(self.shadow_params, parameters): 42 | s_param.sub_(one_minus_decay * (s_param - param)) 43 | 44 | def copy_to(self, parameters): 45 | """ 46 | Copy current parameters into given collection of parameters. 47 | 48 | Args: 49 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 50 | updated with the stored moving averages. 51 | """ 52 | parameters = [p for p in parameters if p.requires_grad] 53 | for s_param, param in zip(self.shadow_params, parameters): 54 | if param.requires_grad: 55 | param.data.copy_(s_param.data) 56 | 57 | def store(self, parameters): 58 | """ 59 | Save the current parameters for restoring later. 60 | 61 | Args: 62 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. 63 | """ 64 | self.collected_params = [param.clone() for param in parameters] 65 | 66 | def restore(self, parameters): 67 | """ 68 | Restore the parameters stored with the `store` method. 69 | Useful to validate the model with EMA parameters without affecting the original optimization process. 70 | Store the parameters before the `copy_to` method. 71 | After validation (or model saving), use this to restore the former parameters. 72 | 73 | Args: 74 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. 
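Example (illustrative sketch of the store/copy_to/restore pattern used in losses.get_step_fn):
    ema.store(model.parameters())
    ema.copy_to(model.parameters())
    # ... evaluate or sample with the EMA weights ...
    ema.restore(model.parameters())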
75 | """ 76 | for c_param, param in zip(self.collected_params, parameters): 77 | param.data.copy_(c_param.data) 78 | 79 | def state_dict(self): 80 | return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params) 81 | 82 | def load_state_dict(self, state_dict): 83 | self.decay = state_dict['decay'] 84 | self.num_updates = state_dict['num_updates'] 85 | self.shadow_params = state_dict['shadow_params'] 86 | -------------------------------------------------------------------------------- /models/hmpb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import math 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch_geometric.nn as pygnn 7 | 8 | from torch_geometric.nn import Linear as Linear_pyg 9 | from torch_geometric.utils import dense_to_sparse 10 | from .transformer_layers import EdgeGateTransLayer 11 | 12 | 13 | class HybridMPBlock(nn.Module): 14 | """Local MPNN + fully-connected attention-based message passing layer. Inspired by GPSLayer.""" 15 | 16 | def __init__(self, dim_h, 17 | local_gnn_type, global_model_type, num_heads, 18 | temb_dim=None, act=None, dropout=0.0, attn_dropout=0.0): 19 | super().__init__() 20 | 21 | self.dim_h = dim_h 22 | self.num_heads = num_heads 23 | self.attn_dropout = attn_dropout 24 | self.local_gnn_type = local_gnn_type 25 | self.global_model_type = global_model_type 26 | if act is None: 27 | self.act = nn.ReLU() 28 | else: 29 | self.act = act 30 | 31 | # time embedding 32 | if temb_dim is not None: 33 | self.t_node = nn.Linear(temb_dim, dim_h) 34 | self.t_edge = nn.Linear(temb_dim, dim_h) 35 | 36 | # local message-passing model 37 | if local_gnn_type == 'None': 38 | self.local_model = None 39 | elif local_gnn_type == 'GINE': 40 | gin_nn = nn.Sequential(Linear_pyg(dim_h, dim_h), nn.ReLU(), Linear_pyg(dim_h, dim_h)) 41 | self.local_model = pygnn.GINEConv(gin_nn) 42 | elif local_gnn_type == 'GAT': 43 | self.local_model = pygnn.GATConv(in_channels=dim_h, 44 | out_channels=dim_h // num_heads, 45 | heads=num_heads, 46 | edge_dim=dim_h) 47 | elif local_gnn_type == 'LocalTrans_1': 48 | self.local_model = EdgeGateTransLayer(dim_h, dim_h // num_heads, num_heads, edge_dim=dim_h) 49 | else: 50 | raise ValueError(f"Unsupported local GNN model: {local_gnn_type}") 51 | 52 | # Global attention transformer-style model. 53 | if global_model_type == 'None': 54 | self.self_attn = None 55 | elif global_model_type == 'FullTrans_1': 56 | self.self_attn = EdgeGateTransLayer(dim_h, dim_h // num_heads, num_heads, edge_dim=dim_h) 57 | else: 58 | raise ValueError(f"Unsupported global x-former model: " 59 | f"{global_model_type}") 60 | 61 | # Normalization for MPNN and Self-Attention representations. 62 | self.norm1_local = nn.GroupNorm(num_groups=min(dim_h // 4, 32), num_channels=dim_h, eps=1e-6) 63 | self.norm1_attn = nn.GroupNorm(num_groups=min(dim_h // 4, 32), num_channels=dim_h, eps=1e-6) 64 | 65 | self.dropout = nn.Dropout(dropout) 66 | 67 | # Feed Forward block -> node. 68 | self.ff_linear1 = nn.Linear(dim_h, dim_h * 2) 69 | self.ff_linear2 = nn.Linear(dim_h * 2, dim_h) 70 | self.norm2_node = nn.GroupNorm(num_groups=min(dim_h // 4, 32), num_channels=dim_h, eps=1e-6) 71 | 72 | # Feed Forward block -> edge. 
73 | self.ff_linear3 = nn.Linear(dim_h, dim_h * 2) 74 | self.ff_linear4 = nn.Linear(dim_h * 2, dim_h) 75 | self.norm2_edge = nn.GroupNorm(num_groups=min(dim_h // 4, 32), num_channels=dim_h, eps=1e-6) 76 | 77 | def _ff_block_node(self, x): 78 | """Feed Forward block. 79 | """ 80 | x = self.dropout(self.act(self.ff_linear1(x))) 81 | return self.dropout(self.ff_linear2(x)) 82 | 83 | def _ff_block_edge(self, x): 84 | """Feed Forward block. 85 | """ 86 | x = self.dropout(self.act(self.ff_linear3(x))) 87 | return self.dropout(self.ff_linear4(x)) 88 | 89 | def forward(self, x, edge_index, dense_edge, dense_index, node_mask, adj_mask, temb=None): 90 | """ 91 | Args: 92 | x: node feature [B*N, dim_h] 93 | edge_index: [2, edge_length] 94 | dense_edge: edge features in dense form [B, N, N, dim_h] 95 | dense_index: indices for valid edges [B, N, N, 1] 96 | node_mask: [B, N] 97 | adj_mask: [B, N, N, 1] 98 | temb: time conditional embedding [B, temb_dim] 99 | Returns: 100 | h 101 | edge 102 | """ 103 | 104 | B, N, _, _ = dense_edge.shape 105 | h_in1 = x 106 | h_in2 = dense_edge 107 | 108 | if temb is not None: 109 | h_edge = (dense_edge + self.t_edge(self.act(temb))[:, None, None, :]) * adj_mask 110 | temb = temb.unsqueeze(1).repeat(1, N, 1) 111 | temb = temb.reshape(-1, temb.size(-1)) 112 | h = (x + self.t_node(self.act(temb))) * node_mask.reshape(-1, 1) 113 | 114 | h_out_list = [] 115 | # Local MPNN with edge attributes 116 | if self.local_model is not None: 117 | edge_attr = h_edge[dense_index] 118 | h_local = self.local_model(h, edge_index, edge_attr) * node_mask.reshape(-1, 1) 119 | h_local = h_in1 + self.dropout(h_local) 120 | h_local = self.norm1_local(h_local) 121 | h_out_list.append(h_local) 122 | 123 | # Multi-head attention 124 | if self.self_attn is not None: 125 | if 'FullTrans' in self.global_model_type: 126 | # extract full connect edge_index and edge_attr 127 | dense_index_full = adj_mask.squeeze(-1).nonzero(as_tuple=True) 128 | edge_index_full, _ = dense_to_sparse(adj_mask.squeeze(-1)) 129 | edge_attr_full = h_edge[dense_index_full] 130 | h_attn = self.self_attn(h, edge_index_full, edge_attr_full) 131 | else: 132 | raise ValueError(f"Unsupported global transformer layer") 133 | h_attn = h_in1 + self.dropout(h_attn) 134 | h_attn = self.norm1_attn(h_attn) 135 | h_out_list.append(h_attn) 136 | 137 | # Combine local and global outputs 138 | assert len(h_out_list) > 0 139 | h = sum(h_out_list) * node_mask.reshape(-1, 1) 140 | h_dense = h.reshape(B, N, -1) 141 | h_edge = h_dense.unsqueeze(1) + h_dense.unsqueeze(2) 142 | 143 | # Feed Forward block 144 | h = h + self._ff_block_node(h) 145 | h = self.norm2_node(h) * node_mask.reshape(-1, 1) 146 | 147 | h_edge = h_in2 + self._ff_block_edge(h_edge) 148 | h_edge = self.norm2_edge(h_edge.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) * adj_mask 149 | 150 | return h, h_edge 151 | 152 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | """Common layers for defining score networks.""" 2 | 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import math 8 | import torch_geometric.nn as graph_nn 9 | 10 | 11 | def get_act(config): 12 | """Get actiuvation functions from the config file.""" 13 | 14 | if config.model.nonlinearity.lower() == 'elu': 15 | return nn.ELU() 16 | elif config.model.nonlinearity.lower() == 'relu': 17 | return nn.ReLU() 18 | elif 
config.model.nonlinearity.lower() == 'lrelu': 19 | return nn.LeakyReLU(negative_slope=0.2) 20 | elif config.model.nonlinearity.lower() == 'swish': 21 | return nn.SiLU() 22 | elif config.model.nonlinearity.lower() == 'tanh': 23 | return nn.Tanh() 24 | else: 25 | raise NotImplementedError('activation function does not exist!') 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, padding=0): 29 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation, 30 | padding=padding) 31 | return conv 32 | 33 | 34 | # from DDPM 35 | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): 36 | assert len(timesteps.shape) == 1 37 | half_dim = embedding_dim // 2 38 | # magic number 10000 is from transformers 39 | emb = math.log(max_positions) / (half_dim - 1) 40 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) 41 | emb = timesteps.float()[:, None] * emb[None, :] 42 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 43 | if embedding_dim % 2 == 1: # zero pad 44 | emb = F.pad(emb, (0, 1), mode='constant') 45 | assert emb.shape == (timesteps.shape[0], embedding_dim) 46 | return emb 47 | -------------------------------------------------------------------------------- /models/transformer_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union, Tuple, Optional 3 | from torch_geometric.typing import PairTensor, Adj, OptTensor 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import Tensor 8 | import torch.nn.functional as F 9 | from torch.nn import Linear 10 | from torch_scatter import scatter 11 | from torch_geometric.nn.conv import MessagePassing 12 | from torch_geometric.utils import softmax 13 | 14 | 15 | class EdgeGateTransLayer(MessagePassing): 16 | """The version of edge feature gating.""" 17 | 18 | _alpha: OptTensor 19 | 20 | def __init__(self, x_channels: int, out_channels: int, 21 | heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None, 22 | bias: bool = True, **kwargs): 23 | kwargs.setdefault('aggr', 'add') 24 | super(EdgeGateTransLayer, self).__init__(node_dim=0, **kwargs) 25 | 26 | self.x_channels = x_channels 27 | self.in_channels = in_channels = x_channels 28 | self.out_channels = out_channels 29 | self.heads = heads 30 | self.dropout = dropout 31 | self.edge_dim = edge_dim 32 | 33 | self.lin_key = Linear(in_channels, heads * out_channels, bias=bias) 34 | self.lin_query = Linear(in_channels, heads * out_channels, bias=bias) 35 | self.lin_value = Linear(in_channels, heads * out_channels, bias=bias) 36 | 37 | self.lin_edge0 = Linear(edge_dim, heads * out_channels, bias=False) 38 | self.lin_edge1 = Linear(edge_dim, heads * out_channels, bias=False) 39 | 40 | self.reset_parameters() 41 | 42 | def reset_parameters(self): 43 | self.lin_key.reset_parameters() 44 | self.lin_query.reset_parameters() 45 | self.lin_value.reset_parameters() 46 | self.lin_edge0.reset_parameters() 47 | self.lin_edge1.reset_parameters() 48 | 49 | def forward(self, x: OptTensor, 50 | edge_index: Adj, 51 | edge_attr: OptTensor = None 52 | ) -> Tensor: 53 | """""" 54 | 55 | H, C = self.heads, self.out_channels 56 | 57 | x_feat = x 58 | query = self.lin_query(x_feat).view(-1, H, C) 59 | key = self.lin_key(x_feat).view(-1, H, C) 60 | value = self.lin_value(x_feat).view(-1, H, C) 61 | 62 | # propagate_type: (x: PairTensor, edge_attr: OptTensor) 63 | out_x = self.propagate(edge_index, 
query=query, key=key, value=value, edge_attr=edge_attr, size=None) 64 | 65 | out_x = out_x.view(-1, self.heads * self.out_channels) 66 | 67 | return out_x 68 | 69 | def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor, 70 | edge_attr: OptTensor, 71 | index: Tensor, ptr: OptTensor, 72 | size_i: Optional[int]) -> Tuple[Tensor, Tensor]: 73 | 74 | edge_attn = self.lin_edge0(edge_attr).view(-1, self.heads, self.out_channels) 75 | edge_attn = torch.tanh(edge_attn) 76 | alpha = (query_i * key_j * edge_attn).sum(dim=-1) / math.sqrt(self.out_channels) 77 | 78 | alpha = softmax(alpha, index, ptr, size_i) 79 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 80 | 81 | # node feature message 82 | msg = value_j 83 | msg = msg * torch.tanh(self.lin_edge1(edge_attr).view(-1, self.heads, self.out_channels)) 84 | msg = msg * alpha.view(-1, self.heads, 1) 85 | 86 | return msg 87 | 88 | def __repr__(self): 89 | return '{}({}, {}, heads={})'.format(self.__class__.__name__, 90 | self.in_channels, 91 | self.out_channels, self.heads) 92 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | """All functions and modules related to model definition. 2 | """ 3 | 4 | import torch 5 | import sde_lib 6 | import numpy as np 7 | from torch_scatter import scatter_min, scatter_max, scatter_mean, scatter_std 8 | 9 | 10 | _MODELS = {} 11 | 12 | 13 | def register_model(cls=None, *, name=None): 14 | """A decorator for registering model classes.""" 15 | 16 | def _register(cls): 17 | if name is None: 18 | local_name = cls.__name__ 19 | else: 20 | local_name = name 21 | if local_name in _MODELS: 22 | raise ValueError(f'Already registered model with name: {local_name}') 23 | _MODELS[local_name] = cls 24 | return cls 25 | 26 | if cls is None: 27 | return _register 28 | else: 29 | return _register(cls) 30 | 31 | 32 | def get_model(name): 33 | return _MODELS[name] 34 | 35 | 36 | def create_model(config): 37 | """Create the score model.""" 38 | model_name = config.model.name 39 | score_model = get_model(model_name)(config) 40 | score_model = score_model.to(config.device) 41 | score_model = torch.nn.DataParallel(score_model) 42 | return score_model 43 | 44 | 45 | def get_model_fn(model, train=False): 46 | """Create a function to give the output of the score-based model. 47 | 48 | Args: 49 | model: The score model. 50 | train: `True` for training and `False` for evaluation. 51 | 52 | Returns: 53 | A model function. 54 | """ 55 | 56 | def model_fn(x, labels, *args, **kwargs): 57 | """Compute the output of the score-based model. 58 | 59 | Args: 60 | x: A mini-batch of input data (Adjacency matrices). 61 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 62 | for different models. 63 | mask: Mask for adjacency matrices. 64 | 65 | Returns: 66 | A tuple of (model output, new mutable states) 67 | """ 68 | if not train: 69 | model.eval() 70 | return model(x, labels, *args, **kwargs) 71 | else: 72 | model.train() 73 | return model(x, labels, *args, **kwargs) 74 | 75 | return model_fn 76 | 77 | 78 | def get_multi_score_fn(atom_sde, bond_sde, model, train=False, continuous=False): 79 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. 80 | 81 | Args: 82 | atom_sde: An `sde_lib.SDE` object that represents the forward SDE. 83 | bond_sde: An `sde_lib.SDE` object that represents the forward SDE. 
84 | model: A score model. 85 | train: `True` for training and `False` for evaluation. 86 | continuous: If `True`, the score-based model is expected to directly take continuous time steps. 87 | 88 | Returns: 89 | A score function. 90 | """ 91 | model_fn = get_model_fn(model, train=train) 92 | 93 | if isinstance(atom_sde, sde_lib.VPSDE) or isinstance(atom_sde, sde_lib.subVPSDE): 94 | def score_fn(x, t, *args, **kwargs): 95 | # Scale neural network output by standard deviation and flip sign 96 | if continuous or isinstance(atom_sde, sde_lib.subVPSDE): 97 | # For VP-trained models, t=0 corresponds to the lowest noise level 98 | # The maximum value of time embedding is assumed to 999 for continuously-trained models. 99 | labels = t * 999 100 | atom_score, bond_score = model_fn(x, labels, *args, **kwargs) 101 | atom_std = atom_sde.marginal_prob(torch.zeros_like(x[0]), t)[1] 102 | bond_std = bond_sde.marginal_prob(torch.zeros_like(x[1]), t)[1] 103 | else: 104 | # For VP-trained models, t=0 corresponds to the lowest noise level 105 | labels = t * (atom_sde.N - 1) 106 | atom_score, bond_score = model_fn(x, labels, *args, **kwargs) 107 | atom_std = atom_sde.sqrt_1m_alpha_cumprod.to(labels.device)[labels.long()] 108 | bond_std = bond_sde.sqrt_1m_alpha_cumprod.to(labels.device)[labels.long()] 109 | 110 | atom_score = -atom_score / atom_std[:, None, None] 111 | bond_score = -bond_score / bond_std[:, None, None, None] 112 | return atom_score, bond_score 113 | 114 | else: 115 | raise NotImplementedError(f"SDE class {atom_sde.__class__.__name__} not yet supported.") 116 | 117 | return score_fn 118 | 119 | 120 | def get_multi_theta_fn(atom_sde, bond_sde, model, train=False, continuous=False): 121 | """Wraps `theta_fn` so that the model output corresponds to the time-dependent noise prediction (theta) consumed by the DPM solvers. 122 | 123 | Args: 124 | atom_sde: An `sde_lib.SDE` object that represents the forward SDE. 125 | bond_sde: An `sde_lib.SDE` object that represents the forward SDE. 126 | model: A score model. 127 | train: `True` for training and `False` for evaluation. 128 | continuous: If `True`, the score-based model is expected to directly take continuous time steps. 129 | 130 | Returns: 131 | A theta function. 132 | """ 133 | model_fn = get_model_fn(model, train=train) 134 | 135 | if isinstance(atom_sde, sde_lib.VPSDE) or isinstance(atom_sde, sde_lib.subVPSDE): 136 | def theta_fn(x, t, *args, **kwargs): 137 | # Scale neural network output by standard deviation and flip sign 138 | if continuous or isinstance(atom_sde, sde_lib.subVPSDE): 139 | # For VP-trained models, t=0 corresponds to the lowest noise level 140 | # The maximum value of time embedding is assumed to 999 for continuously-trained models.
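# Note: unlike get_multi_score_fn above, theta_fn returns the raw noise prediction epsilon_hat without
# rescaling; the corresponding score would be -epsilon_hat / std(t). The DPM-Solver routines in
# dpm_solvers.py consume this noise prediction directly, so no division by the marginal std happens here.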
141 | labels = t * 999 142 | atom_theta, bond_theta = model_fn(x, labels, *args, **kwargs) 143 | else: 144 | raise NotImplementedError() 145 | 146 | return atom_theta, bond_theta 147 | 148 | else: 149 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 150 | 151 | return theta_fn 152 | 153 | 154 | def get_mol_regressor_grad_fn(atom_sde, bond_sde, regressor_fn, norm=False): 155 | """Get the noise graph regressor gradient fn.""" 156 | N = atom_sde.N - 1 157 | 158 | def mol_regressor_grad_fn(x, t, only_grad=False, std=False, *args, **kwargs): 159 | label = t * N 160 | atom_std = atom_sde.marginal_prob(torch.zeros_like(x[0]), t)[1] 161 | bond_std = bond_sde.marginal_prob(torch.zeros_like(x[1]), t)[1] 162 | 163 | with torch.enable_grad(): 164 | atom_in, bond_in = x 165 | atom_in = atom_in.detach().requires_grad_(True) 166 | bond_in = bond_in.detach().requires_grad_(True) 167 | pred = regressor_fn((atom_in, bond_in), label, *args, **kwargs) 168 | try: 169 | atom_grad, bond_grad = torch.autograd.grad(pred.sum(), [atom_in, bond_in]) 170 | except: 171 | print('WARNING: grad error!') 172 | atom_grad = torch.zeros_like(atom_in) 173 | bond_grad = torch.zeros_like(bond_in) 174 | 175 | # multiply mask, std 176 | atom_grad = atom_grad * kwargs['atom_mask'].unsqueeze(-1) 177 | bond_grad = bond_grad * kwargs['bond_mask'] 178 | 179 | if only_grad: 180 | if std: 181 | return atom_grad, bond_grad, atom_std, bond_std 182 | return atom_grad, bond_grad 183 | 184 | atom_norm = torch.norm(atom_grad.reshape(atom_grad.shape[0], -1), dim=-1) 185 | bond_norm = torch.norm(bond_grad.reshape(bond_grad.shape[0], -1), dim=-1) 186 | 187 | if norm: 188 | atom_grad = atom_grad / (atom_norm + 1e-8)[:, None, None] 189 | bond_grad = bond_grad / (bond_norm + 1e-8)[:, None, None, None] 190 | 191 | atom_grad = - atom_std[:, None, None] * atom_grad 192 | bond_grad = - bond_std[:, None, None, None] * bond_grad 193 | return atom_grad, bond_grad 194 | 195 | return mol_regressor_grad_fn 196 | 197 | 198 | def get_guided_theta_fn(theta_fn, regressor_grad_fn, guidance_scale=1.0): 199 | """theta function with gradient guidance.""" 200 | def guided_theta_fn(x, t, *args, **kwargs): 201 | atom_theta, bond_theta = theta_fn(x, t, *args, **kwargs) 202 | atom_grad, bond_grad = regressor_grad_fn(x, t, *args, **kwargs) 203 | 204 | # atom_grad, bond_grad, atom_norm, bond_norm, atom_std, bond_std = regressor_grad_fn(x, t, True, *args, **kwargs) 205 | # atom_score = - atom_theta / atom_std[:, None, None] 206 | # atom_score_norm = torch.norm(atom_score.reshape(atom_score.shape[0], -1), dim=-1) 207 | # bond_score = - bond_theta / bond_std[:, None, None, None] 208 | # bond_score_norm = torch.norm(bond_score.reshape(bond_score.shape[0], -1), dim=-1) 209 | # atom_grad = - atom_std[:, None, None] * atom_grad * atom_score_norm[:, None, None] / (atom_norm + 1e-8)[:, None, None] 210 | # bond_grad = - bond_std[:, None, None, None] * bond_grad * bond_score_norm[:, None, None, None] / (bond_norm + 1e-8)[:, None, None, None] 211 | 212 | return atom_theta + atom_grad * guidance_scale, bond_theta + bond_grad * guidance_scale 213 | 214 | return guided_theta_fn 215 | 216 | 217 | def get_theta_fn(sde, model, train=False, continuous=False): 218 | model_fn = get_model_fn(model, train=train) 219 | 220 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): 221 | def theta_fn(x, t, *args, **kwargs): 222 | # Scale neural network output by standard deviation and flip sign 223 | if continuous or isinstance(sde, 
sde_lib.subVPSDE): 224 | # For VP-trained models, t=0 corresponds to the lowest noise level 225 | # The maximum value of time embedding is assumed to 999 for continuously-trained models. 226 | labels = t * 999 227 | theta = model_fn(x, labels, *args, **kwargs) 228 | else: 229 | raise NotImplementedError() 230 | return theta 231 | else: 232 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 233 | 234 | return theta_fn 235 | 236 | @torch.no_grad() 237 | def get_rw_feat(k_step, dense_adj): 238 | """Compute k_step Random Walk for given dense adjacency matrix.""" 239 | 240 | rw_list = [] 241 | deg = dense_adj.sum(-1, keepdims=True) 242 | AD = dense_adj / (deg + 1e-8) 243 | rw_list.append(AD) 244 | 245 | for _ in range(k_step): 246 | rw = torch.bmm(rw_list[-1], AD) 247 | rw_list.append(rw) 248 | rw_map = torch.stack(rw_list[1:], dim=1) # [B, k_step, N, N] 249 | 250 | rw_landing = torch.diagonal(rw_map, offset=0, dim1=2, dim2=3) # [B, k_step, N] 251 | rw_landing = rw_landing.permute(0, 2, 1) # [B, N, rw_depth] 252 | 253 | # get the shortest path distance indices 254 | tmp_rw = rw_map.sort(dim=1)[0] 255 | spd_ind = (tmp_rw <= 0).sum(dim=1) # [B, N, N] 256 | 257 | spd_onehot = torch.nn.functional.one_hot(spd_ind, num_classes=k_step+1).to(torch.float) 258 | spd_onehot = spd_onehot.permute(0, 3, 1, 2) # [B, kstep, N, N] 259 | 260 | return rw_landing, spd_onehot 261 | -------------------------------------------------------------------------------- /mol_config.csv: -------------------------------------------------------------------------------- 1 | ,zinc250k,zinc_800_graphaf,zinc_800_jt,zinc250k_property,qm9_property,moses,qm9 2 | smile,smiles,smiles,smiles,smile,smile,smiles,SMILES1 3 | prop_list,['qed'],['penalized_logp'],['penalized_logp'],"['qed', 'penalized_logp']","['qed', 'penalized_logp']",[],[] 4 | url,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc250k.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc_800_graphaf.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc_800_jt.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc250k_property.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/qm9_property.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/moses.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/qm9.csv 5 | num_max_node,38,38,38,38,9,38,9 6 | atom_list,"[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9]","[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9]" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ml-collections 2 | absl-py 3 | pandas 4 | matplotlib 5 | tensorboard 6 | molsets 7 | fcd-torch 8 | rdkit==2022.3.5 -------------------------------------------------------------------------------- /run_lib.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import logging 6 | import time 7 | from absl import flags 8 | from torch.utils import tensorboard 9 | from torch_geometric.loader import DataLoader, DenseDataLoader 10 | import pickle 11 | from rdkit import RDLogger, Chem 12 | 13 | from models import cdgs 14 | import losses 15 | import sampling 16 | from models 
import utils as mutils 17 | from models.ema import ExponentialMovingAverage 18 | import datasets 19 | from evaluation import get_FCDMetric, get_nspdk_eval 20 | import sde_lib 21 | import visualize 22 | from utils import * 23 | from moses.metrics.metrics import get_all_metrics 24 | 25 | FLAGS = flags.FLAGS 26 | 27 | 28 | def set_random_seed(config): 29 | seed = config.seed 30 | os.environ['PYTHONHASHSEED'] = str(seed) 31 | 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | np.random.seed(seed) 37 | random.seed(seed) 38 | 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | 42 | 43 | def mol_sde_train(config, workdir): 44 | """Runs the training pipeline of molecule generation. 45 | 46 | Args: 47 | config: Configuration to use. 48 | workdir: Working directory for checkpoints and TF summaries. 49 | If this contains checkpoint training will be resumed from the latest checkpoint. 50 | """ 51 | 52 | ### Ignore info output by RDKit 53 | RDLogger.DisableLog('rdApp.error') 54 | RDLogger.DisableLog('rdApp.warning') 55 | 56 | # Create directories for experimental logs 57 | sample_dir = os.path.join(workdir, "samples") 58 | if not os.path.exists(sample_dir): 59 | os.makedirs(sample_dir) 60 | 61 | tb_dir = os.path.join(workdir, "tensorboard") 62 | if not os.path.exists(tb_dir): 63 | os.makedirs(tb_dir) 64 | writer = tensorboard.SummaryWriter(tb_dir) 65 | 66 | # Initialize model. 67 | score_model = mutils.create_model(config) 68 | ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) 69 | optimizer = losses.get_optimizer(config, score_model.parameters()) 70 | state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) 71 | 72 | # Create checkpoints directly 73 | checkpoint_dir = os.path.join(workdir, "checkpoints") 74 | # Intermediate checkpoints to resume training 75 | checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth") 76 | if not os.path.exists(checkpoint_dir): 77 | os.makedirs(checkpoint_dir) 78 | if not os.path.exists(os.path.dirname(checkpoint_meta_dir)): 79 | os.makedirs(os.path.dirname(checkpoint_meta_dir)) 80 | # Resume training when intermediate checkpoints are detected 81 | state = restore_checkpoint(checkpoint_meta_dir, state, config.device) 82 | initial_step = int(state['step']) 83 | 84 | # Build dataloader and iterators 85 | train_ds, eval_ds, test_ds, n_node_pmf = datasets.get_dataset(config) 86 | 87 | train_loader = DenseDataLoader(train_ds, batch_size=config.training.batch_size, shuffle=True) 88 | eval_loader = DenseDataLoader(eval_ds, batch_size=config.training.eval_batch_size, shuffle=False) 89 | test_loader = DenseDataLoader(test_ds, batch_size=config.training.eval_batch_size, shuffle=False) 90 | n_node_pmf = torch.from_numpy(n_node_pmf).to(config.device) 91 | 92 | train_iter = iter(train_loader) 93 | # create data normalizer and its inverse 94 | scaler = datasets.get_data_scaler(config) 95 | inverse_scaler = datasets.get_data_inverse_scaler(config) 96 | 97 | # Setup SDEs 98 | if config.training.sde.lower() == 'vpsde': 99 | atom_sde = sde_lib.VPSDE(beta_min=config.model.node_beta_min, beta_max=config.model.node_beta_max, 100 | N=config.model.num_scales) 101 | bond_sde = sde_lib.VPSDE(beta_min=config.model.edge_beta_min, beta_max=config.model.edge_beta_max, 102 | N=config.model.num_scales) 103 | sampling_eps = 1e-3 104 | else: 105 | raise NotImplementedError(f"SDE {config.training.sde} unknown.") 106 | 107 | # 
Build one-step training and evaluation functions 108 | optimize_fn = losses.optimization_manager(config) 109 | continuous = config.training.continuous 110 | reduce_mean = config.training.reduce_mean 111 | likelihood_weighting = config.training.likelihood_weighting 112 | train_step_fn = losses.get_step_fn((atom_sde, bond_sde), train=True, optimize_fn=optimize_fn, 113 | reduce_mean=reduce_mean, continuous=continuous, 114 | likelihood_weighting=likelihood_weighting) 115 | eval_step_fn = losses.get_step_fn((atom_sde, bond_sde), train=False, optimize_fn=optimize_fn, 116 | reduce_mean=reduce_mean, continuous=continuous, 117 | likelihood_weighting=likelihood_weighting) 118 | 119 | test_FCDMetric = get_FCDMetric(test_ds.sub_smiles, device=config.device) 120 | eval_FCDMetric = get_FCDMetric(eval_ds.sub_smiles, device=config.device) 121 | 122 | # Build sampling functions 123 | if config.training.snapshot_sampling: 124 | sampling_atom_shape = (config.training.eval_batch_size, config.data.max_node, config.data.atom_channels) 125 | sampling_bond_shape = (config.training.eval_batch_size, config.data.bond_channels, 126 | config.data.max_node, config.data.max_node) 127 | sampling_fn = sampling.get_mol_sampling_fn(config, atom_sde, bond_sde, sampling_atom_shape, sampling_bond_shape, 128 | inverse_scaler, sampling_eps) 129 | 130 | num_train_steps = config.training.n_iters 131 | 132 | logging.info("Starting training loop at step %d." % (initial_step,)) 133 | 134 | for step in range(initial_step, num_train_steps + 1): 135 | try: 136 | graphs = next(train_iter) 137 | except StopIteration: 138 | train_iter = train_loader.__iter__() 139 | graphs = next(train_iter) 140 | 141 | batch = dense_mol(graphs, scaler, config.data.dequantization) 142 | 143 | # Execute one training step 144 | loss = train_step_fn(state, batch) 145 | if step % config.training.log_freq == 0: 146 | logging.info("step: %d, training_loss: %.5e" % (step, loss.item())) 147 | writer.add_scalar("training_loss", loss, step) 148 | 149 | # Save a temporary checkpoint to resume training after pre-emption periodically 150 | if step != 0 and step % config.training.snapshot_freq_for_preemption == 0: 151 | save_checkpoint(checkpoint_meta_dir, state) 152 | 153 | # Report the loss on evaluation dataset periodically 154 | if step % config.training.eval_freq == 0: 155 | for eval_graphs in eval_loader: 156 | eval_batch = dense_mol(eval_graphs, scaler) 157 | eval_loss = eval_step_fn(state, eval_batch) 158 | logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss.item())) 159 | writer.add_scalar("eval_loss", eval_loss.item(), step) 160 | break 161 | for test_graphs in test_loader: 162 | test_batch = dense_mol(test_graphs, scaler) 163 | test_loss = eval_step_fn(state, test_batch) 164 | logging.info("step: %d, test_loss: %.5e" % (step, test_loss.item())) 165 | writer.add_scalar("test_loss", test_loss.item(), step) 166 | break 167 | 168 | # Save a checkpoint periodically and generate samples 169 | if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps: 170 | 171 | # Save the checkpoint. 
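# Note: checkpoints are indexed by step // config.training.snapshot_freq below, so a file such as
# checkpoint_K.pth corresponds (up to the final iteration) to training step K * snapshot_freq.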
172 | save_step = step // config.training.snapshot_freq 173 | save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'), state) 174 | 175 | # Generate and save samples 176 | if config.training.snapshot_sampling: 177 | ema.store(score_model.parameters()) 178 | ema.copy_to(score_model.parameters()) 179 | 180 | atom_sample, bond_sample, sample_steps, sample_nodes = sampling_fn(score_model, n_node_pmf) 181 | 182 | sample_list, valid_wd = tensor2mol(atom_sample, bond_sample, sample_nodes, config.data.atom_list, 183 | correct_validity=True, largest_connected_comp=True) 184 | ## fcd value 185 | smile_list = [Chem.MolToSmiles(mol) for mol in sample_list] 186 | fcd_test = test_FCDMetric(smile_list) 187 | fcd_eval = eval_FCDMetric(smile_list) 188 | 189 | ## log info 190 | valid_wd_rate = np.sum(valid_wd) / len(valid_wd) 191 | logging.info("step: %d, n_mol: %d, validity rate wd check: %.4f, fcd_val: %.4f, fcd_test: %.4f" % 192 | (step, len(sample_list), valid_wd_rate, fcd_eval, fcd_test)) 193 | 194 | ema.restore(score_model.parameters()) 195 | this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step)) 196 | if not os.path.exists(this_sample_dir): 197 | os.makedirs(this_sample_dir) 198 | # graph visualization and save figs 199 | visualize.visualize_mols(sample_list[:16], this_sample_dir, config) 200 | 201 | 202 | def mol_sde_evaluate(config, workdir, eval_folder="eval"): 203 | """Evaluate trained models. 204 | 205 | Args: 206 | config: Configuration to use. 207 | workdir: Working directory for checkpoints. 208 | eval_folder: The subfolder for storing evaluation results. Defaults to "eval". 209 | """ 210 | 211 | ### Ignore info output by RDKit 212 | RDLogger.DisableLog('rdApp.error') 213 | RDLogger.DisableLog('rdApp.warning') 214 | 215 | # Create directory to eval_folder 216 | eval_dir = os.path.join(workdir, eval_folder) 217 | if not os.path.exists(eval_dir): 218 | os.makedirs(eval_dir) 219 | 220 | # Build data pipeline 221 | train_ds, _, test_ds, n_node_pmf = datasets.get_dataset(config) 222 | n_node_pmf = torch.from_numpy(n_node_pmf).to(config.device) 223 | # test_FCDMetric = get_FCDMetric(test_ds.sub_smiles, device=config.device) 224 | 225 | # Create data normalizer and its inverse 226 | scaler = datasets.get_data_scaler(config) 227 | inverse_scaler = datasets.get_data_inverse_scaler(config) 228 | 229 | # Initialize model 230 | score_model = mutils.create_model(config) 231 | optimizer = losses.get_optimizer(config, score_model.parameters()) 232 | ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) 233 | state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) 234 | 235 | checkpoint_dir = os.path.join(workdir, "checkpoints") 236 | 237 | # Setup SDEs 238 | if config.training.sde.lower() == 'vpsde': 239 | atom_sde = sde_lib.VPSDE(beta_min=config.model.node_beta_min, beta_max=config.model.node_beta_max, 240 | N=config.model.num_scales) 241 | bond_sde = sde_lib.VPSDE(beta_min=config.model.edge_beta_min, beta_max=config.model.edge_beta_max, 242 | N=config.model.num_scales) 243 | sampling_eps = 1e-3 244 | elif config.training.sde.lower() == 'subvpsde': 245 | atom_sde = sde_lib.subVPSDE(beta_min=config.model.node_beta_min, beta_max=config.model.node_beta_max, 246 | N=config.model.num_scales) 247 | bond_sde = sde_lib.subVPSDE(beta_min=config.model.edge_beta_min, beta_max=config.model.edge_beta_max, 248 | N=config.model.num_scales) 249 | sampling_eps = 1e-3 250 | else: 251 | raise NotImplementedError(f"SDE {config.training.sde} unknown.")
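# Note: assuming the standard score-SDE VPSDE formulation in sde_lib.py, the perturbation kernel behind
# marginal_prob(x_0, t) is approximately:
#   log_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
#   mean, std = exp(log_coeff) * x_0, sqrt(1. - exp(2. * log_coeff))
# Separate beta ranges for atoms (node_beta_*) and bonds (edge_beta_*) therefore give the two feature
# types different noise schedules while sharing the same diffusion time t.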
252 | 253 | 254 | if config.eval.enable_sampling: 255 | sampling_atom_shape = (config.eval.batch_size, config.data.max_node, config.data.atom_channels) 256 | sampling_bond_shape = (config.eval.batch_size, config.data.bond_channels, 257 | config.data.max_node, config.data.max_node) 258 | sampling_fn = sampling.get_mol_sampling_fn(config, atom_sde, bond_sde, sampling_atom_shape, sampling_bond_shape, 259 | inverse_scaler, sampling_eps) 260 | 261 | # Begin evaluation 262 | begin_ckpt = config.eval.begin_ckpt 263 | logging.info("begin checkpoint: %d" % (begin_ckpt,)) 264 | 265 | for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1): 266 | # Wait if the target checkpoint doesn't exist yet 267 | waiting_message_printed = False 268 | ckpt_filename = os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(ckpt)) 269 | while not os.path.exists(ckpt_filename): 270 | if not waiting_message_printed: 271 | logging.warning("Waiting for the arrival of checkpoint_%d" % (ckpt,)) 272 | waiting_message_printed = True 273 | time.sleep(60) 274 | 275 | # Wait for 2 additional mins in case the file exists but is not ready for reading 276 | ckpt_path = os.path.join(checkpoint_dir, f'checkpoint_{ckpt}.pth') 277 | try: 278 | state = restore_checkpoint(ckpt_path, state, device=config.device) 279 | except: 280 | time.sleep(60) 281 | try: 282 | state = restore_checkpoint(ckpt_path, state, device=config.device) 283 | except: 284 | time.sleep(120) 285 | state = restore_checkpoint(ckpt_path, state, device=config.device) 286 | ema.copy_to(score_model.parameters()) 287 | 288 | # Generate samples and compute MMD stats 289 | if config.eval.enable_sampling: 290 | num_sampling_rounds = int(np.ceil(config.eval.num_samples / config.eval.batch_size)) 291 | all_samples = [] 292 | all_valid_wd = [] 293 | for r in range(num_sampling_rounds): 294 | logging.info("sampling -- ckpt: %d, round: %d" % (ckpt, r)) 295 | atom_sample, bond_sample, sample_steps, sample_nodes = sampling_fn(score_model, n_node_pmf) 296 | logging.info("sample steps: %d" % sample_steps) 297 | 298 | sample_list, valid_wd = tensor2mol(atom_sample, bond_sample, sample_nodes, config.data.atom_list, 299 | correct_validity=True, largest_connected_comp=True) 300 | 301 | all_samples += sample_list 302 | all_valid_wd += valid_wd 303 | 304 | all_samples = all_samples[:config.eval.num_samples] 305 | all_valid_wd = all_valid_wd[:config.eval.num_samples] 306 | smile_list = [] 307 | for mol in all_samples: 308 | if mol is not None: 309 | smile_list.append(Chem.MolToSmiles(mol)) 310 | 311 | # save the graphs 312 | sampler_name = config.sampling.method 313 | 314 | if config.eval.save_graph: 315 | # save the smile strings instead of rdkit mol object 316 | graph_file = os.path.join(eval_dir, sampler_name + "_ckpt_{}.pkl".format(ckpt)) 317 | with open(graph_file, "wb") as f: 318 | pickle.dump(smile_list, f) 319 | 320 | # evaluate 321 | logging.info('Number of molecules: %d' % len(all_samples)) 322 | ## valid, novelty, unique rate 323 | logging.info('sampling -- ckpt: {}, validity w/o correction: {:.6f}'. 
324 | format(ckpt, np.sum(all_valid_wd) / len(all_valid_wd))) 325 | 326 | ## moses metric 327 | scores = get_all_metrics(gen=smile_list, k=len(smile_list), device=config.device, n_jobs=8, 328 | test=test_ds.sub_smiles, train=train_ds.sub_smiles) 329 | for metric in ['valid', f'unique@{len(smile_list)}', 'FCD/Test', 'Novelty']: 330 | logging.info(f'sampling -- ckpt: {ckpt}, {metric}: {scores[metric]}') 331 | 332 | ## NSPDK evaluation 333 | if config.eval.nspdk: 334 | nspdk_eval = get_nspdk_eval(config) 335 | test_smiles = test_ds.sub_smiles 336 | test_mols = [] 337 | for smile in test_smiles: 338 | mol = Chem.MolFromSmiles(smile) 339 | # Chem.Kekulize(mol) 340 | test_mols.append(mol) 341 | test_nx_graphs = mols_to_nx(test_mols) 342 | gen_nx_graphs = mols_to_nx(all_samples) 343 | nspdk_res = nspdk_eval(test_nx_graphs, gen_nx_graphs) 344 | logging.info('sampling -- ckpt: {}, NSPDK: {}'.format(ckpt, nspdk_res)) 345 | 346 | 347 | run_train_dict = { 348 | 'mol_sde': mol_sde_train 349 | } 350 | run_eval_dict = { 351 | 'mol_sde': mol_sde_evaluate, 352 | } 353 | 354 | 355 | def train(config, workdir): 356 | run_train_dict[config.model_type](config, workdir) 357 | 358 | 359 | def evaluate(config, workdir, eval_folder='eval'): 360 | run_eval_dict[config.model_type](config, workdir, eval_folder) 361 | -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | """Various sampling methods.""" 2 | 3 | import functools 4 | 5 | import torch 6 | import numpy as np 7 | import abc 8 | 9 | from models.utils import get_multi_score_fn 10 | from scipy import integrate 11 | # from torchdiffeq import odeint 12 | import sde_lib 13 | from models import utils as mutils 14 | from dpm_solvers import get_mol_sampler_dpm1, get_mol_sampler_dpm2, get_mol_sampler_dpm3, \ 15 | get_mol_sampler_dpm_mix, get_sampler_dpm3 16 | import time 17 | 18 | 19 | _CORRECTORS = {} 20 | _PREDICTORS = {} 21 | 22 | 23 | def register_predictor(cls=None, *, name=None): 24 | """A decorator for registering predictor classes.""" 25 | 26 | def _register(cls): 27 | if name is None: 28 | local_name = cls.__name__ 29 | else: 30 | local_name = name 31 | if local_name in _PREDICTORS: 32 | raise ValueError(f'Already registered predictor with name: {local_name}') 33 | _PREDICTORS[local_name] = cls 34 | return cls 35 | 36 | if cls is None: 37 | return _register 38 | else: 39 | return _register(cls) 40 | 41 | 42 | def register_corrector(cls=None, *, name=None): 43 | """A decorator for registering corrector classes.""" 44 | 45 | def _register(cls): 46 | if name is None: 47 | local_name = cls.__name__ 48 | else: 49 | local_name = name 50 | if local_name in _CORRECTORS: 51 | raise ValueError(f'Already registered corrector with name: {local_name}') 52 | _CORRECTORS[local_name] = cls 53 | return cls 54 | 55 | if cls is None: 56 | return _register 57 | else: 58 | return _register(cls) 59 | 60 | 61 | def get_predictor(name): 62 | return _PREDICTORS[name] 63 | 64 | 65 | def get_corrector(name): 66 | return _CORRECTORS[name] 67 | 68 | 69 | def get_mol_sampling_fn(config, atom_sde, bond_sde, atom_shape, bond_shape, inverse_scaler, eps): 70 | """Create a sampling function for molecule. 71 | 72 | Args: 73 | config: A `ml_collections.ConfigDict` object that contains all configuration information. 74 | atom_sde, bond_sde: A `sde_lib.SDE` object that represents the forward SDE. 
75 | atom_shape, bond_shape: A sequence of integers representing the expected shape of a single sample. 76 | inverse_scaler: The inverse data normalizer function. 77 | eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability. 78 | 79 | Returns: 80 | A function that takes random states and a replicated training state and outputs samples with the 81 | trailing dimensions matching `shape`. 82 | """ 83 | 84 | sampler_name = config.sampling.method 85 | if sampler_name.lower() == 'dpm1': 86 | sampling_fn = get_mol_sampler_dpm1(atom_sde=atom_sde, 87 | bond_sde=bond_sde, 88 | atom_shape=atom_shape, 89 | bond_shape=bond_shape, 90 | inverse_scaler=inverse_scaler, 91 | time_step=config.sampling.ode_step, 92 | eps=eps, 93 | denoise=config.sampling.noise_removal, 94 | device=config.device) 95 | elif sampler_name.lower() == 'dpm2': 96 | sampling_fn = get_mol_sampler_dpm2(atom_sde=atom_sde, 97 | bond_sde=bond_sde, 98 | atom_shape=atom_shape, 99 | bond_shape=bond_shape, 100 | inverse_scaler=inverse_scaler, 101 | time_step=config.sampling.ode_step, 102 | eps=eps, 103 | denoise=config.sampling.noise_removal, 104 | device=config.device) 105 | elif sampler_name.lower() == 'dpm3': 106 | sampling_fn = get_mol_sampler_dpm3(atom_sde=atom_sde, 107 | bond_sde=bond_sde, 108 | atom_shape=atom_shape, 109 | bond_shape=bond_shape, 110 | inverse_scaler=inverse_scaler, 111 | time_step=config.sampling.ode_step, 112 | eps=eps, 113 | denoise=config.sampling.noise_removal, 114 | device=config.device) 115 | elif sampler_name.lower() == 'dpm_mix': 116 | sampling_fn = get_mol_sampler_dpm_mix(atom_sde=atom_sde, 117 | bond_sde=bond_sde, 118 | atom_shape=atom_shape, 119 | bond_shape=bond_shape, 120 | inverse_scaler=inverse_scaler, 121 | time_step=config.sampling.ode_step, 122 | eps=eps, 123 | denoise=config.sampling.noise_removal, 124 | device=config.device) 125 | # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases. 126 | elif sampler_name.lower() == 'pc': 127 | 128 | predictor = get_predictor(config.sampling.predictor.lower()) 129 | corrector = get_corrector(config.sampling.corrector.lower()) 130 | 131 | sampling_fn = get_mol_pc_sampler(atom_sde=atom_sde, 132 | bond_sde=bond_sde, 133 | atom_shape=atom_shape, 134 | bond_shape=bond_shape, 135 | predictor=predictor, 136 | corrector=corrector, 137 | inverse_scaler=inverse_scaler, 138 | snr=(config.sampling.atom_snr, config.sampling.bond_snr), 139 | n_steps=config.sampling.n_steps_each, 140 | probability_flow=config.sampling.probability_flow, 141 | continuous=config.training.continuous, 142 | denoise=config.sampling.noise_removal, 143 | eps=eps, 144 | device=config.device) 145 | else: 146 | raise ValueError(f"Sampler name {sampler_name} unknown.") 147 | 148 | return sampling_fn 149 | 150 | 151 | class Predictor(abc.ABC): 152 | """The abstract class for a predictor algorithm.""" 153 | 154 | def __init__(self, sde, score_fn, probability_flow=False): 155 | super().__init__() 156 | self.sde = sde 157 | # Compute the reverse SDE/ODE 158 | if isinstance(sde, tuple): 159 | self.rsde = (sde[0].reverse(score_fn, probability_flow), sde[1].reverse(score_fn, probability_flow)) 160 | else: 161 | self.rsde = sde.reverse(score_fn, probability_flow) 162 | self.score_fn = score_fn 163 | 164 | @abc.abstractmethod 165 | def update_fn(self, x, t, *args, **kwargs): 166 | """One update of the predictor. 167 | 168 | Args: 169 | x: A PyTorch tensor representing the current state. 
170 | t: A PyTorch tensor representing the current time step. 171 | 172 | Returns: 173 | x: A PyTorch tensor of the next state. 174 | x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. 175 | """ 176 | pass 177 | 178 | @abc.abstractmethod 179 | def update_mol_fn(self, x, t, *args, **kwargs): 180 | """One update of the predictor for molecule graphs. 181 | 182 | Args: 183 | x: A tuple of PyTorch tensor (x_atom, x_bond) representing the current state. 184 | t: A PyTorch tensor representing the current time step. 185 | 186 | Returns: 187 | x: A tuple of PyTorch tensor (x_atom, x_bond) of the next state. 188 | x_mean: A tuple of PyTorch tensor. The next state without random noise. Useful for denoising. 189 | """ 190 | pass 191 | 192 | 193 | class Corrector(abc.ABC): 194 | """The abstract class for a corrector algorithm.""" 195 | 196 | def __init__(self, sde, score_fn, snr, n_steps): 197 | super().__init__() 198 | self.sde = sde 199 | self.score_fn = score_fn 200 | self.snr = snr 201 | self.n_steps = n_steps 202 | 203 | @abc.abstractmethod 204 | def update_fn(self, x, t, *args, **kwargs): 205 | """One update of the corrector. 206 | 207 | Args: 208 | x: A PyTorch tensor representing the current state. 209 | t: A PyTorch tensor representing the current time step. 210 | 211 | Returns: 212 | x: A PyTorch tensor of the next state. 213 | x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. 214 | """ 215 | pass 216 | 217 | @abc.abstractmethod 218 | def update_mol_fn(self, x, t, *args, **kwargs): 219 | """One update of the corrector for molecule graphs. 220 | 221 | Args: 222 | x: A tuple of PyTorch tensor (x_atom, x_bond) representing the current state. 223 | t: A PyTorch tensor representing the current time step. 224 | 225 | Returns: 226 | x: A tuple of PyTorch tensor (x_atom, x_bond) of the next state. 227 | x_mean: A tuple of PyTorch tensor. The next state without random noise. Useful for denoising. 228 | """ 229 | pass 230 | 231 | 232 | @register_predictor(name='euler_maruyama') 233 | class EulerMaruyamaPredictor(Predictor): 234 | def __init__(self, sde, score_fn, probability_flow=False): 235 | super().__init__(sde, score_fn, probability_flow) 236 | 237 | def update_fn(self, x, t, *args, **kwargs): 238 | dt = -1. / self.rsde.N 239 | z = torch.randn_like(x) 240 | z = torch.tril(z, -1) 241 | z = z + z.transpose(-1, -2) 242 | drift, diffusion = self.rsde.sde(x, t, *args, **kwargs) 243 | drift = torch.tril(drift, -1) 244 | drift = drift + drift.transpose(-1, -2) 245 | x_mean = x + drift * dt 246 | x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z 247 | return x, x_mean 248 | 249 | def update_mol_fn(self, x, t, *args, **kwargs): 250 | atom_score, bond_score = self.score_fn(x, t, *args, **kwargs) 251 | # print('predictor atom norm: ', torch.norm(atom_score.reshape(atom_score.shape[0], -1), dim=-1).mean(), t[0]) 252 | 253 | x_atom, x_bond = x 254 | dt = -1. 
/ self.rsde[0].N 255 | 256 | # atom update 257 | z_atom = torch.randn_like(x_atom) 258 | drift_atom, diffusion_atom = self.rsde[0].sde_score(x_atom, t, atom_score) 259 | x_atom_mean = x_atom + drift_atom * dt 260 | x_atom = x_atom_mean + diffusion_atom[:, None, None] * np.sqrt(-dt) * z_atom 261 | 262 | # bond update 263 | z_bond = torch.randn_like(x_bond) 264 | z_bond = torch.tril(z_bond, -1) 265 | z_bond = z_bond + z_bond.transpose(-1, -2) 266 | drift_bond, diffusion_bond = self.rsde[1].sde_score(x_bond, t, bond_score) 267 | 268 | x_bond_mean = x_bond + drift_bond * dt 269 | x_bond = x_bond_mean + diffusion_bond[:, None, None, None] * np.sqrt(-dt) * z_bond 270 | 271 | return (x_atom, x_bond), (x_atom_mean, x_bond_mean) 272 | 273 | 274 | @register_corrector(name='langevin') 275 | class LangevinCorrector(Corrector): 276 | def __init__(self, sde, score_fn, snr, n_steps): 277 | super().__init__(sde, score_fn, snr, n_steps) 278 | 279 | def update_fn(self, x, t, *args, **kwargs): 280 | sde = self.sde 281 | score_fn = self.score_fn 282 | n_steps = self.n_steps 283 | target_snr = self.snr 284 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): 285 | timestep = (t * (sde.N - 1) / sde.T).long() 286 | # Note: it seems that subVPSDE doesn't set alphas 287 | alpha = sde.alphas.to(t.device)[timestep] 288 | else: 289 | alpha = torch.ones_like(t) 290 | 291 | for i in range(n_steps): 292 | 293 | grad = score_fn(x, t, *args, **kwargs) 294 | noise = torch.randn_like(x) 295 | 296 | noise = torch.tril(noise, -1) 297 | noise = noise + noise.transpose(-1, -2) 298 | 299 | mask = kwargs['mask'] 300 | 301 | # mask invalid elements and calculate norm 302 | mask_tmp = mask.reshape(mask.shape[0], -1) 303 | 304 | grad_norm = torch.norm(mask_tmp * grad.reshape(grad.shape[0], -1), dim=-1).mean() 305 | noise_norm = torch.norm(mask_tmp * noise.reshape(noise.shape[0], -1), dim=-1).mean() 306 | 307 | step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha 308 | x_mean = x + step_size[:, None, None, None] * grad 309 | x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise 310 | 311 | return x, x_mean 312 | 313 | 314 | def update_mol_fn(self, x, t, *args, **kwargs): 315 | x_atom, x_bond = x 316 | atom_sde, bond_sde = self.sde 317 | score_fn = self.score_fn 318 | n_steps = self.n_steps 319 | atom_snr, bond_snr = self.snr 320 | if isinstance(atom_sde, sde_lib.VPSDE) or isinstance(atom_sde, sde_lib.subVPSDE): 321 | timestep = (t * (atom_sde.N - 1) / atom_sde.T).long() 322 | # Note: it seems that subVPSDE doesn't set alphas 323 | alpha_atom = atom_sde.alphas.to(t.device)[timestep] 324 | alpha_bond = bond_sde.alphas.to(t.device)[timestep] 325 | else: 326 | alpha_atom = alpha_bond = torch.ones_like(t) 327 | 328 | for i in range(n_steps): 329 | grad_atom, grad_bond = score_fn(x, t, *args, **kwargs) 330 | 331 | # update atom 332 | noise_atom = torch.randn_like(x_atom) 333 | noise_atom = noise_atom * kwargs['atom_mask'].unsqueeze(-1) 334 | 335 | ## mask invalid elements and calculate norm 336 | # atom_mask = kwargs['atom_mask'].unsqueeze(-1) 337 | # atom_mask = atom_mask.repeat(1, 1, grad_atom.shape[-1]).reshape(grad_atom.shape[0], -1) 338 | 339 | # grad_norm_a = torch.norm(atom_mask * grad_atom.reshape(grad_atom.shape[0], -1), dim=-1).mean() 340 | # noise_norm_a = torch.norm(atom_mask * noise_atom.reshape(noise_atom.shape[0], -1), dim=-1).mean() 341 | grad_norm_a = torch.norm(grad_atom.reshape(grad_atom.shape[0], -1), dim=-1).mean() 342 | noise_norm_a = 
torch.norm(noise_atom.reshape(noise_atom.shape[0], -1), dim=-1).mean() 343 | 344 | # print('Corrector atom score norm:', grad_norm_a, t[0]) 345 | step_size_a = (atom_snr * noise_norm_a / grad_norm_a) ** 2 * 2 * alpha_atom 346 | x_atom_mean = x_atom + step_size_a[:, None, None] * grad_norm_a 347 | x_atom = x_atom_mean + torch.sqrt(step_size_a * 2)[:, None, None] * noise_atom 348 | 349 | # update bond 350 | noise_bond = torch.randn_like(x_bond) 351 | noise_bond = torch.tril(noise_bond, -1) 352 | noise_bond = noise_bond + noise_bond.transpose(-1, -2) 353 | noise_bond = noise_bond * kwargs['bond_mask'] 354 | 355 | # bond_mask = kwargs['bond_mask'].repeat(1, grad_bond.shape[1], 1, 1).reshape(grad_bond.shape[0], -1) 356 | # grad_norm_b = torch.norm(bond_mask * grad_bond.reshape(grad_bond.shape[0], -1), dim=-1).mean() 357 | # noise_norm_b = torch.norm(bond_mask * noise_bond.reshape(noise_bond.shape[0], -1), dim=-1).mean() 358 | grad_norm_b = torch.norm(grad_bond.reshape(grad_bond.shape[0], -1), dim=-1).mean() 359 | noise_norm_b = torch.norm(noise_bond.reshape(noise_bond.shape[0], -1), dim=-1).mean() 360 | 361 | step_size_b = (bond_snr * noise_norm_b / grad_norm_b) ** 2 * 2 * alpha_bond 362 | x_bond_mean = x_bond + step_size_b[:, None, None, None] * grad_norm_b 363 | x_bond = x_bond_mean + torch.sqrt(step_size_b * 2)[:, None, None, None] * noise_bond 364 | 365 | return (x_atom, x_bond), (x_atom_mean, x_bond_mean) 366 | 367 | 368 | @register_predictor(name='none') 369 | class NonePredictor(Predictor): 370 | """An empty predictor that does nothing.""" 371 | 372 | def __init__(self, sde, score_fn, probability_flow=False): 373 | pass 374 | 375 | def update_fn(self, x, t, *args, **kwargs): 376 | return x, x 377 | 378 | def update_mol_fn(self, x, t, *args, **kwargs): 379 | return x, x 380 | 381 | 382 | @register_corrector(name='none') 383 | class NoneCorrector(Corrector): 384 | """An empty corrector that does nothing.""" 385 | 386 | def __init__(self, sde, score_fn, snr, n_steps): 387 | pass 388 | 389 | def update_fn(self, x, t, *args, **kwargs): 390 | return x, x 391 | 392 | def update_atom_fn(self, x, t, *args, **kwargs): 393 | return x, x 394 | 395 | def update_bond_fn(self, x, t, *args, **kwargs): 396 | return x, x 397 | 398 | def update_mol_fn(self, x, t, *args, **kwargs): 399 | return x, x 400 | 401 | 402 | def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous, *args, **kwargs): 403 | """A wrapper that configures and returns the update function of predictors.""" 404 | if isinstance(sde, tuple): 405 | score_fn = mutils.get_multi_score_fn(sde[0], sde[1], model, train=False, continuous=continuous) 406 | else: 407 | # score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) 408 | raise ValueError('Score function error.') 409 | if predictor is None: 410 | # Corrector-only sampler 411 | predictor_obj = NonePredictor(sde, score_fn, probability_flow) 412 | else: 413 | predictor_obj = predictor(sde, score_fn, probability_flow) 414 | if isinstance(sde, tuple): 415 | return predictor_obj.update_mol_fn(x, t, *args, **kwargs) 416 | return predictor_obj.update_fn(x, t, *args, **kwargs) 417 | 418 | 419 | def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps, *args, **kwargs): 420 | """A wrapper that configures and returns the update function of correctors.""" 421 | if isinstance(sde, tuple): 422 | score_fn = mutils.get_multi_score_fn(sde[0], sde[1], model, train=False, continuous=continuous) 423 | else: 424 | # score_fn = 
mutils.get_score_fn(sde, model, train=False, continuous=continuous) 425 | raise ValueError('Score function error.') 426 | if corrector is None: 427 | # Predictor-only sampler 428 | corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps) 429 | else: 430 | corrector_obj = corrector(sde, score_fn, snr, n_steps) 431 | if isinstance(sde, tuple): 432 | return corrector_obj.update_mol_fn(x, t, *args, **kwargs) 433 | return corrector_obj.update_fn(x, t, *args, **kwargs) 434 | 435 | 436 | def get_mol_pc_sampler(atom_sde, bond_sde, atom_shape, bond_shape, predictor, corrector, inverse_scaler, snr, 437 | n_steps=1, probability_flow=False, continuous=False, 438 | denoise=True, eps=1e-3, device='cuda'): 439 | """Create a Predictor-Corrector (PC) sampler for molecule graph generation. 440 | 441 | Args: 442 | atom_sde, bond_sde: An `sde_lib.SDE` object representing the forward SDE. 443 | atom_shape, bond_shape: A sequence of integers. The expected shape of a single sample. 444 | predictor: A subclass of `sampling.Predictor` representing the predictor algorithm. 445 | corrector: A subclass of `sampling.Corrector` representing the corrector algorithm. 446 | inverse_scaler: The inverse data normalizer. 447 | snr: A `float` number. The signal-to-noise ratio for configuring correctors. 448 | n_steps: An integer. The number of corrector steps per predictor update. 449 | probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor. 450 | continuous: `True` indicates that the score model was continuously trained. 451 | denoise: If `True`, add one-step denoising to the final samples. 452 | eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues. 453 | device: PyTorch device. 454 | 455 | Returns: 456 | A sampling function that returns samples and the number of function evaluations during sampling. 457 | """ 458 | # Create predictor & corrector update functions 459 | predictor_update_fn = functools.partial(shared_predictor_update_fn, 460 | sde=(atom_sde, bond_sde), 461 | predictor=predictor, 462 | probability_flow=probability_flow, 463 | continuous=continuous) 464 | corrector_update_fn = functools.partial(shared_corrector_update_fn, 465 | sde=(atom_sde, bond_sde), 466 | corrector=corrector, 467 | continuous=continuous, 468 | snr=snr, 469 | n_steps=n_steps) 470 | 471 | 472 | def mol_pc_sampler(model, n_nodes_pmf): 473 | """The PC sampler function. 474 | 475 | Args: 476 | model: A score model. 477 | n_nodes_pmf: Probability mass function of graph nodes. 478 | 479 | Returns: 480 | Samples, number of function evaluations. 481 | """ 482 | with torch.no_grad(): 483 | # Initial sample 484 | 485 | x_atom = atom_sde.prior_sampling(atom_shape).to(device) 486 | x_bond = bond_sde.prior_sampling(bond_shape).to(device) 487 | 488 | timesteps = torch.linspace(atom_sde.T, eps, atom_sde.N, device=device) 489 | 490 | # Sample the number of nodes 491 | n_nodes = torch.multinomial(n_nodes_pmf, atom_shape[0], replacement=True) 492 | atom_mask = torch.zeros((atom_shape[0], atom_shape[1]), device=device) 493 | for i in range(atom_shape[0]): 494 | atom_mask[i][:n_nodes[i]] = 1. 
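            # The bond mask built below is the outer product of the atom mask with itself,
            # restricted to off-diagonal entries and then symmetrized, so prior noise is only
            # kept on pairs of nodes that both exist. A hedged toy check (not part of the sampler):
            #   m = torch.tensor([[1., 1., 0.]])                      # 2 real nodes out of 3 slots
            #   b = (m[:, None, :] * m[:, :, None]).unsqueeze(1)
            #   b = torch.tril(b, -1); b = b + b.transpose(-1, -2)
            #   assert b[0, 0].tolist() == [[0., 1., 0.], [1., 0., 0.], [0., 0., 0.]]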
495 | bond_mask = (atom_mask[:, None, :] * atom_mask[:, :, None]).unsqueeze(1) 496 | bond_mask = torch.tril(bond_mask, -1) 497 | bond_mask = bond_mask + bond_mask.transpose(-1, -2) 498 | 499 | x_atom = x_atom * atom_mask.unsqueeze(-1) 500 | x_bond = x_bond * bond_mask 501 | 502 | for i in range(atom_sde.N): 503 | t = timesteps[i] 504 | vec_t = torch.ones(atom_shape[0], device=t.device) * t 505 | 506 | (x_atom, x_bond), (x_atom_mean, x_bond_mean) = corrector_update_fn((x_atom, x_bond), vec_t, model=model, 507 | atom_mask=atom_mask, 508 | bond_mask=bond_mask) 509 | x_atom = x_atom * atom_mask.unsqueeze(-1) 510 | x_bond = x_bond * bond_mask 511 | 512 | (x_atom, x_bond), (x_atom_mean, x_bond_mean) = predictor_update_fn((x_atom, x_bond), vec_t, model=model, 513 | atom_mask=atom_mask, 514 | bond_mask=bond_mask) 515 | x_atom = x_atom * atom_mask.unsqueeze(-1) 516 | x_bond = x_bond * bond_mask 517 | 518 | return inverse_scaler(x_atom_mean if denoise else x_atom, atom=True) * atom_mask.unsqueeze(-1),\ 519 | inverse_scaler(x_bond_mean if denoise else x_bond, atom=False) * bond_mask,\ 520 | atom_sde.N * (n_steps + 1), n_nodes 521 | 522 | return mol_pc_sampler 523 | -------------------------------------------------------------------------------- /sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included 10 | # particularly slightly different formula for marocyclic penalty 11 | # and taking into account also molecule symmetry (fingerprint density) 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | from __future__ import print_function 19 | 20 | import math 21 | import pickle as cPickle 22 | import os.path as op 23 | from rdkit import Chem 24 | from rdkit.Chem import rdMolDescriptors 25 | # from rdkit.six.moves import cPickle 26 | from rdkit.six import iteritems 27 | 28 | _fscores = None 29 | 30 | 31 | def readFragmentScores(name='fpscores'): 32 | import gzip 33 | global _fscores 34 | # generate the full path filename: 35 | if name == "fpscores": 36 | name = op.join(op.dirname(__file__), name) 37 | _fscores = cPickle.load(gzip.open('%s.pkl.gz' % name)) 38 | outDict = {} 39 | for i in _fscores: 40 | for j in range(1, len(i)): 41 | outDict[i[j]] = float(i[0]) 42 | _fscores = outDict 43 | 44 | 45 | def numBridgeheadsAndSpiro(mol, ri=None): 46 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 47 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 48 | return nBridgehead, nSpiro 49 | 50 | 51 | def calculateScore(m): 52 | if _fscores is None: 53 | readFragmentScores() 54 | 55 | # fragment score 56 | fp = rdMolDescriptors.GetMorganFingerprint(m, 2) #<- 2 is the *radius* of the circular fingerprint 57 | fps = fp.GetNonzeroElements() 58 | score1 = 0. 
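    # score1 below is the frequency-weighted mean of the precomputed fragment contributions
    # loaded from fpscores.pkl.gz; fragments missing from that table fall back to the -4 penalty.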
59 | nf = 0 60 | for bitId, v in iteritems(fps): 61 | nf += v 62 | sfp = bitId 63 | score1 += _fscores.get(sfp, -4) * v 64 | score1 /= nf 65 | 66 | # features score 67 | nAtoms = m.GetNumAtoms() 68 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 69 | ri = m.GetRingInfo() 70 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 71 | nMacrocycles = 0 72 | for x in ri.AtomRings(): 73 | if len(x) > 8: 74 | nMacrocycles += 1 75 | 76 | sizePenalty = nAtoms**1.005 - nAtoms 77 | stereoPenalty = math.log10(nChiralCenters + 1) 78 | spiroPenalty = math.log10(nSpiro + 1) 79 | bridgePenalty = math.log10(nBridgeheads + 1) 80 | macrocyclePenalty = 0. 81 | # --------------------------------------- 82 | # This differs from the paper, which defines: 83 | # macrocyclePenalty = math.log10(nMacrocycles+1) 84 | # This form generates better results when 2 or more macrocycles are present 85 | if nMacrocycles > 0: 86 | macrocyclePenalty = math.log10(2) 87 | 88 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 89 | 90 | # correction for the fingerprint density 91 | # not in the original publication, added in version 1.1 92 | # to make highly symmetrical molecules easier to synthetise 93 | score3 = 0. 94 | if nAtoms > len(fps): 95 | score3 = math.log(float(nAtoms) / len(fps)) * .5 96 | 97 | sascore = score1 + score2 + score3 98 | 99 | # need to transform "raw" value into scale between 1 and 10 100 | min_score = -4.0 101 | max_score = 2.5 102 | sascore = 11. - (sascore - min_score + 1) / (max_score - min_score) * 9. 103 | # smooth the 10-end 104 | if sascore > 8.: 105 | sascore = 8. + math.log(sascore + 1. - 9.) 106 | if sascore > 10.: 107 | sascore = 10.0 108 | elif sascore < 1.: 109 | sascore = 1.0 110 | 111 | return sascore 112 | 113 | 114 | # def processMols(mols): 115 | # print('smiles\tName\tsa_score') 116 | # for m in mols: 117 | # if m is None: 118 | # continue 119 | 120 | # s = calculateScore(m) 121 | 122 | # smiles = Chem.MolToSmiles(m) 123 | # print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 124 | 125 | 126 | # if __name__ == '__main__': 127 | # import sys, time 128 | 129 | # t1 = time.time() 130 | # readFragmentScores("fpscores") 131 | # t2 = time.time() 132 | 133 | # suppl = Chem.SmilesMolSupplier(sys.argv[1]) 134 | # t3 = time.time() 135 | # processMols(suppl) 136 | # t4 = time.time() 137 | 138 | # print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 139 | # file=sys.stderr) 140 | 141 | # 142 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 143 | # All rights reserved. 144 | # 145 | # Redistribution and use in source and binary forms, with or without 146 | # modification, are permitted provided that the following conditions are 147 | # met: 148 | # 149 | # * Redistributions of source code must retain the above copyright 150 | # notice, this list of conditions and the following disclaimer. 151 | # * Redistributions in binary form must reproduce the above 152 | # copyright notice, this list of conditions and the following 153 | # disclaimer in the documentation and/or other materials provided 154 | # with the distribution. 155 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 156 | # nor the names of its contributors may be used to endorse or promote 157 | # products derived from this software without specific prior written permission. 
158 | # 159 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 160 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 161 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 162 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 163 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 164 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 165 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 166 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 167 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 168 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 169 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 170 | # -------------------------------------------------------------------------------- /sde_lib.py: -------------------------------------------------------------------------------- 1 | """Abstract SDE classes, Reverse SDE, and VE/VP SDEs.""" 2 | 3 | import abc 4 | import torch 5 | import numpy as np 6 | 7 | 8 | class SDE(abc.ABC): 9 | """SDE abstract class. Functions are designed for a mini-batch of inputs.""" 10 | 11 | def __init__(self, N): 12 | """Construct an SDE. 13 | 14 | Args: 15 | N: number of discretization time steps. 16 | """ 17 | super().__init__() 18 | self.N = N 19 | 20 | @property 21 | @abc.abstractmethod 22 | def T(self): 23 | """End time of the SDE.""" 24 | pass 25 | 26 | @abc.abstractmethod 27 | def sde(self, x, t): 28 | pass 29 | 30 | @abc.abstractmethod 31 | def marginal_prob(self, x, t): 32 | """Parameters to determine the marginal distribution of the SDE, $p_t(x)$""" 33 | pass 34 | 35 | @abc.abstractmethod 36 | def prior_sampling(self, shape): 37 | """Generate one sample from the prior distribution, $p_T(x)$.""" 38 | pass 39 | 40 | @abc.abstractmethod 41 | def prior_logp(self, z, mask): 42 | """Compute log-density of the prior distribution. 43 | 44 | Useful for computing the log-likelihood via probability flow ODE. 45 | 46 | Args: 47 | z: latent code 48 | Returns: 49 | log probability density 50 | """ 51 | pass 52 | 53 | def discretize(self, x, t): 54 | """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i. 55 | 56 | Useful for reverse diffusion sampling and probability flow sampling. 57 | Defaults to Euler-Maruyama discretization. 58 | 59 | Args: 60 | x: a torch tensor 61 | t: a torch float representing the time step (from 0 to `self.T`) 62 | 63 | Returns: 64 | f, G 65 | """ 66 | dt = 1 / self.N 67 | drift, diffusion = self.sde(x, t) 68 | f = drift * dt 69 | G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device)) 70 | return f, G 71 | 72 | def reverse(self, score_fn, probability_flow=False): 73 | """Create the reverse-time SDE/ODE. 74 | 75 | Args: 76 | score_fn: A time-dependent score-based model that takes x and t and returns the score. 77 | probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling. 78 | """ 79 | 80 | N = self.N 81 | T = self.T 82 | sde_fn = self.sde 83 | discretize_fn = self.discretize 84 | 85 | # Build the class for reverse-time SDE. 
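        # The subclass below implements the reverse-time dynamics
        #   dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw_bar,
        # and, when `probability_flow` is True, the deterministic probability-flow ODE
        #   dx = [f(x, t) - 0.5 * g(t)^2 * score(x, t)] dt,
        # which is why the score term is scaled by 0.5 and the diffusion is zeroed in that case.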
86 | class RSDE(self.__class__): 87 | def __init__(self): 88 | self.N = N 89 | self.probability_flow = probability_flow 90 | 91 | @property 92 | def T(self): 93 | return T 94 | 95 | def sde(self, x, t, *args, **kwargs): 96 | """Create the drift and diffusion functions for the reverse SDE/ODE.""" 97 | 98 | drift, diffusion = sde_fn(x, t) 99 | score = score_fn(x, t, *args, **kwargs) 100 | drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) 101 | # Set the diffusion function to zero for ODEs. 102 | diffusion = 0. if self.probability_flow else diffusion 103 | return drift, diffusion 104 | 105 | def sde_score(self, x, t, score): 106 | """Create the drift and diffusion functions for the reverse SDE/ODE, given score values.""" 107 | drift, diffusion = sde_fn(x, t) 108 | if len(score.shape) == 4: 109 | drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) 110 | elif len(score.shape) == 3: 111 | drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) 112 | else: 113 | raise ValueError 114 | diffusion = 0. if self.probability_flow else diffusion 115 | return drift, diffusion 116 | 117 | def discretize(self, x, t, *args, **kwargs): 118 | """Create discretized iteration rules for the reverse diffusion sampler.""" 119 | f, G = discretize_fn(x, t) 120 | rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t, *args, **kwargs) * \ 121 | (0.5 if self.probability_flow else 1.) 122 | rev_G = torch.zeros_like(G) if self.probability_flow else G 123 | return rev_f, rev_G 124 | 125 | def discretize_score(self, x, t, score): 126 | """Create discretized iteration rules for the reverse diffusion sampler, given score values.""" 127 | f, G = discretize_fn(x, t) 128 | if len(score.shape) == 4: 129 | rev_f = f - G[:, None, None, None] ** 2 * score * \ 130 | (0.5 if self.probability_flow else 1.) 131 | elif len(score.shape) == 3: 132 | rev_f = f - G[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) 133 | else: 134 | raise ValueError 135 | rev_G = torch.zeros_like(G) if self.probability_flow else G 136 | return rev_f, rev_G 137 | 138 | return RSDE() 139 | 140 | 141 | class VPSDE(SDE): 142 | def __init__(self, beta_min=0.1, beta_max=20, N=1000): 143 | """Construct a Variance Preserving SDE. 144 | 145 | Args: 146 | beta_min: value of beta(0) 147 | beta_max: value of beta(1) 148 | N: number of discretization steps 149 | """ 150 | super().__init__(N) 151 | self.beta_0 = beta_min 152 | self.beta_1 = beta_max 153 | self.N = N 154 | self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N) 155 | self.alphas = 1. - self.discrete_betas 156 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 157 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) 158 | self.sqrt_1m_alphas_cumprod = torch.sqrt(1. 
- self.alphas_cumprod) 159 | 160 | @property 161 | def T(self): 162 | return 1 163 | 164 | def sde(self, x, t): 165 | beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) 166 | if len(x.shape) == 4: 167 | drift = -0.5 * beta_t[:, None, None, None] * x 168 | elif len(x.shape) == 3: 169 | drift = -0.5 * beta_t[:, None, None] * x 170 | else: 171 | raise NotImplementedError 172 | diffusion = torch.sqrt(beta_t) 173 | return drift, diffusion 174 | 175 | def marginal_prob(self, x, t): 176 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 177 | if len(x.shape) == 4: 178 | mean = torch.exp(log_mean_coeff[:, None, None, None]) * x 179 | elif len(x.shape) == 3: 180 | mean = torch.exp(log_mean_coeff[:, None, None]) * x 181 | else: 182 | raise ValueError("The shape of x in marginal_prob is not correct.") 183 | std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) 184 | return mean, std 185 | 186 | def log_snr(self, t): 187 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 188 | mean = torch.exp(log_mean_coeff) 189 | std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) 190 | log_snr = torch.log(mean / std) 191 | return log_snr, mean, std 192 | 193 | def log_snr_np(self, t): 194 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 195 | mean = np.exp(log_mean_coeff) 196 | std = np.sqrt(1. - np.exp(2. * log_mean_coeff)) 197 | log_snr = np.log(mean / std) 198 | return log_snr 199 | 200 | def lambda2t(self, lambda_ori): 201 | log_val = torch.log(torch.exp(-2. * lambda_ori) + 1.) 202 | t = 2. * log_val / (torch.sqrt(self.beta_0 ** 2 + 2. * (self.beta_1 - self.beta_0) * log_val) + self.beta_0) 203 | return t 204 | 205 | def lambda2t_np(self, lambda_ori): 206 | log_val = np.log(np.exp(-2. * lambda_ori) + 1.) 207 | t = 2. * log_val / (np.sqrt(self.beta_0 ** 2 + 2. * (self.beta_1 - self.beta_0) * log_val) + self.beta_0) 208 | return t 209 | 210 | def prior_sampling(self, shape): 211 | sample = torch.randn(*shape) 212 | if len(shape) == 4: 213 | sample = torch.tril(sample, -1) 214 | sample = sample + sample.transpose(-1, -2) 215 | 216 | return sample 217 | 218 | def prior_logp(self, z, mask): 219 | N = torch.sum(mask, dim=tuple(range(1, len(mask.shape)))) 220 | logps = -N / 2. * np.log(2 * np.pi) - torch.sum((z * mask) ** 2, dim=(1, 2, 3)) / 2. 
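        # logps is the log-density of a standard normal restricted to the valid (masked) entries:
        # log p(z) = -N/2 * log(2*pi) - ||z * mask||^2 / 2, where N counts the unmasked elements.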
221 | return logps 222 | 223 | def discretize(self, x, t): 224 | """DDPM discretization.""" 225 | timestep = (t * (self.N - 1) / self.T).long() 226 | beta = self.discrete_betas.to(x.device)[timestep] 227 | alpha = self.alphas.to(x.device)[timestep] 228 | sqrt_beta = torch.sqrt(beta) 229 | if len(x.shape) == 4: 230 | f = torch.sqrt(alpha)[:, None, None, None] * x - x 231 | elif len(x.shape) == 3: 232 | f = torch.sqrt(alpha)[:, None, None] * x - x 233 | else: 234 | NotImplementedError 235 | G = sqrt_beta 236 | return f, G 237 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import logging 4 | import re 5 | import copy 6 | import numpy as np 7 | import torch.nn.functional as F 8 | import networkx as nx 9 | from rdkit import Chem, DataStructs 10 | from rdkit.Chem import AllChem, rdMolDescriptors 11 | from rdkit.Chem.Descriptors import MolLogP, qed 12 | from sascorer import calculateScore 13 | 14 | ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1} 15 | bond_decoder_m = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE, 3: Chem.rdchem.BondType.TRIPLE} 16 | 17 | 18 | def restore_checkpoint(ckpt_dir, state, device): 19 | if not os.path.exists(ckpt_dir): 20 | if not os.path.exists(os.path.dirname(ckpt_dir)): 21 | os.makedirs(os.path.dirname(ckpt_dir)) 22 | logging.warning(f"No checkpoint found at {ckpt_dir}. " 23 | f"Returned the same state as input") 24 | return state 25 | else: 26 | loaded_state = torch.load(ckpt_dir, map_location=device) 27 | state['optimizer'].load_state_dict(loaded_state['optimizer']) 28 | state['model'].load_state_dict(loaded_state['model'], strict=False) 29 | state['ema'].load_state_dict(loaded_state['ema']) 30 | state['step'] = loaded_state['step'] 31 | return state 32 | 33 | 34 | def save_checkpoint(ckpt_dir, state): 35 | saved_state = { 36 | 'optimizer': state['optimizer'].state_dict(), 37 | 'model': state['model'].state_dict(), 38 | 'ema': state['ema'].state_dict(), 39 | 'step': state['step'] 40 | } 41 | torch.save(saved_state, ckpt_dir) 42 | 43 | 44 | @torch.no_grad() 45 | def dense_mol(graph_data, scaler=None, dequantization=False): 46 | """Extract features and masks from PyG Dense DataBatch. 47 | 48 | Args: 49 | graph_data: DataBatch object. 50 | y: [B, 1] graph property values. 51 | num_atom: [B, 1] number of atoms in graphs. 52 | smile: [B] smile sequences. 53 | x: [B, max_node, channel1] atom type features. 54 | adj: [B, channel2, max_node, max_node] bond type features. 55 | atom_mask: [B, max_node] 56 | 57 | Returns: 58 | atom_feat: [B, max_node, channel1] 59 | atom_mask: [B, max_node] 60 | bond_feat: [B, channel2, max_node, max_node] 61 | bond_mask: [B, 1, max_node, max_node] 62 | """ 63 | 64 | atom_feat = graph_data.x 65 | bond_feat = graph_data.adj 66 | atom_mask = graph_data.atom_mask 67 | if len(atom_mask.shape) == 1: 68 | atom_mask = atom_mask.unsqueeze(0) 69 | bond_mask = (atom_mask[:, None, :] * atom_mask[:, :, None]).unsqueeze(1) 70 | bond_mask = torch.tril(bond_mask, -1) 71 | bond_mask = bond_mask + bond_mask.transpose(-1, -2) 72 | 73 | if dequantization: 74 | atom_noise = torch.rand_like(atom_feat) 75 | atom_feat = (atom_feat + atom_noise) / 2. 
* atom_mask[:, :, None] 76 | bond_noise = torch.rand_like(bond_feat) 77 | bond_noise = torch.tril(bond_noise, -1) 78 | bond_noise = bond_noise + bond_noise.transpose(1, 2) 79 | bond_feat = (bond_feat + bond_noise) / 2. * bond_mask 80 | 81 | atom_feat = scaler(atom_feat, atom=True) 82 | bond_feat = scaler(bond_feat, atom=False) 83 | 84 | return atom_feat * atom_mask.unsqueeze(-1), atom_mask, bond_feat * bond_mask, bond_mask 85 | 86 | 87 | def adj2graph(adj, sample_nodes): 88 | """Covert the PyTorch tensor adjacency matrices to numpy array. 89 | 90 | Args: 91 | adj: [Batch_size, channel, Max_node, Max_node], assume channel=1 92 | sample_nodes: [Batch_size] 93 | """ 94 | adj_list = [] 95 | # discretization 96 | adj[adj >= 0.5] = 1. 97 | adj[adj < 0.5] = 0. 98 | for i in range(adj.shape[0]): 99 | adj_tmp = adj[i, 0] 100 | # symmetric 101 | adj_tmp = torch.tril(adj_tmp, -1) 102 | adj_tmp = adj_tmp + adj_tmp.transpose(0, 1) 103 | # truncate 104 | adj_tmp = adj_tmp.cpu().numpy()[:sample_nodes[i], :sample_nodes[i]] 105 | adj_list.append(adj_tmp) 106 | 107 | return adj_list 108 | 109 | 110 | def quantize_mol(adjs): 111 | # Quantize generated molecules [B, 1, N, N] 112 | adjs = adjs.squeeze(1) 113 | if type(adjs).__name__ == 'Tensor': 114 | adjs = adjs.detach().cpu() 115 | else: 116 | adjs = torch.tensor(adjs) 117 | adjs = adjs * 3 118 | adjs[adjs >= 2.5] = 3 119 | adjs[torch.bitwise_and(adjs >= 1.5, adjs < 2.5)] = 2 120 | adjs[torch.bitwise_and(adjs >= 0.5, adjs < 1.5)] = 1 121 | adjs[adjs < 0.5] = 0 122 | return np.array(adjs.to(torch.int64)) 123 | 124 | 125 | def quantize_mol_2(adjs): 126 | # Quantize generated molecules [B, 2, N, N] 127 | # The 2nd channel: 0 -> edge type; 1 -> edge existence 128 | if type(adjs).__name__ == 'Tensor': 129 | adjs = adjs.detach().cpu() 130 | else: 131 | adjs = torch.tensor(adjs) 132 | 133 | adj_0 = adjs[:, 0, :, :] 134 | adj_1 = adjs[:, 1, :, :] 135 | 136 | adj_0 = adj_0 * 3 137 | adj_0[adj_0 >= 2.5] = 3 138 | adj_0[torch.bitwise_and(adj_0 >= 1.5, adj_0 < 2.5)] = 2 139 | adj_0[torch.bitwise_and(adj_0 >= 0.5, adj_0 < 1.5)] = 1 140 | adj_0[adj_0 < 0.5] = 0 141 | 142 | adj_1[adj_1 < 0.5] = 0 143 | adj_1[adj_1 >= 0.5] = 1 144 | 145 | adjs = adj_0 * adj_1 146 | return np.array(adjs.to(torch.int64)) 147 | 148 | 149 | def construct_mol(x, A, num_node, atomic_num_list): 150 | mol = Chem.RWMol() 151 | atoms = np.argmax(x, axis=1) 152 | atoms = atoms[:num_node] 153 | 154 | for atom in atoms: 155 | mol.AddAtom(Chem.Atom(int(atomic_num_list[atom]))) 156 | 157 | if len(A.shape) == 2: 158 | adj = A[:num_node, :num_node] 159 | elif A.shape[0] == 4: 160 | # A (edge_type, max_num_node, max_num_node) 161 | adj = np.argmax(A, axis=0) 162 | adj = np.array(adj) 163 | adj = adj[:num_node, :num_node] 164 | 165 | # Note. 3 means no existing edge (when constructing adj matrices) 166 | adj[adj == 3] = -1 167 | adj += 1 168 | adj = adj - np.eye(num_node) 169 | else: 170 | raise ValueError('Wrong Adj shape.') 171 | 172 | for start, end in zip(*np.nonzero(adj)): 173 | if start > end: 174 | mol.AddBond(int(start), int(end), bond_decoder_m[adj[start, end]]) 175 | # remove formal charge for fair comparison with GraphAF, GraphDF, GraphCNF 176 | 177 | # add formal charge to atom: e.g. [O+], [N+], [S+], not support [O-], [N-], [NH+] etc. 
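        # If the valency check below fails by exactly one unit on N, O, or S, a +1 formal charge
        # is assigned instead of discarding the molecule (e.g. a nitrogen with four bonds becomes [N+]).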
178 | flag, atomid_valence = check_valency(mol) 179 | if flag: 180 | continue 181 | else: 182 | assert len(atomid_valence) == 2 183 | idx = atomid_valence[0] 184 | v = atomid_valence[1] 185 | an = mol.GetAtomWithIdx(idx).GetAtomicNum() 186 | if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1: 187 | mol.GetAtomWithIdx(idx).SetFormalCharge(1) 188 | 189 | return mol 190 | 191 | 192 | def check_valency(mol): 193 | """ 194 | Checks that no atoms in the mol have exceeded their possible valency 195 | 196 | Return: 197 | True if no valency issues, False otherwise 198 | """ 199 | try: 200 | Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES) 201 | return True, None 202 | except ValueError as e: 203 | e = str(e) 204 | p = e.find('#') 205 | e_sub = e[p:] 206 | atomid_valence = list(map(int, re.findall(r'\d+', e_sub))) 207 | return False, atomid_valence 208 | 209 | 210 | def correct_mol(mol): 211 | no_correct = False 212 | flag, _ = check_valency(mol) 213 | if flag: 214 | no_correct = True 215 | 216 | while True: 217 | flag, atomid_valence = check_valency(mol) 218 | if flag: 219 | break 220 | else: 221 | assert len(atomid_valence) == 2 222 | idx = atomid_valence[0] 223 | queue = [] 224 | 225 | for b in mol.GetAtomWithIdx(idx).GetBonds(): 226 | queue.append( 227 | (b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx()) 228 | ) 229 | queue.sort(key=lambda tup: tup[1], reverse=True) 230 | 231 | if len(queue) > 0: 232 | start = queue[0][2] 233 | end = queue[0][3] 234 | t = queue[0][1] - 1 235 | mol.RemoveBond(start, end) 236 | if t >= 1: 237 | mol.AddBond(start, end, bond_decoder_m[t]) 238 | 239 | return mol, no_correct 240 | 241 | 242 | def valid_mol_can_with_seg(x, largest_connected_comp=True): 243 | if x is None: 244 | return None 245 | sm = Chem.MolToSmiles(x, isomericSmiles=True) 246 | mol = Chem.MolFromSmiles(sm) 247 | if largest_connected_comp and '.' in sm: 248 | vsm = [(s, len(s)) for s in sm.split('.')] # 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.') 249 | vsm.sort(key=lambda tup: tup[1], reverse=True) 250 | mol = Chem.MolFromSmiles(vsm[0][0]) 251 | return mol 252 | 253 | 254 | def check_chemical_validity(mol): 255 | """ 256 | Check the chemical validity of the mol object. Existing mol object is not modified. 257 | 258 | Args: mol: Rdkit mol object 259 | 260 | Return: 261 | True if chemically valid, False otherwise 262 | """ 263 | 264 | s = Chem.MolToSmiles(mol, isomericSmiles=True) 265 | m = Chem.MolFromSmiles(s) # implicitly performs sanitization 266 | if m: 267 | return True 268 | else: 269 | return False 270 | 271 | 272 | def tensor2mol(x_atom, x_bond, num_atoms, atomic_num_list, correct_validity=True, largest_connected_comp=True): 273 | """Construct molecules from the atom and bond tensors. 274 | 275 | Args: 276 | x_atom: The node tensor [`number of samples`, `maximum number of atoms`, `number of possible atom types`]. 277 | x_bond: The adjacency tensor [`number of samples`, `number of possible bond type`, `maximum number of atoms`, 278 | `maximum number of atoms`] 279 | num_atoms: The number of nodes for every sample [`number of samples`] 280 | atomic_num_list: A list to specify what each atom channel corresponds to. 281 | correct_validity: Whether to use the validity correction introduced by `MoFlow`. 282 | largest_connected_comp: Whether to use the largest connected component as the final molecule in the validity 283 | correction. 284 | 285 | Return: 286 | The list of Rdkit mol object. The check_chemical_validity rate without check. 
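    A hedged usage sketch (the atom list and shapes mirror the QM9 setup; the variable names are
    illustrative):

        gen_mols, valid_wd = tensor2mol(x_atom, x_bond, num_atoms,
                                        atomic_num_list=[6, 7, 8, 9],
                                        correct_validity=True, largest_connected_comp=True)
        validity_wo_correction = sum(valid_wd) / len(valid_wd)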
287 | """ 288 | if x_bond.shape[1] == 1: 289 | x_bond = quantize_mol(x_bond) 290 | elif x_bond.shape[1] == 2: 291 | x_bond = quantize_mol_2(x_bond) 292 | else: 293 | x_bond = x_bond.cpu().detach().numpy() 294 | 295 | x_atom = x_atom.cpu().detach().numpy() 296 | num_nodes = num_atoms.cpu().detach().numpy() 297 | 298 | gen_mols = [] 299 | valid_cum = [] 300 | 301 | for atom_elem, bond_elem, num_node in zip(x_atom, x_bond, num_nodes): 302 | mol = construct_mol(atom_elem, bond_elem, num_node, atomic_num_list) 303 | 304 | if correct_validity: 305 | # correct the invalid molecule 306 | cmol, no_correct = correct_mol(mol) 307 | if no_correct: 308 | valid_cum.append(1) 309 | else: 310 | valid_cum.append(0) 311 | vcmol = valid_mol_can_with_seg(cmol, largest_connected_comp=largest_connected_comp) 312 | gen_mols.append(vcmol) 313 | else: 314 | gen_mols.append(mol) 315 | 316 | return gen_mols, valid_cum 317 | 318 | 319 | def penalized_logp(mol): 320 | """ 321 | Calculate the reward that consists of log p penalized by SA and # long cycles, 322 | as described in (Kusner et al. 2017). Scores are normalized based on the 323 | statistics of 250k_rndm_zinc_drugs_clean.smi dataset. 324 | 325 | Args: 326 | mol: Rdkit mol object 327 | 328 | Returns: 329 | :class:`float` 330 | """ 331 | 332 | # normalization constants, statistics from 250k_rndm_zinc_drugs_clean.smi 333 | logP_mean = 2.4570953396190123 334 | logP_std = 1.434324401111988 335 | SA_mean = -3.0525811293166134 336 | SA_std = 0.8335207024513095 337 | cycle_mean = -0.0485696876403053 338 | cycle_std = 0.2860212110245455 339 | 340 | log_p = MolLogP(mol) 341 | SA = -calculateScore(mol) 342 | 343 | # cycle score 344 | cycle_list = nx.cycle_basis(nx.Graph( 345 | Chem.rdmolops.GetAdjacencyMatrix(mol))) 346 | if len(cycle_list) == 0: 347 | cycle_length = 0 348 | else: 349 | cycle_length = max([len(j) for j in cycle_list]) 350 | if cycle_length <= 6: 351 | cycle_length = 0 352 | else: 353 | cycle_length = cycle_length - 6 354 | cycle_score = -cycle_length 355 | 356 | normalized_log_p = (log_p - logP_mean) / logP_std 357 | normalized_SA = (SA - SA_mean) / SA_std 358 | normalized_cycle = (cycle_score - cycle_mean) / cycle_std 359 | 360 | return normalized_log_p + normalized_SA + normalized_cycle 361 | 362 | 363 | def get_mol_qed(mol): 364 | return qed(mol) 365 | 366 | 367 | def calculate_min_plogp(mol): 368 | """ 369 | Calculate the reward that consists of log p penalized by SA and # long cycles, 370 | as described in (Kusner et al. 2017). Scores are normalized based on the 371 | statistics of 250k_rndm_zinc_drugs_clean.smi dataset. 372 | 373 | Args: 374 | mol: Rdkit mol object 375 | 376 | :rtype: 377 | :class:`float` 378 | """ 379 | 380 | p1 = penalized_logp(mol) 381 | s1 = Chem.MolToSmiles(mol, isomericSmiles=True) 382 | s2 = Chem.MolToSmiles(mol, isomericSmiles=False) 383 | mol1 = Chem.MolFromSmiles(s1) 384 | mol2 = Chem.MolFromSmiles(s2) 385 | p2 = penalized_logp(mol1) 386 | p3 = penalized_logp(mol2) 387 | final_p = min(p1, p2) 388 | final_p = min(final_p, p3) 389 | return final_p 390 | 391 | 392 | def reward_target_molecule_similarity(mol, target, radius=2, nBits=2048, useChirality=True): 393 | """ 394 | Calculate the similarity, based on tanimoto similarity 395 | between the ECFP fingerprints of the x molecule and target molecule. 
396 | 397 | Args: 398 | mol: Rdkit mol object 399 | target: Rdkit mol object 400 | 401 | Returns: 402 | :class:`float`, [0.0, 1.0] 403 | """ 404 | x = rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nBits, useChirality=useChirality) 405 | target = rdMolDescriptors.GetMorganFingerprintAsBitVect(target, radius=radius, nBits=nBits, 406 | useChirality=useChirality) 407 | return DataStructs.TanimotoSimilarity(x, target) 408 | 409 | 410 | def convert_radical_electrons_to_hydrogens(mol): 411 | """ 412 | Convert radical electrons in a molecule into bonds to hydrogens. Only 413 | use this if molecule is valid. Return a new mol object. 414 | 415 | Args: 416 | mol: Rdkit mol object 417 | 418 | :rtype: 419 | Rdkit mol object 420 | """ 421 | 422 | m = copy.deepcopy(mol) 423 | if Chem.Descriptors.NumRadicalElectrons(m) == 0: # not a radical 424 | return m 425 | else: # a radical 426 | print('converting radical electrons to H') 427 | for a in m.GetAtoms(): 428 | num_radical_e = a.GetNumRadicalElectrons() 429 | if num_radical_e > 0: 430 | a.SetNumRadicalElectrons(0) 431 | a.SetNumExplicitHs(num_radical_e) 432 | return m 433 | 434 | 435 | def get_final_smiles(mol): 436 | """ 437 | Returns a SMILES of the final molecule. Converts any radical 438 | electrons into hydrogens. Works only if molecule is valid 439 | :return: SMILES 440 | """ 441 | m = convert_radical_electrons_to_hydrogens(mol) 442 | return Chem.MolToSmiles(m, isomericSmiles=True) 443 | 444 | 445 | def mols_to_nx(mols): 446 | nx_graphs = [] 447 | for mol in mols: 448 | G = nx.Graph() 449 | 450 | for atom in mol.GetAtoms(): 451 | G.add_node(atom.GetIdx(), 452 | label=atom.GetSymbol()) 453 | # atomic_num=atom.GetAtomicNum(), 454 | # formal_charge=atom.GetFormalCharge(), 455 | # chiral_tag=atom.GetChiralTag(), 456 | # hybridization=atom.GetHybridization(), 457 | # num_explicit_hs=atom.GetNumExplicitHs(), 458 | # is_aromatic=atom.GetIsAromatic()) 459 | 460 | for bond in mol.GetBonds(): 461 | G.add_edge(bond.GetBeginAtomIdx(), 462 | bond.GetEndAtomIdx(), 463 | label=int(bond.GetBondTypeAsDouble())) 464 | # bond_type=bond.GetBondType()) 465 | 466 | nx_graphs.append(G) 467 | return nx_graphs 468 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import networkx as nx 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from rdkit.Chem import Draw 7 | import matplotlib 8 | 9 | 10 | def draw_graph_list(graph_list, row, col, f_path, iterations=100, layout='spring', is_single=False, k=1, 11 | node_size=55, alpha=1, width=1.3, remove=True): 12 | 13 | G_list = [nx.to_networkx_graph(graph_list[i]) for i in range(len(graph_list))] 14 | 15 | # remove isolate nodes in graphs 16 | if remove: 17 | for gg in G_list: 18 | gg.remove_nodes_from(list(nx.isolates(gg))) 19 | 20 | plt.switch_backend('agg') 21 | for i, G in enumerate(G_list): 22 | plt.subplot(row, col, i+1) 23 | plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0) 24 | # plt.axis("off") 25 | 26 | # turn off axis label 27 | plt.xticks([]) 28 | plt.yticks([]) 29 | 30 | if layout == 'spring': 31 | pos = nx.spring_layout(G, k=k / np.sqrt(G.number_of_nodes()), iterations=iterations) 32 | elif layout == 'spectral': 33 | pos = nx.spectral_layout(G) 34 | else: 35 | raise ValueError(f'{layout} not recognized.') 36 | 37 | if is_single: 38 | nx.draw_networkx_nodes(G, pos, node_size=node_size, 
node_color='#336699', alpha=1, linewidths=0, 39 | ) 40 | nx.draw_networkx_edges(G, pos, alpha=alpha, width=width) 41 | else: 42 | nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699', alpha=1, linewidths=0.2) 43 | # nx.draw_networkx_nodes(G, pos, node_size=2.0, node_color='#336699', alpha=1, linewidths=1.0) 44 | nx.draw_networkx_edges(G, pos, alpha=0.3, width=0.2) 45 | # nx.draw_networkx_edges(G, pos, alpha=0.3, width=0.5) 46 | 47 | plt.tight_layout() 48 | plt.savefig(f_path, dpi=1600) 49 | plt.close() 50 | 51 | 52 | def visualize_graphs(graph_list, dir_path, config, remove=True): 53 | row = config.sampling.vis_row 54 | col = config.sampling.vis_col 55 | n_graph = row * col 56 | 57 | n_fig = int(np.ceil(len(graph_list) / n_graph)) 58 | for i in range(n_fig): 59 | draw_graph_list(graph_list[i*n_graph:(i+1)*n_graph], row, col, 60 | f_path=os.path.join(dir_path, "sample"+str(i)+".png"), remove=remove) 61 | 62 | 63 | def visualize_diff_graphs(graph_list, dir_path, sample_name, remove=True): 64 | draw_graph_list(graph_list, 4, 4, f_path=os.path.join(dir_path, sample_name+'.png'), remove=remove) 65 | 66 | 67 | def draw_adjacency_matrix(adjs, times, fname='diff_test.png'): 68 | 69 | n_graphs = len(adjs) 70 | plt.figure(figsize=(10, 3), dpi=300) 71 | 72 | for i, adj in enumerate(adjs): 73 | adj = np.clip(adj, a_min=0., a_max=1.) 74 | plt.subplot(1, n_graphs, i+1) 75 | plt.axis('off') 76 | plt.title(str(times[i])) 77 | #plt.imshow(adj, cmap='binary', interpolation="none") 78 | #plt.imshow(adj, cmap='YlGn', interpolation="none") 79 | plt.imshow(adj, cmap='Reds', interpolation="none") 80 | 81 | # plt.savefig(fname, dpi=1600) 82 | plt.savefig(fname, dpi=300) 83 | 84 | 85 | def draw_matrix(adjs, times, fname='diff_test.png'): 86 | 87 | n_graphs = len(adjs) 88 | plt.figure(figsize=(10, 3), dpi=300) 89 | norm = matplotlib.colors.Normalize(vmin=0, vmax=1.0) 90 | 91 | for i, adj in enumerate(adjs): 92 | adj = np.clip(adj, a_min=0., a_max=1.) 93 | plt.subplot(1, n_graphs, i+1) 94 | plt.axis('off') 95 | plt.title(str(times[i])) 96 | #plt.imshow(adj, cmap='binary', interpolation="none") 97 | #plt.imshow(adj, cmap='YlGn', interpolation="none") 98 | plt.imshow(adj, cmap='Reds', interpolation="none", norm=norm) 99 | # plt.colorbar() 100 | 101 | # plt.savefig(fname, dpi=1600) 102 | plt.savefig(fname, dpi=300) 103 | 104 | 105 | def visualize_mols(mol_list, dir_path, config): 106 | # from rdkit.Chem.Draw import IPythonConsole 107 | 108 | row = config.sampling.vis_row 109 | col = config.sampling.vis_col 110 | n_mol = row * col 111 | 112 | img = Draw.MolsToGridImage(mol_list[:n_mol], subImgSize=(400, 400), molsPerRow=row) 113 | img.save(os.path.join(dir_path, 'mol.png')) 114 | 115 | --------------------------------------------------------------------------------
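A minimal, hedged sketch of driving the helpers in `visualize.py` outside of `run_lib.py`; the SMILES strings and the `config.sampling.vis_row` / `vis_col` values are illustrative stand-ins for what the provided configs and `utils.tensor2mol` would supply:

```python
from rdkit import Chem
import ml_collections

import visualize

# Illustrative molecules and grid size; real runs pass RDKit mols produced by utils.tensor2mol
# and the config object loaded from configs/vp_qm9_cdgs.py.
mols = [Chem.MolFromSmiles(s) for s in ['CCO', 'c1ccccc1', 'CC(=O)O']]
config = ml_collections.ConfigDict({'sampling': {'vis_row': 1, 'vis_col': 3}})

# Renders the first vis_row * vis_col molecules into ./mol.png.
visualize.visualize_mols(mols, '.', config)
```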