├── LICENSE
├── README.md
└── genmol
    ├── JTVAE
    │   ├── model.py
    │   ├── preprocess.py
    │   ├── sample.py
    │   ├── savedmodel.pth
    │   ├── train.py
    │   └── train.txt
    ├── ORGAN
    │   ├── Data.py
    │   ├── Metrics_Reward.py
    │   ├── Model.py
    │   ├── NP_Score
    │   │   ├── README
    │   │   ├── __pycache__
    │   │   │   └── npscorer.cpython-37.pyc
    │   │   ├── npscorer.py
    │   │   └── publicnp.model.gz
    │   ├── RewardMetrics.py
    │   ├── Run.py
    │   ├── SA_Score
    │   │   ├── README
    │   │   ├── UnitTestSAScore.py
    │   │   ├── __init__.py
    │   │   ├── __pycache__
    │   │   │   ├── __init__.cpython-37.pyc
    │   │   │   └── sascorer.cpython-37.pyc
    │   │   ├── data
    │   │   │   └── zim.100.txt
    │   │   ├── fpscores.pkl.gz
    │   │   └── sascorer.py
    │   ├── Trainer.py
    │   ├── mcf.csv
    │   ├── test.py
    │   └── wehi_pains.csv
    ├── aae
    │   ├── data.py
    │   ├── model.py
    │   ├── run.py
    │   ├── sample.py
    │   └── train.py
    ├── models.txt
    └── vae
        ├── data.py
        ├── run.py
        ├── samples.py
        ├── trainer.py
        └── vae_model.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 bayeslabs
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GenMol (Molecular Structure Generation)
2 | This is a library that curates different machine-learning methods for molecular generation. You can use this library to advance your research in drug discovery and material discovery.
3 | 
4 | We implemented the following algorithms for molecular generation using PyTorch.
5 | 15 | -------------------------------------------------------------------------------- /genmol/JTVAE/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math, random, sys 4 | from optparse import OptionParser 5 | import pickle 6 | import rdkit 7 | import json 8 | import rdkit.Chem as Chem 9 | from scipy.sparse import csr_matrix 10 | from scipy.sparse.csgraph import minimum_spanning_tree 11 | from collections import defaultdict 12 | import copy 13 | import torch.optim as optim 14 | import torch.optim.lr_scheduler as lr_scheduler 15 | from torch.utils.data import Dataset, DataLoader 16 | from torch.autograd import Variable 17 | import numpy as np 18 | from collections import deque 19 | import os, random 20 | import torch.nn.functional as F 21 | import pdb 22 | 23 | from jvae_preprocess import * 24 | 25 | def get_slots(smiles): 26 | mol = Chem.MolFromSmiles(smiles) 27 | return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()] 28 | 29 | 30 | def get_molecule(node): 31 | return Chem.MolFromSmiles(node.smiles) 32 | 33 | 34 | class Vocab(object): 35 | benzynes = ['C1=CC=CC=C1', 'C1=CC=NC=C1', 'C1=CC=NN=C1', 'C1=CN=CC=N1', 'C1=CN=CN=C1', 'C1=CN=NC=N1', 'C1=CN=NN=C1', 'C1=NC=NC=N1', 'C1=NN=CN=N1'] 36 | penzynes = ['C1=C[NH]C=C1', 'C1=C[NH]C=N1', 'C1=C[NH]N=C1', 'C1=C[NH]N=N1', 'C1=COC=C1', 'C1=COC=N1', 'C1=CON=C1', 'C1=CSC=C1', 'C1=CSC=N1', 'C1=CSN=C1', 'C1=CSN=N1', 'C1=NN=C[NH]1', 'C1=NN=CO1', 'C1=NN=CS1', 'C1=N[NH]C=N1', 'C1=N[NH]N=C1', 'C1=N[NH]N=N1', 'C1=NN=N[NH]1', 'C1=NN=NS1', 'C1=NOC=N1', 'C1=NON=C1', 'C1=NSC=N1', 'C1=NSN=C1'] 37 | 38 | def __init__(self, smiles_list,all_trees): 39 | list_d=[] 40 | 41 | for j in range(0,len(all_trees)): 42 | x=[] 43 | x=all_trees[j].nodes 44 | 45 | for i in range(0,len(x)): 46 | m=get_molecule(x[i]) 47 | m1=Chem.MolToSmiles(m,kekuleSmiles=False) 48 | list_d.append(m1) 49 | 50 | list_f=list(dict.fromkeys(list_d)) 51 | smiles_f=smiles_list+list_f 52 | 53 | self.vocab = smiles_f 54 | self.vmap = {x:i for i,x in enumerate(self.vocab)} 55 | self.slots = [get_slots(smiles) for smiles in self.vocab] 56 | Vocab.benzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 6] + ['C1=CCNCC1'] 57 | Vocab.penzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 5] + ['C1=NCCN1','C1=NNCC1'] 58 | 59 | 60 | def get_index(self, smiles): 61 | return self.vmap[smiles] 62 | 63 | def get_smiles(self, idx): 64 | return self.vocab[idx] 65 | 66 | def get_slots(self, idx): 67 | return copy.deepcopy(self.slots[idx]) 68 | 69 | def size(self): 70 | return len(self.vocab) 71 | 72 | 73 | 74 | def create_variable(tensor, requires_grad=None): 75 | if requires_grad is None: 76 | return Variable(tensor) 77 | else: 78 | return Variable(tensor, requires_grad=requires_grad) 79 | 80 | def index_select_ND(source, dim, index): 81 | index_size = index.size() 82 | suffix_dim = source.size()[1:] 83 | final_size = index_size + suffix_dim 84 | target = source.index_select(dim, index.view(-1)) 85 | return target.view(final_size) 86 | 87 | def GRU(x, h_nei, W_z, W_r, U_r, W_h): 88 | hidden_size = x.size()[-1] 89 | sum_h = h_nei.sum(dim=1) 90 | z_input = torch.cat([x,sum_h], dim=1) 91 | z = torch.sigmoid(W_z(z_input)) 92 | 93 | r_1 = W_r(x).view(-1,1,hidden_size) 94 | r_2 = U_r(h_nei) 95 | r = torch.sigmoid(r_1 + r_2) 96 | 97 | gated_h = r * h_nei 98 | sum_gated_h = gated_h.sum(dim=1) 99 | 
h_input = torch.cat([x,sum_gated_h], dim=1) 100 | pre_h = torch.tanh(W_h(h_input)) 101 | new_h = (1.0 - z) * sum_h + z * pre_h 102 | return new_h 103 | 104 | 105 | 106 | ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown'] 107 | 108 | ATOM_FDIM1 = len(ELEM_LIST) + 6 + 5 + 4 + 1 109 | BOND_FDIM1 = 5 + 6 110 | MAX_NB1 = 6 111 | 112 | def onek_encoding_unk1(x, allowable_set): 113 | if x not in allowable_set: 114 | x = allowable_set[-1] 115 | return list(map(lambda s: x == s, allowable_set)) 116 | 117 | def atom_features1(atom): 118 | return torch.Tensor(onek_encoding_unk1(atom.GetSymbol(), ELEM_LIST) 119 | + onek_encoding_unk1(atom.GetDegree(), [0,1,2,3,4,5]) 120 | + onek_encoding_unk1(atom.GetFormalCharge(), [-1,-2,1,2,0]) 121 | + onek_encoding_unk1(int(atom.GetChiralTag()), [0,1,2,3]) 122 | + [atom.GetIsAromatic()]) 123 | 124 | def bond_features1(bond): 125 | bt = bond.GetBondType() 126 | stereo = int(bond.GetStereo()) 127 | fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()] 128 | fstereo = onek_encoding_unk1(stereo, [0,1,2,3,4,5]) 129 | return torch.Tensor(fbond + fstereo) 130 | 131 | class MPN(nn.Module): 132 | 133 | def __init__(self, hidden_size, depth): 134 | super(MPN, self).__init__() 135 | self.hidden_size = int(hidden_size) 136 | self.depth = depth 137 | 138 | self.W_i = nn.Linear(ATOM_FDIM1 + BOND_FDIM1, hidden_size, bias=False) 139 | self.W_h = nn.Linear(hidden_size, hidden_size, bias=False) 140 | self.W_o = nn.Linear(ATOM_FDIM1 + hidden_size, hidden_size) 141 | 142 | def forward(self, fatoms, fbonds, agraph, bgraph, scope): 143 | fatoms = create_variable(fatoms) 144 | fbonds = create_variable(fbonds) 145 | agraph = create_variable(agraph) 146 | bgraph = create_variable(bgraph) 147 | 148 | binput = self.W_i(fbonds) 149 | message = F.relu(binput) 150 | 151 | for i in range(self.depth - 1): 152 | nei_message = index_select_ND(message, 0, bgraph) 153 | nei_message = nei_message.sum(dim=1) 154 | nei_message = self.W_h(nei_message) 155 | message = F.relu(binput + nei_message) 156 | 157 | nei_message = index_select_ND(message, 0, agraph) 158 | nei_message = nei_message.sum(dim=1) 159 | ainput = torch.cat([fatoms, nei_message], dim=1) 160 | atom_hiddens = F.relu(self.W_o(ainput)) 161 | 162 | max_len = max([x for _,x in scope]) 163 | batch_vecs = [] 164 | for st,le in scope: 165 | cur_vecs = atom_hiddens[st : st + le].mean(dim=0) 166 | batch_vecs.append( cur_vecs ) 167 | 168 | mol_vecs = torch.stack(batch_vecs, dim=0) 169 | return mol_vecs 170 | 171 | @staticmethod 172 | def tensorize(mol_batch): 173 | padding = torch.zeros(ATOM_FDIM1 + BOND_FDIM1) 174 | fatoms,fbonds = [],[padding] #Ensure bond is 1-indexed 175 | in_bonds,all_bonds = [],[(-1,-1)] #Ensure bond is 1-indexed 176 | scope = [] 177 | total_atoms = 0 178 | 179 | for smiles in mol_batch: 180 | mol = get_mol(smiles) 181 | n_atoms = mol.GetNumAtoms() 182 | 183 | for atom in mol.GetAtoms(): 184 | fatoms.append( atom_features1(atom) ) 185 | in_bonds.append([]) 186 | 187 | for bond in mol.GetBonds(): 188 | a1 = bond.GetBeginAtom() 189 | a2 = bond.GetEndAtom() 190 | x = a1.GetIdx() + total_atoms 191 | y = a2.GetIdx() + total_atoms 192 | 193 | b = len(all_bonds) 194 | all_bonds.append((x,y)) 195 | fbonds.append( torch.cat([fatoms[x], bond_features1(bond)], 0) ) 196 | in_bonds[y].append(b) 197 | 198 | b = len(all_bonds) 199 | 
all_bonds.append((y,x)) 200 | fbonds.append( torch.cat([fatoms[y], bond_features1(bond)], 0) ) 201 | in_bonds[x].append(b) 202 | 203 | scope.append((total_atoms,n_atoms)) 204 | total_atoms += n_atoms 205 | 206 | total_bonds = len(all_bonds) 207 | fatoms = torch.stack(fatoms, 0) 208 | fbonds = torch.stack(fbonds, 0) 209 | agraph = torch.zeros(total_atoms,MAX_NB1).long() 210 | bgraph = torch.zeros(total_bonds,MAX_NB1).long() 211 | 212 | for a in range(total_atoms): 213 | for i,b in enumerate(in_bonds[a]): 214 | agraph[a,i] = b 215 | 216 | for b1 in range(1, total_bonds): 217 | x,y = all_bonds[b1] 218 | for i,b2 in enumerate(in_bonds[x]): 219 | if all_bonds[b2][0] != y: 220 | bgraph[b1,i] = b2 221 | 222 | return (fatoms, fbonds, agraph, bgraph, scope) 223 | 224 | 225 | 226 | ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1 227 | BOND_FDIM = 5 228 | MAX_NB = 15 229 | 230 | def onek_encoding_unk(x, allowable_set): 231 | if x not in allowable_set: 232 | x = allowable_set[-1] 233 | return list(map(lambda s: x == s, allowable_set)) 234 | 235 | def atom_features(atom): 236 | return torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST) 237 | + onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5]) 238 | + onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0]) 239 | + [atom.GetIsAromatic()]) 240 | 241 | def bond_features(bond): 242 | bt = bond.GetBondType() 243 | return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]) 244 | 245 | class JTMPN(nn.Module): 246 | 247 | def __init__(self, hidden_size, depth): 248 | super(JTMPN, self).__init__() 249 | self.hidden_size = int(hidden_size) 250 | self.depth = depth 251 | 252 | self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False) 253 | self.W_h = nn.Linear(hidden_size, hidden_size, bias=False) 254 | self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size) 255 | 256 | def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message): #tree_message[0] == vec(0) 257 | fatoms = create_variable(fatoms) 258 | fbonds = create_variable(fbonds) 259 | agraph = create_variable(agraph) 260 | bgraph = create_variable(bgraph) 261 | 262 | binput = self.W_i(fbonds) 263 | graph_message = F.relu(binput) 264 | 265 | for i in range(self.depth - 1): 266 | message = torch.cat([tree_message,graph_message], dim=0) 267 | nei_message = index_select_ND(message, 0, bgraph) 268 | nei_message = nei_message.sum(dim=1) #assuming tree_message[0] == vec(0) 269 | nei_message = self.W_h(nei_message) 270 | graph_message = F.relu(binput + nei_message) 271 | 272 | message = torch.cat([tree_message,graph_message], dim=0) 273 | nei_message = index_select_ND(message, 0, agraph) 274 | nei_message = nei_message.sum(dim=1) 275 | ainput = torch.cat([fatoms, nei_message], dim=1) 276 | atom_hiddens = F.relu(self.W_o(ainput)) 277 | 278 | mol_vecs = [] 279 | for st,le in scope: 280 | mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le 281 | mol_vecs.append(mol_vec) 282 | 283 | mol_vecs = torch.stack(mol_vecs, dim=0) 284 | return mol_vecs 285 | 286 | @staticmethod 287 | def tensorize(cand_batch, mess_dict): 288 | fatoms,fbonds = [],[] 289 | in_bonds,all_bonds = [],[] 290 | total_atoms = 0 291 | total_mess = len(mess_dict) + 1 #must include vec(0) padding 292 | scope = [] 293 | 294 | for smiles,all_nodes,ctr_node in cand_batch: 295 | mol = Chem.MolFromSmiles(smiles) 296 | Chem.Kekulize(mol) #The original jtnn version kekulizes. 
Need to revisit why it is necessary 297 | n_atoms = mol.GetNumAtoms() 298 | ctr_bid = ctr_node.idx 299 | for atom in mol.GetAtoms(): 300 | 301 | fatoms.append( atom_features(atom) ) 302 | in_bonds.append([]) 303 | 304 | for bond in mol.GetBonds(): 305 | 306 | a1 = bond.GetBeginAtom() 307 | a2 = bond.GetEndAtom() 308 | x = a1.GetIdx() + total_atoms 309 | y = a2.GetIdx() + total_atoms 310 | #Here x_nid,y_nid could be 0 311 | x_nid,y_nid = a1.GetAtomMapNum(),a2.GetAtomMapNum() 312 | x_bid = all_nodes[x_nid - 1].idx if x_nid > 0 else -1 313 | y_bid = all_nodes[y_nid - 1].idx if y_nid > 0 else -1 314 | 315 | bfeature = bond_features(bond) 316 | 317 | b = total_mess + len(all_bonds) #bond idx offseted by total_mess 318 | all_bonds.append((x,y)) 319 | fbonds.append( torch.cat([fatoms[x], bfeature], 0) ) 320 | in_bonds[y].append(b) 321 | 322 | b = total_mess + len(all_bonds) 323 | all_bonds.append((y,x)) 324 | fbonds.append( torch.cat([fatoms[y], bfeature], 0) ) 325 | in_bonds[x].append(b) 326 | 327 | if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid: 328 | if (x_bid,y_bid) in mess_dict: 329 | mess_idx = mess_dict[(x_bid,y_bid)] 330 | in_bonds[y].append(mess_idx) 331 | if (y_bid,x_bid) in mess_dict: 332 | mess_idx = mess_dict[(y_bid,x_bid)] 333 | in_bonds[x].append(mess_idx) 334 | 335 | scope.append((total_atoms,n_atoms)) 336 | total_atoms += n_atoms 337 | 338 | total_bonds = len(all_bonds) 339 | fatoms = torch.stack(fatoms, 0) 340 | fbonds = torch.stack(fbonds, 0) 341 | agraph = torch.zeros(total_atoms,MAX_NB).long() 342 | bgraph = torch.zeros(total_bonds,MAX_NB).long() 343 | 344 | for a in range(total_atoms): 345 | for i,b in enumerate(in_bonds[a]): 346 | agraph[a,i] = b 347 | 348 | for b1 in range(total_bonds): 349 | x,y = all_bonds[b1] 350 | for i,b2 in enumerate(in_bonds[x]): #b2 is offseted by total_mess 351 | if b2 < total_mess or all_bonds[b2-total_mess][0] != y: 352 | bgraph[b1,i] = b2 353 | 354 | return (fatoms, fbonds, agraph, bgraph, scope) 355 | 356 | 357 | 358 | def dfs(stack, x, fa_idx): 359 | for y in x.neighbors: 360 | if y.idx == fa_idx: continue 361 | stack.append( (x,y,1) ) 362 | dfs(stack, y, x.idx) 363 | stack.append( (y,x,0) ) 364 | 365 | def have_slots(fa_slots, ch_slots): 366 | if len(fa_slots) > 2 and len(ch_slots) > 2: 367 | return True 368 | matches = [] 369 | for i,s1 in enumerate(fa_slots): 370 | a1,c1,h1 = s1 371 | for j,s2 in enumerate(ch_slots): 372 | a2,c2,h2 = s2 373 | if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4): 374 | matches.append( (i,j) ) 375 | 376 | if len(matches) == 0: return False 377 | 378 | fa_match,ch_match = zip(*matches) 379 | if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2: #never remove atom from ring 380 | fa_slots.pop(fa_match[0]) 381 | if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2: #never remove atom from ring 382 | ch_slots.pop(ch_match[0]) 383 | 384 | return True 385 | 386 | def can_assemble(node_x, node_y): 387 | node_x.nid = 1 388 | node_x.is_leaf = False 389 | set_atommap(node_x.mol, node_x.nid) 390 | 391 | neis = node_x.neighbors + [node_y] 392 | for i,nei in enumerate(neis): 393 | nei.nid = i + 2 394 | nei.is_leaf = (len(nei.neighbors) <= 1) 395 | if nei.is_leaf: 396 | set_atommap(nei.mol, 0) 397 | else: 398 | set_atommap(nei.mol, nei.nid) 399 | 400 | neighbors = [nei for nei in neis if nei.mol.GetNumAtoms() > 1] 401 | neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True) 402 | singletons = [nei for nei in neis if nei.mol.GetNumAtoms() == 1] 403 | neighbors = singletons + neighbors 404 | 
cands,aroma_scores = enum_assemble(node_x, neighbors) 405 | return len(cands) > 0# and sum(aroma_scores) >= 0 406 | 407 | 408 | 409 | MAX_NB0 = 15 410 | MAX_DECODE_LEN0 = 100 411 | 412 | class JTNNDecoder(nn.Module): 413 | 414 | def __init__(self, vocab, hidden_size, latent_size, embedding): 415 | super(JTNNDecoder, self).__init__() 416 | self.hidden_size = int(hidden_size) 417 | self.vocab_size = vocab.size() 418 | self.vocab = vocab 419 | self.embedding = embedding 420 | latent_size=int(latent_size) 421 | #GRU Weights 422 | self.W_z = nn.Linear(2 * hidden_size, hidden_size) 423 | self.U_r = nn.Linear(hidden_size, hidden_size, bias=False) 424 | self.W_r = nn.Linear(hidden_size, hidden_size) 425 | self.W_h = nn.Linear(2 * hidden_size, hidden_size) 426 | 427 | #Word Prediction Weights 428 | self.W = nn.Linear(hidden_size + latent_size, hidden_size) 429 | 430 | #Stop Prediction Weights 431 | self.U = nn.Linear(hidden_size + latent_size, hidden_size) 432 | self.U_i = nn.Linear(2 * hidden_size, hidden_size) 433 | 434 | #Output Weights 435 | self.W_o = nn.Linear(hidden_size, self.vocab_size) 436 | self.U_o = nn.Linear(hidden_size, 1) 437 | 438 | 439 | #Loss Functions 440 | self.pred_loss = nn.CrossEntropyLoss(size_average=False) 441 | self.stop_loss = nn.BCEWithLogitsLoss(size_average=False) 442 | 443 | def aggregate(self, hiddens, contexts, x_tree_vecs, mode): 444 | if mode == 'word': 445 | V, V_o = self.W, self.W_o 446 | elif mode == 'stop': 447 | V, V_o = self.U, self.U_o 448 | else: 449 | raise ValueError('aggregate mode is wrong') 450 | 451 | tree_contexts = x_tree_vecs.index_select(0, contexts) 452 | input_vec = torch.cat([hiddens, tree_contexts], dim=-1) 453 | output_vec = F.relu( V(input_vec) ) 454 | return V_o(output_vec) 455 | 456 | def forward(self, mol_batch, x_tree_vecs): 457 | pred_hiddens,pred_contexts,pred_targets = [],[],[] 458 | stop_hiddens,stop_contexts,stop_targets = [],[],[] 459 | traces = [] 460 | for mol_tree in mol_batch: 461 | s = [] 462 | dfs(s, mol_tree.nodes[0], -1) 463 | traces.append(s) 464 | for node in mol_tree.nodes: 465 | node.neighbors = [] 466 | 467 | #Predict Root 468 | batch_size = len(mol_batch) 469 | pred_hiddens.append(create_variable(torch.zeros(len(mol_batch),self.hidden_size))) 470 | pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch]) 471 | 472 | pred_contexts.append( create_variable( torch.LongTensor(range(batch_size)) ) ) 473 | 474 | max_iter = max([len(tr) for tr in traces]) 475 | padding = create_variable(torch.zeros(self.hidden_size), False) 476 | h = {} 477 | 478 | for t in range(max_iter): 479 | prop_list = [] 480 | batch_list = [] 481 | for i,plist in enumerate(traces): 482 | if t < len(plist): 483 | prop_list.append(plist[t]) 484 | batch_list.append(i) 485 | 486 | cur_x = [] 487 | cur_h_nei,cur_o_nei = [],[] 488 | 489 | for node_x, real_y, _ in prop_list: 490 | #Neighbors for message passing (target not included) 491 | cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx] 492 | pad_len = MAX_NB0 - len(cur_nei) 493 | cur_h_nei.extend(cur_nei) 494 | cur_h_nei.extend([padding] * pad_len) 495 | 496 | #Neighbors for stop prediction (all neighbors) 497 | cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors] 498 | pad_len = MAX_NB0 - len(cur_nei) 499 | cur_o_nei.extend(cur_nei) 500 | cur_o_nei.extend([padding] * pad_len) 501 | 502 | #Current clique embedding 503 | cur_x.append(node_x.wid) 504 | 505 | 506 | #Clique embedding 507 | cur_x = 
create_variable(torch.LongTensor(cur_x)) 508 | cur_x = self.embedding(cur_x) 509 | 510 | #Message passing 511 | cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1,MAX_NB0,self.hidden_size) 512 | new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h) 513 | 514 | #Node Aggregate 515 | cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1,MAX_NB0,self.hidden_size) 516 | cur_o = cur_o_nei.sum(dim=1) 517 | 518 | #Gather targets 519 | pred_target,pred_list = [],[] 520 | stop_target = [] 521 | for i,m in enumerate(prop_list): 522 | node_x,node_y,direction = m 523 | x,y = node_x.idx,node_y.idx 524 | h[(x,y)] = new_h[i] 525 | node_y.neighbors.append(node_x) 526 | if direction == 1: 527 | pred_target.append(node_y.wid) 528 | pred_list.append(i) 529 | stop_target.append(direction) 530 | 531 | #Hidden states for stop prediction 532 | cur_batch = create_variable(torch.LongTensor(batch_list)) 533 | stop_hidden = torch.cat([cur_x,cur_o], dim=1) 534 | stop_hiddens.append( stop_hidden ) 535 | stop_contexts.append( cur_batch ) 536 | stop_targets.extend( stop_target ) 537 | 538 | #Hidden states for clique prediction 539 | if len(pred_list) > 0: 540 | batch_list = [batch_list[i] for i in pred_list] 541 | cur_batch = create_variable(torch.LongTensor(batch_list)) 542 | pred_contexts.append( cur_batch ) 543 | 544 | cur_pred = create_variable(torch.LongTensor(pred_list)) 545 | pred_hiddens.append( new_h.index_select(0, cur_pred) ) 546 | pred_targets.extend( pred_target ) 547 | 548 | #Last stop at root 549 | cur_x,cur_o_nei = [],[] 550 | for mol_tree in mol_batch: 551 | node_x = mol_tree.nodes[0] 552 | cur_x.append(node_x.wid) 553 | cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors] 554 | pad_len = MAX_NB0 - len(cur_nei) 555 | cur_o_nei.extend(cur_nei) 556 | cur_o_nei.extend([padding] * pad_len) 557 | 558 | cur_x = create_variable(torch.LongTensor(cur_x)) 559 | cur_x = self.embedding(cur_x) 560 | cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1,MAX_NB0,self.hidden_size) 561 | cur_o = cur_o_nei.sum(dim=1) 562 | 563 | stop_hidden = torch.cat([cur_x,cur_o], dim=1) 564 | stop_hiddens.append( stop_hidden ) 565 | stop_contexts.append( create_variable( torch.LongTensor(range(batch_size)) ) ) 566 | stop_targets.extend( [0] * len(mol_batch) ) 567 | 568 | #Predict next clique 569 | pred_contexts = torch.cat(pred_contexts, dim=0) 570 | pred_hiddens = torch.cat(pred_hiddens, dim=0) 571 | pred_scores = self.aggregate(pred_hiddens, pred_contexts, x_tree_vecs, 'word') 572 | pred_targets = create_variable(torch.LongTensor(pred_targets)) 573 | pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch) 574 | _,preds = torch.max(pred_scores, dim=1) 575 | pred_acc = torch.eq(preds, pred_targets).float() 576 | pred_acc = torch.sum(pred_acc) / pred_targets.nelement() 577 | 578 | #Predict stop 579 | stop_contexts = torch.cat(stop_contexts, dim=0) 580 | stop_hiddens = torch.cat(stop_hiddens, dim=0) 581 | stop_hiddens = F.relu( self.U_i(stop_hiddens) ) 582 | stop_scores = self.aggregate(stop_hiddens, stop_contexts, x_tree_vecs, 'stop') 583 | stop_scores = stop_scores.squeeze(-1) 584 | stop_targets = create_variable(torch.Tensor(stop_targets)) 585 | 586 | stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch) 587 | stops = torch.ge(stop_scores, 0).float() 588 | stop_acc = torch.eq(stops, stop_targets).float() 589 | stop_acc = torch.sum(stop_acc) / stop_targets.nelement() 590 | 591 | return pred_loss, stop_loss, pred_acc.item(), stop_acc.item() 592 | 593 | def decode(self, 
x_tree_vecs, prob_decode): 594 | assert x_tree_vecs.size(0) == 1 595 | 596 | stack = [] 597 | init_hiddens = create_variable( torch.zeros(1, self.hidden_size) ) 598 | zero_pad = create_variable(torch.zeros(1,1,self.hidden_size)) 599 | contexts = create_variable( torch.LongTensor(1).zero_() ) 600 | 601 | #Root Prediction 602 | root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs, 'word') 603 | _,root_wid = torch.max(root_score, dim=1) 604 | root_wid = root_wid.item() 605 | 606 | root = MolTreeNode(self.vocab.get_smiles(root_wid)) 607 | root.wid = root_wid 608 | root.idx = 0 609 | stack.append( (root, self.vocab.get_slots(root.wid)) ) 610 | 611 | all_nodes = [root] 612 | h = {} 613 | for step in range(MAX_DECODE_LEN0): 614 | node_x,fa_slot = stack[-1] 615 | cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors ] 616 | if len(cur_h_nei) > 0: 617 | cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size) 618 | else: 619 | cur_h_nei = zero_pad 620 | 621 | cur_x = create_variable(torch.LongTensor([node_x.wid])) 622 | cur_x = self.embedding(cur_x) 623 | 624 | #Predict stop 625 | cur_h = cur_h_nei.sum(dim=1) 626 | stop_hiddens = torch.cat([cur_x,cur_h], dim=1) 627 | stop_hiddens = F.relu( self.U_i(stop_hiddens) ) 628 | stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs, 'stop') 629 | 630 | if prob_decode: 631 | backtrack = (torch.bernoulli( torch.sigmoid(stop_score) ).item() == 0) 632 | else: 633 | backtrack = (stop_score.item() < 0) 634 | 635 | if not backtrack: #Forward: Predict next clique 636 | new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h) 637 | pred_score = self.aggregate(new_h, contexts, x_tree_vecs, 'word') 638 | 639 | if prob_decode: 640 | sort_wid = torch.multinomial(F.softmax(pred_score, dim=1).squeeze(), 5) 641 | else: 642 | _,sort_wid = torch.sort(pred_score, dim=1, descending=True) 643 | sort_wid = sort_wid.data.squeeze() 644 | 645 | next_wid = None 646 | for wid in sort_wid[:5]: 647 | slots = self.vocab.get_slots(wid) 648 | node_y = MolTreeNode(self.vocab.get_smiles(wid)) 649 | if have_slots(fa_slot, slots) and can_assemble(node_x, node_y): 650 | next_wid = wid 651 | next_slots = slots 652 | break 653 | 654 | if next_wid is None: 655 | backtrack = True #No more children can be added 656 | else: 657 | node_y = MolTreeNode(self.vocab.get_smiles(next_wid)) 658 | node_y.wid = next_wid 659 | node_y.idx = len(all_nodes) 660 | node_y.neighbors.append(node_x) 661 | h[(node_x.idx,node_y.idx)] = new_h[0] 662 | stack.append( (node_y,next_slots) ) 663 | all_nodes.append(node_y) 664 | 665 | if backtrack: #Backtrack, use if instead of else 666 | if len(stack) == 1: 667 | break #At root, terminate 668 | 669 | node_fa,_ = stack[-2] 670 | cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx ] 671 | if len(cur_h_nei) > 0: 672 | cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size) 673 | else: 674 | cur_h_nei = zero_pad 675 | 676 | new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h) 677 | h[(node_x.idx,node_fa.idx)] = new_h[0] 678 | node_fa.neighbors.append(node_x) 679 | stack.pop() 680 | 681 | return root, all_nodes 682 | 683 | 684 | 685 | 686 | class JTNNEncoder(nn.Module): 687 | 688 | def __init__(self, hidden_size, depth, embedding): 689 | super(JTNNEncoder, self).__init__() 690 | self.hidden_size = int(hidden_size) 691 | self.depth = depth 692 | 693 | self.embedding = embedding 694 | self.outputNN = nn.Sequential( 695 | nn.Linear(2 * 
hidden_size, hidden_size), 696 | nn.ReLU() 697 | ) 698 | self.GRU = GraphGRU(hidden_size, hidden_size, depth=depth) 699 | 700 | def forward(self, fnode, fmess, node_graph, mess_graph, scope): 701 | fnode = create_variable(fnode) 702 | fmess = create_variable(fmess) 703 | node_graph = create_variable(node_graph) 704 | mess_graph = create_variable(mess_graph) 705 | messages = create_variable(torch.zeros(mess_graph.size(0), self.hidden_size)) 706 | 707 | fnode = self.embedding(fnode) 708 | fmess = index_select_ND(fnode, 0, fmess) 709 | messages = self.GRU(messages, fmess, mess_graph) 710 | 711 | mess_nei = index_select_ND(messages, 0, node_graph) 712 | node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1) 713 | node_vecs = self.outputNN(node_vecs) 714 | 715 | max_len = max([x for _,x in scope]) 716 | batch_vecs = [] 717 | for st,le in scope: 718 | cur_vecs = node_vecs[st] #Root is the first node 719 | batch_vecs.append( cur_vecs ) 720 | 721 | tree_vecs = torch.stack(batch_vecs, dim=0) 722 | return tree_vecs, messages 723 | 724 | @staticmethod 725 | def tensorize(tree_batch): 726 | 727 | node_batch = [] 728 | scope = [] 729 | for tree in tree_batch: 730 | scope.append( (len(node_batch), len(tree.nodes)) ) 731 | node_batch.extend(tree.nodes) 732 | 733 | return JTNNEncoder.tensorize_nodes(node_batch, scope) 734 | 735 | @staticmethod 736 | def tensorize_nodes(node_batch, scope): 737 | 738 | messages,mess_dict = [None],{} 739 | fnode = [] 740 | for x in node_batch: 741 | fnode.append(x.wid) 742 | for y in x.neighbors: 743 | mess_dict[(x.idx,y.idx)] = len(messages) 744 | messages.append( (x,y) ) 745 | 746 | node_graph = [[] for i in range(len(node_batch))] 747 | mess_graph = [[] for i in range(len(messages))] 748 | fmess = [0] * len(messages) 749 | 750 | for x,y in messages[1:]: 751 | mid1 = mess_dict[(x.idx,y.idx)] 752 | fmess[mid1] = x.idx 753 | node_graph[y.idx].append(mid1) 754 | for z in y.neighbors: 755 | if z.idx == x.idx: continue 756 | mid2 = mess_dict[(y.idx,z.idx)] 757 | mess_graph[mid2].append(mid1) 758 | 759 | max_len = max([len(t) for t in node_graph] + [1]) 760 | for t in node_graph: 761 | pad_len = max_len - len(t) 762 | t.extend([0] * pad_len) 763 | 764 | max_len = max([len(t) for t in mess_graph] + [1]) 765 | for t in mess_graph: 766 | pad_len = max_len - len(t) 767 | t.extend([0] * pad_len) 768 | 769 | mess_graph = torch.LongTensor(mess_graph) 770 | node_graph = torch.LongTensor(node_graph) 771 | fmess = torch.LongTensor(fmess) 772 | fnode = torch.LongTensor(fnode) 773 | return (fnode, fmess, node_graph, mess_graph, scope), mess_dict 774 | 775 | class GraphGRU(nn.Module): 776 | 777 | def __init__(self, input_size, hidden_size, depth): 778 | super(GraphGRU, self).__init__() 779 | self.hidden_size = int(hidden_size) 780 | self.input_size = input_size 781 | self.depth = depth 782 | 783 | self.W_z = nn.Linear(input_size + hidden_size, hidden_size) 784 | self.W_r = nn.Linear(input_size, hidden_size, bias=False) 785 | self.U_r = nn.Linear(hidden_size, hidden_size) 786 | self.W_h = nn.Linear(input_size + hidden_size, hidden_size) 787 | 788 | def forward(self, h, x, mess_graph): 789 | mask = torch.ones(h.size(0), 1) 790 | mask[0] = 0 #first vector is padding 791 | mask = create_variable(mask) 792 | for it in range(self.depth): 793 | h_nei = index_select_ND(h, 0, mess_graph) 794 | sum_h = h_nei.sum(dim=1) 795 | z_input = torch.cat([x, sum_h], dim=1) 796 | z = torch.sigmoid(self.W_z(z_input)) 797 | 798 | r_1 = self.W_r(x).view(-1, 1, self.hidden_size) 799 | r_2 = self.U_r(h_nei) 800 
| r = torch.sigmoid(r_1 + r_2) 801 | 802 | gated_h = r * h_nei 803 | sum_gated_h = gated_h.sum(dim=1) 804 | h_input = torch.cat([x, sum_gated_h], dim=1) 805 | pre_h = torch.tanh(self.W_h(h_input)) 806 | h = (1.0 - z) * sum_h + z * pre_h 807 | h = h * mask 808 | 809 | return h 810 | 811 | 812 | 813 | 814 | 815 | 816 | 817 | class JTNNVAE(nn.Module): 818 | 819 | def __init__(self, vocab, hidden_size, latent_size, depthT, depthG): 820 | super(JTNNVAE, self).__init__() 821 | self.vocab = vocab 822 | 823 | self.hidden_size = int(hidden_size) 824 | self.latent_size = latent_size = latent_size / 2 #Tree and Mol has two vectors 825 | self.latent_size=int(self.latent_size) 826 | self.jtnn = JTNNEncoder(int(hidden_size),int(depthT), nn.Embedding(780,450)) 827 | self.decoder = JTNNDecoder(vocab, int(hidden_size), int(latent_size), nn.Embedding(780,450)) 828 | 829 | self.jtmpn = JTMPN(int(hidden_size), int(depthG)) 830 | self.mpn = MPN(int(hidden_size), int(depthG)) 831 | 832 | self.A_assm = nn.Linear(int(latent_size), int(hidden_size), bias=False) 833 | self.assm_loss = nn.CrossEntropyLoss(size_average=False) 834 | 835 | self.T_mean = nn.Linear(int(hidden_size), int(latent_size)) 836 | self.T_var = nn.Linear(int(hidden_size), int(latent_size)) 837 | self.G_mean = nn.Linear(int(hidden_size), int(latent_size)) 838 | self.G_var = nn.Linear(int(hidden_size), int(latent_size)) 839 | 840 | def encode(self, jtenc_holder, mpn_holder): 841 | tree_vecs, tree_mess = self.jtnn(*jtenc_holder) 842 | mol_vecs = self.mpn(*mpn_holder) 843 | return tree_vecs, tree_mess, mol_vecs 844 | 845 | def encode_latent(self, jtenc_holder, mpn_holder): 846 | tree_vecs, _ = self.jtnn(*jtenc_holder) 847 | mol_vecs = self.mpn(*mpn_holder) 848 | tree_mean = self.T_mean(tree_vecs) 849 | mol_mean = self.G_mean(mol_vecs) 850 | tree_var = -torch.abs(self.T_var(tree_vecs)) 851 | mol_var = -torch.abs(self.G_var(mol_vecs)) 852 | return torch.cat([tree_mean, mol_mean], dim=1), torch.cat([tree_var, mol_var], dim=1) 853 | 854 | def rsample(self, z_vecs, W_mean, W_var): 855 | batch_size = z_vecs.size(0) 856 | z_mean = W_mean(z_vecs) 857 | z_log_var = -torch.abs(W_var(z_vecs)) #Following Mueller et al. 
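        # Reparameterization trick: clamping z_log_var to be non-positive (the
        # "Following Mueller et al." trick above) keeps the posterior variance sigma^2 <= 1.
        # The next line is the closed-form KL(N(mu, sigma^2) || N(0, I))
        #   = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2), averaged over the batch,
        # and the sample is then z = mu + sigma * epsilon with epsilon ~ N(0, I).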
858 | kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size 859 | epsilon = create_variable(torch.randn_like(z_mean)) 860 | z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon 861 | return z_vecs, kl_loss 862 | 863 | def sample_prior(self, prob_decode=False): 864 | z_tree = torch.randn(1, self.latent_size) 865 | z_mol = torch.randn(1, self.latent_size) 866 | return self.decode(z_tree, z_mol, prob_decode) 867 | 868 | def forward(self, x_batch, beta): 869 | x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder= x_batch 870 | #ncoding the graph and tree 871 | x_tree_vecs, x_tree_mess, x_mol_vecs = self.encode(x_jtenc_holder, x_mpn_holder) 872 | 873 | z_tree_vecs,tree_kl = self.rsample(x_tree_vecs, self.T_mean, self.T_var) 874 | 875 | z_mol_vecs,mol_kl = self.rsample(x_mol_vecs, self.G_mean, self.G_var) 876 | 877 | kl_div = tree_kl + mol_kl 878 | #Decoding the tree 879 | word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs) 880 | #Decoding the graph and assembling the graph 881 | assm_loss, assm_acc = self.assm(x_batch, x_jtmpn_holder, z_mol_vecs, x_tree_mess) 882 | 883 | return word_loss + topo_loss + assm_loss + beta * kl_div, kl_div.item(), word_acc, topo_acc, assm_acc 884 | 885 | def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess): 886 | jtmpn_holder,batch_idx = jtmpn_holder 887 | fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder 888 | batch_idx = create_variable(batch_idx) 889 | 890 | cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, x_tree_mess) 891 | 892 | x_mol_vecs = x_mol_vecs.index_select(0, batch_idx) 893 | x_mol_vecs = self.A_assm(x_mol_vecs) #bilinear 894 | scores = torch.bmm( 895 | x_mol_vecs.unsqueeze(1), 896 | cand_vecs.unsqueeze(-1) 897 | ).squeeze() 898 | 899 | cnt,tot,acc = 0,0,0 900 | all_loss = [] 901 | for i,mol_tree in enumerate(mol_batch): 902 | comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf] 903 | cnt += len(comp_nodes) 904 | for node in comp_nodes: 905 | label = node.cands.index(node.label) 906 | ncand = len(node.cands) 907 | cur_score = scores.narrow(0, tot, ncand) 908 | tot += ncand 909 | 910 | if cur_score.data[label] >= cur_score.max().item(): 911 | acc += 1 912 | 913 | label = create_variable(torch.LongTensor([label])) 914 | all_loss.append( self.assm_loss(cur_score.view(1,-1), label) ) 915 | 916 | all_loss = sum(all_loss) / len(mol_batch) 917 | return all_loss, acc * 1.0 / cnt 918 | 919 | def decode(self, x_tree_vecs, x_mol_vecs, prob_decode): 920 | 921 | assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1 922 | 923 | pred_root,pred_nodes = self.decoder.decode(x_tree_vecs, prob_decode) 924 | if len(pred_nodes) == 0: return None 925 | elif len(pred_nodes) == 1: return pred_root.smiles 926 | 927 | #Mark nid & is_leaf & atommap 928 | for i,node in enumerate(pred_nodes): 929 | node.nid = i + 1 930 | node.is_leaf = (len(node.neighbors) == 1) 931 | if len(node.neighbors) > 1: 932 | set_atommap(node.mol, node.nid) 933 | 934 | scope = [(0, len(pred_nodes))] 935 | jtenc_holder,mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope) 936 | _,tree_mess = self.jtnn(*jtenc_holder) 937 | tree_mess = (tree_mess, mess_dict) #Important: tree_mess is a matrix, mess_dict is a python dict 938 | 939 | x_mol_vecs = self.A_assm(x_mol_vecs).squeeze() #bilinear 940 | 941 | cur_mol = copy_edit_mol(pred_root.mol) 942 | global_amap = [{}] + [{} for node in pred_nodes] 943 | global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in 
cur_mol.GetAtoms()} 944 | 945 | cur_mol,_ = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=True) 946 | if cur_mol is None: 947 | cur_mol = copy_edit_mol(pred_root.mol) 948 | global_amap = [{}] + [{} for node in pred_nodes] 949 | global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in cur_mol.GetAtoms()} 950 | cur_mol,pre_mol = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=False) 951 | if cur_mol is None: cur_mol = pre_mol 952 | 953 | if cur_mol is None: 954 | return None 955 | 956 | 957 | 958 | cur_mol = cur_mol.GetMol() 959 | set_atommap(cur_mol) 960 | cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol)) 961 | return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None 962 | 963 | def dfs_assemble(self, y_tree_mess, x_mol_vecs, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode, check_aroma): 964 | fa_nid = fa_node.nid if fa_node is not None else -1 965 | prev_nodes = [fa_node] if fa_node is not None else [] 966 | 967 | children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid] 968 | neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1] 969 | neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True) 970 | singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1] 971 | neighbors = singletons + neighbors 972 | 973 | cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid] 974 | cands,aroma_score = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap) 975 | if len(cands) == 0 or (sum(aroma_score) < 0 and check_aroma): 976 | return None, cur_mol 977 | 978 | cand_smiles,cand_amap = zip(*cands) 979 | aroma_score = torch.Tensor(aroma_score) 980 | cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles] 981 | 982 | if len(cands) > 1: 983 | jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1]) 984 | fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder 985 | cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0]) 986 | scores = torch.mv(cand_vecs, x_mol_vecs) + aroma_score 987 | else: 988 | scores = torch.Tensor([1.0]) 989 | 990 | if prob_decode: 991 | probs = F.softmax(scores.view(1,-1), dim=1).squeeze() + 1e-7 #prevent prob = 0 992 | cand_idx = torch.multinomial(probs, probs.numel()) 993 | else: 994 | _,cand_idx = torch.sort(scores, descending=True) 995 | 996 | backup_mol = Chem.RWMol(cur_mol) 997 | pre_mol = cur_mol 998 | for i in range(cand_idx.numel()): 999 | cur_mol = Chem.RWMol(backup_mol) 1000 | pred_amap = cand_amap[cand_idx[i].item()] 1001 | new_global_amap = copy.deepcopy(global_amap) 1002 | 1003 | for nei_id,ctr_atom,nei_atom in pred_amap: 1004 | if nei_id == fa_nid: 1005 | continue 1006 | new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom] 1007 | 1008 | cur_mol = attach_mols(cur_mol, children, [], new_global_amap) #father is already attached 1009 | new_mol = cur_mol.GetMol() 1010 | new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol)) 1011 | 1012 | if new_mol is None: continue 1013 | 1014 | has_error = False 1015 | for nei_node in children: 1016 | if nei_node.is_leaf: continue 1017 | tmp_mol, tmp_mol2 = self.dfs_assemble(y_tree_mess, x_mol_vecs, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode, check_aroma) 1018 | if tmp_mol is None: 1019 | has_error = True 1020 | if i == 0: pre_mol = tmp_mol2 1021 | break 1022 | cur_mol = tmp_mol 1023 | 1024 | if not 
has_error: return cur_mol, cur_mol 1025 | 1026 | return None, pre_mol 1027 | 1028 | #Reading the input 1029 | vocab = [x.strip("\r\n ") for x in open('train.txt')] 1030 | #Building the vocabulary 1031 | vocab = Vocab(vocab,mol_trees) 1032 | 1033 | #Defining the model 1034 | model = JTNNVAE(vocab, int(450), int(56), int(20), int(3)) 1035 | 1036 | print("Model") 1037 | print(model) 1038 | 1039 | 1040 | 1041 | 1042 | -------------------------------------------------------------------------------- /genmol/JTVAE/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math, random, sys 4 | from optparse import OptionParser 5 | import pickle 6 | import rdkit 7 | import json 8 | import rdkit.Chem as Chem 9 | from scipy.sparse import csr_matrix 10 | from scipy.sparse.csgraph import minimum_spanning_tree 11 | from collections import defaultdict 12 | import copy 13 | import torch.optim as optim 14 | import torch.optim.lr_scheduler as lr_scheduler 15 | from torch.utils.data import Dataset, DataLoader 16 | from torch.autograd import Variable 17 | import numpy as np 18 | from collections import deque 19 | import os, random 20 | import torch.nn.functional as F 21 | import pdb 22 | 23 | benzynes_i = ['C1=CC=CC=C1', 'C1=CC=NC=C1', 'C1=CC=NN=C1', 'C1=CN=CC=N1', 'C1=CN=CN=C1', 'C1=CN=NC=N1', 'C1=CN=NN=C1', 'C1=NC=NC=N1', 'C1=NN=CN=N1'] 24 | penzynes_i = ['C1=C[NH]C=C1', 'C1=C[NH]C=N1', 'C1=C[NH]N=C1', 'C1=C[NH]N=N1', 'C1=COC=C1', 'C1=COC=N1', 'C1=CON=C1', 'C1=CSC=C1', 'C1=CSC=N1', 'C1=CSN=C1', 'C1=CSN=N1', 'C1=NN=C[NH]1', 'C1=NN=CO1', 'C1=NN=CS1', 'C1=N[NH]C=N1', 'C1=N[NH]N=C1', 'C1=N[NH]N=N1', 'C1=NN=N[NH]1', 'C1=NN=NS1', 'C1=NOC=N1', 'C1=NON=C1', 'C1=NSC=N1', 'C1=NSN=C1'] 25 | 26 | 27 | MST_MAX_WEiGHT_10 = 100 28 | MAX_NCAND_10 = 2000 29 | 30 | def set_atommap(mol, num=0): 31 | for atom in mol.GetAtoms(): 32 | atom.SetAtomMapNum(num) 33 | 34 | def get_mol(smiles): 35 | mol = Chem.MolFromSmiles(smiles) 36 | if mol is None: 37 | return None 38 | Chem.Kekulize(mol) 39 | return mol 40 | 41 | def get_smiles(mol): 42 | return Chem.MolToSmiles(mol, kekuleSmiles=True) 43 | 44 | def sanitize(mol): 45 | try: 46 | smiles = get_smiles(mol) 47 | mol = get_mol(smiles) 48 | except Exception as e: 49 | return None 50 | return mol 51 | 52 | def copy_atom(atom): 53 | new_atom = Chem.Atom(atom.GetSymbol()) 54 | new_atom.SetFormalCharge(atom.GetFormalCharge()) 55 | new_atom.SetAtomMapNum(atom.GetAtomMapNum()) 56 | return new_atom 57 | 58 | def copy_edit_mol(mol): 59 | new_mol = Chem.RWMol(Chem.MolFromSmiles('')) 60 | for atom in mol.GetAtoms(): 61 | new_atom = copy_atom(atom) 62 | new_mol.AddAtom(new_atom) 63 | for bond in mol.GetBonds(): 64 | a1 = bond.GetBeginAtom().GetIdx() 65 | a2 = bond.GetEndAtom().GetIdx() 66 | bt = bond.GetBondType() 67 | new_mol.AddBond(a1, a2, bt) 68 | return new_mol 69 | 70 | def get_clique_mol(mol, atoms): 71 | smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True) 72 | new_mol = Chem.MolFromSmiles(smiles, sanitize=False) 73 | new_mol = copy_edit_mol(new_mol).GetMol() 74 | new_mol = sanitize(new_mol) #We assume this is not None 75 | return new_mol 76 | 77 | def tree_decomp(mol): 78 | 79 | n_atoms = mol.GetNumAtoms() 80 | if n_atoms == 1: #special case 81 | return [[0]], [] 82 | 83 | cliques = [] 84 | for bond in mol.GetBonds(): 85 | a1 = bond.GetBeginAtom().GetIdx() 86 | a2 = bond.GetEndAtom().GetIdx() 87 | if not bond.IsInRing(): 88 | cliques.append([a1,a2]) 89 | 90 | ssr = [list(x) for x in 
Chem.GetSymmSSSR(mol)] 91 | cliques.extend(ssr) 92 | 93 | nei_list = [[] for i in range(n_atoms)] 94 | for i in range(len(cliques)): 95 | for atom in cliques[i]: 96 | nei_list[atom].append(i) 97 | 98 | #Merge Rings with intersection > 2 atoms 99 | for i in range(len(cliques)): 100 | if len(cliques[i]) <= 2: continue 101 | for atom in cliques[i]: 102 | for j in nei_list[atom]: 103 | if i >= j or len(cliques[j]) <= 2: continue 104 | inter = set(cliques[i]) & set(cliques[j]) 105 | if len(inter) > 2: 106 | cliques[i].extend(cliques[j]) 107 | cliques[i] = list(set(cliques[i])) 108 | cliques[j] = [] 109 | 110 | cliques = [c for c in cliques if len(c) > 0] 111 | nei_list = [[] for i in range(n_atoms)] 112 | for i in range(len(cliques)): 113 | for atom in cliques[i]: 114 | nei_list[atom].append(i) 115 | 116 | #Build edges and add singleton cliques 117 | edges = defaultdict(int) 118 | for atom in range(n_atoms): 119 | if len(nei_list[atom]) <= 1: 120 | continue 121 | cnei = nei_list[atom] 122 | bonds = [c for c in cnei if len(cliques[c]) == 2] 123 | rings = [c for c in cnei if len(cliques[c]) > 4] 124 | if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): #In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with. 125 | cliques.append([atom]) 126 | c2 = len(cliques) - 1 127 | for c1 in cnei: 128 | edges[(c1,c2)] = 1 129 | elif len(rings) > 2: #Multiple (n>2) complex rings 130 | cliques.append([atom]) 131 | c2 = len(cliques) - 1 132 | for c1 in cnei: 133 | edges[(c1,c2)] = MST_MAX_WEiGHT_10 - 1 134 | else: 135 | for i in range(len(cnei)): 136 | for j in range(i + 1, len(cnei)): 137 | c1,c2 = cnei[i],cnei[j] 138 | inter = set(cliques[c1]) & set(cliques[c2]) 139 | if edges[(c1,c2)] < len(inter): 140 | edges[(c1,c2)] = len(inter) #cnei[i] < cnei[j] by construction 141 | 142 | edges = [u + (MST_MAX_WEiGHT_10-v,) for u,v in edges.items()] 143 | if len(edges) == 0: 144 | return cliques, edges 145 | 146 | #Compute Maximum Spanning Tree 147 | row,col,data = zip(*edges) 148 | n_clique = len(cliques) 149 | clique_graph = csr_matrix( (data,(row,col)), shape=(n_clique,n_clique) ) 150 | junc_tree = minimum_spanning_tree(clique_graph) 151 | row,col = junc_tree.nonzero() 152 | edges = [(row[i],col[i]) for i in range(len(row))] 153 | return (cliques, edges) 154 | 155 | 156 | 157 | def atom_equal(a1, a2): 158 | return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge() 159 | 160 | #Bond type not considered because all aromatic (so SINGLE matches DOUBLE) 161 | def ring_bond_equal(b1, b2, reverse=False): 162 | b1 = (b1.GetBeginAtom(), b1.GetEndAtom()) 163 | if reverse: 164 | b2 = (b2.GetEndAtom(), b2.GetBeginAtom()) 165 | else: 166 | b2 = (b2.GetBeginAtom(), b2.GetEndAtom()) 167 | return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1]) 168 | 169 | def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap): 170 | prev_nids = [node.nid for node in prev_nodes] 171 | for nei_node in prev_nodes + neighbors: 172 | nei_id,nei_mol = nei_node.nid,nei_node.mol 173 | amap = nei_amap[nei_id] 174 | for atom in nei_mol.GetAtoms(): 175 | if atom.GetIdx() not in amap: 176 | new_atom = copy_atom(atom) 177 | amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom) 178 | 179 | if nei_mol.GetNumBonds() == 0: 180 | nei_atom = nei_mol.GetAtomWithIdx(0) 181 | ctr_atom = ctr_mol.GetAtomWithIdx(amap[0]) 182 | ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum()) 183 | else: 184 | for bond in nei_mol.GetBonds(): 185 | a1 = amap[bond.GetBeginAtom().GetIdx()] 186 | a2 = 
amap[bond.GetEndAtom().GetIdx()] 187 | if ctr_mol.GetBondBetweenAtoms(a1, a2) is None: 188 | ctr_mol.AddBond(a1, a2, bond.GetBondType()) 189 | elif nei_id in prev_nids: #father node overrides 190 | ctr_mol.RemoveBond(a1, a2) 191 | ctr_mol.AddBond(a1, a2, bond.GetBondType()) 192 | return ctr_mol 193 | 194 | def local_attach(ctr_mol, neighbors, prev_nodes, amap_list): 195 | ctr_mol = copy_edit_mol(ctr_mol) 196 | nei_amap = {nei.nid:{} for nei in prev_nodes + neighbors} 197 | 198 | for nei_id,ctr_atom,nei_atom in amap_list: 199 | nei_amap[nei_id][nei_atom] = ctr_atom 200 | 201 | ctr_mol = attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap) 202 | return ctr_mol.GetMol() 203 | 204 | #This version records idx mapping between ctr_mol and nei_mol 205 | def enum_attach(ctr_mol, nei_node, amap, singletons): 206 | nei_mol,nei_idx = nei_node.mol,nei_node.nid 207 | att_confs = [] 208 | black_list = [atom_idx for nei_id,atom_idx,_ in amap if nei_id in singletons] 209 | ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list] 210 | ctr_bonds = [bond for bond in ctr_mol.GetBonds()] 211 | 212 | if nei_mol.GetNumBonds() == 0: #neighbor singleton 213 | nei_atom = nei_mol.GetAtomWithIdx(0) 214 | used_list = [atom_idx for _,atom_idx,_ in amap] 215 | for atom in ctr_atoms: 216 | if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list: 217 | new_amap = amap + [(nei_idx, atom.GetIdx(), 0)] 218 | att_confs.append( new_amap ) 219 | 220 | elif nei_mol.GetNumBonds() == 1: #neighbor is a bond 221 | bond = nei_mol.GetBondWithIdx(0) 222 | bond_val = int(bond.GetBondTypeAsDouble()) 223 | b1,b2 = bond.GetBeginAtom(), bond.GetEndAtom() 224 | 225 | for atom in ctr_atoms: 226 | #Optimize if atom is carbon (other atoms may change valence) 227 | if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val: 228 | continue 229 | if atom_equal(atom, b1): 230 | new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())] 231 | att_confs.append( new_amap ) 232 | elif atom_equal(atom, b2): 233 | new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())] 234 | att_confs.append( new_amap ) 235 | else: 236 | #intersection is an atom 237 | for a1 in ctr_atoms: 238 | for a2 in nei_mol.GetAtoms(): 239 | if atom_equal(a1, a2): 240 | #Optimize if atom is carbon (other atoms may change valence) 241 | if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4: 242 | continue 243 | new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())] 244 | att_confs.append( new_amap ) 245 | 246 | #intersection is an bond 247 | if ctr_mol.GetNumBonds() > 1: 248 | for b1 in ctr_bonds: 249 | for b2 in nei_mol.GetBonds(): 250 | if ring_bond_equal(b1, b2): 251 | new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())] 252 | att_confs.append( new_amap ) 253 | 254 | if ring_bond_equal(b1, b2, reverse=True): 255 | new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())] 256 | att_confs.append( new_amap ) 257 | return att_confs 258 | 259 | #Try rings first: Speed-Up 260 | def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[]): 261 | all_attach_confs = [] 262 | singletons = [nei_node.nid for nei_node in neighbors + prev_nodes if nei_node.mol.GetNumAtoms() == 1] 263 | 264 | def search(cur_amap, depth): 265 | if len(all_attach_confs) > MAX_NCAND_10: 266 | return 267 | if depth == len(neighbors): 268 | 
all_attach_confs.append(cur_amap) 269 | return 270 | 271 | nei_node = neighbors[depth] 272 | cand_amap = enum_attach(node.mol, nei_node, cur_amap, singletons) 273 | cand_smiles = set() 274 | candidates = [] 275 | for amap in cand_amap: 276 | cand_mol = local_attach(node.mol, neighbors[:depth+1], prev_nodes, amap) 277 | cand_mol = sanitize(cand_mol) 278 | if cand_mol is None: 279 | continue 280 | smiles = get_smiles(cand_mol) 281 | if smiles in cand_smiles: 282 | continue 283 | cand_smiles.add(smiles) 284 | candidates.append(amap) 285 | 286 | if len(candidates) == 0: 287 | return 288 | 289 | for new_amap in candidates: 290 | search(new_amap, depth + 1) 291 | 292 | search(prev_amap, 0) 293 | cand_smiles = set() 294 | candidates = [] 295 | aroma_score = [] 296 | for amap in all_attach_confs: 297 | cand_mol = local_attach(node.mol, neighbors, prev_nodes, amap) 298 | cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol)) 299 | smiles = Chem.MolToSmiles(cand_mol) 300 | if smiles in cand_smiles or check_singleton(cand_mol, node, neighbors) == False: 301 | continue 302 | cand_smiles.add(smiles) 303 | candidates.append( (smiles,amap) ) 304 | aroma_score.append( check_aroma(cand_mol, node, neighbors) ) 305 | 306 | return candidates, aroma_score 307 | 308 | def check_singleton(cand_mol, ctr_node, nei_nodes): 309 | rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() > 2] 310 | singletons = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() == 1] 311 | if len(singletons) > 0 or len(rings) == 0: return True 312 | 313 | n_leaf2_atoms = 0 314 | for atom in cand_mol.GetAtoms(): 315 | nei_leaf_atoms = [a for a in atom.GetNeighbors() if not a.IsInRing()] #a.GetDegree() == 1] 316 | if len(nei_leaf_atoms) > 1: 317 | n_leaf2_atoms += 1 318 | 319 | return n_leaf2_atoms == 0 320 | 321 | def check_aroma(cand_mol, ctr_node, nei_nodes): 322 | rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() >= 3] 323 | if len(rings) < 2: return 0 #Only multi-ring system needs to be checked 324 | 325 | get_nid = lambda x: 0 if x.is_leaf else x.nid 326 | benzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in benzynes_i] 327 | penzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in penzynes_i] 328 | if len(benzynes) + len(penzynes) == 0: 329 | return 0 #No specific aromatic rings 330 | 331 | n_aroma_atoms = 0 332 | for atom in cand_mol.GetAtoms(): 333 | if atom.GetAtomMapNum() in benzynes+penzynes and atom.GetIsAromatic(): 334 | n_aroma_atoms += 1 335 | 336 | if n_aroma_atoms >= len(benzynes) * 4 + len(penzynes) * 3: 337 | return 1000 338 | else: 339 | return -0.001 340 | 341 | #Only used for debugging purpose 342 | def dfs_assemble(cur_mol, global_amap, fa_amap, cur_node, fa_node): 343 | fa_nid = fa_node.nid if fa_node is not None else -1 344 | prev_nodes = [fa_node] if fa_node is not None else [] 345 | 346 | children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid] 347 | neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1] 348 | neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True) 349 | singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1] 350 | neighbors = singletons + neighbors 351 | 352 | cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid] 353 | cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap) 354 | 355 | cand_smiles,cand_amap = zip(*cands) 356 | label_idx = cand_smiles.index(cur_node.label) 357 | label_amap = 
cand_amap[label_idx] 358 | 359 | for nei_id,ctr_atom,nei_atom in label_amap: 360 | if nei_id == fa_nid: 361 | continue 362 | global_amap[nei_id][nei_atom] = global_amap[cur_node.nid][ctr_atom] 363 | 364 | cur_mol = attach_mols(cur_mol, children, [], global_amap) #father is already attached 365 | for nei_node in children: 366 | if not nei_node.is_leaf: 367 | dfs_assemble(cur_mol, global_amap, label_amap, nei_node, cur_node) 368 | 369 | 370 | 371 | class MolTreeNode(object): 372 | 373 | def __init__(self, smiles, clique=[]): 374 | self.smiles = smiles 375 | self.mol = get_mol(self.smiles) 376 | 377 | self.clique = [x for x in clique] #copy 378 | self.neighbors = [] 379 | 380 | def add_neighbor(self, nei_node): 381 | self.neighbors.append(nei_node) 382 | 383 | def recover(self, original_mol): 384 | clique = [] 385 | clique.extend(self.clique) 386 | if not self.is_leaf: 387 | for cidx in self.clique: 388 | original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid) 389 | 390 | for nei_node in self.neighbors: 391 | clique.extend(nei_node.clique) 392 | if nei_node.is_leaf: #Leaf node, no need to mark 393 | continue 394 | for cidx in nei_node.clique: 395 | #allow singleton node override the atom mapping 396 | if cidx not in self.clique or len(nei_node.clique) == 1: 397 | atom = original_mol.GetAtomWithIdx(cidx) 398 | atom.SetAtomMapNum(nei_node.nid) 399 | 400 | clique = list(set(clique)) 401 | label_mol = get_clique_mol(original_mol, clique) 402 | self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol))) 403 | 404 | for cidx in clique: 405 | original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0) 406 | 407 | return self.label 408 | 409 | def assemble(self): 410 | neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1] 411 | neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True) 412 | singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1] 413 | neighbors = singletons + neighbors 414 | 415 | cands,aroma = enum_assemble(self, neighbors) 416 | new_cands = [cand for i,cand in enumerate(cands) if aroma[i] >= 0] 417 | if len(new_cands) > 0: cands = new_cands 418 | 419 | if len(cands) > 0: 420 | self.cands, _ = zip(*cands) 421 | self.cands = list(self.cands) 422 | else: 423 | self.cands = [] 424 | 425 | class MolTree(object): 426 | 427 | def __init__(self, smiles): 428 | self.smiles = smiles 429 | self.mol = get_mol(smiles) 430 | 431 | 432 | cliques, edges = tree_decomp(self.mol) 433 | self.nodes = [] 434 | root = 0 435 | for i,c in enumerate(cliques): 436 | cmol = get_clique_mol(self.mol, c) 437 | node = MolTreeNode(get_smiles(cmol), c) 438 | self.nodes.append(node) 439 | if min(c) == 0: root = i 440 | 441 | for x,y in edges: 442 | self.nodes[x].add_neighbor(self.nodes[y]) 443 | self.nodes[y].add_neighbor(self.nodes[x]) 444 | 445 | if root > 0: 446 | self.nodes[0],self.nodes[root] = self.nodes[root],self.nodes[0] 447 | 448 | for i,node in enumerate(self.nodes): 449 | node.nid = i + 1 450 | if len(node.neighbors) > 1: 451 | set_atommap(node.mol, node.nid) 452 | node.is_leaf = (len(node.neighbors) == 1) 453 | 454 | def size(self): 455 | return len(self.nodes) 456 | 457 | def recover(self): 458 | for node in self.nodes: 459 | node.recover(self.mol) 460 | 461 | def assemble(self): 462 | for node in self.nodes: 463 | node.assemble() 464 | 465 | 466 | def tensorize_trees(smiles, assm=True): 467 | mol_tree = MolTree(smiles) 468 | mol_tree.recover() 469 | if assm: 470 | mol_tree.assemble() 471 | for node in mol_tree.nodes: 472 | if node.label 
not in node.cands: 473 | node.cands.append(node.label) 474 | 475 | del mol_tree.mol 476 | for node in mol_tree.nodes: 477 | del node.mol 478 | 479 | return mol_tree 480 | 481 | 482 | 483 | splits = 4 484 | 485 | with open('train.txt') as f: 486 | data = [line.strip("\r\n ").split()[0] for line in f] 487 | 488 | mol_trees=[] 489 | for i in range(0,len(data)): 490 | #Generating the molecular tree for each molecule and appending them to a list 491 | mol_trees.append(tensorize_trees(data[i])) 492 | 493 | print("Molecular trees") 494 | print(mol_trees) 495 | 496 | 497 | trees_data=[] 498 | l = (len(mol_trees) + splits - 1) / splits 499 | #Making the batches of mol trees 500 | for i in range(splits): 501 | s = i * l 502 | sub_data = mol_trees[int(s) : int(s + l)] 503 | trees_data.append(sub_data) 504 | -------------------------------------------------------------------------------- /genmol/JTVAE/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math, random, sys 4 | from optparse import OptionParser 5 | import pickle 6 | import rdkit 7 | import json 8 | import rdkit.Chem as Chem 9 | from scipy.sparse import csr_matrix 10 | from scipy.sparse.csgraph import minimum_spanning_tree 11 | from collections import defaultdict 12 | import copy 13 | import torch.optim as optim 14 | import torch.optim.lr_scheduler as lr_scheduler 15 | from torch.utils.data import Dataset, DataLoader 16 | from torch.autograd import Variable 17 | import numpy as np 18 | from collections import deque 19 | import os, random 20 | import torch.nn.functional as F 21 | import pdb 22 | 23 | from jvae_model import * 24 | 25 | path = "savedmodel.pth" 26 | model=JTNNVAE(vocab, int(450), int(56), int(20), int(3)) 27 | model.load_state_dict(torch.load(path)) 28 | torch.manual_seed(0) 29 | print("Molecules generated") 30 | for i in range(10): 31 | print(model.sample_prior()) 32 | 33 | 34 | -------------------------------------------------------------------------------- /genmol/JTVAE/savedmodel.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/JTVAE/savedmodel.pth -------------------------------------------------------------------------------- /genmol/JTVAE/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math, random, sys 4 | from optparse import OptionParser 5 | import pickle 6 | import rdkit 7 | import json 8 | import rdkit.Chem as Chem 9 | from scipy.sparse import csr_matrix 10 | from scipy.sparse.csgraph import minimum_spanning_tree 11 | from collections import defaultdict 12 | import copy 13 | import torch.optim as optim 14 | import torch.optim.lr_scheduler as lr_scheduler 15 | from torch.utils.data import Dataset, DataLoader 16 | from torch.autograd import Variable 17 | import numpy as np 18 | from collections import deque 19 | import os, random 20 | import torch.nn.functional as F 21 | import pdb 22 | 23 | 24 | from jvae_preprocess import * 25 | from jvae_model import * 26 | 27 | 28 | def set_batch_nodeID(mol_batch, vocab): 29 | tot = 0 30 | 31 | for mol_tree in mol_batch: 32 | 33 | for node in mol_tree.nodes: 34 | node.idx = tot 35 | 36 | s_to_m=Chem.MolFromSmiles(node.smiles) 37 | m_to_s=Chem.MolToSmiles(s_to_m,kekuleSmiles=False) 38 | node.wid = vocab.get_index(m_to_s) 39 | 40 | tot += 1 41 | 42 | 43 | def 
tensorize_x(tree_batch, vocab,assm=True): 44 | set_batch_nodeID(tree_batch, vocab) 45 | smiles_batch = [tree.smiles for tree in tree_batch] 46 | jtenc_holder,mess_dict = JTNNEncoder.tensorize(tree_batch) 47 | jtenc_holder = jtenc_holder 48 | mpn_holder = MPN.tensorize(smiles_batch) 49 | 50 | if assm is False: 51 | return tree_batch, jtenc_holder, mpn_holder 52 | 53 | cands = [] 54 | batch_idx = [] 55 | for i,mol_tree in enumerate(tree_batch): 56 | for node in mol_tree.nodes: 57 | if node.is_leaf or len(node.cands) == 1: continue 58 | cands.extend( [(cand, mol_tree.nodes, node) for cand in node.cands] ) 59 | batch_idx.extend([i] * len(node.cands)) 60 | 61 | jtmpn_holder = JTMPN.tensorize(cands, mess_dict) 62 | batch_idx = torch.LongTensor(batch_idx) 63 | 64 | return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder,batch_idx) 65 | 66 | 67 | class MolTreeDataset(Dataset): 68 | 69 | def __init__(self, data, vocab, assm=True): 70 | self.data = data 71 | self.vocab = vocab 72 | self.assm = assm 73 | 74 | def __len__(self): 75 | return len(self.data) 76 | 77 | def __getitem__(self, idx): 78 | return tensorize_x(self.data[idx], self.vocab,assm=self.assm) 79 | 80 | 81 | 82 | def get_loader(data_1,vocab): 83 | 84 | for i in range(0,len(data_1)): 85 | 86 | if True: 87 | random.shuffle(data_1[i]) 88 | 89 | batches=[] 90 | for j in range(0,len(data_1[i])): 91 | batches.append([]) 92 | 93 | for j in range(0,len(data_1[i])): 94 | 95 | batches[j].append(data_1[i][j]) 96 | 97 | dataset = MolTreeDataset(batches, vocab,True) 98 | 99 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x:x[0]) 100 | 101 | for b in dataloader: 102 | yield b 103 | 104 | del batches, dataset, dataloader 105 | 106 | 107 | 108 | 109 | for param in model.parameters(): 110 | if param.dim() == 1: 111 | nn.init.constant_(param, 0) 112 | else: 113 | nn.init.xavier_normal_(param) 114 | 115 | 116 | print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)) 117 | 118 | 119 | 120 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 121 | scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9) 122 | scheduler.step() 123 | 124 | param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()])) 125 | grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None])) 126 | 127 | total_step = 0 128 | beta = 0.0 129 | meters = np.zeros(4) 130 | path = "savedmodel.pth" 131 | print("Training") 132 | #Training starts here... 
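# Overview of the loop below: 10 epochs over the batched molecular trees.
# Each batch is fed to the JT-VAE, which returns the total loss, the KL
# divergence and three accuracy terms (wacc, tacc, sacc); the loss is
# backpropagated with the gradient norm clipped to 50, failing batches are
# skipped by the try/except, the weights are written to savedmodel.pth as
# training proceeds, the learning rate decays through the ExponentialLR
# scheduler, and beta (the KL weight) is annealed towards 1.0 in steps of 0.002.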
133 | for epoch in range(10): 134 | 135 | #Loading the data 136 | loader=get_loader(trees_data,vocab) 137 | 138 | for batch in loader: 139 | total_step += 1 140 | try: 141 | model.zero_grad() 142 | #Send the batch to the model 143 | loss, kl_div, wacc, tacc, sacc = model(batch, beta) 144 | #Backward propagation 145 | loss.backward() 146 | nn.utils.clip_grad_norm_(model.parameters(),50.0) 147 | optimizer.step() 148 | except Exception as e: 149 | print(e) 150 | continue 151 | 152 | meters = meters + np.array([kl_div, wacc * 100, tacc * 100, sacc * 100]) 153 | 154 | 155 | torch.save(model.state_dict(), path) 156 | 157 | 158 | scheduler.step() 159 | #print("learning rate: %.6f" % scheduler.get_lr()[0]) 160 | 161 | beta = min(1.0, beta + 0.002) 162 | 163 | print("Epoch :" + str(epoch)) 164 | 165 | -------------------------------------------------------------------------------- /genmol/JTVAE/train.txt: -------------------------------------------------------------------------------- 1 | CCC(NC(=O)c1scnc1C1CC1)C(=O)N1CCOCC1 2 | O=C1OCCC1Sc1nnc(-c2c[nH]c3ccccc23)n1C1CC1 3 | CCN(C)S(=O)(=O)N1CCC(Nc2cccc(OC)c2)CC1 4 | CC(=O)Nc1cccc(NC(C)c2ccccn2)c1 5 | Cc1cc(-c2nc3sc(C4CC4)nn3c2C#N)ccc1Cl 6 | CCOCCCNC(=O)c1cc(OC)ccc1Br 7 | Cc1nc(-c2ccncc2)[nH]c(=O)c1CC(=O)NC1CCCC1 8 | C#CCN(CC#C)C(=O)c1cc2ccccc2cc1OC(F)F 9 | CCOc1ccc(CN2c3ccccc3NCC2C)cc1N 10 | NC(=O)C1CCC(CNc2cc(-c3ccccc3)nc3ccnn23)CC1 11 | Cc1csc(Sc2cc(C)nc3ncnn23)n1 12 | COCCN1CCN(C(=O)CCc2nc(C(C)(C)C)no2)CC1=O 13 | CCN(CC(C)C#N)C(=O)CN1CCN(C(=O)CC(F)(F)F)CC1 14 | Cc1ccc(C(=O)Nc2ccc(S(N)(=O)=O)cc2F)cc1C 15 | O=C(c1cccs1)N1CCc2[nH]nc(-c3cccc(F)c3)c2C1 16 | O=C(C1CC1)N1CCc2ccc(NS(=O)(=O)c3ccccc3)cc21 17 | O=C(CCNC(=O)c1ccccc1)Nc1cc(Cl)cc(Cl)c1 18 | CCN1CC(c2nc3ccccc3n2CC(=O)OC(C)C)CC1=O 19 | CCC(NC(=O)COc1ccccc1O)c1ccc(OC)cc1 20 | C=CCn1c(SCCc2c(C)noc2C)n[nH]c1=O 21 | CNC(=O)c1cc(NC(=O)N2C[C@H]3CCC[NH+]3C[C@H]2C)ccc1F 22 | CCCc1ccc([C@@H]([NH3+])C(OC)OC)cc1 23 | [NH3+]C[C@H](c1cc(Cl)cs1)N1CC[C@@H]2CCCC[C@@H]2C1 24 | CC(C)CN(CC(C)C)C(=O)NCC(C)(C)N1CCOCC1 25 | COc1cc(C(=O)N[C@H]2C[C@H]2c2cccc(Cl)c2)cc(OC)c1OC 26 | COCCNC(=O)c1cccc(N2CC[NH2+][C@@H](c3ccccc3)C2)n1 27 | Cc1ccc(-c2cc(OCC(=O)[O-])nc(NCc3ccc4c(c3)OCO4)n2)cc1 28 | COc1cccc(NC2CC[NH+](C[C@@H](O)CN3C[C@H](C)O[C@@H](C)C3)CC2)c1 29 | CC(C)[C@H](NC(=O)Nc1ccccc1)C(=O)NCc1ccco1 30 | O=C(Nc1ccccn1)N1CCC(n2cc[nH+]c2)CC1 31 | Cc1cc(C)cc([C@H]2OCC[C@H]2C[NH2+]C(C)C)c1 32 | C[C@H]1CSC(NC[C@]2(C)CCCO2)=[NH+]1 33 | Cc1ccc(O)c(Cc2ccccc2Cl)c1 34 | O=C1CCCC2=C1[C@H](c1ccc([N+](=O)[O-])o1)n1nnnc1N2 35 | C=CCS/C(N)=C(C#N)/C(C#N)=C(\N)SCC=C 36 | Cc1sc(=O)n(CCC(=O)N(C)[C@@H]2CCCC[C@H]2S(C)(=O)=O)c1C 37 | COc1cc([N+](=O)[O-])cc(/C=N/Nc2ccccc2C)c1O 38 | CCOC(=O)[C@@H]1CCCN(Cc2nn3c(=O)cc(C)nc3s2)C1 39 | Cc1sc2nc(C[NH+](C(C)C)C3CCCC3)nc(N)c2c1C 40 | O=C(NCCCc1nc2ccccc2s1)N1CCc2ccccc2C1 41 | CCOc1ccc(Br)cc1C(=O)Nc1cccc(OC[C@@H]2CCCO2)c1 42 | CCn1nc(C(=O)N2CCN(c3ccccc3)CC2)c2c1CC[NH+](Cc1cc(C)c(OC)cc1C)C2 43 | O=C(N[C@@H]1[C@@H]2CCO[C@@H]2C12CCC2)c1cnc([C@H]2CCCO2)s1 44 | CCOc1ccc(C(=O)Nc2ccc(F)cc2F)cc1 45 | O=C(C1CCC(F)(F)CC1)N1CCN(c2ncccc2F)CC1 46 | O=C(NCc1ccc(F)c(Cl)c1)[C@@H]1CSCN1C(=O)c1c[nH]c2ccccc12 47 | Cc1cc(CCNC(=O)Nc2ccccc2F)on1 48 | COCCOC(=O)C1=C(N)Oc2c(oc(CO)cc2=O)[C@H]1c1cccc(F)c1 49 | CCOC(=O)C(=O)Nc1cc(C)nn1-c1nc([O-])c2c(n1)CCC2 50 | CN(C(=O)[C@H]1CCCN(c2ncnc3onc(-c4ccc(F)cc4)c23)C1)C1CCCCC1 51 | -------------------------------------------------------------------------------- /genmol/ORGAN/Data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import pandas as pd 3 | 4 | Data = pd.read_csv('C:/Users/haroon_03/Desktop/smiles.csv') 5 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 6 | chars = set() 7 | for string in Data['SMILES']: 8 | chars.update(string) 9 | train_data = Data[Data['SPLIT'] == 'train'] 10 | test_data = Data[Data['SPLIT'] == 'test'] 11 | test_scaffold = Data[Data['SPLIT'] == 'test_scaffolds'] 12 | 13 | all_syms = sorted(list(chars) + ['', '', '', '']) 14 | vocabulary = all_syms 15 | 16 | c2i = {c: i for i, c in enumerate(all_syms)} 17 | i2c = {i: c for i, c in enumerate(all_syms)} 18 | 19 | train_data = (train_data['SMILES'].squeeze()).astype(str).tolist() 20 | test_scaffold = (test_scaffold['SMILES'].squeeze()).astype(str).tolist() 21 | test_data = (test_data['SMILES'].squeeze()).astype(str).tolist() 22 | -------------------------------------------------------------------------------- /genmol/ORGAN/Metrics_Reward.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | from scipy.spatial.distance import cosine as cos_distance 4 | from fcd_torch import FCD as FCDMetric 5 | from fcd_torch import calculate_frechet_distance 6 | from RewardMetrics import * 7 | from rdkit import rdBase 8 | import random 9 | 10 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 11 | 12 | def get_all_metrics(test, gen, k=None, n_jobs=1, device=device, 13 | batch_size=512, test_scaffolds=None, 14 | ptest=None, ptest_scaffolds=None, 15 | pool=None, gpu=None): 16 | """ 17 | Computes all available metrics between test (scaffold test) 18 | and generated sets of SMILES. 19 | Parameters: 20 | test: list of test SMILES 21 | gen: list of generated SMILES 22 | k: list with values for unique@k. Will calculate number of 23 | unique molecules in the first k molecules. Default [1000, 10000] 24 | n_jobs: number of workers for parallel processing 25 | device: 'cpu' or 'cuda:n', where n is GPU device number 26 | batch_size: batch size for FCD metric 27 | test_scaffolds: list of scaffold test SMILES 28 | Will compute only on the general test set if not specified 29 | ptest: dict with precalculated statistics of the test set 30 | ptest_scaffolds: dict with precalculated statistics 31 | of the scaffold test set 32 | pool: optional multiprocessing pool to use for parallelization 33 | gpu: deprecated, use `device` 34 | 35 | Available metrics: 36 | * %valid 37 | * %unique@k 38 | * Frechet ChemNet Distance (FCD) 39 | * Fragment similarity (Frag) 40 | * Scaffold similarity (Scaf) 41 | * Similarity to nearest neighbour (SNN) 42 | * Internal diversity (IntDiv) 43 | * Internal diversity 2: using square root of mean squared 44 | Tanimoto similarity (IntDiv2) 45 | * %passes filters (Filters) 46 | * Distribution difference for logP, SA, QED, NP, weight 47 | """ 48 | if k is None: 49 | k = [1000, 10000] 50 | rdBase 51 | metrics = {} 52 | if gpu is not None: 53 | warnings.warn( 54 | "parameter `gpu` is deprecated. 
Use `device`", 55 | DeprecationWarning 56 | ) 57 | if gpu == -1: 58 | device = 'cpu' 59 | else: 60 | device = 'cuda:{}'.format(gpu) 61 | close_pool = False 62 | if pool is None: 63 | if n_jobs != 1: 64 | pool = Pool(n_jobs) 65 | close_pool = True 66 | else: 67 | pool = 1 68 | metrics['valid'] = fraction_valid(gen) 69 | gen = remove_invalid(gen, canonize=True) 70 | if not isinstance(k, (list, tuple)): 71 | k = [k] 72 | for _k in k: 73 | metrics['unique@{}'.format(_k)] = fraction_unique(gen, _k) 74 | 75 | if ptest is None: 76 | ptest = compute_intermediate_statistics(test, n_jobs=n_jobs, 77 | device=device, 78 | batch_size=batch_size, 79 | pool=pool) 80 | if test_scaffolds is not None and ptest_scaffolds is None: 81 | ptest_scaffolds = compute_intermediate_statistics( 82 | test_scaffolds, n_jobs=n_jobs, 83 | device=device, batch_size=batch_size, 84 | pool=pool 85 | ) 86 | mols = mapper(pool)(get_mol, gen) 87 | kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size} 88 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size} 89 | metrics['FCD/Test'] = FCDMetric(**kwargs_fcd)(gen=gen, pref=ptest['FCD']) 90 | metrics['SNN/Test'] = SNNMetric(**kwargs)(gen=mols, pref=ptest['SNN']) 91 | metrics['Frag/Test'] = FragMetric(**kwargs)(gen=mols, pref=ptest['Frag']) 92 | metrics['Scaf/Test'] = ScafMetric(**kwargs)(gen=mols, pref=ptest['Scaf']) 93 | if ptest_scaffolds is not None: 94 | metrics['FCD/TestSF'] = FCDMetric(**kwargs_fcd)( 95 | gen=gen, pref=ptest_scaffolds['FCD'] 96 | ) 97 | metrics['SNN/TestSF'] = SNNMetric(**kwargs)( 98 | gen=mols, pref=ptest_scaffolds['SNN'] 99 | ) 100 | metrics['Frag/TestSF'] = FragMetric(**kwargs)( 101 | gen=mols, pref=ptest_scaffolds['Frag'] 102 | ) 103 | metrics['Scaf/TestSF'] = ScafMetric(**kwargs)( 104 | gen=mols, pref=ptest_scaffolds['Scaf'] 105 | ) 106 | 107 | metrics['IntDiv'] = internal_diversity(mols, pool, device=device) 108 | metrics['IntDiv2'] = internal_diversity(mols, pool, device=device, p=2) 109 | metrics['Filters'] = fraction_passes_filters(mols, pool) 110 | 111 | # Properties 112 | for name, func in [('logP', logP), ('SA', SA), 113 | ('QED', QED), ('NP', NP), 114 | ('weight', weight)]: 115 | metrics[name] = FrechetMetric(func, **kwargs)(gen=mols, 116 | pref=ptest[name]) 117 | enable_rdkit_log() 118 | if close_pool: 119 | pool.terminate() 120 | return metrics 121 | 122 | 123 | def compute_intermediate_statistics(smiles, n_jobs=1, device='cpu', 124 | batch_size=512, pool=None): 125 | """ 126 | The function precomputes statistics such as mean and variance for FCD, etc. 127 | It is useful to compute the statistics for test and scaffold test sets to 128 | speedup metrics calculation. 
129 | """ 130 | close_pool = False 131 | if pool is None: 132 | if n_jobs != 1: 133 | pool = Pool(n_jobs) 134 | close_pool = True 135 | else: 136 | pool = 1 137 | statistics = {} 138 | mols = mapper(pool)(get_mol, smiles) 139 | kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size} 140 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size} 141 | statistics['FCD'] = FCDMetric(**kwargs_fcd).precalc(smiles) 142 | statistics['SNN'] = SNNMetric(**kwargs).precalc(mols) 143 | statistics['Frag'] = FragMetric(**kwargs).precalc(mols) 144 | statistics['Scaf'] = ScafMetric(**kwargs).precalc(mols) 145 | for name, func in [('logP', logP), ('SA', SA), 146 | ('QED', QED), ('NP', NP), 147 | ('weight', weight)]: 148 | statistics[name] = FrechetMetric(func, **kwargs).precalc(mols) 149 | if close_pool: 150 | pool.terminate() 151 | return statistics 152 | 153 | 154 | def fraction_passes_filters(gen, n_jobs=1): 155 | """ 156 | Computes the fraction of molecules that pass filters: 157 | * MCF 158 | * PAINS 159 | * Only allowed atoms ('C','N','S','O','F','Cl','Br','H') 160 | * No charges 161 | """ 162 | passes = mapper(n_jobs)(mol_passes_filters, gen) 163 | return np.mean(passes) 164 | 165 | 166 | def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan', 167 | gen_fps=None, p=1): 168 | """ 169 | Computes internal diversity as: 170 | 1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y)) 171 | """ 172 | if gen_fps is None: 173 | gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs) 174 | return 1 - (average_agg_tanimoto(gen_fps, gen_fps, 175 | agg='mean', device=device, p=p)).mean() 176 | 177 | 178 | def fraction_unique(gen, k=None, n_jobs=1, check_validity=True): 179 | """ 180 | Computes a number of unique molecules 181 | Parameters: 182 | gen: list of SMILES 183 | k: compute unique@k 184 | n_jobs: number of threads for calculation 185 | check_validity: raises ValueError if invalid molecules are present 186 | """ 187 | if k is not None: 188 | if len(gen) < k: 189 | warnings.warn( 190 | "Can't compute unique@{}.".format(k) + 191 | "gen contains only {} molecules".format(len(gen)) 192 | ) 193 | gen = gen[:k] 194 | canonic = set(mapper(n_jobs)(canonic_smiles, gen)) 195 | if None in canonic and check_validity: 196 | raise ValueError("Invalid molecule passed to unique@k") 197 | return len(canonic) / len(gen) 198 | 199 | 200 | def fraction_valid(gen, n_jobs=1): 201 | """ 202 | Computes a number of valid molecules 203 | Parameters: 204 | gen: list of SMILES 205 | n_jobs: number of threads for calculation 206 | """ 207 | gen = mapper(n_jobs)(get_mol, gen) 208 | return 1 - gen.count(None) / len(gen) 209 | 210 | 211 | def remove_invalid(gen, canonize=True, n_jobs=1): 212 | """ 213 | Removes invalid molecules from the dataset 214 | """ 215 | if not canonize: 216 | mols = mapper(n_jobs)(get_mol, gen) 217 | return [gen_ for gen_, mol in zip(gen, mols) if mol is not None] 218 | else: 219 | return [x for x in mapper(n_jobs)(canonic_smiles, gen) if x is not None] 220 | 221 | 222 | class Metric: 223 | def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs): 224 | self.n_jobs = n_jobs 225 | self.device = device 226 | self.batch_size = batch_size 227 | for k, v in kwargs.values(): 228 | setattr(self, k, v) 229 | 230 | def __call__(self, ref=None, gen=None, pref=None, pgen=None): 231 | assert (ref is None) != (pref is None), "specify ref xor pref" 232 | assert (gen is None) != (pgen is None), "specify gen xor pgen" 233 | if pref is None: 234 | pref = self.precalc(ref) 235 | 
if pgen is None: 236 | pgen = self.precalc(gen) 237 | return self.metric(pref, pgen) 238 | 239 | def precalc(self, moleclues): 240 | raise NotImplementedError 241 | 242 | def metric(self, pref, pgen): 243 | raise NotImplementedError 244 | 245 | 246 | class SNNMetric(Metric): 247 | """ 248 | Computes average max similarities of gen SMILES to ref SMILES 249 | """ 250 | 251 | def __init__(self, fp_type='morgan', **kwargs): 252 | self.fp_type = fp_type 253 | super().__init__(**kwargs) 254 | 255 | def precalc(self, mols): 256 | return {'fps': fingerprints(mols, n_jobs=self.n_jobs, fp_type=self.fp_type)} 257 | 258 | def metric(self, pref, pgen): 259 | return average_agg_tanimoto(pref['fps'], pgen['fps'], 260 | device=self.device) 261 | 262 | 263 | def cos_similarity(ref_counts, gen_counts): 264 | """ 265 | Computes cosine similarity between 266 | dictionaries of form {name: count}. Non-present 267 | elements are considered zero: 268 | 269 | sim = / ||r|| / ||g|| 270 | """ 271 | if len(ref_counts) == 0 or len(gen_counts) == 0: 272 | return np.nan 273 | keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys())) 274 | ref_vec = np.array([ref_counts.get(k, 0) for k in keys]) 275 | gen_vec = np.array([gen_counts.get(k, 0) for k in keys]) 276 | return 1 - cos_distance(ref_vec, gen_vec) 277 | 278 | 279 | class FragMetric(Metric): 280 | def precalc(self, mols): 281 | return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)} 282 | 283 | def metric(self, pref, pgen): 284 | return cos_similarity(pref['frag'], pgen['frag']) 285 | 286 | 287 | class ScafMetric(Metric): 288 | def precalc(self, mols): 289 | return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)} 290 | 291 | def metric(self, pref, pgen): 292 | return cos_similarity(pref['scaf'], pgen['scaf']) 293 | 294 | 295 | class FrechetMetric(Metric): 296 | def __init__(self, func=None, **kwargs): 297 | self.func = func 298 | super().__init__(**kwargs) 299 | 300 | def precalc(self, mols): 301 | if self.func is not None: 302 | values = mapper(self.n_jobs)(self.func, mols) 303 | else: 304 | values = mols 305 | return {'mu': np.mean(values), 'var': np.var(values)} 306 | 307 | def metric(self, pref, pgen): 308 | return calculate_frechet_distance( 309 | pref['mu'], pref['var'], pgen['mu'], pgen['var'] 310 | ) 311 | 312 | 313 | class MetricsReward: 314 | supported_metrics = ['fcd', 'snn', 'fragments', 'scaffolds', 315 | 'internal_diversity', 'filters', 316 | 'logp', 'sa', 'qed', 'np', 'weight'] 317 | 318 | @staticmethod 319 | def _nan2zero(value): 320 | if value == np.nan: 321 | return 0 322 | 323 | return value 324 | 325 | def __init__(self, n_ref_subsample, n_rollouts, n_jobs, metrics=[]): 326 | assert all([m in MetricsReward.supported_metrics for m in metrics]) 327 | 328 | self.n_ref_subsample = n_ref_subsample 329 | self.n_rollouts = n_rollouts 330 | # TODO: profile this. Pool works too slow. 
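# Note on the line below: the `if False` makes the conditional a no-op, so the
# configured n_jobs is intentionally discarded and metric computation for the
# reward always runs single-process -- per the TODO above, the multiprocessing
# Pool proved too slow here.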
331 | n_jobs = n_jobs if False else 1 332 | self.n_jobs = n_jobs 333 | self.metrics = metrics 334 | 335 | def get_reference_data(self, data): 336 | ref_smiles = remove_invalid(data, canonize=True, n_jobs=self.n_jobs) 337 | ref_mols = mapper(self.n_jobs)(get_mol, ref_smiles) 338 | return ref_smiles, ref_mols 339 | 340 | def _get_metrics(self, ref, ref_mols, rollout): 341 | rollout_mols = mapper(self.n_jobs)(get_mol, rollout) 342 | result = [[0 if m is None else 1] for m in rollout_mols] 343 | 344 | if sum([r[0] for r in result], 0) == 0: 345 | return result 346 | 347 | rollout = remove_invalid(rollout, canonize=True, n_jobs=self.n_jobs) 348 | rollout_mols = mapper(self.n_jobs)(get_mol, rollout) 349 | if len(rollout) < 2: 350 | return result 351 | 352 | if len(self.metrics): 353 | for metric_name in self.metrics: 354 | if metric_name == 'fcd': 355 | m = FCDMetric(n_jobs=self.n_jobs)(ref, rollout) 356 | elif metric_name == 'morgan': 357 | m = SNNMetric(n_jobs=self.n_jobs)(ref_mols, rollout_mols) 358 | elif metric_name == 'fragments': 359 | m = FragMetric(n_jobs=self.n_jobs)(ref_mols, rollout_mols) 360 | elif metric_name == 'scaffolds': 361 | m = ScafMetric(n_jobs=self.n_jobs)(ref_mols, rollout_mols) 362 | elif metric_name == 'internal_diversity': 363 | m = internal_diversity(rollout_mols, n_jobs=self.n_jobs) 364 | elif metric_name == 'filters': 365 | m = fraction_passes_filters( 366 | rollout_mols, n_jobs=self.n_jobs 367 | ) 368 | elif metric_name == 'logp': 369 | m = -FrechetMetric(func=logP, n_jobs=self.n_jobs)( 370 | ref_mols, rollout_mols 371 | ) 372 | elif metric_name == 'sa': 373 | m = -FrechetMetric(func=SA, n_jobs=self.n_jobs)( 374 | ref_mols, rollout_mols 375 | ) 376 | elif metric_name == 'qed': 377 | m = -FrechetMetric(func=QED, n_jobs=self.n_jobs)( 378 | ref_mols, rollout_mols 379 | ) 380 | elif metric_name == 'np': 381 | m = -FrechetMetric(func=NP, n_jobs=self.n_jobs)( 382 | ref_mols, rollout_mols 383 | ) 384 | elif metric_name == 'weight': 385 | m = -FrechetMetric(func=weight, n_jobs=self.n_jobs)( 386 | ref_mols, rollout_mols 387 | ) 388 | 389 | m = MetricsReward._nan2zero(m) 390 | for i in range(len(rollout)): 391 | result[i].append(m) 392 | 393 | return result 394 | 395 | def __call__(self, gen, ref, ref_mols): 396 | 397 | idxs = random.sample(range(len(ref)), self.n_ref_subsample) 398 | ref_subsample = [ref[idx] for idx in idxs] 399 | ref_mols_subsample = [ref_mols[idx] for idx in idxs] 400 | 401 | gen_counter = Counter(gen) 402 | gen_counts = [gen_counter[g] for g in gen] 403 | 404 | n = len(gen) // self.n_rollouts 405 | rollouts = [gen[i::n] for i in range(n)] 406 | 407 | metrics_values = [self._get_metrics( 408 | ref_subsample, ref_mols_subsample, rollout 409 | ) for rollout in rollouts] 410 | metrics_values = map( 411 | lambda rm: [ 412 | sum(r, 0) / len(r) 413 | for r in rm 414 | ], metrics_values) 415 | reward_values = sum(zip(*metrics_values), ()) 416 | reward_values = [v / c for v, c in zip(reward_values, gen_counts)] 417 | 418 | return reward_values 419 | -------------------------------------------------------------------------------- /genmol/ORGAN/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 5 | from Data import * 6 | from Metrics_Reward import MetricsReward 7 | 8 | class Generator(nn.Module): 9 | def __init__(self, embedding_layer, hidden_size, num_layers, dropout): 
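        # Generator architecture: token embedding -> multi-layer LSTM over
        # packed, padded sequences -> linear layer projecting each hidden state
        # onto logits over the vocabulary (embedding_layer.num_embeddings).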
10 | super(Generator, self).__init__() 11 | 12 | self.embedding_layer = embedding_layer 13 | self.lstm_layer = nn.LSTM(embedding_layer.embedding_dim, 14 | hidden_size, num_layers, 15 | batch_first=True, dropout=dropout) 16 | self.linear_layer = nn.Linear(hidden_size, 17 | embedding_layer.num_embeddings) 18 | 19 | def forward(self, x, lengths, states=None): 20 | x = self.embedding_layer(x) 21 | x = pack_padded_sequence(x, lengths, batch_first=True) 22 | x, states = self.lstm_layer(x, states) 23 | x, _ = pad_packed_sequence(x, batch_first=True) 24 | x = self.linear_layer(x) 25 | 26 | return x, lengths, states 27 | 28 | 29 | class Discriminator(nn.Module): 30 | def __init__(self, desc_embedding_layer, convs, dropout=0): 31 | super(Discriminator, self).__init__() 32 | 33 | self.embedding_layer = desc_embedding_layer 34 | self.conv_layers = nn.ModuleList( 35 | [nn.Conv2d(1, f, kernel_size=( 36 | n, self.embedding_layer.embedding_dim) 37 | ) for f, n in convs]) 38 | sum_filters = sum([f for f, _ in convs]) 39 | self.highway_layer = nn.Linear(sum_filters, sum_filters) 40 | self.dropout_layer = nn.Dropout(p=dropout) 41 | self.output_layer = nn.Linear(sum_filters, 1) 42 | 43 | def forward(self, x): 44 | x = self.embedding_layer(x) 45 | x = x.unsqueeze(1) 46 | convs = [F.elu(conv_layer(x)).squeeze(3) 47 | for conv_layer in self.conv_layers] 48 | x = [F.max_pool1d(c, c.shape[2]).squeeze(2) for c in convs] 49 | x = torch.cat(x, dim=1) 50 | 51 | h = self.highway_layer(x) 52 | t = torch.sigmoid(h) 53 | x = t * F.elu(h) + (1 - t) * x 54 | x = self.dropout_layer(x) 55 | out = self.output_layer(x) 56 | 57 | return out 58 | 59 | 60 | class ORGAN(nn.Module): 61 | def __init__(self): 62 | super(ORGAN, self).__init__() 63 | 64 | self.metrics_reward = MetricsReward(n_ref_subsample=100, n_rollouts=16, n_jobs=1, metrics=[]) 65 | 66 | self.reward_weight = 0.7 67 | 68 | self.convs = [(100, 1), (200, 2), (200, 3), 69 | (200, 4), (200, 5), (100, 6), 70 | (100, 7), (100, 8), (100, 9), 71 | (100, 10)] 72 | 73 | self.embedding_layer = nn.Embedding( 74 | len(vocabulary), embedding_dim=32, padding_idx=c2i['']) 75 | 76 | self.desc_embedding_layer = nn.Embedding( 77 | len(vocabulary), embedding_dim=32, padding_idx=c2i['']) 78 | 79 | self.generator = Generator(self.embedding_layer, hidden_size=512, num_layers=2, dropout=0) 80 | 81 | self.discriminator = Discriminator(self.desc_embedding_layer, self.convs, dropout=0) 82 | 83 | def device(self): 84 | return next(self.parameters()).device 85 | 86 | def generator_forward(self, *args, **kwargs): 87 | return self.generator(*args, **kwargs) 88 | 89 | def discriminator_forward(self, *args, **kwargs): 90 | return self.discriminator(*args, **kwargs) 91 | 92 | def forward(self, *args, **kwargs): 93 | return self.sample(*args, **kwargs) 94 | 95 | def char2id(self, c): 96 | if c not in c2i: 97 | return c2i[''] 98 | 99 | return c2i[c] 100 | 101 | def id2char(self, id): 102 | if id not in i2c: 103 | return i2c[14] 104 | 105 | return i2c[id] 106 | 107 | def string2id(self, string, add_bos=False, add_eos=False): 108 | ids = [self.char2id(c) for c in string] 109 | 110 | if add_bos: 111 | ids = [c2i['']] + ids 112 | 113 | if add_eos: 114 | ids = ids + [c2i['']] 115 | 116 | return ids 117 | 118 | def ids2string(self, ids, rem_bos=True, rem_eos=True): 119 | if len(ids) == 0: 120 | return '' 121 | if rem_bos and ids[0] == c2i['']: 122 | ids = ids[1:] 123 | if rem_eos and ids[-1] == c2i['']: 124 | ids = ids[:-1] 125 | 126 | string = ''.join([self.id2char(id) for id in ids]) 127 | 128 | return 
string 129 | 130 | def string2tensor(self, string): 131 | ids = self.string2id(string, add_bos=True, add_eos=True) 132 | tensor = torch.tensor(ids, dtype=torch.long, device=device) 133 | 134 | return tensor 135 | 136 | def tensor2string(self, tensor): 137 | ids = tensor.tolist() 138 | string = self.ids2string(ids, rem_bos=True, rem_eos=True) 139 | 140 | return string 141 | 142 | def sample_tensor(self, n, max_length=100): 143 | prevs = torch.empty(n, 1, 144 | dtype=torch.long).fill_(c2i['']) 145 | samples, lengths = self._proceed_sequences(prevs, None, max_length) 146 | 147 | samples = torch.cat([prevs, samples], dim=-1) 148 | lengths += 1 149 | 150 | return samples, lengths 151 | 152 | def sample(self, batch_n=64, max_length=100): 153 | samples, lengths = self.sample_tensor(batch_n, max_length) 154 | samples = [t[:l] for t, l in zip(samples, lengths)] 155 | 156 | return [self.tensor2string(t) for t in samples] 157 | 158 | def _proceed_sequences(self, prevs, states, max_length): 159 | with torch.no_grad(): 160 | n_sequences = prevs.shape[0] 161 | 162 | sequences = [] 163 | lengths = torch.zeros(n_sequences, 164 | dtype=torch.long, device=device) 165 | 166 | one_lens = torch.ones(n_sequences, 167 | dtype=torch.long, device=device) 168 | is_end = prevs.eq(c2i['']).view(-1) 169 | 170 | for _ in range(max_length): 171 | outputs, _, states = self.generator(prevs, one_lens, states) 172 | probs = F.softmax(outputs, dim=-1).view(n_sequences, -1) 173 | currents = torch.multinomial(probs, 1) 174 | 175 | currents[is_end, :] = c2i[''] 176 | sequences.append(currents) 177 | lengths[~is_end] += 1 178 | 179 | is_end[currents.view(-1) == c2i['']] = 1 180 | if is_end.sum() == n_sequences: 181 | break 182 | 183 | prevs = currents 184 | 185 | sequences = torch.cat(sequences, dim=-1) 186 | 187 | return sequences, lengths 188 | 189 | def rollout(self, ref_smiles, ref_mols, n_samples, n_rollouts, max_length=100): 190 | with torch.no_grad(): 191 | sequences = [] 192 | rewards = [] 193 | ref_smiles = ref_smiles 194 | ref_mols = ref_mols 195 | lengths = torch.zeros(n_samples, dtype=torch.long, device=device) 196 | 197 | one_lens = torch.ones(n_samples, dtype=torch.long, ) 198 | prevs = torch.empty(n_samples, 1, dtype=torch.long, device=device).fill_(c2i['']) 199 | is_end = torch.zeros(n_samples, dtype=torch.uint8, device=device) 200 | states = None 201 | 202 | sequences.append(prevs) 203 | lengths += 1 204 | 205 | for current_len in range(10): 206 | print(current_len) 207 | outputs, _, states = self.generator(prevs, one_lens, states) 208 | 209 | probs = F.softmax(outputs, dim=-1).view(n_samples, -1) 210 | currents = torch.multinomial(probs, 1) 211 | 212 | currents[is_end, :] = c2i[''] 213 | sequences.append(currents) 214 | lengths[~is_end] += 1 215 | 216 | rollout_prevs = currents[~is_end, :].repeat(n_rollouts, 1) 217 | rollout_states = ( 218 | states[0][:, ~is_end, :].repeat(1, n_rollouts, 1), 219 | states[1][:, ~is_end, :].repeat(1, n_rollouts, 1) 220 | ) 221 | rollout_sequences, rollout_lengths = self._proceed_sequences( 222 | rollout_prevs, rollout_states, max_length - current_len 223 | ) 224 | 225 | rollout_sequences = torch.cat( 226 | [s[~is_end, :].repeat(n_rollouts, 1) for s in sequences] + [rollout_sequences], dim=-1) 227 | rollout_lengths += lengths[~is_end].repeat(n_rollouts) 228 | 229 | rollout_rewards = torch.sigmoid( 230 | self.discriminator(rollout_sequences).detach() 231 | ) 232 | 233 | if self.metrics_reward is not None and self.reward_weight > 0: 234 | strings = [ 235 | self.tensor2string(t[:l]) 
236 | for t, l in zip(rollout_sequences, rollout_lengths) 237 | ] 238 | 239 | obj_rewards = torch.tensor( 240 | self.metrics_reward(strings, ref_smiles, ref_mols)).view(-1, 1) 241 | rollout_rewards = (rollout_rewards * (1 - self.reward_weight) + 242 | obj_rewards * self.reward_weight 243 | ) 244 | print('Metrics Rewards = ', obj_rewards) 245 | current_rewards = torch.zeros(n_samples, device=device) 246 | 247 | current_rewards[~is_end] = rollout_rewards.view( 248 | n_rollouts, -1 249 | ).mean(dim=0) 250 | rewards.append(current_rewards.view(-1, 1)) 251 | 252 | is_end[currents.view(-1) == c2i['']] = 1 253 | if is_end.sum() >= 10: 254 | break 255 | prevs = currents 256 | 257 | sequences = torch.cat(sequences, dim=1) 258 | rewards = torch.cat(rewards, dim=1) 259 | 260 | return sequences, rewards, lengths 261 | -------------------------------------------------------------------------------- /genmol/ORGAN/NP_Score/README: -------------------------------------------------------------------------------- 1 | RDKit-based implementation of the method described in: 2 | 3 | Natural Product-likeness Score and Its Application for Prioritization of Compound Libraries 4 | Peter Ertl, Silvio Roggo, and Ansgar Schuffenhauer 5 | Journal of Chemical Information and Modeling, 48, 68-74 (2008) 6 | http://pubs.acs.org/doi/abs/10.1021/ci700286x 7 | 8 | Contribution from Peter Ertl 9 | 10 | -------------------------------------------------------------------------------- /genmol/ORGAN/NP_Score/__pycache__/npscorer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/NP_Score/__pycache__/npscorer.cpython-37.pyc -------------------------------------------------------------------------------- /genmol/ORGAN/NP_Score/npscorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of natural product-likeness as described in: 3 | # 4 | # Natural Product-likeness Score and Its Application for Prioritization of 5 | # Compound Libraries 6 | # Peter Ertl, Silvio Roggo, and Ansgar Schuffenhauer 7 | # Journal of Chemical Information and Modeling, 48, 68-74 (2008) 8 | # http://pubs.acs.org/doi/abs/10.1021/ci700286x 9 | # 10 | # for the training of this model only openly available data have been used 11 | # ~50,000 natural products collected from various open databases 12 | # ~1 million drug-like molecules from ZINC as a "non-NP background" 13 | # 14 | # peter ertl, august 2015 15 | # 16 | 17 | from __future__ import print_function 18 | from rdkit import Chem 19 | from rdkit.Chem import rdMolDescriptors 20 | import sys 21 | import math 22 | import gzip 23 | import pickle 24 | import os.path 25 | from collections import namedtuple 26 | 27 | 28 | _fscores = None 29 | 30 | 31 | def readNPModel(filename=os.path.join(os.path.dirname(__file__),'publicnp.model.gz')): 32 | """Reads and returns the scoring model, 33 | which has to be passed to the scoring functions.""" 34 | global _fscores 35 | _fscores = pickle.load(gzip.open(filename)) 36 | return _fscores 37 | 38 | 39 | def scoreMolWConfidence(mol, fscore): 40 | """Next to the NP Likeness Score, this function outputs a confidence value 41 | between 0..1 that descibes how many fragments of the tested molecule 42 | were found in the model data set (1: all fragments were found). 
43 | 44 | Returns namedtuple NPLikeness(nplikeness, confidence)""" 45 | 46 | if mol is None: 47 | raise ValueError('invalid molecule') 48 | fp = rdMolDescriptors.GetMorganFingerprint(mol, 2) 49 | bits = fp.GetNonzeroElements() 50 | 51 | # calculating the score 52 | score = 0.0 53 | bits_found = 0 54 | for bit in bits: 55 | if bit in fscore: 56 | bits_found += 1 57 | score += fscore[bit] 58 | 59 | score /= float(mol.GetNumAtoms()) 60 | confidence = float(bits_found / len(bits)) 61 | 62 | # preventing score explosion for exotic molecules 63 | if score > 4: 64 | score = 4. + math.log10(score - 4. + 1.) 65 | elif score < -4: 66 | score = -4. - math.log10(-4. - score + 1.) 67 | NPLikeness = namedtuple("NPLikeness", "nplikeness,confidence") 68 | return NPLikeness(score, confidence) 69 | 70 | 71 | def scoreMol(mol, fscore=None): 72 | """Calculates the Natural Product Likeness of a molecule. 73 | 74 | Returns the score as float in the range -5..5.""" 75 | if _fscores is None: 76 | readNPModel() 77 | fscore = fscore or _fscores 78 | return scoreMolWConfidence(mol, fscore).nplikeness 79 | 80 | 81 | def processMols(fscore, suppl): 82 | print("calculating ...", file=sys.stderr) 83 | n = 0 84 | for i, m in enumerate(suppl): 85 | if m is None: 86 | continue 87 | 88 | n += 1 89 | score = "%.3f" % scoreMol(m, fscore) 90 | 91 | smiles = Chem.MolToSmiles(m, True) 92 | name = m.GetProp('_Name') 93 | print(smiles + "\t" + name + "\t" + score) 94 | 95 | print("finished, " + str(n) + " molecules processed", file=sys.stderr) 96 | 97 | 98 | if __name__ == '__main__': 99 | fscore = readNPModel() # fills fscore 100 | 101 | suppl = Chem.SmilesMolSupplier( 102 | sys.argv[1], smilesColumn=0, nameColumn=1, titleLine=False 103 | ) 104 | processMols(fscore, suppl) 105 | 106 | # 107 | # Copyright (c) 2015, Novartis Institutes for BioMedical Research Inc. 108 | # All rights reserved. 109 | # 110 | # Redistribution and use in source and binary forms, with or without 111 | # modification, are permitted provided that the following conditions are 112 | # met: 113 | # 114 | # * Redistributions of source code must retain the above copyright 115 | # notice, this list of conditions and the following disclaimer. 116 | # * Redistributions in binary form must reproduce the above 117 | # copyright notice, this list of conditions and the following 118 | # disclaimer in the documentation and/or other materials provided 119 | # with the distribution. 120 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 121 | # nor the names of its contributors may be used to endorse or promote 122 | # products derived from this software without specific prior written 123 | # permission. 124 | # 125 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 126 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 127 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 128 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 129 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 130 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 131 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 132 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 133 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 134 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 135 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
136 | # 137 | -------------------------------------------------------------------------------- /genmol/ORGAN/NP_Score/publicnp.model.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/NP_Score/publicnp.model.gz -------------------------------------------------------------------------------- /genmol/ORGAN/RewardMetrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter 3 | from functools import partial 4 | import numpy as np 5 | import pandas as pd 6 | import scipy.sparse 7 | import torch 8 | from rdkit import Chem 9 | from rdkit.Chem import AllChem 10 | from rdkit.Chem import MACCSkeys 11 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan 12 | from rdkit.Chem.QED import qed 13 | from rdkit.Chem.Scaffolds import MurckoScaffold 14 | from rdkit.Chem import Descriptors 15 | from multiprocessing import Pool 16 | from SA_Score import sascorer 17 | 18 | from NP_Score import npscorer 19 | _base_dir = os.path.split(__file__)[0] 20 | _mcf = pd.read_csv(os.path.join(_base_dir, 'mcf.csv')) 21 | _pains = pd.read_csv(os.path.join(_base_dir, 'wehi_pains.csv'), 22 | names=['smarts', 'names']) 23 | _filters = [Chem.MolFromSmarts(x) for x in 24 | _mcf.append(_pains, sort=True)['smarts'].values] 25 | 26 | def mapper(n_jobs): 27 | # n_jobs = 8 28 | ''' 29 | Returns function for map call. 30 | If n_jobs == 1, will use standard map 31 | If n_jobs > 1, will use multiprocessing pool 32 | If n_jobs is a pool object, will return its map function 33 | ''' 34 | if n_jobs == 1: 35 | def _mapper(*args, **kwargs): 36 | return list(map(*args, **kwargs)) 37 | 38 | return _mapper 39 | elif isinstance(n_jobs, int): 40 | pool = Pool(n_jobs) 41 | 42 | def _mapper(*args, **kwargs): 43 | try: 44 | result = pool.map(*args, **kwargs) 45 | finally: 46 | pool.terminate() 47 | return result 48 | 49 | return _mapper 50 | else: 51 | return n_jobs.map 52 | 53 | 54 | def get_mol(smiles_or_mol): 55 | '''' 56 | Loads SMILES/molecule into RDKit's object 57 | ''' 58 | if isinstance(smiles_or_mol, str): 59 | if len(smiles_or_mol) == 0: 60 | return None 61 | mol = Chem.MolFromSmiles(smiles_or_mol) 62 | if mol is None: 63 | return None 64 | try: 65 | Chem.SanitizeMol(mol) 66 | except ValueError: 67 | return None 68 | return mol 69 | else: 70 | return smiles_or_mol 71 | 72 | 73 | def canonic_smiles(smiles_or_mol): 74 | mol = get_mol(smiles_or_mol) 75 | if mol is None: 76 | return None 77 | return Chem.MolToSmiles(mol) 78 | 79 | 80 | def logP(mol): 81 | """ 82 | Computes RDKit's logP 83 | """ 84 | return Chem.Crippen.MolLogP(mol) 85 | 86 | 87 | def SA(mol): 88 | """ 89 | Computes RDKit's Synthetic Accessibility score 90 | """ 91 | return sascorer.calculateScore(mol) 92 | 93 | 94 | def NP(mol): 95 | """ 96 | Computes RDKit's Natural Product-likeness score 97 | """ 98 | return npscorer.scoreMol(mol) 99 | 100 | 101 | def QED(mol): 102 | """ 103 | Computes RDKit's QED score 104 | """ 105 | return qed(mol) 106 | 107 | 108 | def weight(mol): 109 | """ 110 | Computes molecular weight for given molecule. 
111 | Returns float, 112 | """ 113 | return Descriptors.MolWt(mol) 114 | 115 | 116 | def get_n_rings(mol): 117 | """ 118 | Computes the number of rings in a molecule 119 | """ 120 | return mol.GetRingInfo().NumRings() 121 | 122 | 123 | def fragmenter(mol): 124 | """ 125 | fragment mol using BRICS and return smiles list 126 | """ 127 | fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol)) 128 | fgs_smi = Chem.MolToSmiles(fgs).split(".") 129 | return fgs_smi 130 | 131 | 132 | def compute_fragments(mol_list, n_jobs=1): 133 | """ 134 | fragment list of mols using BRICS and return smiles list 135 | """ 136 | fragments = Counter() 137 | for mol_frag in mapper(n_jobs)(fragmenter, mol_list): 138 | fragments.update(mol_frag) 139 | return fragments 140 | 141 | 142 | def compute_scaffolds(mol_list, n_jobs=1, min_rings=2): 143 | """ 144 | Extracts a scafold from a molecule in a form of a canonic SMILES 145 | """ 146 | scaffolds = Counter() 147 | map_ = mapper(n_jobs) 148 | scaffolds = Counter( 149 | map_(partial(compute_scaffold, min_rings=min_rings), mol_list)) 150 | if None in scaffolds: 151 | scaffolds.pop(None) 152 | return scaffolds 153 | 154 | 155 | def compute_scaffold(mol, min_rings=2): 156 | mol = get_mol(mol) 157 | scaffold = MurckoScaffold.GetScaffoldForMol(mol) 158 | n_rings = get_n_rings(scaffold) 159 | scaffold_smiles = Chem.MolToSmiles(scaffold) 160 | if scaffold_smiles == '' or n_rings < min_rings: 161 | return None 162 | else: 163 | return scaffold_smiles 164 | 165 | 166 | def average_agg_tanimoto(stock_vecs, gen_vecs, 167 | batch_size=5000, agg='max', 168 | device='cpu', p=1): 169 | """ 170 | For each molecule in gen_vecs finds closest molecule in stock_vecs. 171 | Returns average tanimoto score for between these molecules 172 | 173 | Parameters: 174 | stock_vecs: numpy array 175 | gen_vecs: numpy array 176 | agg: max or mean 177 | p: power for averaging: (mean x^p)^(1/p) 178 | """ 179 | assert agg in ['max', 'mean'], "Can aggregate only max or mean" 180 | agg_tanimoto = np.zeros(len(gen_vecs)) 181 | total = np.zeros(len(gen_vecs)) 182 | for j in range(0, stock_vecs.shape[0], batch_size): 183 | x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float() 184 | for i in range(0, gen_vecs.shape[0], batch_size): 185 | y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float() 186 | y_gen = y_gen.transpose(0, 1) 187 | tp = torch.mm(x_stock, y_gen) 188 | jac = (tp / (x_stock.sum(1, keepdim=True) + 189 | y_gen.sum(0, keepdim=True) - tp)).cpu().numpy() 190 | jac[np.isnan(jac)] = 1 191 | if p != 1: 192 | jac = jac ** p 193 | if agg == 'max': 194 | agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum( 195 | agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0)) 196 | elif agg == 'mean': 197 | agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0) 198 | total[i:i + y_gen.shape[1]] += jac.shape[0] 199 | if agg == 'mean': 200 | agg_tanimoto /= total 201 | if p != 1: 202 | agg_tanimoto = (agg_tanimoto) ** (1 / p) 203 | return np.mean(agg_tanimoto) 204 | 205 | 206 | def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2, 207 | morgan__n=1024, *args, **kwargs): 208 | """ 209 | Generates fingerprint for SMILES 210 | If smiles is invalid, returns None 211 | Returns numpy array of fingerprint bits 212 | 213 | Parameters: 214 | smiles: SMILES string 215 | type: type of fingerprint: [MACCS|morgan] 216 | dtype: if not None, specifies the dtype of returned array 217 | """ 218 | fp_type = fp_type.lower() 219 | molecule = get_mol(smiles_or_mol, *args, **kwargs) 220 | if molecule is None: 221 | 
return None 222 | if fp_type == 'maccs': 223 | keys = MACCSkeys.GenMACCSKeys(molecule) 224 | keys = np.array(keys.GetOnBits()) 225 | fingerprint = np.zeros(166, dtype='uint8') 226 | if len(keys) != 0: 227 | fingerprint[keys - 1] = 1 # We drop 0-th key that is always zero 228 | elif fp_type == 'morgan': 229 | fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n), 230 | dtype='uint8') 231 | else: 232 | raise ValueError("Unknown fingerprint type {}".format(fp_type)) 233 | if dtype is not None: 234 | fingerprint = fingerprint.astype(dtype) 235 | return fingerprint 236 | 237 | 238 | def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args, 239 | 240 | **kwargs): 241 | ''' 242 | Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers 243 | e.g.fingerprints(smiles_mols_array, type='morgan', n_jobs=10) 244 | Inserts np.NaN to rows corresponding to incorrect smiles. 245 | IMPORTANT: if there is at least one np.NaN, the dtype would be float 246 | Parameters: 247 | smiles_mols_array: list/array/pd.Series of smiles or already computed 248 | RDKit molecules 249 | n_jobs: number of parralel workers to execute 250 | already_unique: flag for performance reasons, if smiles array is big 251 | and already unique. Its value is set to True if smiles_mols_array 252 | contain RDKit molecules already. 253 | ''' 254 | if isinstance(smiles_mols_array, pd.Series): 255 | smiles_mols_array = smiles_mols_array.values 256 | else: 257 | smiles_mols_array = np.asarray(smiles_mols_array) 258 | if not isinstance(smiles_mols_array[0], str): 259 | already_unique = True 260 | 261 | if not already_unique: 262 | smiles_mols_array, inv_index = np.unique(smiles_mols_array, 263 | return_inverse=True) 264 | 265 | fps = mapper(n_jobs)( 266 | partial(fingerprint, *args, **kwargs), smiles_mols_array 267 | ) 268 | 269 | length = 1 270 | for fp in fps: 271 | if fp is not None: 272 | length = fp.shape[-1] 273 | first_fp = fp 274 | break 275 | fps = [fp if fp is not None else np.array([np.NaN]).repeat(length)[None, :] 276 | for fp in fps] 277 | if scipy.sparse.issparse(first_fp): 278 | fps = scipy.sparse.vstack(fps).tocsr() 279 | else: 280 | fps = np.vstack(fps) 281 | if not already_unique: 282 | return fps[inv_index] 283 | else: 284 | return fps 285 | 286 | 287 | def mol_passes_filters(mol, 288 | allowed=None, 289 | isomericSmiles=False): 290 | """ 291 | Checks if mol 292 | * passes MCF and PAINS filters, 293 | * has only allowed atoms 294 | * is not charged 295 | """ 296 | allowed = allowed or {'C', 'N', 'S', 'O', 'F', 'Cl', 'Br', 'H'} 297 | mol = get_mol(mol) 298 | if mol is None: 299 | return False 300 | ring_info = mol.GetRingInfo() 301 | if ring_info.NumRings() != 0 and any( 302 | len(x) >= 8 for x in ring_info.AtomRings() 303 | ): 304 | return False 305 | h_mol = Chem.AddHs(mol) 306 | if any(atom.GetFormalCharge() != 0 for atom in mol.GetAtoms()): 307 | return False 308 | if any(atom.GetSymbol() not in allowed for atom in mol.GetAtoms()): 309 | return False 310 | if any(h_mol.HasSubstructMatch(smarts) for smarts in _filters): 311 | return False 312 | smiles = Chem.MolToSmiles(mol, isomericSmiles=isomericSmiles) 313 | if smiles is None or len(smiles) == 0: 314 | return False 315 | if Chem.MolFromSmiles(smiles) is None: 316 | return False 317 | return True 318 | -------------------------------------------------------------------------------- /genmol/ORGAN/Run.py: -------------------------------------------------------------------------------- 1 | import tqdm as tqdm 2 | from 
Metrics_Reward import * 3 | from Data import * 4 | from Trainer import fit 5 | from Model import ORGAN 6 | 7 | 8 | def sampler(model): 9 | n_samples = 100 10 | samples = [] 11 | with tqdm(total=n_samples, desc='Generating Samples')as T: 12 | while n_samples > 0: 13 | current_samples = model.sample(min(n_samples, 64), max_length=100) 14 | samples.extend(current_samples) 15 | n_samples -= len(current_samples) 16 | T.update(len(current_samples)) 17 | 18 | return samples 19 | 20 | 21 | def evaluate(test, samples, test_scaffolds=None, ptest=None, ptest_scaffolds=None): 22 | gen = samples 23 | metrics = get_all_metrics(test, gen, k=[1000, 1000], n_jobs=1, 24 | device=device, 25 | test_scaffolds=test_scaffolds, 26 | ptest=ptest, ptest_scaffolds=ptest_scaffolds) 27 | for name, value in metrics.items(): 28 | print('{}, {}'.format(name, value)) 29 | 30 | 31 | model = ORGAN() 32 | fit(model, train_data) 33 | samples = sampler(model) 34 | evaluate(test_data, samples, test_scaffold) 35 | -------------------------------------------------------------------------------- /genmol/ORGAN/SA_Score/README: -------------------------------------------------------------------------------- 1 | RDKit-based implementation of the method described in: 2 | 3 | Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 4 | Peter Ertl and Ansgar Schuffenhauer 5 | Journal of Cheminformatics 1:8 (2009) 6 | http://www.jcheminf.com/content/1/1/8 7 | 8 | Contribution from Peter Ertl and Greg Landrum 9 | 10 | -------------------------------------------------------------------------------- /genmol/ORGAN/SA_Score/UnitTestSAScore.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import unittest 4 | 5 | import sascorer 6 | from rdkit import Chem 7 | 8 | print(sascorer.__file__) 9 | 10 | 11 | class TestCase(unittest.TestCase): 12 | 13 | def test1(self): 14 | with open('data/zim.100.txt') as f: 15 | testData = [x.strip().split('\t') for x in f] 16 | testData.pop(0) 17 | for row in testData: 18 | smi = row[0] 19 | m = Chem.MolFromSmiles(smi) 20 | tgt = float(row[2]) 21 | val = sascorer.calculateScore(m) 22 | self.assertAlmostEqual(tgt, val, 3) 23 | 24 | 25 | if __name__ == '__main__': 26 | import sys 27 | import getopt 28 | import re 29 | 30 | doLong = 0 31 | if len(sys.argv) > 1: 32 | args, extras = getopt.getopt(sys.argv[1:], 'l') 33 | for arg, val in args: 34 | if arg == '-l': 35 | doLong = 1 36 | sys.argv.remove('-l') 37 | if doLong: 38 | for methName in dir(TestCase): 39 | if re.match('_test', methName): 40 | newName = re.sub('_test', 'test', methName) 41 | exec('TestCase.%s = TestCase.%s' % (newName, methName)) 42 | 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /genmol/ORGAN/SA_Score/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/SA_Score/__init__.py -------------------------------------------------------------------------------- /genmol/ORGAN/SA_Score/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/SA_Score/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /genmol/ORGAN/SA_Score/__pycache__/sascorer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/SA_Score/__pycache__/sascorer.cpython-37.pyc -------------------------------------------------------------------------------- /genmol/ORGAN/SA_Score/data/zim.100.txt: -------------------------------------------------------------------------------- 1 | smiles Name sa_score 2 | Cc1c(C(=O)NCCO)[n+](=O)c2ccccc2n1[O-] ZINC21984717 3.166 3 | Cn1cc(NC=O)cc1C(=O)Nc1cc(C(=O)Nc2cc(C(=O)NCCC(N)=[NH2+])n(C)c2)n(C)c1 ZINC03872327 3.328 4 | OC(c1ccncc1)c1ccc(OCC[NH+]2CCCC2)cc1 ZINC34421620 3.822 5 | CC(C(=O)[O-])c1ccc(-c2ccccc2)cc1 ZINC00000361 2.462 6 | C[NH+](C)CC(O)Cn1c2ccc(Br)cc2c2cc(Br)ccc21 ZINC00626529 3.577 7 | NC(=[NH2+])NCC1COc2ccccc2O1 ZINC00000357 3.290 8 | CCC(C)(C)[NH2+]CC(O)COc1ccccc1C#N ZINC04214111 3.698 9 | C[NH+](C)CC(O)Cn1c2ccc(Br)cc2c2cc(Br)ccc21 ZINC00626528 3.577 10 | CC12CCC3C(CCC4CC(=O)CCC43C)C1CCC2=O ZINC04081985 3.912 11 | COc1ccc(OC(=O)N(CC(=O)[O-])Cc2ccc(OCCc3nc(-c4ccccc4)oc3C)cc2)cc1 ZINC03935839 2.644 12 | COc1ccccc1OC(=O)c1ccccc1 ZINC00000349 1.342 13 | CC(C)CC[NH2+]CC1COc2ccccc2O1 ZINC04214115 3.701 14 | CN1CCN(C(=O)OC2c3nccnc3C(=O)N2c2ccc(Cl)cn2)CC1 ZINC19632834 3.196 15 | CCC1(c2ccccc2)C(=O)N(COC)C(=O)N(COC)C1=O ZINC02986592 2.759 16 | Nc1ccc(S(=O)(=O)Nc2ccccc2)cc1 ZINC00141883 1.529 17 | O=C([O-])CCCNC(=O)NC1CCCCC1 ZINC08754389 2.493 18 | CCC(C)C(C(=O)OC1CC[N+](C)(C)CC1)c1ccccc1 ZINC00000595 3.399 19 | CCC(C)SSc1ncc[nH]1 ZINC13209429 3.983 20 | CC[N+](C)(CC)CCOC(=O)C(O)(c1cccs1)C1CCCC1 ZINC01690860 3.471 21 | CC12CCC3C(CCC4CC(=O)CCC43C)C1CCC2O ZINC03814360 3.994 22 | CC12CCC3C4CCC(=O)C=C4CCC3C1CCC2O ZINC03814379 4.056 23 | OCC1OC(OC2C(CO)OC(O)C(O)C2O)C(O)C(O)C1O ZINC04095762 4.282 24 | CC(C)CC(CC[NH+](C(C)C)C(C)C)(C(N)=O)c1ccccn1 ZINC02016048 4.092 25 | C=CC1(C)CC(=O)C2(O)C(C)(O1)C(OC(C)=O)C(OC(=O)CC[NH+](C)C)C1C(C)(C)CCC(O)C12C ZINC38595287 5.519 26 | C=CC[NH+]1CCCC1CNC(=O)c1cc(S(N)(=O)=O)cc(OC)c1OC ZINC00601278 4.286 27 | CC(=O)OC1C[NH+]2CCC1CC2 ZINC00492792 5.711 28 | CC12CCC3C(CCC4CC(=O)CCC43C)C1CCC2O ZINC03814418 3.994 29 | CC1(O)CCC2C3CCC4=CC(=O)CCC4(C)C3CCC21C ZINC03814422 4.022 30 | CC(=O)OC1(C(C)=O)CCC2C3C=C(Cl)C4=CC(=O)C5CC5C4(C)C3CCC21C ZINC03814423 4.827 31 | C#CC1(O)CCC2C3CCc4cc(OC)ccc4C3CCC21C ZINC03815424 3.810 32 | C=CC1(C)CC(OC(=O)CSCC[NH+](CC)CC)C2(C)C3C(=O)CCC3(CCC2C)C(C)C1O ZINC25757051 6.200 33 | O=C([O-])C(=O)Nc1nc(-c2ccc3c(c2)OCCO3)cs1 ZINC03623428 2.594 34 | CC[NH+]1CCCC1CNC(=O)C(O)(c1ccccc1)c1ccccc1 ZINC00900569 3.950 35 | CC(C)(OCc1nn(Cc2ccccc2)c2ccccc12)C(=O)[O-] ZINC00004594 2.573 36 | Cc1nnc(C(C)C)n1C1CC2CCC(C1)[NH+]2CCC(NC(=O)C1CCC(F)(F)CC1)c1ccccc1 ZINC03817234 5.316 37 | Nc1ncnc2c1ncn2C1OC(COP(=O)([O-])OP(=O)([O-])OP(=O)([O-])[O-])C(O)C1O ZINC03871612 5.290 38 | O=C([O-])CNC(=O)c1ccccc1 ZINC00097685 2.097 39 | Nc1ncnc2c1ncn2C1OC(COP(=O)([O-])OP(=O)([O-])OP(=O)([O-])[O-])C(O)C1O ZINC03871613 5.290 40 | Nc1ncnc2c1ncn2C1OC(COP(=O)([O-])OP(=O)([O-])OP(=O)([O-])[O-])C(O)C1O ZINC03871614 5.290 41 | c1ccc(OCc2ccc(CCCN3CCOCC3)cc2)cc1 ZINC19865692 1.702 42 | CC=CC1=C(C(=O)[O-])N2C(=O)C(NC(=O)C(N)c3ccc(O)cc3)C2SC1 ZINC20444132 4.042 43 | C[NH+]1CCCC1COc1cccnc1 ZINC03805141 4.510 44 | O=C([O-])C(O)CC(O)C(O)CO ZINC04803503 4.398 45 | O=C([O-])C(O)CC(O)C(O)CO ZINC01696607 4.398 46 | 
C[NH+]1CCCC1Cc1c[nH]c2ccc(CCS(=O)(=O)c3ccccc3)cc12 ZINC03823475 3.921 47 | C(=Cc1ccccc1)C[NH+]1CCN(C(c2ccccc2)c2ccccc2)CC1 ZINC19632891 2.973 48 | Nc1ncnc2c1ncn2C1OC(COP(=O)([O-])OP(=O)([O-])OP(=O)([O-])[O-])C(O)C1O ZINC03871615 5.290 49 | CC(c1ccccc1)N(C)C=O ZINC06932229 2.562 50 | CC(=O)C1CCC2C3CCC4CC(C)(O)CCC4(C)C3CCC12C ZINC03824281 4.279 51 | O=C([O-])C(O)CC(O)C(O)CO ZINC04803506 4.398 52 | COc1cc(O)c(C(=O)c2ccccc2)c(O)c1 ZINC00000187 1.868 53 | O=C([O-])C(O)CC(O)C(O)CO ZINC04803507 4.398 54 | COc1c2c(cc3c1C(O)N(C)CC3)OCO2 ZINC00000186 3.183 55 | CCC(C(=O)[O-])c1ccc(CC(C)C)cc1 ZINC00015537 2.827 56 | O=C([O-])C1[NH+]=C(c2ccccc2)c2cc(Cl)ccc2NC1(O)O ZINC38611850 4.011 57 | O=C([O-])C1[NH+]=C(c2ccccc2)c2cc(Cl)ccc2NC1(O)O ZINC38611851 4.011 58 | OCC(O)COc1ccc(Cl)cc1 ZINC00000135 2.102 59 | NC(=O)NC(=O)C(Cl)c1ccccc1 ZINC00000134 2.455 60 | OC(c1ccccc1)(c1ccccc1)C1C[NH+]2CCC1CC2 ZINC01298963 4.530 61 | C[NH2+]CC(C)c1ccccc1 ZINC04298801 3.471 62 | Clc1cccc(Cl)c1N=C1NCCO1 ZINC13835972 3.267 63 | [NH3+]C(Cc1ccccc1)C(=O)CCl ZINC02504633 3.251 64 | CC(C)Cn1cnc2c1c1ccccc1nc2N ZINC19632912 2.230 65 | CC(O)CN(C)c1ccc(NN)nn1 ZINC00000624 3.193 66 | CC1(O)CCC2C3CCC4=CC(=O)CCC4=C3C=CC21C ZINC00001727 4.461 67 | CCC(C(=O)[O-])c1ccc(-c2ccccc2)cc1 ZINC00000111 2.505 68 | CC(=O)OCC1OC(n2ncc(=O)[nH]c2=O)C(OC(C)=O)C1OC(C)=O ZINC03830255 3.832 69 | CC(=O)OCC1OC(n2ncc(=O)[nH]c2=O)C(OC(C)=O)C1OC(C)=O ZINC03830256 3.832 70 | Cn1cc(C(=O)c2cccc3ccccc32)cc1C(=O)[O-] ZINC00001783 2.456 71 | CC(=O)OCC1OC(n2ncc(=O)[nH]c2=O)C(OC(C)=O)C1OC(C)=O ZINC03830257 3.832 72 | Cc1cccc(-c2nc3ccccc3c(Nc3ccc4[nH]ncc4c3)n2)n1 ZINC39279791 2.358 73 | O=C([O-])C1CC2CCCCC2[NH2+]1 ZINC04899687 5.422 74 | CC(=O)OCC(=O)C1CCC2C3CC=C4CC(O)CCC4(C)C3CCC12C ZINC00538219 4.187 75 | O=C([O-])C1CC2CCCCC2[NH2+]1 ZINC04899686 5.422 76 | O=C(OCc1ccccc1)C(O)c1ccccc1 ZINC00000078 2.038 77 | CC(=O)OCC(=O)C1(O)CCC2C3CCC4=CC(=O)C=CC4(C)C3C(O)CC21C ZINC00608041 4.394 78 | Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1 ZINC02570895 2.144 79 | COCc1cccc(CC(O)C=CC2C(O)CC(=O)C2CCSCCCC(=O)OC)c1 ZINC03940680 3.934 80 | CCC(=O)N(c1ccccc1)C1CC[NH+](C(C)Cc2ccccc2)CC1 ZINC01664586 3.582 81 | CCC(=O)N(c1ccccc1)C1CC[NH+](C(C)Cc2ccccc2)CC1 ZINC01664587 3.582 82 | CCOC(=O)Nc1ccc2c(c1)N(C(=O)CCN1CCOCC1)c1ccccc1S2 ZINC19340795 2.446 83 | O=C([O-])Cc1cc(=O)[nH]c(=O)[nH]1 ZINC00403617 3.258 84 | NC(=O)C([NH3+])Cc1c[nH]c2ccccc12 ZINC04899521 3.224 85 | NC(=O)C([NH3+])Cc1ccc(O)cc1 ZINC04899513 3.280 86 | O=C(c1cc2ccccc2o1)N1CCN(Cc2ccccc2)CC1 ZINC19632922 1.799 87 | O=C(CO)C(O)C(O)CO ZINC00902219 3.473 88 | CC(Cc1ccccc1)NC(=O)C([NH3+])CCCC[NH3+] ZINC11680943 3.967 89 | C[NH+]1CCC(c2c(O)cc(=O)c3c(O)cc(-c4ccccc4Cl)oc2-3)C(O)C1 ZINC05966679 4.616 90 | CN(C)c1ccc(O)c2c1CC1CC3C([NH+](C)C)C(=O)C(C(N)=O)=C(O)C3(O)C(=O)C1=C2O ZINC04019704 4.713 91 | Cc1cc2nc3c(=O)[nH]c(=O)nc-3n(CC(O)C(O)C(O)CO)c2cc1C ZINC03650334 3.791 92 | C[NH+]1C2CCC1CC(OC(=O)c1c[nH]c3ccccc13)C2 ZINC18130447 4.892 93 | Cc1ccccc1NC(=O)C(C)[NH+]1CCCC1 ZINC00000051 3.809 94 | O=S(=O)([O-])CCN1CCOCC1 ZINC19419111 2.776 95 | C[NH+]1CCN(CC(=O)N2c3ccccc3C(=O)Nc3cccnc32)CC1 ZINC19632927 3.379 96 | CCCCCC=CCC=CCCCCCCCC(=O)[O-] ZINC03802188 2.805 97 | CC(CC([NH3+])C(=O)[O-])C(=O)[O-] ZINC01747048 5.690 98 | CC1c2cccc(O)c2C(=O)C2=C(O)C3(O)C(O)=C(C(N)=O)C(=O)C([NH+](C)C)C3C(O)C21 ZINC04019706 5.069 99 | Cc1cc2nc3nc([O-])[nH]c(=O)c3nc2cc1C ZINC12446789 3.079 100 | CC1=CC(C)C2(CO)COC(c3ccc(O)cc3)C1C2C ZINC38190856 4.749 101 | CC[NH+]1CCC(=C2c3ccccc3CCc3ccccc32)C1C ZINC02020004 3.925 102 | 
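(The table above is the tab-separated reference set for the SA scorer: column 1 is the SMILES, column 2 the ZINC identifier, column 3 the expected sa_score; UnitTestSAScore.py checks sascorer.calculateScore against the third column to three decimal places.)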
-------------------------------------------------------------------------------- /genmol/ORGAN/SA_Score/fpscores.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/SA_Score/fpscores.pkl.gz -------------------------------------------------------------------------------- /genmol/ORGAN/SA_Score/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on 5 | # Molecular Complexity and Fragment Contributions 6 | # Peter Ertl and Ansgar Schuffenhauer 7 | # Journal of Cheminformatics 1:8 (2009) 8 | # http://www.jcheminf.com/content/1/1/8 9 | # 10 | # several small modifications to the original paper are included 11 | # particularly slightly different formula for marocyclic penalty 12 | # and taking into account also molecule symmetry (fingerprint density) 13 | # 14 | # for a set of 10k diverse molecules the agreement between the original method 15 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 16 | # 17 | # peter ertl & greg landrum, september 2013 18 | # 19 | from __future__ import print_function 20 | 21 | import math 22 | import os.path as op 23 | 24 | from rdkit import Chem 25 | from rdkit.Chem import rdMolDescriptors 26 | from rdkit.six import iteritems 27 | import pickle 28 | 29 | _fscores = None 30 | 31 | 32 | def readFragmentScores(name='fpscores'): 33 | import gzip 34 | global _fscores 35 | # generate the full path filename: 36 | if name == "fpscores": 37 | name = op.join(op.dirname(__file__), name) 38 | _fscores = pickle.load(gzip.open('%s.pkl.gz' % name)) 39 | outDict = {} 40 | for i in _fscores: 41 | for j in range(1, len(i)): 42 | outDict[i[j]] = float(i[0]) 43 | _fscores = outDict 44 | 45 | 46 | def numBridgeheadsAndSpiro(mol, ri=None): 47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 49 | return nBridgehead, nSpiro 50 | 51 | 52 | def calculateScore(m): 53 | if _fscores is None: 54 | readFragmentScores() 55 | 56 | # fragment score 57 | fp = rdMolDescriptors.GetMorganFingerprint( 58 | m, 2 # <- 2 is the *radius* of the circular fingerprint 59 | ) 60 | fps = fp.GetNonzeroElements() 61 | score1 = 0. 62 | nf = 0 63 | for bitId, v in iteritems(fps): 64 | nf += v 65 | sfp = bitId 66 | score1 += _fscores.get(sfp, -4) * v 67 | score1 /= nf 68 | 69 | # features score 70 | nAtoms = m.GetNumAtoms() 71 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 72 | ri = m.GetRingInfo() 73 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 74 | nMacrocycles = 0 75 | for x in ri.AtomRings(): 76 | if len(x) > 8: 77 | nMacrocycles += 1 78 | 79 | sizePenalty = nAtoms ** 1.005 - nAtoms 80 | stereoPenalty = math.log10(nChiralCenters + 1) 81 | spiroPenalty = math.log10(nSpiro + 1) 82 | bridgePenalty = math.log10(nBridgeheads + 1) 83 | macrocyclePenalty = 0. 84 | # --------------------------------------- 85 | # This differs from the paper, which defines: 86 | # macrocyclePenalty = math.log10(nMacrocycles+1) 87 | # This form generates better results when 2 or more macrocycles are present 88 | if nMacrocycles > 0: 89 | macrocyclePenalty = math.log10(2) 90 | 91 | score2 = (0. 
- sizePenalty - stereoPenalty - 92 | spiroPenalty - bridgePenalty - macrocyclePenalty) 93 | 94 | # correction for the fingerprint density 95 | # not in the original publication, added in version 1.1 96 | # to make highly symmetrical molecules easier to synthetise 97 | score3 = 0. 98 | if nAtoms > len(fps): 99 | score3 = math.log(float(nAtoms) / len(fps)) * .5 100 | 101 | sascore = score1 + score2 + score3 102 | 103 | # need to transform "raw" value into scale between 1 and 10 104 | min = -4.0 105 | max = 2.5 106 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 107 | # smooth the 10-end 108 | if sascore > 8.: 109 | sascore = 8. + math.log(sascore + 1. - 9.) 110 | if sascore > 10.: 111 | sascore = 10.0 112 | elif sascore < 1.: 113 | sascore = 1.0 114 | 115 | return sascore 116 | 117 | 118 | def processMols(mols): 119 | print('smiles\tName\tsa_score') 120 | for i, m in enumerate(mols): 121 | if m is None: 122 | continue 123 | 124 | s = calculateScore(m) 125 | 126 | smiles = Chem.MolToSmiles(m) 127 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 128 | 129 | 130 | if __name__ == '__main__': 131 | import sys 132 | import time 133 | 134 | t1 = time.time() 135 | readFragmentScores("fpscores") 136 | t2 = time.time() 137 | 138 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 139 | t3 = time.time() 140 | processMols(suppl) 141 | t4 = time.time() 142 | 143 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ( 144 | (t2 - t1), (t4 - t3)), 145 | file=sys.stderr) 146 | 147 | # 148 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 149 | # All rights reserved. 150 | # 151 | # Redistribution and use in source and binary forms, with or without 152 | # modification, are permitted provided that the following conditions are 153 | # met: 154 | # 155 | # * Redistributions of source code must retain the above copyright 156 | # notice, this list of conditions and the following disclaimer. 157 | # * Redistributions in binary form must reproduce the above 158 | # copyright notice, this list of conditions and the following 159 | # disclaimer in the documentation and/or other materials provided 160 | # with the distribution. 161 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 162 | # nor the names of its contributors may be used to endorse or promote 163 | # products derived from this software without specific prior written 164 | # permission. 165 | # 166 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 167 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 168 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 169 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 170 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 171 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 172 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 173 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 174 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 175 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 176 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
177 | # 178 | -------------------------------------------------------------------------------- /genmol/ORGAN/Trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from tqdm import tqdm 5 | from torch.nn.utils.rnn import pad_sequence 6 | from torch.utils.data import DataLoader 7 | from torch.optim import Adam 8 | import random 9 | 10 | from Data import * 11 | 12 | n_batch = 64 13 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | discriminator_pretrain_epochs = 50 15 | discriminator_epochs = 10 16 | generator_pretrain_epochs = 50 17 | max_length = 100 18 | save_frequency = 25 19 | generator_updates = 1 20 | discriminator_updates = 1 21 | n_samples = 64 22 | n_rollouts = 16 23 | pg_iters = 10 24 | 25 | class PolicyGradientLoss(nn.Module): 26 | def forward(self, outputs, targets, rewards, lengths): 27 | log_probs = F.log_softmax(outputs, dim=2) 28 | items = torch.gather( 29 | log_probs, 2, targets.unsqueeze(2) 30 | ) * rewards.unsqueeze(2) 31 | loss = -sum( 32 | [t[:l].sum() for t, l in zip(items, lengths)] 33 | ) / lengths.sum().float() 34 | return loss 35 | 36 | 37 | def generator_collate_fn(model): 38 | def collate(data): 39 | data.sort(key=len, reverse=True) 40 | tensors = [model.string2tensor(string) 41 | for string in data] 42 | 43 | prevs = pad_sequence( 44 | [t[:-1] for t in tensors], 45 | batch_first=True, padding_value=c2i[''] 46 | ) 47 | nexts = pad_sequence( 48 | [t[1:] for t in tensors], 49 | batch_first=True, padding_value=c2i[''] 50 | ) 51 | lens = torch.tensor( 52 | [len(t) - 1 for t in tensors], 53 | dtype=torch.long, device=device) 54 | return prevs, nexts, lens 55 | 56 | return collate 57 | 58 | 59 | def get_dataloader(training_data, collate_fn): 60 | return DataLoader(training_data, batch_size=n_batch, 61 | shuffle=True, num_workers=8, collate_fn=collate_fn, worker_init_fn=None) 62 | 63 | 64 | def _pretrain_generator_epoch(model, tqdm_data, criterion, optimizer): 65 | model.discriminator.eval() 66 | if optimizer is None: 67 | model.eval() 68 | else: 69 | model.train() 70 | 71 | postfix = {'loss': 0, 'running_loss': 0} 72 | 73 | for i, batch in enumerate(tqdm_data): 74 | (prevs, nexts, lens) = (data.to(device) for data in batch) 75 | outputs, _, _, = model.generator_forward(prevs, lens) 76 | 77 | loss = criterion(outputs.view(-1, outputs.shape[-1]), 78 | nexts.view(-1)) 79 | 80 | if optimizer is not None: 81 | optimizer.zero_grad() 82 | loss.backward() 83 | optimizer.step() 84 | 85 | postfix['loss'] = loss.item() 86 | postfix['running_loss'] += ( 87 | loss.item() - postfix['running_loss'] 88 | ) / (i + 1) 89 | tqdm_data.set_postfix(postfix) 90 | 91 | postfix['mode'] = ('Pretrain: eval generator' 92 | if optimizer is None 93 | else 'Pretrain: train generator') 94 | return postfix 95 | 96 | 97 | def _pretrain_generator(model, train_loader): 98 | generator = model.generator 99 | criterion = nn.CrossEntropyLoss(ignore_index=c2i['']) 100 | optimizer = torch.optim.Adam(model.generator.parameters(), lr=1e-4) 101 | 102 | model.zero_grad() 103 | for epoch in range(generator_pretrain_epochs): 104 | tqdm_data = tqdm(train_loader, desc='Generator training (epoch #{})'.format(epoch)) 105 | postfix = _pretrain_generator_epoch(model, tqdm_data, criterion, optimizer) 106 | if epoch % save_frequency == 0: 107 | generator = generator.to('cpu') 108 | torch.save(generator.state_dict(), 'model.csv'[:-4] + 109 | '_generator_{0:03d}.csv'.format(epoch)) 110 | 
generator = generator.to(device) 111 | 112 | 113 | def discriminator_collate_fn(model): 114 | def collate(data): 115 | data.sort(key=len, reverse=True) 116 | tensors = [model.string2tensor(string) for string in data] 117 | inputs = pad_sequence(tensors, batch_first=True, padding_value=c2i['']) 118 | 119 | return inputs 120 | 121 | return collate 122 | 123 | 124 | def _pretrain_discriminator_epoch(model, tqdm_data, 125 | criterion, optimizer=None): 126 | model.eval() 127 | if optimizer is None: 128 | model.eval() 129 | else: 130 | model.train() 131 | 132 | postfix = {'loss': 0, 133 | 'running_loss': 0} 134 | for i, inputs_from_data in enumerate(tqdm_data): 135 | inputs_from_data = inputs_from_data.to(device) 136 | inputs_from_model, _ = model.sample_tensor(n_batch, 100) 137 | 138 | targets = torch.zeros(n_batch, 1, device=device) 139 | outputs = model.discriminator_forward(inputs_from_model) 140 | loss = criterion(outputs, targets) / 2 141 | 142 | targets = torch.ones(inputs_from_data.shape[0], 1, device=device) 143 | outputs = model.discriminator_forward(inputs_from_data) 144 | loss += criterion(outputs, targets) / 2 145 | 146 | if optimizer is not None: 147 | optimizer.zero_grad() 148 | loss.backward() 149 | optimizer.step() 150 | 151 | postfix['loss'] = loss.item() 152 | postfix['running_loss'] += (loss.item() - 153 | postfix['running_loss']) / (i + 1) 154 | tqdm_data.set_postfix(postfix) 155 | 156 | postfix['mode'] = ('Pretrain: eval discriminator' 157 | if optimizer is None 158 | else 'Pretrain: train discriminator') 159 | return postfix 160 | 161 | 162 | def _pretrain_discriminator(model, train_loader): 163 | discriminator = model.discriminator 164 | criterion = nn.BCEWithLogitsLoss() 165 | optimizer = torch.optim.Adam(model.discriminator.parameters(), 166 | lr=1e-4) 167 | 168 | model.zero_grad() 169 | for epoch in range(discriminator_pretrain_epochs): 170 | tqdm_data = tqdm( 171 | train_loader, 172 | desc='Discriminator training (epoch #{})'.format(epoch) 173 | ) 174 | postfix = _pretrain_discriminator_epoch( 175 | model, tqdm_data, criterion, optimizer 176 | ) 177 | if epoch % save_frequency == 0: 178 | discriminator = discriminator.to('cpu') 179 | torch.save(discriminator.state_dict(), 'model.csv'[:-4] + '_discriminator_{0:03d}.csv'.format(epoch)) 180 | discriminator = discriminator.to(device) 181 | 182 | 183 | def _policy_gradient_iter(model, train_loader, criterion, optimizer, iter_, ref_smiles, ref_mols): 184 | smooth = 0.1 185 | 186 | # Generator 187 | gen_postfix = {'generator_loss': 0, 188 | 'smoothed_reward': 0} 189 | 190 | gen_tqdm = tqdm(range(generator_updates), 191 | desc='PG generator training (iter #{})'.format(iter_)) 192 | for _ in gen_tqdm: 193 | model.eval() 194 | sequences, rewards, lengths = model.rollout(ref_smiles, ref_mols, n_samples=n_samples, 195 | n_rollouts=n_rollouts, max_len=max_length) 196 | model.train() 197 | 198 | lengths, indices = torch.sort(lengths, descending=True) 199 | sequences = sequences[indices, ...] 200 | rewards = rewards[indices, ...] 
201 | 202 | generator_outputs, lengths, _ = model.generator_forward( 203 | sequences[:, :-1], lengths - 1 204 | ) 205 | generator_loss = criterion['generator']( 206 | generator_outputs, sequences[:, 1:], rewards, lengths 207 | ) 208 | 209 | optimizer['generator'].zero_grad() 210 | generator_loss.backward() 211 | nn.utils.clip_grad_value_(model.generator.parameters(), clip_value=5) 212 | optimizer['generator'].step() 213 | 214 | gen_postfix['generator_loss'] += ( 215 | generator_loss.item() - 216 | gen_postfix['generator_loss'] 217 | ) * smooth 218 | mean_episode_reward = torch.cat( 219 | [t[:l] for t, l in zip(rewards, lengths)] 220 | ).mean().item() 221 | gen_postfix['smoothed_reward'] += ( 222 | mean_episode_reward - gen_postfix['smoothed_reward'] 223 | ) * smooth 224 | gen_tqdm.set_postfix(gen_postfix) 225 | 226 | # Discriminator 227 | discrim_postfix = {'discrim-r_loss': 0} 228 | discrim_tqdm = tqdm( 229 | range(discriminator_updates), 230 | desc='PG discrim-r training (iter #{})'.format(iter_) 231 | ) 232 | for _ in discrim_tqdm: 233 | model.generator.eval() 234 | n_batches = ( 235 | len(train_loader) + n_batch - 1 236 | ) // n_batch 237 | sampled_batches = [ 238 | model.sample_tensor(n_batch, 239 | max_length=max_length)[0] 240 | for _ in range(n_batches) 241 | ] 242 | 243 | for _ in range(discriminator_epochs): 244 | random.shuffle(sampled_batches) 245 | 246 | for inputs_from_model, inputs_from_data in zip( 247 | sampled_batches, train_loader 248 | ): 249 | # print(inputs_from_model) 250 | inputs_from_data = inputs_from_data.to(device) 251 | print(inputs_from_data) 252 | 253 | discrim_outputs = model.discriminator_forward( 254 | inputs_from_model 255 | ) 256 | discrim_targets = torch.zeros(len(discrim_outputs), 257 | 1, device=device) 258 | discrim_loss = criterion['discriminator']( 259 | discrim_outputs, discrim_targets 260 | ) / 2 261 | 262 | discrim_outputs = model.discriminator.forward( 263 | inputs_from_data) 264 | discrim_targets = torch.ones( 265 | len(discrim_outputs), 1, device=device) 266 | discrim_loss += criterion['discriminator']( 267 | discrim_outputs, discrim_targets 268 | ) / 2 269 | optimizer['discriminator'].zero_grad() 270 | discrim_loss.backward() 271 | optimizer['discriminator'].step() 272 | 273 | discrim_postfix['discrim-r_loss'] += ( 274 | discrim_loss.item() - 275 | discrim_postfix['discrim-r_loss'] 276 | ) * smooth 277 | 278 | discrim_tqdm.set_postfix(discrim_postfix) 279 | 280 | postfix = {**gen_postfix, **discrim_postfix} 281 | postfix['mode'] = 'Policy Gradient (iter #{})'.format(iter_) 282 | return postfix 283 | 284 | 285 | def _train_policy_gradient(model, pg_train_loader, ref_smiles, ref_mols): 286 | criterion = { 287 | 'generator': PolicyGradientLoss(), 288 | 'discriminator': nn.BCEWithLogitsLoss(), 289 | } 290 | 291 | optimizer = { 292 | 'generator': torch.optim.Adam(model.generator.parameters(), 293 | lr=1e-4), 294 | 'discriminator': torch.optim.Adam( 295 | model.discriminator.parameters(), lr=1e-4) 296 | } 297 | ref_smiles = ref_smiles 298 | ref_mols = ref_mols 299 | model.zero_grad() 300 | for iter_ in range(pg_iters): 301 | postfix = _policy_gradient_iter(model, pg_train_loader, criterion, optimizer, iter_, ref_smiles, ref_mols) 302 | 303 | 304 | def fit(model, train_data): 305 | # Generator 306 | gen_collate_fn = generator_collate_fn(model) 307 | gen_train_loader = get_dataloader(train_data, gen_collate_fn) 308 | _pretrain_generator(model, gen_train_loader) 309 | 310 | # Discriminator 311 | dsc_collate_fn = discriminator_collate_fn(model) 312 | 
desc_train_loader = get_dataloader(train_data, dsc_collate_fn) 313 | _pretrain_discriminator(model, desc_train_loader) 314 | 315 | # Policy Gradient 316 | if model.metrics_reward is not None: 317 | (ref_smiles, ref_mols) = model.metrics_reward.get_reference_data(train_data) 318 | 319 | pg_train_loader = desc_train_loader 320 | _train_policy_gradient(model, pg_train_loader, ref_smiles, ref_mols) 321 | 322 | del ref_smiles 323 | del ref_mols 324 | # 325 | return model 326 | -------------------------------------------------------------------------------- /genmol/ORGAN/mcf.csv: -------------------------------------------------------------------------------- 1 | names,smarts 2 | MCF1,[#6]=&!@[#6]-[#6]#[#7] 3 | MCF2,[#6]=&!@[#6]-[#16](=[#8])=[#8] 4 | MCF3,[#6]=&!@[#6&!H0]-&!@[#6](=[#8])-&!@[#7] 5 | MCF4,"[H]C([H])([#6])[F,Cl,Br,I]" 6 | MCF5,[#6]1-[#8]-[#6]-1 7 | MCF6,[#6]-[#7]=[#6]=[#8] 8 | MCF7,[#6&!H0]=[#8] 9 | MCF8,"[#6](=&!@[#7&!H0])-&!@[#6,#7,#8,#16]" 10 | MCF9,[#6]1-[#7]-[#6]-1 11 | MCF10,[#6]~&!@[#7]~&!@[#7]~&!@[#6] 12 | MCF11,[#7]=&!@[#7] 13 | MCF12,[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#8]-1 14 | MCF13,[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#16]-1 15 | MCF14,"[#17,#35,#53]-c(:*):[!#1!#6]:*" 16 | MCF15,[H][#7]([H])-[#6]-1=[#6]-[#6]=[#6]-[#6]=[#6]-1 17 | MCF16,[#16]~[#16] 18 | MCF17,[#7]~&!@[#7]~&!@[#7] 19 | MCF18,[#7]-&!@[#6&!H0&!H1]-&!@[#7] 20 | MCF19,[#6&!H0](-&!@[#8])-&!@[#8] 21 | MCF20,[#35].[#35].[#35] 22 | MCF21,[#17].[#17].[#17].[#17] 23 | MCF22,[#9].[#9].[#9].[#9].[#9].[#9].[#9] 24 | -------------------------------------------------------------------------------- /genmol/ORGAN/test.py: -------------------------------------------------------------------------------- 1 | from Data import * 2 | import unittest 3 | import tqdm as tqdm 4 | from Metrics_Reward import * 5 | from Model import ORGAN 6 | 7 | class test_metrics(unittest.TestCase): 8 | test = ['Oc1ccccc1-c1cccc2cnccc12', 9 | 'COc1cccc(NC(=O)Cc2coc3ccc(OC)cc23)c1'] 10 | test_sf = ['COCc1nnc(NC(=O)COc2ccc(C(C)(C)C)cc2)s1', 11 | 'O=C(C1CC2C=CC1C2)N1CCOc2ccccc21', 12 | 'Nc1c(Br)cccc1C(=O)Nc1ccncn1'] 13 | gen = ['CNC', 'Oc1ccccc1-c1cccc2cnccc12', 14 | 'INVALID', 'CCCP', 15 | 'Cc1noc(C)c1CN(C)C(=O)Nc1cc(F)cc(F)c1', 16 | 'Cc1nc(NCc2ccccc2)no1-c1ccccc1'] 17 | target = {'valid': 2 / 3, 18 | 'unique@3': 1.0, 19 | 'FCD/Test': 52.58371754126664, 20 | 'SNN/Test': 0.3152585653588176, 21 | 'Frag/Test': 0.3, 22 | 'Scaf/Test': 0.5, 23 | 'IntDiv': 0.7189187309761661, 24 | 'Filters': 0.75, 25 | 'logP': 4.9581881764518005, 26 | 'SA': 0.5086898026154574, 27 | 'QED': 0.045033731661603064, 28 | 'NP': 0.2902816615644048, 29 | 'weight': 14761.927533455337} 30 | 31 | def test_get_all_metrics_multiprocess(self): 32 | metrics = get_all_metrics(test_data, samples, k=3) 33 | fail = set() 34 | for metric in self.target: 35 | if not np.allclose(metrics[metric], self.target[metric]): 36 | warnings.warn( 37 | "Metric `{}` value does not match expected " 38 | "value. 
Got {}, expected {}".format(metric, 39 | metrics[metric], 40 | self.target[metric]) 41 | ) 42 | fail.add(metric) 43 | assert len(fail) == 0, f"Some metrics didn't pass tests: {fail}" 44 | 45 | def test_get_all_metrics_scaffold(self): 46 | get_all_metrics(self.test, self.gen, 47 | test_scaffolds=self.test_sf, 48 | k=3, n_jobs=2) 49 | mols = ['CCNC', 'CCC', 'INVALID', 'CCC'] 50 | assert np.allclose(fraction_valid(mols), 3 / 4), "Failed valid" 51 | assert np.allclose(fraction_unique(mols, check_validity=False), 52 | 3 / 4), "Failed unique" 53 | assert np.allclose(fraction_unique(mols, k=2), 1), "Failed unique" 54 | mols = [Chem.MolFromSmiles(x) for x in mols] 55 | assert np.allclose(fraction_valid(mols), 3 / 4), "Failed valid" 56 | assert np.allclose(fraction_unique(mols, check_validity=False), 57 | 3 / 4), "Failed unique" 58 | assert np.allclose(fraction_unique(mols, k=2), 1), "Failed unique" 59 | 60 | def sampler(model): 61 | n_samples = 100000 62 | samples = [] 63 | with tqdm(total=n_samples, desc='Generating Samples')as T: 64 | while n_samples > 0: 65 | current_samples = model.sample(min(n_samples, 64), max_length=100) 66 | samples.extend(current_samples) 67 | n_samples -= len(current_samples) 68 | T.update(len(current_samples)) 69 | 70 | return samples 71 | 72 | 73 | def evaluate(test, samples, test_scaffolds=None, ptest=None, ptest_scaffolds=None): 74 | gen = samples 75 | k = [50, 99] 76 | n_jobs = 1 77 | batch_size = 20 78 | ptest = ptest 79 | ptest_scaffolds = 20 80 | pool = None 81 | gpu = None 82 | metrics = get_all_metrics(test, gen, k, n_jobs, device, batch_size, test_scaffolds, ptest, ptest_scaffolds) 83 | for name, value in metrics.items(): 84 | print('{}, {}'.format(name, value)) 85 | 86 | 87 | model = ORGAN() 88 | samples = sampler(model) 89 | model = model.to(device) 90 | evaluate(test_data, samples) 91 | -------------------------------------------------------------------------------- /genmol/aae/data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | 4 | data = pd.read_csv('C:/Users/ASUS\Desktop/intern things/dataset_iso_v1.csv') 5 | train_data1 = data[data['SPLIT'] == 'train'] 6 | train_data_smiles2 = (train_data1["SMILES"].squeeze()).astype(str).tolist() 7 | train_data = train_data_smiles2 8 | 9 | chars = set() 10 | for string in train_data: 11 | chars.update(string) 12 | all_sys = sorted(list(chars)) + ['', '', '', ''] 13 | vocab = all_sys 14 | c2i = {c: i for i, c in enumerate(all_sys)} 15 | i2c = {i: c for i, c in enumerate(all_sys)} 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | vector = torch.eye(len(c2i)) 18 | 19 | 20 | def char2id(char): 21 | if char not in c2i: 22 | return c2i[''] 23 | else: 24 | return c2i[char] 25 | 26 | 27 | def id2char(id): 28 | if id not in i2c: 29 | return i2c[32] 30 | else: 31 | return i2c[id] 32 | 33 | def string2ids(string,add_bos=False, add_eos=False): 34 | ids = [char2id(c) for c in string] 35 | if add_bos: 36 | ids = [c2i['']] + ids 37 | if add_eos: 38 | ids = ids + [c2i['']] 39 | return ids 40 | def ids2string(ids, rem_bos=True, rem_eos=True): 41 | if len(ids) == 0: 42 | return '' 43 | if rem_bos and ids[0] == c2i['']: 44 | ids = ids[1:] 45 | if rem_eos and ids[-1] == c2i['']: 46 | ids = ids[:-1] 47 | string = ''.join([id2char(id) for id in ids]) 48 | return string 49 | def string2tensor(string, device='model'): 50 | ids = string2ids(string, add_bos=True, add_eos=True) 51 | tensor = torch.tensor(ids, 
dtype=torch.long,device=device if device == 'model' else device) 52 | return tensor 53 | 54 | vector = torch.eye(len(c2i)) -------------------------------------------------------------------------------- /genmol/aae/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import pandas as pd 4 | from torch.nn.utils.rnn import pad_sequence 5 | import torch.nn as nn 6 | 7 | 8 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 9 | 10 | 11 | from data import * 12 | 13 | 14 | emb_dim = 30 15 | hidden_dim = 64 16 | latent_dim = 4 17 | disc_input = 64 18 | disc_output = 84 19 | batch_size = 50 20 | 21 | 22 | class encoder(nn.Module): 23 | def __init__(self, vocab, emb_dim, hidden_dim, latent_dim): 24 | super(encoder, self).__init__() 25 | self.hidden_dim = hidden_dim 26 | self.latent_dim = latent_dim 27 | self.emb_dim = emb_dim 28 | self.vocab = vocab 29 | 30 | self.embeddings_layer = nn.Embedding(len(vocab), emb_dim, padding_idx=c2i['']) 31 | 32 | self.rnn = nn.LSTM(emb_dim, hidden_dim) 33 | self.fc = nn.Linear(hidden_dim, latent_dim) 34 | self.relu = nn.ReLU() 35 | nn.Drop = nn.Dropout(p=0.25) 36 | 37 | def forward(self, x, lengths): 38 | batch_size = x.shape[0] 39 | 40 | x = self.embeddings_layer(x) 41 | x = pack_padded_sequence(x, lengths, batch_first=True) 42 | output, (_, x) = self.rnn(x) 43 | 44 | x = x.permute(1, 2, 0).view(batch_size, -1) 45 | x = self.fc(x) 46 | state = self.relu(x) 47 | return state 48 | 49 | 50 | class decoder(nn.Module): 51 | def __init__(self, vocab, emb_dim, latent_dim, hidden_dim): 52 | super(decoder, self).__init__() 53 | self.latent_dim = latent_dim 54 | self.hidden_dim = hidden_dim 55 | self.emb_dim = emb_dim 56 | self.vocab = vocab 57 | 58 | self.latent = nn.Linear(latent_dim, hidden_dim) 59 | self.embeddings_layer = nn.Embedding(len(vocab), emb_dim, padding_idx=c2i['']) 60 | self.rnn = nn.LSTM(emb_dim, hidden_dim, batch_first=True) 61 | self.fc = nn.Linear(hidden_dim, len(vocab)) 62 | 63 | def forward(self, x, lengths, state, is_latent_state=False): 64 | if is_latent_state: 65 | c0 = self.latent(state) 66 | 67 | c0 = c0.unsqueeze(0) 68 | h0 = torch.zeros_like(c0) 69 | 70 | state = (h0, c0) 71 | 72 | x = self.embeddings_layer(x) 73 | 74 | x = pack_padded_sequence(x, lengths, batch_first=True) 75 | 76 | x, state = self.rnn(x, state) 77 | 78 | x, lengths = pad_packed_sequence(x, batch_first=True) 79 | x = self.fc(x) 80 | 81 | return x, lengths, state 82 | 83 | 84 | class Discriminator(nn.Module): 85 | def __init__(self, latent_dim, disc_input, disc_output): 86 | super(Discriminator, self).__init__() 87 | self.latent_dim = latent_dim 88 | self.disc_input = disc_input 89 | self.disc_output = disc_output 90 | 91 | self.lin1 = nn.Linear(latent_dim, disc_input) 92 | self.lin2 = nn.Linear(disc_input, disc_output) 93 | self.lin3 = nn.Linear(disc_output, 1) 94 | self.sig = nn.Sigmoid() 95 | 96 | def forward(self, x): 97 | x = self.lin1(x) 98 | x = self.lin2(x) 99 | x = self.lin3(x) 100 | 101 | x = self.sig(x) 102 | return x 103 | class AAE(nn.Module): 104 | def __init__(self): 105 | super(AAE,self).__init__() 106 | self.encoder = encoder(vocab,emb_dim,hidden_dim,latent_dim) 107 | self.decoder = decoder(vocab,emb_dim,latent_dim,hidden_dim) 108 | self.discriminator = Discriminator(latent_dim,disc_input,disc_output) 109 | -------------------------------------------------------------------------------- /genmol/aae/run.py: 
-------------------------------------------------------------------------------- 1 | from data import * 2 | from model import * 3 | from train import * 4 | from sample import * 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | model = AAE().to(device) 8 | fit(model,train_data) 9 | model.eval() 10 | get_samples(model) 11 | 12 | -------------------------------------------------------------------------------- /genmol/aae/sample.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from tqdm import tqdm 3 | from data import * 4 | def sample(model,n_batch, max_len=100): 5 | with torch.no_grad(): 6 | samples = [] 7 | lengths = torch.zeros(n_batch, dtype=torch.long, device=device) 8 | state = sample_latent(n_batch) 9 | prevs = torch.empty(n_batch, 1, dtype=torch.long, device=device).fill_(c2i[""]) 10 | one_lens = torch.ones(n_batch, dtype=torch.long, device=device) 11 | is_end = torch.zeros(n_batch, dtype=torch.uint8, device=device) 12 | for i in range(max_len): 13 | logits, _, state = model.decoder(prevs, one_lens, state, i == 0) 14 | currents = torch.argmax(logits, dim=-1) 15 | is_end[currents.view(-1) == c2i[""]] = 1 16 | if is_end.sum() == max_len: 17 | break 18 | 19 | currents[is_end, :] = c2i[""] 20 | samples.append(currents) 21 | lengths[~is_end] += 1 22 | prevs = currents 23 | if len(samples): 24 | samples = torch.cat(samples, dim=-1) 25 | samples = [tensor2string(t[:l]) for t, l in zip(samples, lengths)] 26 | else: 27 | samples = ['' for _ in range(n_batch)] 28 | return samples 29 | 30 | 31 | def get_samples(model): 32 | samples = [] 33 | n = 300 34 | max_len = 100 35 | with tqdm(total=300, desc='Generating samples') as T: 36 | while n > 0: 37 | current_samples = sample(model,min(n, batch_size), max_len) 38 | samples.extend(current_samples) 39 | n -= len(current_samples) 40 | T.update(len(current_samples)) 41 | print(samples) 42 | 43 | def tensor2string(tensor): 44 | ids = tensor.tolist() 45 | string = ids2string(ids, rem_bos=True, rem_eos=True) 46 | return string 47 | def sample_latent(n): 48 | return torch.randn(n,latent_dim) -------------------------------------------------------------------------------- /genmol/aae/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | 6 | import torch.nn.functional as F 7 | from model import * 8 | 9 | def pretrain(model, train_loader): 10 | criterion = nn.CrossEntropyLoss() 11 | optimizer = torch.optim.Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=0.001) 12 | model.zero_grad() 13 | for epoch in range(4): 14 | if optimizer is None: 15 | model.train() 16 | else: 17 | model.eval() 18 | for i, (encoder_inputs, decoder_inputs, decoder_targets) in enumerate(train_loader): 19 | encoder_inputs = (data.to(device) for data in encoder_inputs) 20 | decoder_inputs = (data.to(device) for data in decoder_inputs) 21 | decoder_targets = (data.to(device) for data in decoder_targets) 22 | 23 | latent_code = model.encoder(*encoder_inputs) 24 | decoder_output, decoder_output_lengths, states = model.decoder(*decoder_inputs, latent_code, 25 | is_latent_state=True) 26 | 27 | decoder_outputs = torch.cat([t[:l] for t, l in zip(decoder_output, decoder_output_lengths)], dim=0) 28 | decoder_targets = torch.cat([t[:l] for t, l in zip(*decoder_targets)], dim=0) 29 | loss = 
criterion(decoder_outputs, decoder_targets) 30 | 31 | if optimizer is not None: 32 | optimizer.zero_grad() 33 | loss.backward() 34 | optimizer.step() 35 | 36 | 37 | def train(model, train_loader): 38 | criterion = {"enc": nn.CrossEntropyLoss(), "gen": lambda t: -torch.mean(F.logsigmoid(t)),"disc": nn.BCEWithLogitsLoss()} 39 | 40 | optimizers = {'auto': torch.optim.Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=0.001), 41 | 'gen': torch.optim.Adam(model.encoder.parameters(), lr=0.001), 42 | 'disc': torch.optim.Adam(model.discriminator.parameters(), lr=0.001)} 43 | 44 | model.zero_grad() 45 | for epoch in range(10): 46 | if optimizers is None: 47 | model.train() 48 | else: 49 | model.eval() 50 | 51 | for i, (encoder_inputs, decoder_inputs, decoder_targets) in enumerate(train_loader): 52 | encoder_inputs = (data.to(device) for data in encoder_inputs) 53 | decoder_inputs = (data.to(device) for data in decoder_inputs) 54 | decoder_targets = (data.to(device) for data in decoder_targets) 55 | 56 | latent_code = model.encoder(*encoder_inputs) 57 | decoder_output, decoder_output_lengths, states = model.decoder(*decoder_inputs, latent_code, 58 | is_latent_state=True) 59 | discriminator_output = model.discriminator(latent_code) 60 | 61 | decoder_outputs = torch.cat([t[:l] for t, l in zip(decoder_output, decoder_output_lengths)], dim=0) 62 | decoder_targets = torch.cat([t[:l] for t, l in zip(*decoder_targets)], dim=0) 63 | 64 | autoencoder_loss = criterion["enc"](decoder_outputs, decoder_targets) 65 | generation_loss = criterion["gen"](discriminator_output) 66 | 67 | if i % 2 == 0: 68 | discriminator_input = torch.randn(batch_size, latent_dim) 69 | discriminator_output = model.discriminator(discriminator_input) 70 | discriminator_targets = torch.ones(batch_size, 1) 71 | else: 72 | discriminator_targets = torch.zeros(batch_size, 1) 73 | discriminator_loss = criterion["disc"](discriminator_output, discriminator_targets) 74 | 75 | if optimizers is not None: 76 | optimizers["auto"].zero_grad() 77 | autoencoder_loss.backward(retain_graph=True) 78 | optimizers["auto"].step() 79 | 80 | optimizers["gen"].zero_grad() 81 | autoencoder_loss.backward(retain_graph=True) 82 | optimizers["gen"].step() 83 | 84 | optimizers["disc"].zero_grad() 85 | autoencoder_loss.backward(retain_graph=True) 86 | optimizers["disc"].step() 87 | 88 | def fit(model,train_data): 89 | train_loader = get_dataloader(model, train_data, collate_fn=None, shuffle=True) 90 | pretrain(model,train_loader) 91 | train(model,train_loader) 92 | 93 | def get_collate_device(model): 94 | return device 95 | def get_dataloader(model, data, collate_fn=None, shuffle=True): 96 | if collate_fn is None: 97 | collate_fn = get_collate_fn(model) 98 | return DataLoader(data, batch_size= batch_size,shuffle=shuffle,collate_fn=collate_fn) 99 | 100 | 101 | def get_collate_fn(model): 102 | device = get_collate_device(model) 103 | 104 | def collate(data): 105 | data.sort(key=lambda x: len(x), reverse=True) 106 | 107 | tensors = [string2tensor(string, device=device) for string in data] 108 | lengths = torch.tensor([len(t) for t in tensors], dtype=torch.long, device=device) 109 | 110 | encoder_inputs = pad_sequence(tensors, batch_first=True, padding_value=c2i[""]) 111 | encoder_input_lengths = lengths - 2 112 | 113 | decoder_inputs = pad_sequence([t[:-1] for t in tensors], batch_first=True, padding_value=c2i[""]) 114 | decoder_input_lengths = lengths - 1 115 | decoder_targets = pad_sequence([t[1:] for t in tensors], batch_first=True, 
padding_value=c2i[""]) 116 | decoder_target_lengths = lengths - 1 117 | return (encoder_inputs, encoder_input_lengths), (decoder_inputs, decoder_input_lengths), (decoder_targets, decoder_target_lengths) 118 | 119 | return collate -------------------------------------------------------------------------------- /genmol/models.txt: -------------------------------------------------------------------------------- 1 | Char-Rnn 2 | ORGAN 3 | MOLGAN 4 | VAE 5 | AAE 6 | JTVAE 7 | Release 8 | -------------------------------------------------------------------------------- /genmol/vae/data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | 4 | 5 | data = pd.read_csv('C:/Users/ASUS\Desktop/intern things/dataset_iso_v1.csv') 6 | train_data1 = data[data['SPLIT'] == 'train'] 7 | train_data_smiles2 = (train_data1["SMILES"].squeeze()).astype(str).tolist() 8 | train_data = train_data_smiles2 9 | 10 | chars = set() 11 | for string in train_data: 12 | chars.update(string) 13 | all_sys = sorted(list(chars)) + ['', '', '', ''] 14 | vocab = all_sys 15 | c2i = {c: i for i, c in enumerate(all_sys)} 16 | i2c = {i: c for i, c in enumerate(all_sys)} 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | vector = torch.eye(len(c2i)) 19 | 20 | 21 | def char2id(char): 22 | if char not in c2i: 23 | return c2i[''] 24 | else: 25 | return c2i[char] 26 | 27 | 28 | def id2char(id): 29 | if id not in i2c: 30 | return i2c[32] 31 | else: 32 | return i2c[id] 33 | 34 | def string2ids(string,add_bos=False, add_eos=False): 35 | ids = [char2id(c) for c in string] 36 | if add_bos: 37 | ids = [c2i['']] + ids 38 | if add_eos: 39 | ids = ids + [c2i['']] 40 | return ids 41 | def ids2string(ids, rem_bos=True, rem_eos=True): 42 | if len(ids) == 0: 43 | return '' 44 | if rem_bos and ids[0] == c2i['']: 45 | ids = ids[1:] 46 | if rem_eos and ids[-1] == c2i['']: 47 | ids = ids[:-1] 48 | string = ''.join([id2char(id) for id in ids]) 49 | return string 50 | def string2tensor(string, device='model'): 51 | ids = string2ids(string, add_bos=True, add_eos=True) 52 | tensor = torch.tensor(ids, dtype=torch.long,device=device if device == 'model' else device) 53 | return tensor 54 | tensor = [string2tensor(string, device=device) for string in train_data] 55 | vector = torch.eye(len(c2i)) 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /genmol/vae/run.py: -------------------------------------------------------------------------------- 1 | 2 | from trainer import * 3 | from vae_model import VAE 4 | from data import * 5 | from samples import * 6 | 7 | model = VAE(vocab,vector).to(device) 8 | fit(model, train_data) 9 | model.eval() 10 | sample = sample.take_samples(model,n_batch) 11 | print(sample) -------------------------------------------------------------------------------- /genmol/vae/samples.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tqdm import tqdm 3 | import pandas as pd 4 | n_samples = 3000 5 | n_jobs = 1 6 | max_len = 100 7 | 8 | class sample(): 9 | def take_samples(model,n_batch): 10 | n = n_samples 11 | samples = [] 12 | with tqdm(total=n_samples, desc='Generating samples') as T: 13 | while n > 0: 14 | current_samples = model.sample(min(n, n_batch), max_len) 15 | samples.extend(current_samples) 16 | n -= len(current_samples) 17 | T.update(len(current_samples)) 18 | samples = 
pd.DataFrame(samples, columns=['SMILES']) 19 | return samples -------------------------------------------------------------------------------- /genmol/vae/trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | from torch.nn.utils import clip_grad_norm_ 6 | import math 7 | import numpy as np 8 | from collections import UserList, defaultdict 9 | n_last = 1000 10 | n_batch = 32 11 | kl_start = 0 12 | kl_w_start = 0.0 13 | kl_w_end = 1.0 14 | n_epoch = 50 15 | n_workers = 0 16 | 17 | clip_grad = 50 18 | lr_start = 0.003 19 | lr_n_period = 10 20 | lr_n_mult = 1 21 | lr_end = 3 * 1e-4 22 | lr_n_restarts = 6 23 | from data import * 24 | 25 | def _n_epoch(): 26 | return sum(lr_n_period * (lr_n_mult ** i) for i in range(lr_n_restarts)) 27 | 28 | def _train_epoch(model, epoch, train_loader, kl_weight, optimizer=None): 29 | if optimizer is None: 30 | model.eval() 31 | else: 32 | model.train() 33 | 34 | kl_loss_values = CircularBuffer(n_last) 35 | recon_loss_values = CircularBuffer(n_last) 36 | loss_values = CircularBuffer(n_last) 37 | for i, input_batch in enumerate(train_loader): 38 | input_batch = tuple(data.to(device) for data in input_batch) 39 | 40 | #forward 41 | kl_loss, recon_loss = model(input_batch) 42 | loss = kl_weight * kl_loss + recon_loss 43 | #backward 44 | if optimizer is not None: 45 | optimizer.zero_grad() 46 | loss.backward() 47 | clip_grad_norm_(get_optim_params(model),clip_grad) 48 | optimizer.step() 49 | 50 | kl_loss_values.add(kl_loss.item()) 51 | recon_loss_values.add(recon_loss.item()) 52 | loss_values.add(loss.item()) 53 | lr = (optimizer.param_groups[0]['lr'] if optimizer is not None else None) 54 | 55 | #update train_loader 56 | kl_loss_value = kl_loss_values.mean() 57 | recon_loss_value = recon_loss_values.mean() 58 | loss_value = loss_values.mean() 59 | postfix = [f'loss={loss_value:.5f}',f'(kl={kl_loss_value:.5f}',f'recon={recon_loss_value:.5f})',f'klw={kl_weight:.5f} lr={lr:.5f}'] 60 | postfix = {'epoch': epoch,'kl_weight': kl_weight,'lr': lr,'kl_loss': kl_loss_value,'recon_loss': recon_loss_value,'loss': loss_value,'mode': 'Eval' if optimizer is None else 'Train'} 61 | return postfix 62 | 63 | def _train(model, train_loader, val_loader=None, logger=None): 64 | optimizer = optim.Adam(get_optim_params(model),lr= lr_start) 65 | 66 | lr_annealer = CosineAnnealingLRWithRestart(optimizer) 67 | 68 | model.zero_grad() 69 | for epoch in range(n_epoch): 70 | 71 | kl_annealer = KLAnnealer(n_epoch) 72 | kl_weight = kl_annealer(epoch) 73 | postfix = _train_epoch(model, epoch,train_loader, kl_weight, optimizer) 74 | lr_annealer.step() 75 | def fit(model, train_data, val_data=None): 76 | logger = Logger() if False is not None else None 77 | train_loader = get_dataloader(model,train_data,shuffle=True) 78 | 79 | 80 | 81 | val_loader = None if val_data is None else get_dataloader(model, val_data, shuffle=False) 82 | _train(model, train_loader, val_loader, logger) 83 | return model 84 | def get_collate_device(model): 85 | return model.device 86 | def get_dataloader(model, train_data, collate_fn=None, shuffle=True): 87 | if collate_fn is None: 88 | collate_fn = get_collate_fn(model) 89 | print(collate_fn) 90 | return DataLoader(train_data, batch_size=n_batch, shuffle=shuffle, num_workers=n_workers, collate_fn=collate_fn) 91 | 92 | def get_collate_fn(model): 93 | device = 
get_collate_device(model) 94 | 95 | def collate(train_data): 96 | train_data.sort(key=len, reverse=True) 97 | tensors = [string2tensor(string, device=device) for string in train_data] 98 | return tensors 99 | 100 | return collate 101 | 102 | def get_optim_params(model): 103 | return (p for p in model.parameters() if p.requires_grad) 104 | 105 | class KLAnnealer: 106 | def __init__(self,n_epoch): 107 | self.i_start = kl_start 108 | self.w_start = kl_w_start 109 | self.w_max = kl_w_end 110 | self.n_epoch = n_epoch 111 | 112 | 113 | self.inc = (self.w_max - self.w_start) / (self.n_epoch - self.i_start) 114 | 115 | def __call__(self, i): 116 | k = (i - self.i_start) if i >= self.i_start else 0 117 | return self.w_start + k * self.inc 118 | 119 | 120 | 121 | class CosineAnnealingLRWithRestart(_LRScheduler): 122 | def __init__(self , optimizer): 123 | self.n_period = lr_n_period 124 | self.n_mult = lr_n_mult 125 | self.lr_end = lr_end 126 | 127 | self.current_epoch = 0 128 | self.t_end = self.n_period 129 | 130 | # Also calls first epoch 131 | super().__init__(optimizer, -1) 132 | 133 | def get_lr(self): 134 | return [self.lr_end + (base_lr - self.lr_end) * 135 | (1 + math.cos(math.pi * self.current_epoch / self.t_end)) / 2 136 | for base_lr in self.base_lrs] 137 | 138 | def step(self, epoch=None): 139 | if epoch is None: 140 | epoch = self.last_epoch + 1 141 | self.last_epoch = epoch 142 | self.current_epoch += 1 143 | 144 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 145 | param_group['lr'] = lr 146 | 147 | if self.current_epoch == self.t_end: 148 | self.current_epoch = 0 149 | self.t_end = self.n_mult * self.t_end 150 | 151 | 152 | 153 | 154 | class CircularBuffer: 155 | def __init__(self, size): 156 | self.max_size = size 157 | self.data = np.zeros(self.max_size) 158 | self.size = 0 159 | self.pointer = -1 160 | 161 | def add(self, element): 162 | self.size = min(self.size + 1, self.max_size) 163 | self.pointer = (self.pointer + 1) % self.max_size 164 | self.data[self.pointer] = element 165 | return element 166 | 167 | def last(self): 168 | assert self.pointer != -1, "Can't get an element from an empty buffer!" 
169 | return self.data[self.pointer] 170 | 171 | def mean(self): 172 | return self.data.mean() 173 | 174 | 175 | class Logger(UserList): 176 | def __init__(self, data=None): 177 | super().__init__() 178 | self.sdata = defaultdict(list) 179 | for step in (data or []): 180 | self.append(step) 181 | 182 | def __getitem__(self, key): 183 | if isinstance(key, int): 184 | return self.data[key] 185 | elif isinstance(key, slice): 186 | return Logger(self.data[key]) 187 | else: 188 | ldata = self.sdata[key] 189 | if isinstance(ldata[0], dict): 190 | return Logger(ldata) 191 | else: 192 | return ldata 193 | 194 | def append(self, step_dict): 195 | super().append(step_dict) 196 | for k, v in step_dict.items(): 197 | self.sdata[k].append(v) 198 | 199 | 200 | 201 | 202 | -------------------------------------------------------------------------------- /genmol/vae/vae_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | q_bidir = True 7 | q_d_h = 256 8 | q_n_layers = 1 9 | q_dropout = 0.5 10 | d_n_layers = 3 11 | d_dropout = 0 12 | d_z = 128 13 | d_d_h = 512 14 | from data import * 15 | class VAE(nn.Module): 16 | def __init__(self,vocab,vector): 17 | super().__init__() 18 | self.vocabulary = vocab 19 | self.vector = vector 20 | 21 | n_vocab, d_emb = len(vocab), vector.size(1) 22 | self.x_emb = nn.Embedding(n_vocab, d_emb, c2i['']) 23 | self.x_emb.weight.data.copy_(vector) 24 | 25 | #ENCODER 26 | 27 | self.encoder_rnn = nn.GRU(d_emb,q_d_h,num_layers=q_n_layers,batch_first=True,dropout=q_dropout if q_n_layers > 1 else 0,bidirectional=q_bidir) 28 | q_d_last = q_d_h * (2 if q_bidir else 1) 29 | self.q_mu = nn.Linear(q_d_last, d_z) 30 | self.q_logvar = nn.Linear(q_d_last, d_z) 31 | 32 | 33 | 34 | # Decoder 35 | self.decoder_rnn = nn.GRU(d_emb + d_z,d_d_h,num_layers=d_n_layers,batch_first=True,dropout=d_dropout if d_n_layers > 1 else 0) 36 | self.decoder_latent = nn.Linear(d_z, d_d_h) 37 | self.decoder_fullyc = nn.Linear(d_d_h, n_vocab) 38 | 39 | 40 | 41 | # Grouping the model's parameters 42 | self.encoder = nn.ModuleList([self.encoder_rnn,self.q_mu,self.q_logvar]) 43 | self.decoder = nn.ModuleList([self.decoder_rnn,self.decoder_latent,self.decoder_fullyc]) 44 | self.vae = nn.ModuleList([self.x_emb,self.encoder,self.decoder]) 45 | 46 | 47 | 48 | @property 49 | def device(self): 50 | return next(self.parameters()).device 51 | 52 | def string2tensor(self, string, device='model'): 53 | ids = string2ids(string, add_bos=True, add_eos=True) 54 | tensor = torch.tensor(ids, dtype=torch.long,device=self.device if device == 'model' else device) 55 | return tensor 56 | 57 | def tensor2string(self, tensor): 58 | ids = tensor.tolist() 59 | string = ids2string(ids, rem_bos=True, rem_eos=True) 60 | return string 61 | 62 | def forward(self,x): 63 | z, kl_loss = self.forward_encoder(x) 64 | recon_loss = self.forward_decoder(x, z) 65 | print("forward") 66 | return kl_loss, recon_loss 67 | 68 | def forward_encoder(self,x): 69 | x = [self.x_emb(i_x) for i_x in x] 70 | x = nn.utils.rnn.pack_sequence(x) 71 | _, h = self.encoder_rnn(x, None) 72 | h = h[-(1 + int(self.encoder_rnn.bidirectional)):] 73 | h = torch.cat(h.split(1), dim=-1).squeeze(0) 74 | mu, logvar = self.q_mu(h), self.q_logvar(h) 75 | eps = torch.randn_like(mu) 76 | z = mu + (logvar / 2).exp() * eps 77 | kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean() 78 | return z, kl_loss 79 | 80 | def 
forward_decoder(self,x, z): 81 | lengths = [len(i_x) for i_x in x] 82 | x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value= c2i['']) 83 | x_emb = self.x_emb(x) 84 | z_0 = z.unsqueeze(1).repeat(1, x_emb.size(1), 1) 85 | x_input = torch.cat([x_emb, z_0], dim=-1) 86 | x_input = nn.utils.rnn.pack_padded_sequence(x_input, lengths, batch_first=True) 87 | h_0 = self.decoder_latent(z) 88 | h_0 = h_0.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1) 89 | output, _ = self.decoder_rnn(x_input, h_0) 90 | output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) 91 | y = self.decoder_fullyc(output) 92 | 93 | recon_loss = F.cross_entropy(y[:, :-1].contiguous().view(-1, y.size(-1)),x[:, 1:].contiguous().view(-1),ignore_index= c2i['']) 94 | return recon_loss 95 | 96 | 97 | def sample_z_prior(self,n_batch): 98 | return torch.randn(n_batch,self.q_mu.out_features,device= self.x_emb.weight.device) 99 | def sample(self,n_batch, max_len=100, z=None, temp=1.0): 100 | with torch.no_grad(): 101 | if z is None: 102 | z = self.sample_z_prior(n_batch) 103 | z = z.to(self.device) 104 | z_0 = z.unsqueeze(1) 105 | h = self.decoder_latent(z) 106 | h = h.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1) 107 | w = torch.tensor(c2i[''], device=self.device).repeat(n_batch) 108 | x = torch.tensor([c2i['']], device=device).repeat(n_batch, max_len) 109 | x[:, 0] = c2i[''] 110 | end_pads = torch.tensor([max_len], device=self.device).repeat(n_batch) 111 | eos_mask = torch.zeros(n_batch, dtype=torch.uint8, device=self.device) 112 | 113 | 114 | for i in range(1, max_len): 115 | x_emb = self.x_emb(w).unsqueeze(1) 116 | x_input = torch.cat([x_emb, z_0], dim=-1) 117 | 118 | o, h = self.decoder_rnn(x_input, h) 119 | y = self.decoder_fullyc(o.squeeze(1)) 120 | y = F.softmax(y / temp, dim=-1) 121 | 122 | w = torch.multinomial(y, 1)[:, 0] 123 | x[~eos_mask, i] = w[~eos_mask] 124 | i_eos_mask = ~eos_mask & (w == c2i['']) 125 | end_pads[i_eos_mask] = i + 1 126 | eos_mask = eos_mask | i_eos_mask 127 | 128 | 129 | new_x = [] 130 | for i in range(x.size(0)): 131 | new_x.append(x[i, :end_pads[i]]) 132 | 133 | 134 | return [self.tensor2string(i_x) for i_x in new_x] --------------------------------------------------------------------------------
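As a closing illustration, here is a small post-processing sketch for the SMILES strings returned by VAE.sample (or sample.take_samples in genmol/vae/samples.py). It uses RDKit to measure how many generated strings are chemically valid and how many canonical structures are unique, in the same spirit as the fraction_valid / fraction_unique checks exercised in genmol/ORGAN/test.py. It is not part of the repository; the function name validity_report and the example arguments (32 samples, max_len=100) are assumptions made for illustration.

from rdkit import Chem

def validity_report(generated):
    # canonicalise every SMILES string that RDKit can parse
    canonical = []
    for smi in generated:
        mol = Chem.MolFromSmiles(smi)
        if mol is not None:
            canonical.append(Chem.MolToSmiles(mol))
    n_total = max(len(generated), 1)
    n_valid = max(len(canonical), 1)
    return {
        'valid': len(canonical) / n_total,
        'unique': len(set(canonical)) / n_valid,
    }

# Example usage after fit(model, train_data) as in genmol/vae/run.py:
# generated = model.sample(32, max_len=100)
# print(validity_report(generated))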