├── .gitignore ├── requirements.txt ├── data └── README.md ├── README.md ├── config.py ├── train.py ├── dataset.py ├── gvae.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/processed/ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torch_geometric==1.7.0 3 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Dataset downloaded from: https://moleculenet.org/ 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Graph Variational Autoencoder with PyTorch Geometric for Molecule Generation. 2 | 3 | This model is based on this paper: 4 | https://jcheminf.biomedcentral.com/track/pdf/10.1186/s13321-019-0396-x.pdf 5 | 6 | Note that the model is not fully functioning yet.
7 | 8 | Please have a look at the original implementation in tensorflow for further directions: 9 | https://github.com/seokhokang/graphvae_approx/ 10 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | #DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 4 | DEVICE = "cpu" 5 | 6 | # Supported edge types 7 | SUPPORTED_EDGES = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"] 8 | 9 | # Supported atoms 10 | SUPPORTED_ATOMS = ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I"] 11 | ATOMIC_NUMBERS = [6, 7, 8, 9, 15, 16, 17, 35, 53] 12 | 13 | # Dataset (if you change this, delete the processed files to run again) 14 | MAX_MOLECULE_SIZE = 20 15 | 16 | # To remove valence errors ect. 17 | DISABLE_RDKIT_WARNINGS = True -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | from dataset import MoleculeDataset 4 | from tqdm import tqdm 5 | import numpy as np 6 | import mlflow.pytorch 7 | from utils import (count_parameters, gvae_loss, 8 | slice_edge_type_from_edge_feats, slice_atom_type_from_node_feats) 9 | from gvae import GVAE 10 | from config import DEVICE as device 11 | 12 | # Load data 13 | train_dataset = MoleculeDataset(root="data/", filename="HIV_train_oversampled.csv")[:10000] 14 | test_dataset = MoleculeDataset(root="data/", filename="HIV_test.csv", test=True)[:1000] 15 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) 16 | test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True) 17 | 18 | # Load model 19 | model = GVAE(feature_size=train_dataset[0].x.shape[1]) 20 | model = model.to(device) 21 | print("Model parameters: ", count_parameters(model)) 22 | 23 | # Define loss and optimizer 24 | 
loss_fn = gvae_loss 25 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 26 | kl_beta = 0.5 27 | 28 | # Train function 29 | def run_one_epoch(data_loader, type, epoch, kl_beta): 30 | # Store per batch loss and accuracy 31 | all_losses = [] 32 | all_kldivs = [] 33 | 34 | # Iterate over data loader 35 | for _, batch in enumerate(tqdm(data_loader)): 36 | # Some of the data points have invalid adjacency matrices 37 | try: 38 | # Use GPU 39 | batch.to(device) 40 | # Reset gradients 41 | optimizer.zero_grad() 42 | # Call model 43 | triu_logits, node_logits, mu, logvar = model(batch.x.float(), 44 | batch.edge_attr.float(), 45 | batch.edge_index, 46 | batch.batch) 47 | # Calculate loss and backpropagate 48 | edge_targets = slice_edge_type_from_edge_feats(batch.edge_attr.float()) 49 | node_targets = slice_atom_type_from_node_feats(batch.x.float(), as_index=True) 50 | loss, kl_div = loss_fn(triu_logits, node_logits, 51 | batch.edge_index, edge_targets, 52 | node_targets, mu, logvar, 53 | batch.batch, kl_beta) 54 | if type == "Train": 55 | loss.backward() 56 | optimizer.step() 57 | # Store loss and metrics 58 | all_losses.append(loss.detach().cpu().numpy()) 59 | #all_accs.append(acc) 60 | all_kldivs.append(kl_div.detach().cpu().numpy()) 61 | except IndexError as error: 62 | # For a few graphs the edge information is not correct 63 | # Simply skip the batch containing those 64 | print("Error: ", error) 65 | 66 | # Perform sampling 67 | if type == "Test": 68 | generated_mols = model.sample_mols(num=10000) 69 | print(f"Generated {generated_mols} molecules.") 70 | mlflow.log_metric(key=f"Sampled molecules", value=float(generated_mols), step=epoch) 71 | 72 | print(f"{type} epoch {epoch} loss: ", np.array(all_losses).mean()) 73 | mlflow.log_metric(key=f"{type} Epoch Loss", value=float(np.array(all_losses).mean()), step=epoch) 74 | mlflow.log_metric(key=f"{type} KL Divergence", value=float(np.array(all_kldivs).mean()), step=epoch) 75 | mlflow.pytorch.log_model(model, 
"model") 76 | 77 | # Run training 78 | with mlflow.start_run() as run: 79 | for epoch in range(100): 80 | model.train() 81 | run_one_epoch(train_loader, type="Train", epoch=epoch, kl_beta=kl_beta) 82 | if epoch % 5 == 0: 83 | print("Start test epoch...") 84 | model.eval() 85 | run_one_epoch(test_loader, type="Test", epoch=epoch, kl_beta=kl_beta) -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import torch_geometric 4 | from torch_geometric.data import Dataset 5 | import numpy as np 6 | import os 7 | from tqdm import tqdm 8 | import deepchem as dc 9 | from config import MAX_MOLECULE_SIZE 10 | from utils import slice_atom_type_from_node_feats 11 | import re 12 | 13 | print(f"Torch version: {torch.__version__}") 14 | print(f"Cuda available: {torch.cuda.is_available()}") 15 | print(f"Torch geometric version: {torch_geometric.__version__}") 16 | 17 | class MoleculeDataset(Dataset): 18 | def __init__(self, root, filename, test=False, transform=None, pre_transform=None, length=0): 19 | """ 20 | root = Where the dataset should be stored. This folder is split 21 | into raw_dir (downloaded dataset) and processed_dir (processed data). 22 | """ 23 | self.test = test 24 | self.filename = filename 25 | self.length = length 26 | super(MoleculeDataset, self).__init__(root, transform, pre_transform) 27 | 28 | @property 29 | def raw_file_names(self): 30 | """ If this file exists in raw_dir, the download is not triggered. 31 | (The download func. 
is not implemented here) 32 | """ 33 | return self.filename 34 | 35 | @property 36 | def processed_file_names(self): 37 | """ If these files are found in raw_dir, processing is skipped """ 38 | processed_files = [f for f in os.listdir(self.processed_dir) if not f.startswith("pre")] 39 | 40 | if self.test: 41 | processed_files = [file for file in processed_files if "test" in file] 42 | if len(processed_files) == 0: 43 | return ["no_files.dummy"] 44 | last_file = sorted(processed_files)[-1] 45 | index = int(re.search(r'\d+', last_file).group()) 46 | self.length = index 47 | return [f'data_test_{i}.pt' for i in list(range(0, index))] 48 | else: 49 | processed_files = [file for file in processed_files if not "test" in file] 50 | if len(processed_files) == 0: 51 | return ["no_files.dummy"] 52 | last_file = sorted(processed_files)[-1] 53 | index = int(re.search(r'\d+', last_file).group()) 54 | self.length = index 55 | return [f'data_{i}.pt' for i in list(range(0, index))] 56 | 57 | 58 | def download(self): 59 | pass 60 | 61 | def process(self): 62 | self.data = pd.read_csv(self.raw_paths[0]).reset_index() 63 | featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True) 64 | for _, mol in tqdm(self.data.iterrows(), total=self.data.shape[0]): 65 | # Featurize molecule 66 | f = featurizer.featurize(mol["smiles"]) 67 | data = f[0].to_pyg_graph() 68 | data.y = self._get_label(mol["HIV_active"]) 69 | data.smiles = mol["smiles"] 70 | 71 | # Get the molecule's atom types 72 | atom_types = slice_atom_type_from_node_feats(data.x) 73 | 74 | # Only save if molecule is in permitted size 75 | if (data.x.shape[0] < MAX_MOLECULE_SIZE) and -1 not in atom_types: 76 | if self.test: 77 | torch.save(data, 78 | os.path.join(self.processed_dir, 79 | f'data_test_{self.length}.pt')) 80 | else: 81 | torch.save(data, 82 | os.path.join(self.processed_dir, 83 | f'data_{self.length}.pt')) 84 | self.length += 1 85 | else: 86 | print("Skipping invalid mol (too big/unknown atoms): ", data.smiles) 87 | 
print(f"Done. Stored {self.length} preprocessed molecules.") 88 | 89 | def _get_label(self, label): 90 | label = np.asarray([label]) 91 | return torch.tensor(label, dtype=torch.int64) 92 | 93 | def len(self): 94 | return self.length 95 | 96 | def get(self, idx): 97 | """ 98 | - Equivalent to __getitem__ in pytorch 99 | - Is not needed for PyG's InMemoryDataset 100 | """ 101 | if self.test: 102 | data = torch.load(os.path.join(self.processed_dir, 103 | f'data_test_{idx}.pt')) 104 | else: 105 | data = torch.load(os.path.join(self.processed_dir, 106 | f'data_{idx}.pt')) 107 | return data -------------------------------------------------------------------------------- /gvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Linear 4 | from torch_geometric.nn.conv import TransformerConv 5 | from torch_geometric.nn import Set2Set 6 | from torch_geometric.nn import BatchNorm 7 | from config import SUPPORTED_ATOMS, SUPPORTED_EDGES, MAX_MOLECULE_SIZE, ATOMIC_NUMBERS 8 | from utils import graph_representation_to_molecule, to_one_hot 9 | from tqdm import tqdm 10 | 11 | class GVAE(nn.Module): 12 | def __init__(self, feature_size): 13 | super(GVAE, self).__init__() 14 | self.encoder_embedding_size = 64 15 | self.edge_dim = 11 16 | self.latent_embedding_size = 128 17 | self.num_edge_types = len(SUPPORTED_EDGES) 18 | self.num_atom_types = len(SUPPORTED_ATOMS) 19 | self.max_num_atoms = MAX_MOLECULE_SIZE 20 | self.decoder_hidden_neurons = 512 21 | 22 | # Encoder layers 23 | self.conv1 = TransformerConv(feature_size, 24 | self.encoder_embedding_size, 25 | heads=4, 26 | concat=False, 27 | beta=True, 28 | edge_dim=self.edge_dim) 29 | self.bn1 = BatchNorm(self.encoder_embedding_size) 30 | self.conv2 = TransformerConv(self.encoder_embedding_size, 31 | self.encoder_embedding_size, 32 | heads=4, 33 | concat=False, 34 | beta=True, 35 | edge_dim=self.edge_dim) 36 | self.bn2 = 
BatchNorm(self.encoder_embedding_size) 37 | self.conv3 = TransformerConv(self.encoder_embedding_size, 38 | self.encoder_embedding_size, 39 | heads=4, 40 | concat=False, 41 | beta=True, 42 | edge_dim=self.edge_dim) 43 | self.bn3 = BatchNorm(self.encoder_embedding_size) 44 | self.conv4 = TransformerConv(self.encoder_embedding_size, 45 | self.encoder_embedding_size, 46 | heads=4, 47 | concat=False, 48 | beta=True, 49 | edge_dim=self.edge_dim) 50 | 51 | # Pooling layers 52 | self.pooling = Set2Set(self.encoder_embedding_size, processing_steps=4) 53 | 54 | # Latent transform layers 55 | self.mu_transform = Linear(self.encoder_embedding_size*2, 56 | self.latent_embedding_size) 57 | self.logvar_transform = Linear(self.encoder_embedding_size*2, 58 | self.latent_embedding_size) 59 | 60 | # Decoder layers 61 | # --- Shared layers 62 | self.linear_1 = Linear(self.latent_embedding_size, self.decoder_hidden_neurons) 63 | self.linear_2 = Linear(self.decoder_hidden_neurons, self.decoder_hidden_neurons) 64 | 65 | # --- Atom decoding (outputs a matrix: (max_num_atoms) * (# atom_types + "none"-type)) 66 | atom_output_dim = self.max_num_atoms*(self.num_atom_types + 1) 67 | self.atom_decode = Linear(self.decoder_hidden_neurons, atom_output_dim) 68 | 69 | # --- Edge decoding (outputs a triu tensor: (max_num_atoms*(max_num_atoms-1)/2*(#edge_types + 1) )) 70 | edge_output_dim = int(((self.max_num_atoms * (self.max_num_atoms - 1)) / 2) * (self.num_edge_types + 1)) 71 | self.edge_decode = Linear(self.decoder_hidden_neurons, edge_output_dim) 72 | 73 | 74 | def encode(self, x, edge_attr, edge_index, batch_index): 75 | # GNN layers 76 | x = self.conv1(x, edge_index, edge_attr).relu() 77 | x = self.bn1(x) 78 | x = self.conv2(x, edge_index, edge_attr).relu() 79 | x = self.bn2(x) 80 | x = self.conv3(x, edge_index, edge_attr).relu() 81 | x = self.bn3(x) 82 | x = self.conv4(x, edge_index, edge_attr).relu() 83 | 84 | # Pool to global representation 85 | x = self.pooling(x, batch_index) 86 | 87 | # 
Latent transform layers 88 | mu = self.mu_transform(x) 89 | logvar = self.logvar_transform(x) 90 | return mu, logvar 91 | 92 | def decode_graph(self, graph_z): 93 | """ 94 | Decodes a latent vector into a continuous graph representation 95 | consisting of node types and edge types. 96 | """ 97 | # Pass through shared layers 98 | z = self.linear_1(graph_z).relu() 99 | z = self.linear_2(z).relu() 100 | # Decode atom types 101 | atom_logits = self.atom_decode(z) 102 | # Decode edge types 103 | edge_logits = self.edge_decode(z) 104 | 105 | return atom_logits, edge_logits 106 | 107 | 108 | def decode(self, z, batch_index): 109 | node_logits = [] 110 | triu_logits = [] 111 | # Iterate over molecules in batch 112 | for graph_id in torch.unique(batch_index): 113 | # Get latent vector for this graph 114 | graph_z = z[graph_id] 115 | 116 | # Recover graph from latent vector 117 | atom_logits, edge_logits = self.decode_graph(graph_z) 118 | 119 | # Store per graph results 120 | node_logits.append(atom_logits) 121 | triu_logits.append(edge_logits) 122 | 123 | # Concatenate all outputs of the batch 124 | node_logits = torch.cat(node_logits) 125 | triu_logits = torch.cat(triu_logits) 126 | return triu_logits, node_logits 127 | 128 | 129 | def reparameterize(self, mu, logvar): 130 | if self.training: 131 | # Get standard deviation 132 | std = torch.exp(logvar) 133 | # Returns random numbers from a normal distribution 134 | eps = torch.randn_like(std) 135 | # Return sampled values 136 | return eps.mul(std).add_(mu) 137 | else: 138 | return mu 139 | 140 | def forward(self, x, edge_attr, edge_index, batch_index): 141 | # Encode the molecule 142 | mu, logvar = self.encode(x, edge_attr, edge_index, batch_index) 143 | # Sample latent vector (per atom) 144 | z = self.reparameterize(mu, logvar) 145 | # Decode latent vector into original molecule 146 | triu_logits, node_logits = self.decode(z, batch_index) 147 | 148 | return triu_logits, node_logits, mu, logvar 149 | 150 | 151 | def 
sample_mols(self, num=10000): 152 | print("Sampling molecules ... ") 153 | 154 | n_valid = 0 155 | # Sample molecules and check if they are valid 156 | for _ in tqdm(range(num)): 157 | # Sample latent space 158 | z = torch.randn(1, self.latent_embedding_size) 159 | 160 | # Get model output (this could also be batched) 161 | dummy_batch_index = torch.Tensor([0]).int() 162 | triu_logits, node_logits = self.decode(z, dummy_batch_index) 163 | 164 | # Reshape triu predictions 165 | edge_matrix_shape = (int((MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1))/2), len(SUPPORTED_EDGES) + 1) 166 | triu_preds_matrix = triu_logits.reshape(edge_matrix_shape) 167 | triu_preds = torch.argmax(triu_preds_matrix, dim=1) 168 | 169 | # Reshape node predictions 170 | node_matrix_shape = (MAX_MOLECULE_SIZE, (len(SUPPORTED_ATOMS) + 1)) 171 | node_preds_matrix = node_logits.reshape(node_matrix_shape) 172 | node_preds = torch.argmax(node_preds_matrix[:, :9], dim=1) 173 | 174 | # Get atomic numbers 175 | node_preds_one_hot = to_one_hot(node_preds, options=ATOMIC_NUMBERS) 176 | atom_numbers_dummy = torch.Tensor(ATOMIC_NUMBERS).repeat(node_preds_one_hot.shape[0], 1) 177 | atom_types = torch.masked_select(atom_numbers_dummy, node_preds_one_hot.bool()) 178 | 179 | # Attempt to create valid molecule 180 | smiles, _ = graph_representation_to_molecule(atom_types, triu_preds.float()) 181 | 182 | # A dot means disconnected 183 | if smiles and "." 
not in smiles: 184 | print("Successfully generated: ", smiles) 185 | n_valid += 1 186 | return n_valid -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.utils import to_dense_adj 2 | import torch 3 | from rdkit import Chem 4 | from rdkit import RDLogger 5 | from config import DEVICE as device 6 | from config import (SUPPORTED_ATOMS, SUPPORTED_EDGES, MAX_MOLECULE_SIZE, ATOMIC_NUMBERS, 7 | DISABLE_RDKIT_WARNINGS) 8 | 9 | # Disable rdkit warnings 10 | if DISABLE_RDKIT_WARNINGS: 11 | RDLogger.DisableLog('rdApp.*') 12 | 13 | def count_parameters(model): 14 | """ 15 | Counts the number of parameters for a Pytorch model 16 | """ 17 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 18 | 19 | def kl_loss(mu=None, logstd=None): 20 | """ 21 | Closed formula of the KL divergence for normal distributions 22 | """ 23 | MAX_LOGSTD = 10 24 | logstd = logstd.clamp(max=MAX_LOGSTD) 25 | kl_div = -0.5 * torch.mean(torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1)) 26 | 27 | # Limit numeric errors 28 | kl_div = kl_div.clamp(max=1000) 29 | return kl_div 30 | 31 | def slice_graph_targets(graph_id, edge_targets, node_targets, batch_index): 32 | """ 33 | Slices out the upper triangular part of an adjacency matrix for 34 | a single graph from a large adjacency matrix for a full batch. 35 | For the node features the corresponding section in the batch is sliced out. 
36 | -------- 37 | graph_id: The ID of the graph (in the batch index) to slice 38 | edge_targets: A dense adjacency matrix for the whole batch 39 | node_targets: A tensor of node labels for the whole batch 40 | batch_index: The node to graph map for the batch 41 | """ 42 | # Create mask for nodes of this graph id 43 | graph_mask = torch.eq(batch_index, graph_id) 44 | # Row slice and column slice batch targets to get graph edge targets 45 | graph_edge_targets = edge_targets[graph_mask][:, graph_mask] 46 | # Get triangular upper part of adjacency matrix for targets 47 | size = graph_edge_targets.shape[0] 48 | triu_indices = torch.triu_indices(size, size, offset=1) 49 | triu_mask = torch.squeeze(to_dense_adj(triu_indices)).bool() 50 | graph_edge_targets = graph_edge_targets[triu_mask] 51 | # Slice node targets 52 | graph_node_targets = node_targets[graph_mask] 53 | return graph_edge_targets, graph_node_targets 54 | 55 | def slice_graph_predictions(triu_logits, node_logits, graph_triu_size, triu_start_point, graph_size, node_start_point): 56 | """ 57 | Slices out the corresponding section from a list of batch triu values. 58 | Given a start point and the size of a graph's triu, simply slices 59 | the section from the batch list. 
60 | ------- 61 | triu_logits: A batch of triu predictions of different graphs 62 | node_logits: A batch of node predictions with fixed size MAX_GRAPH_SIZE 63 | graph_triu_size: Size of the triu of the graph to slice 64 | triu_start_point: Index of the first node of this graph in the triu batch 65 | graph_size: Max graph size 66 | node_start_point: Index of the first node of this graph in the nodes batch 67 | """ 68 | # Slice edge logits 69 | graph_logits_triu = torch.squeeze( 70 | triu_logits[triu_start_point:triu_start_point + graph_triu_size] 71 | ) 72 | # Slice node logits 73 | graph_node_logits = torch.squeeze( 74 | node_logits[node_start_point:node_start_point + graph_size] 75 | ) 76 | return graph_logits_triu, graph_node_logits 77 | 78 | 79 | def slice_edge_type_from_edge_feats(edge_feats): 80 | """ 81 | This function only works for the MolGraphConvFeaturizer used in the dataset. 82 | It slices the one-hot encoded edge type from the edge feature matrix. 83 | The first 4 values stand for ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]. 84 | """ 85 | edge_types_one_hot = edge_feats[:, :4] 86 | edge_types = edge_types_one_hot.nonzero(as_tuple=False) 87 | # Start index at 1, zero will be no edge 88 | edge_types[:, 1] = edge_types[:, 1] + 1 89 | return edge_types 90 | 91 | 92 | def slice_atom_type_from_node_feats(node_features, as_index=False): 93 | """ 94 | This function only works for the MolGraphConvFeaturizer used in the dataset. 95 | It slices the one-hot encoded atom type from the node feature matrix. 96 | Unknown atom types are not considered and not expected in the datset. 
97 | """ 98 | supported_atoms = SUPPORTED_ATOMS 99 | atomic_numbers = ATOMIC_NUMBERS 100 | 101 | # Slice first X entries from the node feature matrix 102 | atom_types_one_hot = node_features[:, :len(supported_atoms)] 103 | if not as_index: 104 | # Map the index to the atomic number 105 | atom_numbers_dummy = torch.Tensor(atomic_numbers).repeat(atom_types_one_hot.shape[0], 1) 106 | atom_types = torch.masked_select(atom_numbers_dummy, atom_types_one_hot.bool()) 107 | else: 108 | atom_types = torch.argmax(atom_types_one_hot, dim=1) 109 | return atom_types 110 | 111 | def to_one_hot(x, options): 112 | """ 113 | Converts a tensor of values to a one-hot vector 114 | based on the entries in options. 115 | """ 116 | return torch.nn.functional.one_hot(x.long(), len(options)) 117 | 118 | def squared_difference(input, target): 119 | return (input - target) ** 2 120 | 121 | 122 | def triu_to_dense(triu_values, num_nodes): 123 | """ 124 | Converts a triangular upper part of a matrix as flat vector 125 | to a squared adjacency matrix with a specific size (num_nodes). 126 | """ 127 | dense_adj = torch.zeros((num_nodes, num_nodes)).to(device).float() 128 | triu_indices = torch.triu_indices(num_nodes, num_nodes, offset=1) 129 | tril_indices = torch.tril_indices(num_nodes, num_nodes, offset=-1) 130 | dense_adj[triu_indices[0], triu_indices[1]] = triu_values 131 | dense_adj[tril_indices[0], tril_indices[1]] = triu_values 132 | return dense_adj 133 | 134 | 135 | def triu_to_3d_dense(triu_values, num_nodes, depth=len(SUPPORTED_EDGES)): 136 | """ 137 | Converts the triangular upper part of a matrix 138 | for several dimensions into a 3d tensor. 
139 | """ 140 | # Create placeholder for 3d matrix 141 | adj_matrix_3d = torch.empty((num_nodes, num_nodes, depth), dtype=torch.float, device=device) 142 | for edge_type in range(len(SUPPORTED_EDGES)): 143 | adj_mat_edge_type = triu_to_dense(triu_values[:, edge_type].float(), num_nodes) 144 | adj_matrix_3d[:, :, edge_type] = adj_mat_edge_type 145 | return adj_matrix_3d 146 | 147 | def calculate_node_edge_pair_loss(node_tar, edge_tar, node_pred, edge_pred): 148 | """ 149 | Calculates a loss based on the sum of node-edge pairs. 150 | node_tar: [nodes, supported atoms] 151 | node_pred: [max nodes, supported atoms + 1] 152 | edge_tar: [triu values for target nodes, supported edges] 153 | edge_pred: [triu values for predicted nodes, supported edges] 154 | 155 | """ 156 | # Recover full 3d adjacency matrix for edge predictions 157 | edge_pred_3d = triu_to_3d_dense(edge_pred, node_pred.shape[0]) # [num nodes, num nodes, edge types] 158 | 159 | # Recover full 3d adjacency matrix for edge targets 160 | edge_tar_3d = triu_to_3d_dense(edge_tar, node_tar.shape[0]) # [num nodes, num nodes, edge types] 161 | 162 | # --- The two output matrices tell us how many edges are connected with each of the atom types 163 | # Multiply each of the edge types with the atom types for the predictions 164 | node_edge_preds = torch.empty((MAX_MOLECULE_SIZE, len(SUPPORTED_ATOMS), len(SUPPORTED_EDGES)), dtype=torch.float, device=device) 165 | for edge in range(len(SUPPORTED_EDGES)): 166 | node_edge_preds[:, :, edge] = torch.matmul(edge_pred_3d[:, :, edge], node_pred[:, :9]) 167 | 168 | # Multiply each of the edge types with the atom types for the targets 169 | node_edge_tar = torch.empty((node_tar.shape[0], len(SUPPORTED_ATOMS), len(SUPPORTED_EDGES)), dtype=torch.float, device=device) 170 | for edge in range(len(SUPPORTED_EDGES)): 171 | node_edge_tar[:, :, edge] = torch.matmul(edge_tar_3d[:, :, edge], node_tar.float()) 172 | 173 | # Reduce to matrix with [num atom types, num edge types] 174 | 
node_edge_pred_matrix = torch.sum(node_edge_preds, dim=0) 175 | node_edge_tar_matrix = torch.sum(node_edge_tar, dim=0) 176 | 177 | if torch.equal(node_edge_pred_matrix.int(), node_edge_tar_matrix.int()): 178 | print("Reconstructed node-edge pairs: ", node_edge_pred_matrix.int()) 179 | 180 | node_edge_loss = torch.mean(sum(squared_difference(node_edge_pred_matrix, node_edge_tar_matrix.float()))) 181 | 182 | # Calculate node-edge-node for preds 183 | node_edge_node_preds = torch.empty((MAX_MOLECULE_SIZE, MAX_MOLECULE_SIZE, len(SUPPORTED_EDGES)), dtype=torch.float, device=device) 184 | for edge in range(len(SUPPORTED_EDGES)): 185 | node_edge_node_preds[:, :, edge] = torch.matmul(node_edge_preds[:, :, edge], node_pred[:, :9].t()) 186 | 187 | # Calculate node-edge-node for targets 188 | node_edge_node_tar = torch.empty((node_tar.shape[0], node_tar.shape[0], len(SUPPORTED_EDGES)), dtype=torch.float, device=device) 189 | for edge in range(len(SUPPORTED_EDGES)): 190 | node_edge_node_tar[:, :, edge] = torch.matmul(node_edge_tar[:, :, edge], node_tar.float().t()) 191 | 192 | # Node edge node loss 193 | node_edge_node_loss = sum(squared_difference(torch.sum(node_edge_node_preds, [0,1]), 194 | torch.sum(node_edge_node_tar, [0,1]))) 195 | 196 | # TODO: Improve loss 197 | return node_edge_loss # * node_edge_node_loss 198 | 199 | 200 | def approximate_recon_loss(node_targets, node_preds, triu_targets, triu_preds): 201 | """ 202 | See: https://github.com/seokhokang/graphvae_approx/ 203 | TODO: Improve loss function 204 | """ 205 | # Convert targets to one hot 206 | onehot_node_targets = to_one_hot(node_targets, SUPPORTED_ATOMS ) #+ ["None"] 207 | onehot_triu_targets = to_one_hot(triu_targets, ["None"] + SUPPORTED_EDGES) 208 | 209 | # Reshape node predictions 210 | node_matrix_shape = (MAX_MOLECULE_SIZE, (len(SUPPORTED_ATOMS) + 1)) 211 | node_preds_matrix = node_preds.reshape(node_matrix_shape) 212 | 213 | # Reshape triu predictions 214 | edge_matrix_shape = (int((MAX_MOLECULE_SIZE 
* (MAX_MOLECULE_SIZE - 1))/2), len(SUPPORTED_EDGES) + 1) 215 | triu_preds_matrix = triu_preds.reshape(edge_matrix_shape) 216 | 217 | # Apply sum on labels per (node/edge) type and discard "none" types 218 | node_preds_reduced = torch.sum(node_preds_matrix[:, :9], 0) 219 | node_targets_reduced = torch.sum(onehot_node_targets, 0) 220 | triu_preds_reduced = torch.sum(triu_preds_matrix[:, 1:], 0) 221 | triu_targets_reduced = torch.sum(onehot_triu_targets[:, 1:], 0) 222 | 223 | # Calculate node-sum loss and edge-sum loss 224 | node_loss = sum(squared_difference(node_preds_reduced, node_targets_reduced.float())) 225 | edge_loss = sum(squared_difference(triu_preds_reduced, triu_targets_reduced.float())) 226 | 227 | # Calculate node-edge-sum loss 228 | # Forces the model to properly arrange the matrices 229 | node_edge_loss = calculate_node_edge_pair_loss(onehot_node_targets, 230 | onehot_triu_targets, 231 | node_preds_matrix, 232 | triu_preds_matrix) 233 | 234 | approx_loss = node_loss + edge_loss + node_edge_loss 235 | 236 | if all(node_targets_reduced == node_preds_reduced.int()) and \ 237 | all(triu_targets_reduced == triu_preds_reduced.int()): 238 | print("Reconstructed all edges: ", node_targets_reduced) 239 | print("and all nodes: ", node_targets_reduced) 240 | return approx_loss 241 | 242 | 243 | def gvae_loss(triu_logits, node_logits, edge_index, edge_types, node_types, \ 244 | mu, logvar, batch_index, kl_beta): 245 | """ 246 | Calculates the loss for the graph variational autoencoder, 247 | consiting of a node loss, an edge loss and the KL divergence. 
248 | """ 249 | # Convert target edge index to dense adjacency matrix 250 | batch_edge_targets = torch.squeeze(to_dense_adj(edge_index)) 251 | 252 | # Add edge types to adjacency targets 253 | batch_edge_targets[edge_index[0], edge_index[1]] = edge_types[:, 1].float() 254 | 255 | # For this model we always have the same (fixed) output dimension 256 | graph_size = MAX_MOLECULE_SIZE*(len(SUPPORTED_ATOMS) + 1) 257 | graph_triu_size = int((MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1)) / 2) * (len(SUPPORTED_EDGES) + 1) 258 | 259 | # Reconstruction loss per graph 260 | batch_recon_loss = [] 261 | triu_indices_counter = 0 262 | graph_size_counter = 0 263 | 264 | # Loop over graphs in this batch 265 | for graph_id in torch.unique(batch_index): 266 | # Get upper triangular targets for this graph from the whole batch 267 | triu_targets, node_targets = slice_graph_targets(graph_id, 268 | batch_edge_targets, 269 | node_types, 270 | batch_index) 271 | 272 | # Get upper triangular predictions for this graph from the whole batch 273 | triu_preds, node_preds = slice_graph_predictions(triu_logits, 274 | node_logits, 275 | graph_triu_size, 276 | triu_indices_counter, 277 | graph_size, 278 | graph_size_counter) 279 | 280 | # Update counter to the index of the next (upper-triu) graph 281 | triu_indices_counter = triu_indices_counter + graph_triu_size 282 | graph_size_counter = graph_size_counter + graph_size 283 | 284 | # Calculate losses 285 | recon_loss = approximate_recon_loss(node_targets, 286 | node_preds, 287 | triu_targets, 288 | triu_preds) 289 | batch_recon_loss.append(recon_loss) 290 | 291 | # Take average of all losses 292 | num_graphs = torch.unique(batch_index).shape[0] 293 | batch_recon_loss = torch.true_divide(sum(batch_recon_loss), num_graphs) 294 | 295 | # KL Divergence 296 | kl_divergence = kl_loss(mu, logvar) 297 | 298 | return batch_recon_loss + kl_beta * kl_divergence, kl_divergence 299 | 300 | 301 | 302 | def graph_representation_to_molecule(node_types, 
adjacency_triu): 303 | """ 304 | Converts the predicted graph to a molecule and validates it 305 | using RDKit. 306 | """ 307 | # Create empty mol 308 | mol = Chem.RWMol() 309 | 310 | # Add atoms to mol and store their index 311 | node_to_idx = {} 312 | for i in range(len(node_types)): 313 | a = Chem.Atom(int(node_types[i])) 314 | molIdx = mol.AddAtom(a) 315 | node_to_idx[i] = molIdx 316 | 317 | # Add edges to mol 318 | num_nodes = len(node_types) 319 | adjacency_matrix = triu_to_dense(adjacency_triu, num_nodes) 320 | for ix, row in enumerate(adjacency_matrix): 321 | for iy, bond in enumerate(row): 322 | # only traverse half the matrix 323 | if iy <= ix: 324 | continue 325 | 326 | # add bonds 327 | if bond == 0: 328 | continue 329 | else: 330 | if bond == 1: 331 | bond_type = Chem.rdchem.BondType.SINGLE 332 | elif bond == 2: 333 | bond_type = Chem.rdchem.BondType.DOUBLE 334 | elif bond == 3: 335 | bond_type = Chem.rdchem.BondType.TRIPLE 336 | elif bond == 4: 337 | bond_type = Chem.rdchem.BondType.AROMATIC 338 | mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type) 339 | # Convert RWMol to mol and Smiles 340 | mol = mol.GetMol() 341 | smiles = Chem.MolToSmiles(mol) 342 | 343 | # Sanitize molecule (make sure it is valid) 344 | try: 345 | Chem.SanitizeMol(mol) 346 | except: 347 | smiles = None 348 | 349 | # TODO: Visualize and save (use deepchem smiles_to_image) 350 | return smiles, mol 351 | --------------------------------------------------------------------------------