├── LICENSE
├── README.md
└── genmol
    ├── JTVAE
    │   ├── model.py
    │   ├── preprocess.py
    │   ├── sample.py
    │   ├── savedmodel.pth
    │   ├── train.py
    │   └── train.txt
    ├── ORGAN
    │   ├── Data.py
    │   ├── Metrics_Reward.py
    │   ├── Model.py
    │   ├── NP_Score
    │   │   ├── README
    │   │   ├── __pycache__
    │   │   │   └── npscorer.cpython-37.pyc
    │   │   ├── npscorer.py
    │   │   └── publicnp.model.gz
    │   ├── RewardMetrics.py
    │   ├── Run.py
    │   ├── SA_Score
    │   │   ├── README
    │   │   ├── UnitTestSAScore.py
    │   │   ├── __init__.py
    │   │   ├── __pycache__
    │   │   │   ├── __init__.cpython-37.pyc
    │   │   │   └── sascorer.cpython-37.pyc
    │   │   ├── data
    │   │   │   └── zim.100.txt
    │   │   ├── fpscores.pkl.gz
    │   │   └── sascorer.py
    │   ├── Trainer.py
    │   ├── mcf.csv
    │   ├── test.py
    │   └── wehi_pains.csv
    ├── aae
    │   ├── data.py
    │   ├── model.py
    │   ├── run.py
    │   ├── sample.py
    │   └── train.py
    ├── models.txt
    └── vae
        ├── data.py
        ├── run.py
        ├── samples.py
        ├── trainer.py
        └── vae_model.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 bayeslabs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GenMol (Molecular Structure Generation)
2 | This is a library that curates different machine-learning methods for molecular generation. You can use it to advance your research in drug discovery and materials discovery.
3 |
4 | We implemented the following algorithms in PyTorch for molecular generation.
5 |
15 |
--------------------------------------------------------------------------------
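
A minimal sampling sketch for the JTVAE module below (an illustration, not repository code): it simply strings together what preprocess.py, model.py, and sample.py execute, and assumes the genmol/JTVAE directory is the working directory with a train.txt of one SMILES per line. Hyperparameters match the JTNNVAE call at the bottom of model.py.

```python
# Minimal JTVAE sampling sketch (illustration; mirrors genmol/JTVAE/sample.py).
# Importing `model` also runs preprocess.py's top-level code, which reads
# train.txt and builds junction trees, so it can take a while on large files.
import torch
from model import JTNNVAE, vocab   # model.py defines `vocab` at module level

model = JTNNVAE(vocab, 450, 56, 20, 3)               # hidden=450, latent=56, depthT=20, depthG=3
model.load_state_dict(torch.load("savedmodel.pth"))  # pretrained weights shipped with the repo
torch.manual_seed(0)
print([model.sample_prior() for _ in range(10)])     # SMILES strings decoded from the prior
```
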
/genmol/JTVAE/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math, random, sys
4 | from optparse import OptionParser
5 | import pickle
6 | import rdkit
7 | import json
8 | import rdkit.Chem as Chem
9 | from scipy.sparse import csr_matrix
10 | from scipy.sparse.csgraph import minimum_spanning_tree
11 | from collections import defaultdict
12 | import copy
13 | import torch.optim as optim
14 | import torch.optim.lr_scheduler as lr_scheduler
15 | from torch.utils.data import Dataset, DataLoader
16 | from torch.autograd import Variable
17 | import numpy as np
18 | from collections import deque
19 | import os, random
20 | import torch.nn.functional as F
21 | import pdb
22 |
23 | from preprocess import *   #MolTree utilities and the mol_trees list are defined in preprocess.py
24 |
25 | def get_slots(smiles):
26 | mol = Chem.MolFromSmiles(smiles)
27 | return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
28 |
29 |
30 | def get_molecule(node):
31 | return Chem.MolFromSmiles(node.smiles)
32 |
33 |
34 | class Vocab(object):
35 | benzynes = ['C1=CC=CC=C1', 'C1=CC=NC=C1', 'C1=CC=NN=C1', 'C1=CN=CC=N1', 'C1=CN=CN=C1', 'C1=CN=NC=N1', 'C1=CN=NN=C1', 'C1=NC=NC=N1', 'C1=NN=CN=N1']
36 | penzynes = ['C1=C[NH]C=C1', 'C1=C[NH]C=N1', 'C1=C[NH]N=C1', 'C1=C[NH]N=N1', 'C1=COC=C1', 'C1=COC=N1', 'C1=CON=C1', 'C1=CSC=C1', 'C1=CSC=N1', 'C1=CSN=C1', 'C1=CSN=N1', 'C1=NN=C[NH]1', 'C1=NN=CO1', 'C1=NN=CS1', 'C1=N[NH]C=N1', 'C1=N[NH]N=C1', 'C1=N[NH]N=N1', 'C1=NN=N[NH]1', 'C1=NN=NS1', 'C1=NOC=N1', 'C1=NON=C1', 'C1=NSC=N1', 'C1=NSN=C1']
37 |
38 | def __init__(self, smiles_list,all_trees):
39 | list_d=[]
40 |
41 | for j in range(0,len(all_trees)):
42 | x=[]
43 | x=all_trees[j].nodes
44 |
45 | for i in range(0,len(x)):
46 | m=get_molecule(x[i])
47 | m1=Chem.MolToSmiles(m,kekuleSmiles=False)
48 | list_d.append(m1)
49 |
50 | list_f=list(dict.fromkeys(list_d))
51 | smiles_f=smiles_list+list_f
52 |
53 | self.vocab = smiles_f
54 | self.vmap = {x:i for i,x in enumerate(self.vocab)}
55 | self.slots = [get_slots(smiles) for smiles in self.vocab]
56 | Vocab.benzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 6] + ['C1=CCNCC1']
57 | Vocab.penzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 5] + ['C1=NCCN1','C1=NNCC1']
58 |
59 |
60 | def get_index(self, smiles):
61 | return self.vmap[smiles]
62 |
63 | def get_smiles(self, idx):
64 | return self.vocab[idx]
65 |
66 | def get_slots(self, idx):
67 | return copy.deepcopy(self.slots[idx])
68 |
69 | def size(self):
70 | return len(self.vocab)
71 |
72 |
73 |
74 | def create_variable(tensor, requires_grad=None):
75 | if requires_grad is None:
76 | return Variable(tensor)
77 | else:
78 | return Variable(tensor, requires_grad=requires_grad)
79 |
80 | def index_select_ND(source, dim, index):
81 | index_size = index.size()
82 | suffix_dim = source.size()[1:]
83 | final_size = index_size + suffix_dim
84 | target = source.index_select(dim, index.view(-1))
85 | return target.view(final_size)
86 |
87 | def GRU(x, h_nei, W_z, W_r, U_r, W_h):
88 | hidden_size = x.size()[-1]
89 | sum_h = h_nei.sum(dim=1)
90 | z_input = torch.cat([x,sum_h], dim=1)
91 | z = torch.sigmoid(W_z(z_input))
92 |
93 | r_1 = W_r(x).view(-1,1,hidden_size)
94 | r_2 = U_r(h_nei)
95 | r = torch.sigmoid(r_1 + r_2)
96 |
97 | gated_h = r * h_nei
98 | sum_gated_h = gated_h.sum(dim=1)
99 | h_input = torch.cat([x,sum_gated_h], dim=1)
100 | pre_h = torch.tanh(W_h(h_input))
101 | new_h = (1.0 - z) * sum_h + z * pre_h
102 | return new_h
103 |
104 |
105 |
106 | ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
107 |
108 | ATOM_FDIM1 = len(ELEM_LIST) + 6 + 5 + 4 + 1
109 | BOND_FDIM1 = 5 + 6
110 | MAX_NB1 = 6
111 |
112 | def onek_encoding_unk1(x, allowable_set):
113 | if x not in allowable_set:
114 | x = allowable_set[-1]
115 | return list(map(lambda s: x == s, allowable_set))
116 |
117 | def atom_features1(atom):
118 | return torch.Tensor(onek_encoding_unk1(atom.GetSymbol(), ELEM_LIST)
119 | + onek_encoding_unk1(atom.GetDegree(), [0,1,2,3,4,5])
120 | + onek_encoding_unk1(atom.GetFormalCharge(), [-1,-2,1,2,0])
121 | + onek_encoding_unk1(int(atom.GetChiralTag()), [0,1,2,3])
122 | + [atom.GetIsAromatic()])
123 |
124 | def bond_features1(bond):
125 | bt = bond.GetBondType()
126 | stereo = int(bond.GetStereo())
127 | fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
128 | fstereo = onek_encoding_unk1(stereo, [0,1,2,3,4,5])
129 | return torch.Tensor(fbond + fstereo)
130 |
131 | class MPN(nn.Module):
132 |
133 | def __init__(self, hidden_size, depth):
134 | super(MPN, self).__init__()
135 | self.hidden_size = int(hidden_size)
136 | self.depth = depth
137 |
138 | self.W_i = nn.Linear(ATOM_FDIM1 + BOND_FDIM1, hidden_size, bias=False)
139 | self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
140 | self.W_o = nn.Linear(ATOM_FDIM1 + hidden_size, hidden_size)
141 |
142 | def forward(self, fatoms, fbonds, agraph, bgraph, scope):
143 | fatoms = create_variable(fatoms)
144 | fbonds = create_variable(fbonds)
145 | agraph = create_variable(agraph)
146 | bgraph = create_variable(bgraph)
147 |
148 | binput = self.W_i(fbonds)
149 | message = F.relu(binput)
150 |
151 | for i in range(self.depth - 1):
152 | nei_message = index_select_ND(message, 0, bgraph)
153 | nei_message = nei_message.sum(dim=1)
154 | nei_message = self.W_h(nei_message)
155 | message = F.relu(binput + nei_message)
156 |
157 | nei_message = index_select_ND(message, 0, agraph)
158 | nei_message = nei_message.sum(dim=1)
159 | ainput = torch.cat([fatoms, nei_message], dim=1)
160 | atom_hiddens = F.relu(self.W_o(ainput))
161 |
162 | max_len = max([x for _,x in scope])
163 | batch_vecs = []
164 | for st,le in scope:
165 | cur_vecs = atom_hiddens[st : st + le].mean(dim=0)
166 | batch_vecs.append( cur_vecs )
167 |
168 | mol_vecs = torch.stack(batch_vecs, dim=0)
169 | return mol_vecs
170 |
171 | @staticmethod
172 | def tensorize(mol_batch):
173 | padding = torch.zeros(ATOM_FDIM1 + BOND_FDIM1)
174 | fatoms,fbonds = [],[padding] #Ensure bond is 1-indexed
175 | in_bonds,all_bonds = [],[(-1,-1)] #Ensure bond is 1-indexed
176 | scope = []
177 | total_atoms = 0
178 |
179 | for smiles in mol_batch:
180 | mol = get_mol(smiles)
181 | n_atoms = mol.GetNumAtoms()
182 |
183 | for atom in mol.GetAtoms():
184 | fatoms.append( atom_features1(atom) )
185 | in_bonds.append([])
186 |
187 | for bond in mol.GetBonds():
188 | a1 = bond.GetBeginAtom()
189 | a2 = bond.GetEndAtom()
190 | x = a1.GetIdx() + total_atoms
191 | y = a2.GetIdx() + total_atoms
192 |
193 | b = len(all_bonds)
194 | all_bonds.append((x,y))
195 | fbonds.append( torch.cat([fatoms[x], bond_features1(bond)], 0) )
196 | in_bonds[y].append(b)
197 |
198 | b = len(all_bonds)
199 | all_bonds.append((y,x))
200 | fbonds.append( torch.cat([fatoms[y], bond_features1(bond)], 0) )
201 | in_bonds[x].append(b)
202 |
203 | scope.append((total_atoms,n_atoms))
204 | total_atoms += n_atoms
205 |
206 | total_bonds = len(all_bonds)
207 | fatoms = torch.stack(fatoms, 0)
208 | fbonds = torch.stack(fbonds, 0)
209 | agraph = torch.zeros(total_atoms,MAX_NB1).long()
210 | bgraph = torch.zeros(total_bonds,MAX_NB1).long()
211 |
212 | for a in range(total_atoms):
213 | for i,b in enumerate(in_bonds[a]):
214 | agraph[a,i] = b
215 |
216 | for b1 in range(1, total_bonds):
217 | x,y = all_bonds[b1]
218 | for i,b2 in enumerate(in_bonds[x]):
219 | if all_bonds[b2][0] != y:
220 | bgraph[b1,i] = b2
221 |
222 | return (fatoms, fbonds, agraph, bgraph, scope)
223 |
224 |
225 |
226 | ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
227 | BOND_FDIM = 5
228 | MAX_NB = 15
229 |
230 | def onek_encoding_unk(x, allowable_set):
231 | if x not in allowable_set:
232 | x = allowable_set[-1]
233 | return list(map(lambda s: x == s, allowable_set))
234 |
235 | def atom_features(atom):
236 | return torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
237 | + onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
238 | + onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
239 | + [atom.GetIsAromatic()])
240 |
241 | def bond_features(bond):
242 | bt = bond.GetBondType()
243 | return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()])
244 |
245 | class JTMPN(nn.Module):
246 |
247 | def __init__(self, hidden_size, depth):
248 | super(JTMPN, self).__init__()
249 | self.hidden_size = int(hidden_size)
250 | self.depth = depth
251 |
252 | self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
253 | self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
254 | self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
255 |
256 | def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message): #tree_message[0] == vec(0)
257 | fatoms = create_variable(fatoms)
258 | fbonds = create_variable(fbonds)
259 | agraph = create_variable(agraph)
260 | bgraph = create_variable(bgraph)
261 |
262 | binput = self.W_i(fbonds)
263 | graph_message = F.relu(binput)
264 |
265 | for i in range(self.depth - 1):
266 | message = torch.cat([tree_message,graph_message], dim=0)
267 | nei_message = index_select_ND(message, 0, bgraph)
268 | nei_message = nei_message.sum(dim=1) #assuming tree_message[0] == vec(0)
269 | nei_message = self.W_h(nei_message)
270 | graph_message = F.relu(binput + nei_message)
271 |
272 | message = torch.cat([tree_message,graph_message], dim=0)
273 | nei_message = index_select_ND(message, 0, agraph)
274 | nei_message = nei_message.sum(dim=1)
275 | ainput = torch.cat([fatoms, nei_message], dim=1)
276 | atom_hiddens = F.relu(self.W_o(ainput))
277 |
278 | mol_vecs = []
279 | for st,le in scope:
280 | mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
281 | mol_vecs.append(mol_vec)
282 |
283 | mol_vecs = torch.stack(mol_vecs, dim=0)
284 | return mol_vecs
285 |
286 | @staticmethod
287 | def tensorize(cand_batch, mess_dict):
288 | fatoms,fbonds = [],[]
289 | in_bonds,all_bonds = [],[]
290 | total_atoms = 0
291 | total_mess = len(mess_dict) + 1 #must include vec(0) padding
292 | scope = []
293 |
294 | for smiles,all_nodes,ctr_node in cand_batch:
295 | mol = Chem.MolFromSmiles(smiles)
296 | Chem.Kekulize(mol) #The original jtnn version kekulizes. Need to revisit why it is necessary
297 | n_atoms = mol.GetNumAtoms()
298 | ctr_bid = ctr_node.idx
299 | for atom in mol.GetAtoms():
300 |
301 | fatoms.append( atom_features(atom) )
302 | in_bonds.append([])
303 |
304 | for bond in mol.GetBonds():
305 |
306 | a1 = bond.GetBeginAtom()
307 | a2 = bond.GetEndAtom()
308 | x = a1.GetIdx() + total_atoms
309 | y = a2.GetIdx() + total_atoms
310 | #Here x_nid,y_nid could be 0
311 | x_nid,y_nid = a1.GetAtomMapNum(),a2.GetAtomMapNum()
312 | x_bid = all_nodes[x_nid - 1].idx if x_nid > 0 else -1
313 | y_bid = all_nodes[y_nid - 1].idx if y_nid > 0 else -1
314 |
315 | bfeature = bond_features(bond)
316 |
317 | b = total_mess + len(all_bonds) #bond idx offseted by total_mess
318 | all_bonds.append((x,y))
319 | fbonds.append( torch.cat([fatoms[x], bfeature], 0) )
320 | in_bonds[y].append(b)
321 |
322 | b = total_mess + len(all_bonds)
323 | all_bonds.append((y,x))
324 | fbonds.append( torch.cat([fatoms[y], bfeature], 0) )
325 | in_bonds[x].append(b)
326 |
327 | if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
328 | if (x_bid,y_bid) in mess_dict:
329 | mess_idx = mess_dict[(x_bid,y_bid)]
330 | in_bonds[y].append(mess_idx)
331 | if (y_bid,x_bid) in mess_dict:
332 | mess_idx = mess_dict[(y_bid,x_bid)]
333 | in_bonds[x].append(mess_idx)
334 |
335 | scope.append((total_atoms,n_atoms))
336 | total_atoms += n_atoms
337 |
338 | total_bonds = len(all_bonds)
339 | fatoms = torch.stack(fatoms, 0)
340 | fbonds = torch.stack(fbonds, 0)
341 | agraph = torch.zeros(total_atoms,MAX_NB).long()
342 | bgraph = torch.zeros(total_bonds,MAX_NB).long()
343 |
344 | for a in range(total_atoms):
345 | for i,b in enumerate(in_bonds[a]):
346 | agraph[a,i] = b
347 |
348 | for b1 in range(total_bonds):
349 | x,y = all_bonds[b1]
350 | for i,b2 in enumerate(in_bonds[x]): #b2 is offseted by total_mess
351 | if b2 < total_mess or all_bonds[b2-total_mess][0] != y:
352 | bgraph[b1,i] = b2
353 |
354 | return (fatoms, fbonds, agraph, bgraph, scope)
355 |
356 |
357 |
358 | def dfs(stack, x, fa_idx):
359 | for y in x.neighbors:
360 | if y.idx == fa_idx: continue
361 | stack.append( (x,y,1) )
362 | dfs(stack, y, x.idx)
363 | stack.append( (y,x,0) )
364 |
365 | def have_slots(fa_slots, ch_slots):
366 | if len(fa_slots) > 2 and len(ch_slots) > 2:
367 | return True
368 | matches = []
369 | for i,s1 in enumerate(fa_slots):
370 | a1,c1,h1 = s1
371 | for j,s2 in enumerate(ch_slots):
372 | a2,c2,h2 = s2
373 | if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
374 | matches.append( (i,j) )
375 |
376 | if len(matches) == 0: return False
377 |
378 | fa_match,ch_match = zip(*matches)
379 | if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2: #never remove atom from ring
380 | fa_slots.pop(fa_match[0])
381 | if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2: #never remove atom from ring
382 | ch_slots.pop(ch_match[0])
383 |
384 | return True
385 |
386 | def can_assemble(node_x, node_y):
387 | node_x.nid = 1
388 | node_x.is_leaf = False
389 | set_atommap(node_x.mol, node_x.nid)
390 |
391 | neis = node_x.neighbors + [node_y]
392 | for i,nei in enumerate(neis):
393 | nei.nid = i + 2
394 | nei.is_leaf = (len(nei.neighbors) <= 1)
395 | if nei.is_leaf:
396 | set_atommap(nei.mol, 0)
397 | else:
398 | set_atommap(nei.mol, nei.nid)
399 |
400 | neighbors = [nei for nei in neis if nei.mol.GetNumAtoms() > 1]
401 | neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
402 | singletons = [nei for nei in neis if nei.mol.GetNumAtoms() == 1]
403 | neighbors = singletons + neighbors
404 | cands,aroma_scores = enum_assemble(node_x, neighbors)
405 | return len(cands) > 0# and sum(aroma_scores) >= 0
406 |
407 |
408 |
409 | MAX_NB0 = 15
410 | MAX_DECODE_LEN0 = 100
411 |
412 | class JTNNDecoder(nn.Module):
413 |
414 | def __init__(self, vocab, hidden_size, latent_size, embedding):
415 | super(JTNNDecoder, self).__init__()
416 | self.hidden_size = int(hidden_size)
417 | self.vocab_size = vocab.size()
418 | self.vocab = vocab
419 | self.embedding = embedding
420 | latent_size=int(latent_size)
421 | #GRU Weights
422 | self.W_z = nn.Linear(2 * hidden_size, hidden_size)
423 | self.U_r = nn.Linear(hidden_size, hidden_size, bias=False)
424 | self.W_r = nn.Linear(hidden_size, hidden_size)
425 | self.W_h = nn.Linear(2 * hidden_size, hidden_size)
426 |
427 | #Word Prediction Weights
428 | self.W = nn.Linear(hidden_size + latent_size, hidden_size)
429 |
430 | #Stop Prediction Weights
431 | self.U = nn.Linear(hidden_size + latent_size, hidden_size)
432 | self.U_i = nn.Linear(2 * hidden_size, hidden_size)
433 |
434 | #Output Weights
435 | self.W_o = nn.Linear(hidden_size, self.vocab_size)
436 | self.U_o = nn.Linear(hidden_size, 1)
437 |
438 |
439 | #Loss Functions
440 | self.pred_loss = nn.CrossEntropyLoss(reduction='sum')
441 | self.stop_loss = nn.BCEWithLogitsLoss(reduction='sum')
442 |
443 | def aggregate(self, hiddens, contexts, x_tree_vecs, mode):
444 | if mode == 'word':
445 | V, V_o = self.W, self.W_o
446 | elif mode == 'stop':
447 | V, V_o = self.U, self.U_o
448 | else:
449 | raise ValueError('aggregate mode is wrong')
450 |
451 | tree_contexts = x_tree_vecs.index_select(0, contexts)
452 | input_vec = torch.cat([hiddens, tree_contexts], dim=-1)
453 | output_vec = F.relu( V(input_vec) )
454 | return V_o(output_vec)
455 |
456 | def forward(self, mol_batch, x_tree_vecs):
457 | pred_hiddens,pred_contexts,pred_targets = [],[],[]
458 | stop_hiddens,stop_contexts,stop_targets = [],[],[]
459 | traces = []
460 | for mol_tree in mol_batch:
461 | s = []
462 | dfs(s, mol_tree.nodes[0], -1)
463 | traces.append(s)
464 | for node in mol_tree.nodes:
465 | node.neighbors = []
466 |
467 | #Predict Root
468 | batch_size = len(mol_batch)
469 | pred_hiddens.append(create_variable(torch.zeros(len(mol_batch),self.hidden_size)))
470 | pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
471 |
472 | pred_contexts.append( create_variable( torch.LongTensor(range(batch_size)) ) )
473 |
474 | max_iter = max([len(tr) for tr in traces])
475 | padding = create_variable(torch.zeros(self.hidden_size), False)
476 | h = {}
477 |
478 | for t in range(max_iter):
479 | prop_list = []
480 | batch_list = []
481 | for i,plist in enumerate(traces):
482 | if t < len(plist):
483 | prop_list.append(plist[t])
484 | batch_list.append(i)
485 |
486 | cur_x = []
487 | cur_h_nei,cur_o_nei = [],[]
488 |
489 | for node_x, real_y, _ in prop_list:
490 | #Neighbors for message passing (target not included)
491 | cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
492 | pad_len = MAX_NB0 - len(cur_nei)
493 | cur_h_nei.extend(cur_nei)
494 | cur_h_nei.extend([padding] * pad_len)
495 |
496 | #Neighbors for stop prediction (all neighbors)
497 | cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors]
498 | pad_len = MAX_NB0 - len(cur_nei)
499 | cur_o_nei.extend(cur_nei)
500 | cur_o_nei.extend([padding] * pad_len)
501 |
502 | #Current clique embedding
503 | cur_x.append(node_x.wid)
504 |
505 |
506 | #Clique embedding
507 | cur_x = create_variable(torch.LongTensor(cur_x))
508 | cur_x = self.embedding(cur_x)
509 |
510 | #Message passing
511 | cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1,MAX_NB0,self.hidden_size)
512 | new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
513 |
514 | #Node Aggregate
515 | cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1,MAX_NB0,self.hidden_size)
516 | cur_o = cur_o_nei.sum(dim=1)
517 |
518 | #Gather targets
519 | pred_target,pred_list = [],[]
520 | stop_target = []
521 | for i,m in enumerate(prop_list):
522 | node_x,node_y,direction = m
523 | x,y = node_x.idx,node_y.idx
524 | h[(x,y)] = new_h[i]
525 | node_y.neighbors.append(node_x)
526 | if direction == 1:
527 | pred_target.append(node_y.wid)
528 | pred_list.append(i)
529 | stop_target.append(direction)
530 |
531 | #Hidden states for stop prediction
532 | cur_batch = create_variable(torch.LongTensor(batch_list))
533 | stop_hidden = torch.cat([cur_x,cur_o], dim=1)
534 | stop_hiddens.append( stop_hidden )
535 | stop_contexts.append( cur_batch )
536 | stop_targets.extend( stop_target )
537 |
538 | #Hidden states for clique prediction
539 | if len(pred_list) > 0:
540 | batch_list = [batch_list[i] for i in pred_list]
541 | cur_batch = create_variable(torch.LongTensor(batch_list))
542 | pred_contexts.append( cur_batch )
543 |
544 | cur_pred = create_variable(torch.LongTensor(pred_list))
545 | pred_hiddens.append( new_h.index_select(0, cur_pred) )
546 | pred_targets.extend( pred_target )
547 |
548 | #Last stop at root
549 | cur_x,cur_o_nei = [],[]
550 | for mol_tree in mol_batch:
551 | node_x = mol_tree.nodes[0]
552 | cur_x.append(node_x.wid)
553 | cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors]
554 | pad_len = MAX_NB0 - len(cur_nei)
555 | cur_o_nei.extend(cur_nei)
556 | cur_o_nei.extend([padding] * pad_len)
557 |
558 | cur_x = create_variable(torch.LongTensor(cur_x))
559 | cur_x = self.embedding(cur_x)
560 | cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1,MAX_NB0,self.hidden_size)
561 | cur_o = cur_o_nei.sum(dim=1)
562 |
563 | stop_hidden = torch.cat([cur_x,cur_o], dim=1)
564 | stop_hiddens.append( stop_hidden )
565 | stop_contexts.append( create_variable( torch.LongTensor(range(batch_size)) ) )
566 | stop_targets.extend( [0] * len(mol_batch) )
567 |
568 | #Predict next clique
569 | pred_contexts = torch.cat(pred_contexts, dim=0)
570 | pred_hiddens = torch.cat(pred_hiddens, dim=0)
571 | pred_scores = self.aggregate(pred_hiddens, pred_contexts, x_tree_vecs, 'word')
572 | pred_targets = create_variable(torch.LongTensor(pred_targets))
573 | pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
574 | _,preds = torch.max(pred_scores, dim=1)
575 | pred_acc = torch.eq(preds, pred_targets).float()
576 | pred_acc = torch.sum(pred_acc) / pred_targets.nelement()
577 |
578 | #Predict stop
579 | stop_contexts = torch.cat(stop_contexts, dim=0)
580 | stop_hiddens = torch.cat(stop_hiddens, dim=0)
581 | stop_hiddens = F.relu( self.U_i(stop_hiddens) )
582 | stop_scores = self.aggregate(stop_hiddens, stop_contexts, x_tree_vecs, 'stop')
583 | stop_scores = stop_scores.squeeze(-1)
584 | stop_targets = create_variable(torch.Tensor(stop_targets))
585 |
586 | stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
587 | stops = torch.ge(stop_scores, 0).float()
588 | stop_acc = torch.eq(stops, stop_targets).float()
589 | stop_acc = torch.sum(stop_acc) / stop_targets.nelement()
590 |
591 | return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
592 |
593 | def decode(self, x_tree_vecs, prob_decode):
594 | assert x_tree_vecs.size(0) == 1
595 |
596 | stack = []
597 | init_hiddens = create_variable( torch.zeros(1, self.hidden_size) )
598 | zero_pad = create_variable(torch.zeros(1,1,self.hidden_size))
599 | contexts = create_variable( torch.LongTensor(1).zero_() )
600 |
601 | #Root Prediction
602 | root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs, 'word')
603 | _,root_wid = torch.max(root_score, dim=1)
604 | root_wid = root_wid.item()
605 |
606 | root = MolTreeNode(self.vocab.get_smiles(root_wid))
607 | root.wid = root_wid
608 | root.idx = 0
609 | stack.append( (root, self.vocab.get_slots(root.wid)) )
610 |
611 | all_nodes = [root]
612 | h = {}
613 | for step in range(MAX_DECODE_LEN0):
614 | node_x,fa_slot = stack[-1]
615 | cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors ]
616 | if len(cur_h_nei) > 0:
617 | cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
618 | else:
619 | cur_h_nei = zero_pad
620 |
621 | cur_x = create_variable(torch.LongTensor([node_x.wid]))
622 | cur_x = self.embedding(cur_x)
623 |
624 | #Predict stop
625 | cur_h = cur_h_nei.sum(dim=1)
626 | stop_hiddens = torch.cat([cur_x,cur_h], dim=1)
627 | stop_hiddens = F.relu( self.U_i(stop_hiddens) )
628 | stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs, 'stop')
629 |
630 | if prob_decode:
631 | backtrack = (torch.bernoulli( torch.sigmoid(stop_score) ).item() == 0)
632 | else:
633 | backtrack = (stop_score.item() < 0)
634 |
635 | if not backtrack: #Forward: Predict next clique
636 | new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
637 | pred_score = self.aggregate(new_h, contexts, x_tree_vecs, 'word')
638 |
639 | if prob_decode:
640 | sort_wid = torch.multinomial(F.softmax(pred_score, dim=1).squeeze(), 5)
641 | else:
642 | _,sort_wid = torch.sort(pred_score, dim=1, descending=True)
643 | sort_wid = sort_wid.data.squeeze()
644 |
645 | next_wid = None
646 | for wid in sort_wid[:5]:
647 | slots = self.vocab.get_slots(wid)
648 | node_y = MolTreeNode(self.vocab.get_smiles(wid))
649 | if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
650 | next_wid = wid
651 | next_slots = slots
652 | break
653 |
654 | if next_wid is None:
655 | backtrack = True #No more children can be added
656 | else:
657 | node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
658 | node_y.wid = next_wid
659 | node_y.idx = len(all_nodes)
660 | node_y.neighbors.append(node_x)
661 | h[(node_x.idx,node_y.idx)] = new_h[0]
662 | stack.append( (node_y,next_slots) )
663 | all_nodes.append(node_y)
664 |
665 | if backtrack: #Backtrack, use if instead of else
666 | if len(stack) == 1:
667 | break #At root, terminate
668 |
669 | node_fa,_ = stack[-2]
670 | cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx ]
671 | if len(cur_h_nei) > 0:
672 | cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
673 | else:
674 | cur_h_nei = zero_pad
675 |
676 | new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
677 | h[(node_x.idx,node_fa.idx)] = new_h[0]
678 | node_fa.neighbors.append(node_x)
679 | stack.pop()
680 |
681 | return root, all_nodes
682 |
683 |
684 |
685 |
686 | class JTNNEncoder(nn.Module):
687 |
688 | def __init__(self, hidden_size, depth, embedding):
689 | super(JTNNEncoder, self).__init__()
690 | self.hidden_size = int(hidden_size)
691 | self.depth = depth
692 |
693 | self.embedding = embedding
694 | self.outputNN = nn.Sequential(
695 | nn.Linear(2 * hidden_size, hidden_size),
696 | nn.ReLU()
697 | )
698 | self.GRU = GraphGRU(hidden_size, hidden_size, depth=depth)
699 |
700 | def forward(self, fnode, fmess, node_graph, mess_graph, scope):
701 | fnode = create_variable(fnode)
702 | fmess = create_variable(fmess)
703 | node_graph = create_variable(node_graph)
704 | mess_graph = create_variable(mess_graph)
705 | messages = create_variable(torch.zeros(mess_graph.size(0), self.hidden_size))
706 |
707 | fnode = self.embedding(fnode)
708 | fmess = index_select_ND(fnode, 0, fmess)
709 | messages = self.GRU(messages, fmess, mess_graph)
710 |
711 | mess_nei = index_select_ND(messages, 0, node_graph)
712 | node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
713 | node_vecs = self.outputNN(node_vecs)
714 |
715 | max_len = max([x for _,x in scope])
716 | batch_vecs = []
717 | for st,le in scope:
718 | cur_vecs = node_vecs[st] #Root is the first node
719 | batch_vecs.append( cur_vecs )
720 |
721 | tree_vecs = torch.stack(batch_vecs, dim=0)
722 | return tree_vecs, messages
723 |
724 | @staticmethod
725 | def tensorize(tree_batch):
726 |
727 | node_batch = []
728 | scope = []
729 | for tree in tree_batch:
730 | scope.append( (len(node_batch), len(tree.nodes)) )
731 | node_batch.extend(tree.nodes)
732 |
733 | return JTNNEncoder.tensorize_nodes(node_batch, scope)
734 |
735 | @staticmethod
736 | def tensorize_nodes(node_batch, scope):
737 |
738 | messages,mess_dict = [None],{}
739 | fnode = []
740 | for x in node_batch:
741 | fnode.append(x.wid)
742 | for y in x.neighbors:
743 | mess_dict[(x.idx,y.idx)] = len(messages)
744 | messages.append( (x,y) )
745 |
746 | node_graph = [[] for i in range(len(node_batch))]
747 | mess_graph = [[] for i in range(len(messages))]
748 | fmess = [0] * len(messages)
749 |
750 | for x,y in messages[1:]:
751 | mid1 = mess_dict[(x.idx,y.idx)]
752 | fmess[mid1] = x.idx
753 | node_graph[y.idx].append(mid1)
754 | for z in y.neighbors:
755 | if z.idx == x.idx: continue
756 | mid2 = mess_dict[(y.idx,z.idx)]
757 | mess_graph[mid2].append(mid1)
758 |
759 | max_len = max([len(t) for t in node_graph] + [1])
760 | for t in node_graph:
761 | pad_len = max_len - len(t)
762 | t.extend([0] * pad_len)
763 |
764 | max_len = max([len(t) for t in mess_graph] + [1])
765 | for t in mess_graph:
766 | pad_len = max_len - len(t)
767 | t.extend([0] * pad_len)
768 |
769 | mess_graph = torch.LongTensor(mess_graph)
770 | node_graph = torch.LongTensor(node_graph)
771 | fmess = torch.LongTensor(fmess)
772 | fnode = torch.LongTensor(fnode)
773 | return (fnode, fmess, node_graph, mess_graph, scope), mess_dict
774 |
775 | class GraphGRU(nn.Module):
776 |
777 | def __init__(self, input_size, hidden_size, depth):
778 | super(GraphGRU, self).__init__()
779 | self.hidden_size = int(hidden_size)
780 | self.input_size = input_size
781 | self.depth = depth
782 |
783 | self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
784 | self.W_r = nn.Linear(input_size, hidden_size, bias=False)
785 | self.U_r = nn.Linear(hidden_size, hidden_size)
786 | self.W_h = nn.Linear(input_size + hidden_size, hidden_size)
787 |
788 | def forward(self, h, x, mess_graph):
789 | mask = torch.ones(h.size(0), 1)
790 | mask[0] = 0 #first vector is padding
791 | mask = create_variable(mask)
792 | for it in range(self.depth):
793 | h_nei = index_select_ND(h, 0, mess_graph)
794 | sum_h = h_nei.sum(dim=1)
795 | z_input = torch.cat([x, sum_h], dim=1)
796 | z = torch.sigmoid(self.W_z(z_input))
797 |
798 | r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
799 | r_2 = self.U_r(h_nei)
800 | r = torch.sigmoid(r_1 + r_2)
801 |
802 | gated_h = r * h_nei
803 | sum_gated_h = gated_h.sum(dim=1)
804 | h_input = torch.cat([x, sum_gated_h], dim=1)
805 | pre_h = torch.tanh(self.W_h(h_input))
806 | h = (1.0 - z) * sum_h + z * pre_h
807 | h = h * mask
808 |
809 | return h
810 |
811 |
812 |
813 |
814 |
815 |
816 |
817 | class JTNNVAE(nn.Module):
818 |
819 | def __init__(self, vocab, hidden_size, latent_size, depthT, depthG):
820 | super(JTNNVAE, self).__init__()
821 | self.vocab = vocab
822 |
823 | self.hidden_size = int(hidden_size)
824 | self.latent_size = latent_size = latent_size / 2 #Tree and Mol each get half of the latent vector
825 | self.latent_size=int(self.latent_size)
826 | self.jtnn = JTNNEncoder(int(hidden_size),int(depthT), nn.Embedding(780,450))
827 | self.decoder = JTNNDecoder(vocab, int(hidden_size), int(latent_size), nn.Embedding(780,450))
828 |
829 | self.jtmpn = JTMPN(int(hidden_size), int(depthG))
830 | self.mpn = MPN(int(hidden_size), int(depthG))
831 |
832 | self.A_assm = nn.Linear(int(latent_size), int(hidden_size), bias=False)
833 | self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
834 |
835 | self.T_mean = nn.Linear(int(hidden_size), int(latent_size))
836 | self.T_var = nn.Linear(int(hidden_size), int(latent_size))
837 | self.G_mean = nn.Linear(int(hidden_size), int(latent_size))
838 | self.G_var = nn.Linear(int(hidden_size), int(latent_size))
839 |
840 | def encode(self, jtenc_holder, mpn_holder):
841 | tree_vecs, tree_mess = self.jtnn(*jtenc_holder)
842 | mol_vecs = self.mpn(*mpn_holder)
843 | return tree_vecs, tree_mess, mol_vecs
844 |
845 | def encode_latent(self, jtenc_holder, mpn_holder):
846 | tree_vecs, _ = self.jtnn(*jtenc_holder)
847 | mol_vecs = self.mpn(*mpn_holder)
848 | tree_mean = self.T_mean(tree_vecs)
849 | mol_mean = self.G_mean(mol_vecs)
850 | tree_var = -torch.abs(self.T_var(tree_vecs))
851 | mol_var = -torch.abs(self.G_var(mol_vecs))
852 | return torch.cat([tree_mean, mol_mean], dim=1), torch.cat([tree_var, mol_var], dim=1)
853 |
854 | def rsample(self, z_vecs, W_mean, W_var):
855 | batch_size = z_vecs.size(0)
856 | z_mean = W_mean(z_vecs)
857 | z_log_var = -torch.abs(W_var(z_vecs)) #Following Mueller et al.
858 | kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
859 | epsilon = create_variable(torch.randn_like(z_mean))
860 | z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
861 | return z_vecs, kl_loss
862 |
863 | def sample_prior(self, prob_decode=False):
864 | z_tree = torch.randn(1, self.latent_size)
865 | z_mol = torch.randn(1, self.latent_size)
866 | return self.decode(z_tree, z_mol, prob_decode)
867 |
868 | def forward(self, x_batch, beta):
869 | x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder= x_batch
870 | #Encoding the graph and tree
871 | x_tree_vecs, x_tree_mess, x_mol_vecs = self.encode(x_jtenc_holder, x_mpn_holder)
872 |
873 | z_tree_vecs,tree_kl = self.rsample(x_tree_vecs, self.T_mean, self.T_var)
874 |
875 | z_mol_vecs,mol_kl = self.rsample(x_mol_vecs, self.G_mean, self.G_var)
876 |
877 | kl_div = tree_kl + mol_kl
878 | #Decoding the tree
879 | word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
880 | #Decoding the graph and assembling the graph
881 | assm_loss, assm_acc = self.assm(x_batch, x_jtmpn_holder, z_mol_vecs, x_tree_mess)
882 |
883 | return word_loss + topo_loss + assm_loss + beta * kl_div, kl_div.item(), word_acc, topo_acc, assm_acc
884 |
885 | def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess):
886 | jtmpn_holder,batch_idx = jtmpn_holder
887 | fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder
888 | batch_idx = create_variable(batch_idx)
889 |
890 | cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, x_tree_mess)
891 |
892 | x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
893 | x_mol_vecs = self.A_assm(x_mol_vecs) #bilinear
894 | scores = torch.bmm(
895 | x_mol_vecs.unsqueeze(1),
896 | cand_vecs.unsqueeze(-1)
897 | ).squeeze()
898 |
899 | cnt,tot,acc = 0,0,0
900 | all_loss = []
901 | for i,mol_tree in enumerate(mol_batch):
902 | comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
903 | cnt += len(comp_nodes)
904 | for node in comp_nodes:
905 | label = node.cands.index(node.label)
906 | ncand = len(node.cands)
907 | cur_score = scores.narrow(0, tot, ncand)
908 | tot += ncand
909 |
910 | if cur_score.data[label] >= cur_score.max().item():
911 | acc += 1
912 |
913 | label = create_variable(torch.LongTensor([label]))
914 | all_loss.append( self.assm_loss(cur_score.view(1,-1), label) )
915 |
916 | all_loss = sum(all_loss) / len(mol_batch)
917 | return all_loss, acc * 1.0 / cnt
918 |
919 | def decode(self, x_tree_vecs, x_mol_vecs, prob_decode):
920 |
921 | assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1
922 |
923 | pred_root,pred_nodes = self.decoder.decode(x_tree_vecs, prob_decode)
924 | if len(pred_nodes) == 0: return None
925 | elif len(pred_nodes) == 1: return pred_root.smiles
926 |
927 | #Mark nid & is_leaf & atommap
928 | for i,node in enumerate(pred_nodes):
929 | node.nid = i + 1
930 | node.is_leaf = (len(node.neighbors) == 1)
931 | if len(node.neighbors) > 1:
932 | set_atommap(node.mol, node.nid)
933 |
934 | scope = [(0, len(pred_nodes))]
935 | jtenc_holder,mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope)
936 | _,tree_mess = self.jtnn(*jtenc_holder)
937 | tree_mess = (tree_mess, mess_dict) #Important: tree_mess is a matrix, mess_dict is a python dict
938 |
939 | x_mol_vecs = self.A_assm(x_mol_vecs).squeeze() #bilinear
940 |
941 | cur_mol = copy_edit_mol(pred_root.mol)
942 | global_amap = [{}] + [{} for node in pred_nodes]
943 | global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in cur_mol.GetAtoms()}
944 |
945 | cur_mol,_ = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=True)
946 | if cur_mol is None:
947 | cur_mol = copy_edit_mol(pred_root.mol)
948 | global_amap = [{}] + [{} for node in pred_nodes]
949 | global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in cur_mol.GetAtoms()}
950 | cur_mol,pre_mol = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=False)
951 | if cur_mol is None: cur_mol = pre_mol
952 |
953 | if cur_mol is None:
954 | return None
955 |
956 |
957 |
958 | cur_mol = cur_mol.GetMol()
959 | set_atommap(cur_mol)
960 | cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
961 | return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None
962 |
963 | def dfs_assemble(self, y_tree_mess, x_mol_vecs, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode, check_aroma):
964 | fa_nid = fa_node.nid if fa_node is not None else -1
965 | prev_nodes = [fa_node] if fa_node is not None else []
966 |
967 | children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
968 | neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
969 | neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
970 | singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
971 | neighbors = singletons + neighbors
972 |
973 | cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid]
974 | cands,aroma_score = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
975 | if len(cands) == 0 or (sum(aroma_score) < 0 and check_aroma):
976 | return None, cur_mol
977 |
978 | cand_smiles,cand_amap = zip(*cands)
979 | aroma_score = torch.Tensor(aroma_score)
980 | cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]
981 |
982 | if len(cands) > 1:
983 | jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
984 | fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder
985 | cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])
986 | scores = torch.mv(cand_vecs, x_mol_vecs) + aroma_score
987 | else:
988 | scores = torch.Tensor([1.0])
989 |
990 | if prob_decode:
991 | probs = F.softmax(scores.view(1,-1), dim=1).squeeze() + 1e-7 #prevent prob = 0
992 | cand_idx = torch.multinomial(probs, probs.numel())
993 | else:
994 | _,cand_idx = torch.sort(scores, descending=True)
995 |
996 | backup_mol = Chem.RWMol(cur_mol)
997 | pre_mol = cur_mol
998 | for i in range(cand_idx.numel()):
999 | cur_mol = Chem.RWMol(backup_mol)
1000 | pred_amap = cand_amap[cand_idx[i].item()]
1001 | new_global_amap = copy.deepcopy(global_amap)
1002 |
1003 | for nei_id,ctr_atom,nei_atom in pred_amap:
1004 | if nei_id == fa_nid:
1005 | continue
1006 | new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]
1007 |
1008 | cur_mol = attach_mols(cur_mol, children, [], new_global_amap) #father is already attached
1009 | new_mol = cur_mol.GetMol()
1010 | new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
1011 |
1012 | if new_mol is None: continue
1013 |
1014 | has_error = False
1015 | for nei_node in children:
1016 | if nei_node.is_leaf: continue
1017 | tmp_mol, tmp_mol2 = self.dfs_assemble(y_tree_mess, x_mol_vecs, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode, check_aroma)
1018 | if tmp_mol is None:
1019 | has_error = True
1020 | if i == 0: pre_mol = tmp_mol2
1021 | break
1022 | cur_mol = tmp_mol
1023 |
1024 | if not has_error: return cur_mol, cur_mol
1025 |
1026 | return None, pre_mol
1027 |
1028 | #Reading the input
1029 | vocab = [x.strip("\r\n ") for x in open('train.txt')]
1030 | #Building the vocabulary
1031 | vocab = Vocab(vocab,mol_trees)
1032 |
1033 | #Defining the model
1034 | model = JTNNVAE(vocab, int(450), int(56), int(20), int(3))
1035 |
1036 | print("Model")
1037 | print(model)
1038 |
1039 |
1040 |
1041 |
1042 |
--------------------------------------------------------------------------------
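
The core sampling machinery in model.py is JTNNVAE.rsample, the standard VAE reparameterization with a negative-clamped log-variance (per the "Following Mueller et al." comment). Below is a self-contained sketch of that step with toy dimensions, using the same formulas and assuming nothing beyond PyTorch.

```python
# Standalone sketch of the reparameterization used in JTNNVAE.rsample (toy sizes).
# W_mean / W_var stand in for the model's T_mean/T_var (or G_mean/G_var) layers.
import torch
import torch.nn as nn

batch, hidden, latent = 3, 8, 4
W_mean, W_var = nn.Linear(hidden, latent), nn.Linear(hidden, latent)
z_vecs = torch.randn(batch, hidden)              # stand-in for encoder tree/mol vectors

z_mean = W_mean(z_vecs)
z_log_var = -torch.abs(W_var(z_vecs))            # log-variance forced to be <= 0
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch
epsilon = torch.randn_like(z_mean)
z = z_mean + torch.exp(z_log_var / 2) * epsilon  # z = mu + sigma * eps
print(z.shape, kl_loss.item())
```
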
/genmol/JTVAE/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math, random, sys
4 | from optparse import OptionParser
5 | import pickle
6 | import rdkit
7 | import json
8 | import rdkit.Chem as Chem
9 | from scipy.sparse import csr_matrix
10 | from scipy.sparse.csgraph import minimum_spanning_tree
11 | from collections import defaultdict
12 | import copy
13 | import torch.optim as optim
14 | import torch.optim.lr_scheduler as lr_scheduler
15 | from torch.utils.data import Dataset, DataLoader
16 | from torch.autograd import Variable
17 | import numpy as np
18 | from collections import deque
19 | import os, random
20 | import torch.nn.functional as F
21 | import pdb
22 |
23 | benzynes_i = ['C1=CC=CC=C1', 'C1=CC=NC=C1', 'C1=CC=NN=C1', 'C1=CN=CC=N1', 'C1=CN=CN=C1', 'C1=CN=NC=N1', 'C1=CN=NN=C1', 'C1=NC=NC=N1', 'C1=NN=CN=N1']
24 | penzynes_i = ['C1=C[NH]C=C1', 'C1=C[NH]C=N1', 'C1=C[NH]N=C1', 'C1=C[NH]N=N1', 'C1=COC=C1', 'C1=COC=N1', 'C1=CON=C1', 'C1=CSC=C1', 'C1=CSC=N1', 'C1=CSN=C1', 'C1=CSN=N1', 'C1=NN=C[NH]1', 'C1=NN=CO1', 'C1=NN=CS1', 'C1=N[NH]C=N1', 'C1=N[NH]N=C1', 'C1=N[NH]N=N1', 'C1=NN=N[NH]1', 'C1=NN=NS1', 'C1=NOC=N1', 'C1=NON=C1', 'C1=NSC=N1', 'C1=NSN=C1']
25 |
26 |
27 | MST_MAX_WEIGHT_10 = 100
28 | MAX_NCAND_10 = 2000
29 |
30 | def set_atommap(mol, num=0):
31 | for atom in mol.GetAtoms():
32 | atom.SetAtomMapNum(num)
33 |
34 | def get_mol(smiles):
35 | mol = Chem.MolFromSmiles(smiles)
36 | if mol is None:
37 | return None
38 | Chem.Kekulize(mol)
39 | return mol
40 |
41 | def get_smiles(mol):
42 | return Chem.MolToSmiles(mol, kekuleSmiles=True)
43 |
44 | def sanitize(mol):
45 | try:
46 | smiles = get_smiles(mol)
47 | mol = get_mol(smiles)
48 | except Exception as e:
49 | return None
50 | return mol
51 |
52 | def copy_atom(atom):
53 | new_atom = Chem.Atom(atom.GetSymbol())
54 | new_atom.SetFormalCharge(atom.GetFormalCharge())
55 | new_atom.SetAtomMapNum(atom.GetAtomMapNum())
56 | return new_atom
57 |
58 | def copy_edit_mol(mol):
59 | new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
60 | for atom in mol.GetAtoms():
61 | new_atom = copy_atom(atom)
62 | new_mol.AddAtom(new_atom)
63 | for bond in mol.GetBonds():
64 | a1 = bond.GetBeginAtom().GetIdx()
65 | a2 = bond.GetEndAtom().GetIdx()
66 | bt = bond.GetBondType()
67 | new_mol.AddBond(a1, a2, bt)
68 | return new_mol
69 |
70 | def get_clique_mol(mol, atoms):
71 | smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
72 | new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
73 | new_mol = copy_edit_mol(new_mol).GetMol()
74 | new_mol = sanitize(new_mol) #We assume this is not None
75 | return new_mol
76 |
77 | def tree_decomp(mol):
78 |
79 | n_atoms = mol.GetNumAtoms()
80 | if n_atoms == 1: #special case
81 | return [[0]], []
82 |
83 | cliques = []
84 | for bond in mol.GetBonds():
85 | a1 = bond.GetBeginAtom().GetIdx()
86 | a2 = bond.GetEndAtom().GetIdx()
87 | if not bond.IsInRing():
88 | cliques.append([a1,a2])
89 |
90 | ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
91 | cliques.extend(ssr)
92 |
93 | nei_list = [[] for i in range(n_atoms)]
94 | for i in range(len(cliques)):
95 | for atom in cliques[i]:
96 | nei_list[atom].append(i)
97 |
98 | #Merge Rings with intersection > 2 atoms
99 | for i in range(len(cliques)):
100 | if len(cliques[i]) <= 2: continue
101 | for atom in cliques[i]:
102 | for j in nei_list[atom]:
103 | if i >= j or len(cliques[j]) <= 2: continue
104 | inter = set(cliques[i]) & set(cliques[j])
105 | if len(inter) > 2:
106 | cliques[i].extend(cliques[j])
107 | cliques[i] = list(set(cliques[i]))
108 | cliques[j] = []
109 |
110 | cliques = [c for c in cliques if len(c) > 0]
111 | nei_list = [[] for i in range(n_atoms)]
112 | for i in range(len(cliques)):
113 | for atom in cliques[i]:
114 | nei_list[atom].append(i)
115 |
116 | #Build edges and add singleton cliques
117 | edges = defaultdict(int)
118 | for atom in range(n_atoms):
119 | if len(nei_list[atom]) <= 1:
120 | continue
121 | cnei = nei_list[atom]
122 | bonds = [c for c in cnei if len(cliques[c]) == 2]
123 | rings = [c for c in cnei if len(cliques[c]) > 4]
124 | if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): #In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
125 | cliques.append([atom])
126 | c2 = len(cliques) - 1
127 | for c1 in cnei:
128 | edges[(c1,c2)] = 1
129 | elif len(rings) > 2: #Multiple (n>2) complex rings
130 | cliques.append([atom])
131 | c2 = len(cliques) - 1
132 | for c1 in cnei:
133 | edges[(c1,c2)] = MST_MAX_WEIGHT_10 - 1
134 | else:
135 | for i in range(len(cnei)):
136 | for j in range(i + 1, len(cnei)):
137 | c1,c2 = cnei[i],cnei[j]
138 | inter = set(cliques[c1]) & set(cliques[c2])
139 | if edges[(c1,c2)] < len(inter):
140 | edges[(c1,c2)] = len(inter) #cnei[i] < cnei[j] by construction
141 |
142 | edges = [u + (MST_MAX_WEIGHT_10-v,) for u,v in edges.items()]
143 | if len(edges) == 0:
144 | return cliques, edges
145 |
146 | #Compute Maximum Spanning Tree
147 | row,col,data = zip(*edges)
148 | n_clique = len(cliques)
149 | clique_graph = csr_matrix( (data,(row,col)), shape=(n_clique,n_clique) )
150 | junc_tree = minimum_spanning_tree(clique_graph)
151 | row,col = junc_tree.nonzero()
152 | edges = [(row[i],col[i]) for i in range(len(row))]
153 | return (cliques, edges)
154 |
155 |
156 |
157 | def atom_equal(a1, a2):
158 | return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
159 |
160 | #Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
161 | def ring_bond_equal(b1, b2, reverse=False):
162 | b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
163 | if reverse:
164 | b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
165 | else:
166 | b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
167 | return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
168 |
169 | def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap):
170 | prev_nids = [node.nid for node in prev_nodes]
171 | for nei_node in prev_nodes + neighbors:
172 | nei_id,nei_mol = nei_node.nid,nei_node.mol
173 | amap = nei_amap[nei_id]
174 | for atom in nei_mol.GetAtoms():
175 | if atom.GetIdx() not in amap:
176 | new_atom = copy_atom(atom)
177 | amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
178 |
179 | if nei_mol.GetNumBonds() == 0:
180 | nei_atom = nei_mol.GetAtomWithIdx(0)
181 | ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
182 | ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
183 | else:
184 | for bond in nei_mol.GetBonds():
185 | a1 = amap[bond.GetBeginAtom().GetIdx()]
186 | a2 = amap[bond.GetEndAtom().GetIdx()]
187 | if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
188 | ctr_mol.AddBond(a1, a2, bond.GetBondType())
189 | elif nei_id in prev_nids: #father node overrides
190 | ctr_mol.RemoveBond(a1, a2)
191 | ctr_mol.AddBond(a1, a2, bond.GetBondType())
192 | return ctr_mol
193 |
194 | def local_attach(ctr_mol, neighbors, prev_nodes, amap_list):
195 | ctr_mol = copy_edit_mol(ctr_mol)
196 | nei_amap = {nei.nid:{} for nei in prev_nodes + neighbors}
197 |
198 | for nei_id,ctr_atom,nei_atom in amap_list:
199 | nei_amap[nei_id][nei_atom] = ctr_atom
200 |
201 | ctr_mol = attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap)
202 | return ctr_mol.GetMol()
203 |
204 | #This version records idx mapping between ctr_mol and nei_mol
205 | def enum_attach(ctr_mol, nei_node, amap, singletons):
206 | nei_mol,nei_idx = nei_node.mol,nei_node.nid
207 | att_confs = []
208 | black_list = [atom_idx for nei_id,atom_idx,_ in amap if nei_id in singletons]
209 | ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list]
210 | ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
211 |
212 | if nei_mol.GetNumBonds() == 0: #neighbor singleton
213 | nei_atom = nei_mol.GetAtomWithIdx(0)
214 | used_list = [atom_idx for _,atom_idx,_ in amap]
215 | for atom in ctr_atoms:
216 | if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
217 | new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
218 | att_confs.append( new_amap )
219 |
220 | elif nei_mol.GetNumBonds() == 1: #neighbor is a bond
221 | bond = nei_mol.GetBondWithIdx(0)
222 | bond_val = int(bond.GetBondTypeAsDouble())
223 | b1,b2 = bond.GetBeginAtom(), bond.GetEndAtom()
224 |
225 | for atom in ctr_atoms:
226 | #Optimize if atom is carbon (other atoms may change valence)
227 | if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
228 | continue
229 | if atom_equal(atom, b1):
230 | new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
231 | att_confs.append( new_amap )
232 | elif atom_equal(atom, b2):
233 | new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
234 | att_confs.append( new_amap )
235 | else:
236 | #intersection is an atom
237 | for a1 in ctr_atoms:
238 | for a2 in nei_mol.GetAtoms():
239 | if atom_equal(a1, a2):
240 | #Optimize if atom is carbon (other atoms may change valence)
241 | if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
242 | continue
243 | new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
244 | att_confs.append( new_amap )
245 |
246 | #intersection is a bond
247 | if ctr_mol.GetNumBonds() > 1:
248 | for b1 in ctr_bonds:
249 | for b2 in nei_mol.GetBonds():
250 | if ring_bond_equal(b1, b2):
251 | new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
252 | att_confs.append( new_amap )
253 |
254 | if ring_bond_equal(b1, b2, reverse=True):
255 | new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
256 | att_confs.append( new_amap )
257 | return att_confs
258 |
259 | #Try rings first: Speed-Up
260 | def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[]):
261 | all_attach_confs = []
262 | singletons = [nei_node.nid for nei_node in neighbors + prev_nodes if nei_node.mol.GetNumAtoms() == 1]
263 |
264 | def search(cur_amap, depth):
265 | if len(all_attach_confs) > MAX_NCAND_10:
266 | return
267 | if depth == len(neighbors):
268 | all_attach_confs.append(cur_amap)
269 | return
270 |
271 | nei_node = neighbors[depth]
272 | cand_amap = enum_attach(node.mol, nei_node, cur_amap, singletons)
273 | cand_smiles = set()
274 | candidates = []
275 | for amap in cand_amap:
276 | cand_mol = local_attach(node.mol, neighbors[:depth+1], prev_nodes, amap)
277 | cand_mol = sanitize(cand_mol)
278 | if cand_mol is None:
279 | continue
280 | smiles = get_smiles(cand_mol)
281 | if smiles in cand_smiles:
282 | continue
283 | cand_smiles.add(smiles)
284 | candidates.append(amap)
285 |
286 | if len(candidates) == 0:
287 | return
288 |
289 | for new_amap in candidates:
290 | search(new_amap, depth + 1)
291 |
292 | search(prev_amap, 0)
293 | cand_smiles = set()
294 | candidates = []
295 | aroma_score = []
296 | for amap in all_attach_confs:
297 | cand_mol = local_attach(node.mol, neighbors, prev_nodes, amap)
298 | cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
299 | smiles = Chem.MolToSmiles(cand_mol)
300 | if smiles in cand_smiles or check_singleton(cand_mol, node, neighbors) == False:
301 | continue
302 | cand_smiles.add(smiles)
303 | candidates.append( (smiles,amap) )
304 | aroma_score.append( check_aroma(cand_mol, node, neighbors) )
305 |
306 | return candidates, aroma_score
307 |
308 | def check_singleton(cand_mol, ctr_node, nei_nodes):
309 | rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() > 2]
310 | singletons = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() == 1]
311 | if len(singletons) > 0 or len(rings) == 0: return True
312 |
313 | n_leaf2_atoms = 0
314 | for atom in cand_mol.GetAtoms():
315 | nei_leaf_atoms = [a for a in atom.GetNeighbors() if not a.IsInRing()] #a.GetDegree() == 1]
316 | if len(nei_leaf_atoms) > 1:
317 | n_leaf2_atoms += 1
318 |
319 | return n_leaf2_atoms == 0
320 |
321 | def check_aroma(cand_mol, ctr_node, nei_nodes):
322 | rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() >= 3]
323 | if len(rings) < 2: return 0 #Only multi-ring system needs to be checked
324 |
325 | get_nid = lambda x: 0 if x.is_leaf else x.nid
326 | benzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in benzynes_i]
327 | penzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in penzynes_i]
328 | if len(benzynes) + len(penzynes) == 0:
329 | return 0 #No specific aromatic rings
330 |
331 | n_aroma_atoms = 0
332 | for atom in cand_mol.GetAtoms():
333 | if atom.GetAtomMapNum() in benzynes+penzynes and atom.GetIsAromatic():
334 | n_aroma_atoms += 1
335 |
336 | if n_aroma_atoms >= len(benzynes) * 4 + len(penzynes) * 3:
337 | return 1000
338 | else:
339 | return -0.001
340 |
341 | #Only used for debugging purposes
342 | def dfs_assemble(cur_mol, global_amap, fa_amap, cur_node, fa_node):
343 | fa_nid = fa_node.nid if fa_node is not None else -1
344 | prev_nodes = [fa_node] if fa_node is not None else []
345 |
346 | children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
347 | neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
348 | neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
349 | singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
350 | neighbors = singletons + neighbors
351 |
352 | cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid]
353 | cands, _ = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
354 |
355 | cand_smiles,cand_amap = zip(*cands)
356 | label_idx = cand_smiles.index(cur_node.label)
357 | label_amap = cand_amap[label_idx]
358 |
359 | for nei_id,ctr_atom,nei_atom in label_amap:
360 | if nei_id == fa_nid:
361 | continue
362 | global_amap[nei_id][nei_atom] = global_amap[cur_node.nid][ctr_atom]
363 |
364 | cur_mol = attach_mols(cur_mol, children, [], global_amap) #father is already attached
365 | for nei_node in children:
366 | if not nei_node.is_leaf:
367 | dfs_assemble(cur_mol, global_amap, label_amap, nei_node, cur_node)
368 |
369 |
370 |
371 | class MolTreeNode(object):
372 |
373 | def __init__(self, smiles, clique=[]):
374 | self.smiles = smiles
375 | self.mol = get_mol(self.smiles)
376 |
377 | self.clique = [x for x in clique] #copy
378 | self.neighbors = []
379 |
380 | def add_neighbor(self, nei_node):
381 | self.neighbors.append(nei_node)
382 |
383 | def recover(self, original_mol):
384 | clique = []
385 | clique.extend(self.clique)
386 | if not self.is_leaf:
387 | for cidx in self.clique:
388 | original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)
389 |
390 | for nei_node in self.neighbors:
391 | clique.extend(nei_node.clique)
392 | if nei_node.is_leaf: #Leaf node, no need to mark
393 | continue
394 | for cidx in nei_node.clique:
395 | #allow singleton node override the atom mapping
396 | if cidx not in self.clique or len(nei_node.clique) == 1:
397 | atom = original_mol.GetAtomWithIdx(cidx)
398 | atom.SetAtomMapNum(nei_node.nid)
399 |
400 | clique = list(set(clique))
401 | label_mol = get_clique_mol(original_mol, clique)
402 | self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
403 |
404 | for cidx in clique:
405 | original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
406 |
407 | return self.label
408 |
409 | def assemble(self):
410 | neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
411 | neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
412 | singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
413 | neighbors = singletons + neighbors
414 |
415 | cands,aroma = enum_assemble(self, neighbors)
416 | new_cands = [cand for i,cand in enumerate(cands) if aroma[i] >= 0]
417 | if len(new_cands) > 0: cands = new_cands
418 |
419 | if len(cands) > 0:
420 | self.cands, _ = zip(*cands)
421 | self.cands = list(self.cands)
422 | else:
423 | self.cands = []
424 |
425 | class MolTree(object):
426 |
427 | def __init__(self, smiles):
428 | self.smiles = smiles
429 | self.mol = get_mol(smiles)
430 |
431 |
432 | cliques, edges = tree_decomp(self.mol)
433 | self.nodes = []
434 | root = 0
435 | for i,c in enumerate(cliques):
436 | cmol = get_clique_mol(self.mol, c)
437 | node = MolTreeNode(get_smiles(cmol), c)
438 | self.nodes.append(node)
439 | if min(c) == 0: root = i
440 |
441 | for x,y in edges:
442 | self.nodes[x].add_neighbor(self.nodes[y])
443 | self.nodes[y].add_neighbor(self.nodes[x])
444 |
445 | if root > 0:
446 | self.nodes[0],self.nodes[root] = self.nodes[root],self.nodes[0]
447 |
448 | for i,node in enumerate(self.nodes):
449 | node.nid = i + 1
450 | if len(node.neighbors) > 1:
451 | set_atommap(node.mol, node.nid)
452 | node.is_leaf = (len(node.neighbors) == 1)
453 |
454 | def size(self):
455 | return len(self.nodes)
456 |
457 | def recover(self):
458 | for node in self.nodes:
459 | node.recover(self.mol)
460 |
461 | def assemble(self):
462 | for node in self.nodes:
463 | node.assemble()
464 |
465 |
466 | def tensorize_trees(smiles, assm=True):
467 | mol_tree = MolTree(smiles)
468 | mol_tree.recover()
469 | if assm:
470 | mol_tree.assemble()
471 | for node in mol_tree.nodes:
472 | if node.label not in node.cands:
473 | node.cands.append(node.label)
474 |
475 | del mol_tree.mol
476 | for node in mol_tree.nodes:
477 | del node.mol
478 |
479 | return mol_tree
480 |
481 |
482 |
483 | splits = 4
484 |
485 | with open('train.txt') as f:
486 | data = [line.strip("\r\n ").split()[0] for line in f]
487 |
488 | mol_trees=[]
489 | for i in range(0,len(data)):
490 |     #Generate the molecular tree for each molecule and append it to the list
491 | mol_trees.append(tensorize_trees(data[i]))
492 |
493 | print("Molecular trees")
494 | print(mol_trees)
495 |
496 |
497 | trees_data=[]
498 | l = (len(mol_trees) + splits - 1) // splits   #ceil division: number of trees per split
499 | #Making the batches of mol trees
500 | for i in range(splits):
501 | s = i * l
502 | sub_data = mol_trees[int(s) : int(s + l)]
503 | trees_data.append(sub_data)
504 |
--------------------------------------------------------------------------------
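The MolTree / tensorize_trees code above can be exercised on a single molecule before running the full preprocessing loop. A minimal sketch, assuming this file is importable under the name jvae_model (the name used by sample.py and train.py); note that importing it also executes the data-loading code at the bottom of the file:

```python
from jvae_model import MolTree   # assumes model.py is importable as jvae_model

# First SMILES from train.txt, used here purely as an illustration
tree = MolTree('CCC(NC(=O)c1scnc1C1CC1)C(=O)N1CCOCC1')
print(tree.size())               # number of clusters in the junction tree
for node in tree.nodes:
    print(node.nid, node.smiles, node.clique, node.is_leaf)
```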
/genmol/JTVAE/sample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math, random, sys
4 | from optparse import OptionParser
5 | import pickle
6 | import rdkit
7 | import json
8 | import rdkit.Chem as Chem
9 | from scipy.sparse import csr_matrix
10 | from scipy.sparse.csgraph import minimum_spanning_tree
11 | from collections import defaultdict
12 | import copy
13 | import torch.optim as optim
14 | import torch.optim.lr_scheduler as lr_scheduler
15 | from torch.utils.data import Dataset, DataLoader
16 | from torch.autograd import Variable
17 | import numpy as np
18 | from collections import deque
19 | import os, random
20 | import torch.nn.functional as F
21 | import pdb
22 |
23 | from jvae_model import *
24 |
25 | path = "savedmodel.pth"
26 | model=JTNNVAE(vocab, int(450), int(56), int(20), int(3))
27 | model.load_state_dict(torch.load(path))
28 | torch.manual_seed(0)
29 | print("Molecules generated")
30 | for i in range(10):
31 | print(model.sample_prior())
32 |
33 |
34 |
--------------------------------------------------------------------------------
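sample_prior() returns raw SMILES strings, so it can be useful to check how many of them RDKit actually parses. A minimal sketch with a hypothetical helper (not part of this repository):

```python
from rdkit import Chem

def fraction_parsable(smiles_list):
    """Hypothetical helper: fraction of SMILES that RDKit can parse."""
    mols = [Chem.MolFromSmiles(s) for s in smiles_list]
    return sum(m is not None for m in mols) / len(smiles_list)

# samples = [model.sample_prior() for _ in range(10)]
# print(fraction_parsable(samples))
```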
/genmol/JTVAE/savedmodel.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/JTVAE/savedmodel.pth
--------------------------------------------------------------------------------
/genmol/JTVAE/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math, random, sys
4 | from optparse import OptionParser
5 | import pickle
6 | import rdkit
7 | import json
8 | import rdkit.Chem as Chem
9 | from scipy.sparse import csr_matrix
10 | from scipy.sparse.csgraph import minimum_spanning_tree
11 | from collections import defaultdict
12 | import copy
13 | import torch.optim as optim
14 | import torch.optim.lr_scheduler as lr_scheduler
15 | from torch.utils.data import Dataset, DataLoader
16 | from torch.autograd import Variable
17 | import numpy as np
18 | from collections import deque
19 | import os, random
20 | import torch.nn.functional as F
21 | import pdb
22 |
23 |
24 | from jvae_preprocess import *
25 | from jvae_model import *
26 |
27 |
28 | def set_batch_nodeID(mol_batch, vocab):
29 | tot = 0
30 |
31 | for mol_tree in mol_batch:
32 |
33 | for node in mol_tree.nodes:
34 | node.idx = tot
35 |
36 | s_to_m=Chem.MolFromSmiles(node.smiles)
37 | m_to_s=Chem.MolToSmiles(s_to_m,kekuleSmiles=False)
38 | node.wid = vocab.get_index(m_to_s)
39 |
40 | tot += 1
41 |
42 |
43 | def tensorize_x(tree_batch, vocab,assm=True):
44 | set_batch_nodeID(tree_batch, vocab)
45 | smiles_batch = [tree.smiles for tree in tree_batch]
46 | jtenc_holder,mess_dict = JTNNEncoder.tensorize(tree_batch)
47 | jtenc_holder = jtenc_holder
48 | mpn_holder = MPN.tensorize(smiles_batch)
49 |
50 | if assm is False:
51 | return tree_batch, jtenc_holder, mpn_holder
52 |
53 | cands = []
54 | batch_idx = []
55 | for i,mol_tree in enumerate(tree_batch):
56 | for node in mol_tree.nodes:
57 | if node.is_leaf or len(node.cands) == 1: continue
58 | cands.extend( [(cand, mol_tree.nodes, node) for cand in node.cands] )
59 | batch_idx.extend([i] * len(node.cands))
60 |
61 | jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
62 | batch_idx = torch.LongTensor(batch_idx)
63 |
64 | return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder,batch_idx)
65 |
66 |
67 | class MolTreeDataset(Dataset):
68 |
69 | def __init__(self, data, vocab, assm=True):
70 | self.data = data
71 | self.vocab = vocab
72 | self.assm = assm
73 |
74 | def __len__(self):
75 | return len(self.data)
76 |
77 | def __getitem__(self, idx):
78 | return tensorize_x(self.data[idx], self.vocab,assm=self.assm)
79 |
80 |
81 |
82 | def get_loader(data_1,vocab):
83 |
84 | for i in range(0,len(data_1)):
85 |
86 | if True:
87 | random.shuffle(data_1[i])
88 |
89 | batches=[]
90 | for j in range(0,len(data_1[i])):
91 | batches.append([])
92 |
93 | for j in range(0,len(data_1[i])):
94 |
95 | batches[j].append(data_1[i][j])
96 |
97 | dataset = MolTreeDataset(batches, vocab,True)
98 |
99 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x:x[0])
100 |
101 | for b in dataloader:
102 | yield b
103 |
104 | del batches, dataset, dataloader
105 |
106 |
107 |
108 |
109 | for param in model.parameters():
110 | if param.dim() == 1:
111 | nn.init.constant_(param, 0)
112 | else:
113 | nn.init.xavier_normal_(param)
114 |
115 |
116 | print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
117 |
118 |
119 |
120 | optimizer = optim.Adam(model.parameters(), lr=1e-3)
121 | scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
122 | scheduler.step()
123 |
124 | param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
125 | grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))
126 |
127 | total_step = 0
128 | beta = 0.0
129 | meters = np.zeros(4)
130 | path = "savedmodel.pth"
131 | print("Training")
132 | #Training starts here...
133 | for epoch in range(10):
134 |
135 | #Loading the data
136 | loader=get_loader(trees_data,vocab)
137 |
138 | for batch in loader:
139 | total_step += 1
140 | try:
141 | model.zero_grad()
142 | #Send the batch to the model
143 | loss, kl_div, wacc, tacc, sacc = model(batch, beta)
144 | #Backward propagation
145 | loss.backward()
146 | nn.utils.clip_grad_norm_(model.parameters(),50.0)
147 | optimizer.step()
148 | except Exception as e:
149 | print(e)
150 | continue
151 |
152 | meters = meters + np.array([kl_div, wacc * 100, tacc * 100, sacc * 100])
153 |
154 |
155 | torch.save(model.state_dict(), path)
156 |
157 |
158 | scheduler.step()
159 | #print("learning rate: %.6f" % scheduler.get_lr()[0])
160 |
161 | beta = min(1.0, beta + 0.002)
162 |
163 | print("Epoch :" + str(epoch))
164 |
165 |
--------------------------------------------------------------------------------
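The param_norm/grad_norm lambdas and gradient clipping above are generic PyTorch diagnostics; a self-contained sketch on a toy module (not the JTVAE model) showing what they compute:

```python
import math
import torch
import torch.nn as nn

toy = nn.Linear(4, 2)                                   # stand-in for the JTVAE model
param_norm = lambda m: math.sqrt(sum(p.norm().item() ** 2 for p in m.parameters()))
grad_norm = lambda m: math.sqrt(sum(p.grad.norm().item() ** 2
                                    for p in m.parameters() if p.grad is not None))

loss = toy(torch.randn(3, 4)).sum()
loss.backward()
nn.utils.clip_grad_norm_(toy.parameters(), 50.0)        # same clipping threshold as above
print(param_norm(toy), grad_norm(toy))
```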
/genmol/JTVAE/train.txt:
--------------------------------------------------------------------------------
1 | CCC(NC(=O)c1scnc1C1CC1)C(=O)N1CCOCC1
2 | O=C1OCCC1Sc1nnc(-c2c[nH]c3ccccc23)n1C1CC1
3 | CCN(C)S(=O)(=O)N1CCC(Nc2cccc(OC)c2)CC1
4 | CC(=O)Nc1cccc(NC(C)c2ccccn2)c1
5 | Cc1cc(-c2nc3sc(C4CC4)nn3c2C#N)ccc1Cl
6 | CCOCCCNC(=O)c1cc(OC)ccc1Br
7 | Cc1nc(-c2ccncc2)[nH]c(=O)c1CC(=O)NC1CCCC1
8 | C#CCN(CC#C)C(=O)c1cc2ccccc2cc1OC(F)F
9 | CCOc1ccc(CN2c3ccccc3NCC2C)cc1N
10 | NC(=O)C1CCC(CNc2cc(-c3ccccc3)nc3ccnn23)CC1
11 | Cc1csc(Sc2cc(C)nc3ncnn23)n1
12 | COCCN1CCN(C(=O)CCc2nc(C(C)(C)C)no2)CC1=O
13 | CCN(CC(C)C#N)C(=O)CN1CCN(C(=O)CC(F)(F)F)CC1
14 | Cc1ccc(C(=O)Nc2ccc(S(N)(=O)=O)cc2F)cc1C
15 | O=C(c1cccs1)N1CCc2[nH]nc(-c3cccc(F)c3)c2C1
16 | O=C(C1CC1)N1CCc2ccc(NS(=O)(=O)c3ccccc3)cc21
17 | O=C(CCNC(=O)c1ccccc1)Nc1cc(Cl)cc(Cl)c1
18 | CCN1CC(c2nc3ccccc3n2CC(=O)OC(C)C)CC1=O
19 | CCC(NC(=O)COc1ccccc1O)c1ccc(OC)cc1
20 | C=CCn1c(SCCc2c(C)noc2C)n[nH]c1=O
21 | CNC(=O)c1cc(NC(=O)N2C[C@H]3CCC[NH+]3C[C@H]2C)ccc1F
22 | CCCc1ccc([C@@H]([NH3+])C(OC)OC)cc1
23 | [NH3+]C[C@H](c1cc(Cl)cs1)N1CC[C@@H]2CCCC[C@@H]2C1
24 | CC(C)CN(CC(C)C)C(=O)NCC(C)(C)N1CCOCC1
25 | COc1cc(C(=O)N[C@H]2C[C@H]2c2cccc(Cl)c2)cc(OC)c1OC
26 | COCCNC(=O)c1cccc(N2CC[NH2+][C@@H](c3ccccc3)C2)n1
27 | Cc1ccc(-c2cc(OCC(=O)[O-])nc(NCc3ccc4c(c3)OCO4)n2)cc1
28 | COc1cccc(NC2CC[NH+](C[C@@H](O)CN3C[C@H](C)O[C@@H](C)C3)CC2)c1
29 | CC(C)[C@H](NC(=O)Nc1ccccc1)C(=O)NCc1ccco1
30 | O=C(Nc1ccccn1)N1CCC(n2cc[nH+]c2)CC1
31 | Cc1cc(C)cc([C@H]2OCC[C@H]2C[NH2+]C(C)C)c1
32 | C[C@H]1CSC(NC[C@]2(C)CCCO2)=[NH+]1
33 | Cc1ccc(O)c(Cc2ccccc2Cl)c1
34 | O=C1CCCC2=C1[C@H](c1ccc([N+](=O)[O-])o1)n1nnnc1N2
35 | C=CCS/C(N)=C(C#N)/C(C#N)=C(\N)SCC=C
36 | Cc1sc(=O)n(CCC(=O)N(C)[C@@H]2CCCC[C@H]2S(C)(=O)=O)c1C
37 | COc1cc([N+](=O)[O-])cc(/C=N/Nc2ccccc2C)c1O
38 | CCOC(=O)[C@@H]1CCCN(Cc2nn3c(=O)cc(C)nc3s2)C1
39 | Cc1sc2nc(C[NH+](C(C)C)C3CCCC3)nc(N)c2c1C
40 | O=C(NCCCc1nc2ccccc2s1)N1CCc2ccccc2C1
41 | CCOc1ccc(Br)cc1C(=O)Nc1cccc(OC[C@@H]2CCCO2)c1
42 | CCn1nc(C(=O)N2CCN(c3ccccc3)CC2)c2c1CC[NH+](Cc1cc(C)c(OC)cc1C)C2
43 | O=C(N[C@@H]1[C@@H]2CCO[C@@H]2C12CCC2)c1cnc([C@H]2CCCO2)s1
44 | CCOc1ccc(C(=O)Nc2ccc(F)cc2F)cc1
45 | O=C(C1CCC(F)(F)CC1)N1CCN(c2ncccc2F)CC1
46 | O=C(NCc1ccc(F)c(Cl)c1)[C@@H]1CSCN1C(=O)c1c[nH]c2ccccc12
47 | Cc1cc(CCNC(=O)Nc2ccccc2F)on1
48 | COCCOC(=O)C1=C(N)Oc2c(oc(CO)cc2=O)[C@H]1c1cccc(F)c1
49 | CCOC(=O)C(=O)Nc1cc(C)nn1-c1nc([O-])c2c(n1)CCC2
50 | CN(C(=O)[C@H]1CCCN(c2ncnc3onc(-c4ccc(F)cc4)c23)C1)C1CCCCC1
51 |
--------------------------------------------------------------------------------
/genmol/ORGAN/Data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 |
4 | Data = pd.read_csv('C:/Users/haroon_03/Desktop/smiles.csv')
5 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
6 | chars = set()
7 | for string in Data['SMILES']:
8 | chars.update(string)
9 | train_data = Data[Data['SPLIT'] == 'train']
10 | test_data = Data[Data['SPLIT'] == 'test']
11 | test_scaffold = Data[Data['SPLIT'] == 'test_scaffolds']
12 |
13 | all_syms = sorted(list(chars) + ['<bos>', '<eos>', '<pad>', '<unk>'])
14 | vocabulary = all_syms
15 |
16 | c2i = {c: i for i, c in enumerate(all_syms)}
17 | i2c = {i: c for i, c in enumerate(all_syms)}
18 |
19 | train_data = (train_data['SMILES'].squeeze()).astype(str).tolist()
20 | test_scaffold = (test_scaffold['SMILES'].squeeze()).astype(str).tolist()
21 | test_data = (test_data['SMILES'].squeeze()).astype(str).tolist()
22 |
--------------------------------------------------------------------------------
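The character-level vocabulary built above maps every symbol seen in the SMILES column, plus four special tokens, to integer ids. A self-contained sketch of the same construction on a toy list (the special-token names `<bos>`, `<eos>`, `<pad>`, `<unk>` are the inferred ones used throughout Model.py):

```python
toy_smiles = ['CCO', 'c1ccccc1']
chars = set()
for s in toy_smiles:
    chars.update(s)

all_syms = sorted(list(chars) + ['<bos>', '<eos>', '<pad>', '<unk>'])
c2i = {c: i for i, c in enumerate(all_syms)}
i2c = {i: c for i, c in enumerate(all_syms)}

encoded = [c2i['<bos>']] + [c2i[ch] for ch in 'CCO'] + [c2i['<eos>']]
print(all_syms)
print(encoded)
```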
/genmol/ORGAN/Metrics_Reward.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numpy as np
3 | from scipy.spatial.distance import cosine as cos_distance
4 | from fcd_torch import FCD as FCDMetric
5 | from fcd_torch import calculate_frechet_distance
6 | from RewardMetrics import *
7 | from rdkit import rdBase
8 | import random
9 |
10 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
11 |
12 | def get_all_metrics(test, gen, k=None, n_jobs=1, device=device,
13 | batch_size=512, test_scaffolds=None,
14 | ptest=None, ptest_scaffolds=None,
15 | pool=None, gpu=None):
16 | """
17 | Computes all available metrics between test (scaffold test)
18 | and generated sets of SMILES.
19 | Parameters:
20 | test: list of test SMILES
21 | gen: list of generated SMILES
22 | k: list with values for unique@k. Will calculate number of
23 | unique molecules in the first k molecules. Default [1000, 10000]
24 | n_jobs: number of workers for parallel processing
25 | device: 'cpu' or 'cuda:n', where n is GPU device number
26 | batch_size: batch size for FCD metric
27 | test_scaffolds: list of scaffold test SMILES
28 | Will compute only on the general test set if not specified
29 | ptest: dict with precalculated statistics of the test set
30 | ptest_scaffolds: dict with precalculated statistics
31 | of the scaffold test set
32 | pool: optional multiprocessing pool to use for parallelization
33 | gpu: deprecated, use `device`
34 |
35 | Available metrics:
36 | * %valid
37 | * %unique@k
38 | * Frechet ChemNet Distance (FCD)
39 | * Fragment similarity (Frag)
40 | * Scaffold similarity (Scaf)
41 | * Similarity to nearest neighbour (SNN)
42 | * Internal diversity (IntDiv)
43 | * Internal diversity 2: using square root of mean squared
44 | Tanimoto similarity (IntDiv2)
45 | * %passes filters (Filters)
46 | * Distribution difference for logP, SA, QED, NP, weight
47 | """
48 | if k is None:
49 | k = [1000, 10000]
50 |     rdBase.DisableLog('rdApp.*')  #silence RDKit logging while metrics are computed
51 | metrics = {}
52 | if gpu is not None:
53 | warnings.warn(
54 | "parameter `gpu` is deprecated. Use `device`",
55 | DeprecationWarning
56 | )
57 | if gpu == -1:
58 | device = 'cpu'
59 | else:
60 | device = 'cuda:{}'.format(gpu)
61 | close_pool = False
62 | if pool is None:
63 | if n_jobs != 1:
64 | pool = Pool(n_jobs)
65 | close_pool = True
66 | else:
67 | pool = 1
68 | metrics['valid'] = fraction_valid(gen)
69 | gen = remove_invalid(gen, canonize=True)
70 | if not isinstance(k, (list, tuple)):
71 | k = [k]
72 | for _k in k:
73 | metrics['unique@{}'.format(_k)] = fraction_unique(gen, _k)
74 |
75 | if ptest is None:
76 | ptest = compute_intermediate_statistics(test, n_jobs=n_jobs,
77 | device=device,
78 | batch_size=batch_size,
79 | pool=pool)
80 | if test_scaffolds is not None and ptest_scaffolds is None:
81 | ptest_scaffolds = compute_intermediate_statistics(
82 | test_scaffolds, n_jobs=n_jobs,
83 | device=device, batch_size=batch_size,
84 | pool=pool
85 | )
86 | mols = mapper(pool)(get_mol, gen)
87 | kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}
88 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size}
89 | metrics['FCD/Test'] = FCDMetric(**kwargs_fcd)(gen=gen, pref=ptest['FCD'])
90 | metrics['SNN/Test'] = SNNMetric(**kwargs)(gen=mols, pref=ptest['SNN'])
91 | metrics['Frag/Test'] = FragMetric(**kwargs)(gen=mols, pref=ptest['Frag'])
92 | metrics['Scaf/Test'] = ScafMetric(**kwargs)(gen=mols, pref=ptest['Scaf'])
93 | if ptest_scaffolds is not None:
94 | metrics['FCD/TestSF'] = FCDMetric(**kwargs_fcd)(
95 | gen=gen, pref=ptest_scaffolds['FCD']
96 | )
97 | metrics['SNN/TestSF'] = SNNMetric(**kwargs)(
98 | gen=mols, pref=ptest_scaffolds['SNN']
99 | )
100 | metrics['Frag/TestSF'] = FragMetric(**kwargs)(
101 | gen=mols, pref=ptest_scaffolds['Frag']
102 | )
103 | metrics['Scaf/TestSF'] = ScafMetric(**kwargs)(
104 | gen=mols, pref=ptest_scaffolds['Scaf']
105 | )
106 |
107 | metrics['IntDiv'] = internal_diversity(mols, pool, device=device)
108 | metrics['IntDiv2'] = internal_diversity(mols, pool, device=device, p=2)
109 | metrics['Filters'] = fraction_passes_filters(mols, pool)
110 |
111 | # Properties
112 | for name, func in [('logP', logP), ('SA', SA),
113 | ('QED', QED), ('NP', NP),
114 | ('weight', weight)]:
115 | metrics[name] = FrechetMetric(func, **kwargs)(gen=mols,
116 | pref=ptest[name])
117 |     rdBase.EnableLog('rdApp.*')
118 | if close_pool:
119 | pool.terminate()
120 | return metrics
121 |
122 |
123 | def compute_intermediate_statistics(smiles, n_jobs=1, device='cpu',
124 | batch_size=512, pool=None):
125 | """
126 | The function precomputes statistics such as mean and variance for FCD, etc.
127 | It is useful to compute the statistics for test and scaffold test sets to
128 | speedup metrics calculation.
129 | """
130 | close_pool = False
131 | if pool is None:
132 | if n_jobs != 1:
133 | pool = Pool(n_jobs)
134 | close_pool = True
135 | else:
136 | pool = 1
137 | statistics = {}
138 | mols = mapper(pool)(get_mol, smiles)
139 | kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}
140 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size}
141 | statistics['FCD'] = FCDMetric(**kwargs_fcd).precalc(smiles)
142 | statistics['SNN'] = SNNMetric(**kwargs).precalc(mols)
143 | statistics['Frag'] = FragMetric(**kwargs).precalc(mols)
144 | statistics['Scaf'] = ScafMetric(**kwargs).precalc(mols)
145 | for name, func in [('logP', logP), ('SA', SA),
146 | ('QED', QED), ('NP', NP),
147 | ('weight', weight)]:
148 | statistics[name] = FrechetMetric(func, **kwargs).precalc(mols)
149 | if close_pool:
150 | pool.terminate()
151 | return statistics
152 |
153 |
154 | def fraction_passes_filters(gen, n_jobs=1):
155 | """
156 | Computes the fraction of molecules that pass filters:
157 | * MCF
158 | * PAINS
159 | * Only allowed atoms ('C','N','S','O','F','Cl','Br','H')
160 | * No charges
161 | """
162 | passes = mapper(n_jobs)(mol_passes_filters, gen)
163 | return np.mean(passes)
164 |
165 |
166 | def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
167 | gen_fps=None, p=1):
168 | """
169 | Computes internal diversity as:
170 | 1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
171 | """
172 | if gen_fps is None:
173 | gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
174 | return 1 - (average_agg_tanimoto(gen_fps, gen_fps,
175 | agg='mean', device=device, p=p)).mean()
176 |
177 |
178 | def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
179 | """
180 |     Computes the fraction of unique molecules
181 | Parameters:
182 | gen: list of SMILES
183 | k: compute unique@k
184 | n_jobs: number of threads for calculation
185 | check_validity: raises ValueError if invalid molecules are present
186 | """
187 | if k is not None:
188 | if len(gen) < k:
189 | warnings.warn(
190 | "Can't compute unique@{}.".format(k) +
191 | "gen contains only {} molecules".format(len(gen))
192 | )
193 | gen = gen[:k]
194 | canonic = set(mapper(n_jobs)(canonic_smiles, gen))
195 | if None in canonic and check_validity:
196 | raise ValueError("Invalid molecule passed to unique@k")
197 | return len(canonic) / len(gen)
198 |
199 |
200 | def fraction_valid(gen, n_jobs=1):
201 | """
202 |     Computes the fraction of valid molecules
203 | Parameters:
204 | gen: list of SMILES
205 | n_jobs: number of threads for calculation
206 | """
207 | gen = mapper(n_jobs)(get_mol, gen)
208 | return 1 - gen.count(None) / len(gen)
209 |
210 |
211 | def remove_invalid(gen, canonize=True, n_jobs=1):
212 | """
213 | Removes invalid molecules from the dataset
214 | """
215 | if not canonize:
216 | mols = mapper(n_jobs)(get_mol, gen)
217 | return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
218 | else:
219 | return [x for x in mapper(n_jobs)(canonic_smiles, gen) if x is not None]
220 |
221 |
222 | class Metric:
223 | def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
224 | self.n_jobs = n_jobs
225 | self.device = device
226 | self.batch_size = batch_size
227 |         for k, v in kwargs.items():
228 | setattr(self, k, v)
229 |
230 | def __call__(self, ref=None, gen=None, pref=None, pgen=None):
231 | assert (ref is None) != (pref is None), "specify ref xor pref"
232 | assert (gen is None) != (pgen is None), "specify gen xor pgen"
233 | if pref is None:
234 | pref = self.precalc(ref)
235 | if pgen is None:
236 | pgen = self.precalc(gen)
237 | return self.metric(pref, pgen)
238 |
239 | def precalc(self, moleclues):
240 | raise NotImplementedError
241 |
242 | def metric(self, pref, pgen):
243 | raise NotImplementedError
244 |
245 |
246 | class SNNMetric(Metric):
247 | """
248 | Computes average max similarities of gen SMILES to ref SMILES
249 | """
250 |
251 | def __init__(self, fp_type='morgan', **kwargs):
252 | self.fp_type = fp_type
253 | super().__init__(**kwargs)
254 |
255 | def precalc(self, mols):
256 | return {'fps': fingerprints(mols, n_jobs=self.n_jobs, fp_type=self.fp_type)}
257 |
258 | def metric(self, pref, pgen):
259 | return average_agg_tanimoto(pref['fps'], pgen['fps'],
260 | device=self.device)
261 |
262 |
263 | def cos_similarity(ref_counts, gen_counts):
264 | """
265 | Computes cosine similarity between
266 | dictionaries of form {name: count}. Non-present
267 | elements are considered zero:
268 |
269 |     sim = <r, g> / ||r|| / ||g||
270 | """
271 | if len(ref_counts) == 0 or len(gen_counts) == 0:
272 | return np.nan
273 | keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys()))
274 | ref_vec = np.array([ref_counts.get(k, 0) for k in keys])
275 | gen_vec = np.array([gen_counts.get(k, 0) for k in keys])
276 | return 1 - cos_distance(ref_vec, gen_vec)
277 |
278 |
279 | class FragMetric(Metric):
280 | def precalc(self, mols):
281 | return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)}
282 |
283 | def metric(self, pref, pgen):
284 | return cos_similarity(pref['frag'], pgen['frag'])
285 |
286 |
287 | class ScafMetric(Metric):
288 | def precalc(self, mols):
289 | return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)}
290 |
291 | def metric(self, pref, pgen):
292 | return cos_similarity(pref['scaf'], pgen['scaf'])
293 |
294 |
295 | class FrechetMetric(Metric):
296 | def __init__(self, func=None, **kwargs):
297 | self.func = func
298 | super().__init__(**kwargs)
299 |
300 | def precalc(self, mols):
301 | if self.func is not None:
302 | values = mapper(self.n_jobs)(self.func, mols)
303 | else:
304 | values = mols
305 | return {'mu': np.mean(values), 'var': np.var(values)}
306 |
307 | def metric(self, pref, pgen):
308 | return calculate_frechet_distance(
309 | pref['mu'], pref['var'], pgen['mu'], pgen['var']
310 | )
311 |
312 |
313 | class MetricsReward:
314 | supported_metrics = ['fcd', 'snn', 'fragments', 'scaffolds',
315 | 'internal_diversity', 'filters',
316 | 'logp', 'sa', 'qed', 'np', 'weight']
317 |
318 | @staticmethod
319 | def _nan2zero(value):
320 |         if np.isnan(value):
321 | return 0
322 |
323 | return value
324 |
325 | def __init__(self, n_ref_subsample, n_rollouts, n_jobs, metrics=[]):
326 | assert all([m in MetricsReward.supported_metrics for m in metrics])
327 |
328 | self.n_ref_subsample = n_ref_subsample
329 | self.n_rollouts = n_rollouts
330 | # TODO: profile this. Pool works too slow.
331 | n_jobs = n_jobs if False else 1
332 | self.n_jobs = n_jobs
333 | self.metrics = metrics
334 |
335 | def get_reference_data(self, data):
336 | ref_smiles = remove_invalid(data, canonize=True, n_jobs=self.n_jobs)
337 | ref_mols = mapper(self.n_jobs)(get_mol, ref_smiles)
338 | return ref_smiles, ref_mols
339 |
340 | def _get_metrics(self, ref, ref_mols, rollout):
341 | rollout_mols = mapper(self.n_jobs)(get_mol, rollout)
342 | result = [[0 if m is None else 1] for m in rollout_mols]
343 |
344 | if sum([r[0] for r in result], 0) == 0:
345 | return result
346 |
347 | rollout = remove_invalid(rollout, canonize=True, n_jobs=self.n_jobs)
348 | rollout_mols = mapper(self.n_jobs)(get_mol, rollout)
349 | if len(rollout) < 2:
350 | return result
351 |
352 | if len(self.metrics):
353 | for metric_name in self.metrics:
354 | if metric_name == 'fcd':
355 | m = FCDMetric(n_jobs=self.n_jobs)(ref, rollout)
356 |                 elif metric_name == 'snn':
357 | m = SNNMetric(n_jobs=self.n_jobs)(ref_mols, rollout_mols)
358 | elif metric_name == 'fragments':
359 | m = FragMetric(n_jobs=self.n_jobs)(ref_mols, rollout_mols)
360 | elif metric_name == 'scaffolds':
361 | m = ScafMetric(n_jobs=self.n_jobs)(ref_mols, rollout_mols)
362 | elif metric_name == 'internal_diversity':
363 | m = internal_diversity(rollout_mols, n_jobs=self.n_jobs)
364 | elif metric_name == 'filters':
365 | m = fraction_passes_filters(
366 | rollout_mols, n_jobs=self.n_jobs
367 | )
368 | elif metric_name == 'logp':
369 | m = -FrechetMetric(func=logP, n_jobs=self.n_jobs)(
370 | ref_mols, rollout_mols
371 | )
372 | elif metric_name == 'sa':
373 | m = -FrechetMetric(func=SA, n_jobs=self.n_jobs)(
374 | ref_mols, rollout_mols
375 | )
376 | elif metric_name == 'qed':
377 | m = -FrechetMetric(func=QED, n_jobs=self.n_jobs)(
378 | ref_mols, rollout_mols
379 | )
380 | elif metric_name == 'np':
381 | m = -FrechetMetric(func=NP, n_jobs=self.n_jobs)(
382 | ref_mols, rollout_mols
383 | )
384 | elif metric_name == 'weight':
385 | m = -FrechetMetric(func=weight, n_jobs=self.n_jobs)(
386 | ref_mols, rollout_mols
387 | )
388 |
389 | m = MetricsReward._nan2zero(m)
390 | for i in range(len(rollout)):
391 | result[i].append(m)
392 |
393 | return result
394 |
395 | def __call__(self, gen, ref, ref_mols):
396 |
397 | idxs = random.sample(range(len(ref)), self.n_ref_subsample)
398 | ref_subsample = [ref[idx] for idx in idxs]
399 | ref_mols_subsample = [ref_mols[idx] for idx in idxs]
400 |
401 | gen_counter = Counter(gen)
402 | gen_counts = [gen_counter[g] for g in gen]
403 |
404 | n = len(gen) // self.n_rollouts
405 | rollouts = [gen[i::n] for i in range(n)]
406 |
407 | metrics_values = [self._get_metrics(
408 | ref_subsample, ref_mols_subsample, rollout
409 | ) for rollout in rollouts]
410 | metrics_values = map(
411 | lambda rm: [
412 | sum(r, 0) / len(r)
413 | for r in rm
414 | ], metrics_values)
415 | reward_values = sum(zip(*metrics_values), ())
416 | reward_values = [v / c for v, c in zip(reward_values, gen_counts)]
417 |
418 | return reward_values
419 |
--------------------------------------------------------------------------------
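A hedged usage sketch of get_all_metrics. It assumes fcd_torch and RDKit are installed and that the script is run from the ORGAN directory; the SMILES lists are illustrative, and on lists this small the distribution-based metrics (FCD, Frechet properties) are degenerate, so this only shows the call signature:

```python
from Metrics_Reward import get_all_metrics

test = ['CCO', 'c1ccccc1', 'CC(=O)O', 'CCN']
gen = ['CCO', 'CCCN', 'not_a_smiles']
metrics = get_all_metrics(test, gen, k=[2], n_jobs=1, device='cpu')
for name, value in metrics.items():
    print(name, value)
```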
/genmol/ORGAN/Model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5 | from Data import *
6 | from Metrics_Reward import MetricsReward
7 |
8 | class Generator(nn.Module):
9 | def __init__(self, embedding_layer, hidden_size, num_layers, dropout):
10 | super(Generator, self).__init__()
11 |
12 | self.embedding_layer = embedding_layer
13 | self.lstm_layer = nn.LSTM(embedding_layer.embedding_dim,
14 | hidden_size, num_layers,
15 | batch_first=True, dropout=dropout)
16 | self.linear_layer = nn.Linear(hidden_size,
17 | embedding_layer.num_embeddings)
18 |
19 | def forward(self, x, lengths, states=None):
20 | x = self.embedding_layer(x)
21 | x = pack_padded_sequence(x, lengths, batch_first=True)
22 | x, states = self.lstm_layer(x, states)
23 | x, _ = pad_packed_sequence(x, batch_first=True)
24 | x = self.linear_layer(x)
25 |
26 | return x, lengths, states
27 |
28 |
29 | class Discriminator(nn.Module):
30 | def __init__(self, desc_embedding_layer, convs, dropout=0):
31 | super(Discriminator, self).__init__()
32 |
33 | self.embedding_layer = desc_embedding_layer
34 | self.conv_layers = nn.ModuleList(
35 | [nn.Conv2d(1, f, kernel_size=(
36 | n, self.embedding_layer.embedding_dim)
37 | ) for f, n in convs])
38 | sum_filters = sum([f for f, _ in convs])
39 | self.highway_layer = nn.Linear(sum_filters, sum_filters)
40 | self.dropout_layer = nn.Dropout(p=dropout)
41 | self.output_layer = nn.Linear(sum_filters, 1)
42 |
43 | def forward(self, x):
44 | x = self.embedding_layer(x)
45 | x = x.unsqueeze(1)
46 | convs = [F.elu(conv_layer(x)).squeeze(3)
47 | for conv_layer in self.conv_layers]
48 | x = [F.max_pool1d(c, c.shape[2]).squeeze(2) for c in convs]
49 | x = torch.cat(x, dim=1)
50 |
51 | h = self.highway_layer(x)
52 | t = torch.sigmoid(h)
53 | x = t * F.elu(h) + (1 - t) * x
54 | x = self.dropout_layer(x)
55 | out = self.output_layer(x)
56 |
57 | return out
58 |
59 |
60 | class ORGAN(nn.Module):
61 | def __init__(self):
62 | super(ORGAN, self).__init__()
63 |
64 | self.metrics_reward = MetricsReward(n_ref_subsample=100, n_rollouts=16, n_jobs=1, metrics=[])
65 |
66 | self.reward_weight = 0.7
67 |
68 | self.convs = [(100, 1), (200, 2), (200, 3),
69 | (200, 4), (200, 5), (100, 6),
70 | (100, 7), (100, 8), (100, 9),
71 | (100, 10)]
72 |
73 | self.embedding_layer = nn.Embedding(
74 |             len(vocabulary), embedding_dim=32, padding_idx=c2i['<pad>'])
75 |
76 | self.desc_embedding_layer = nn.Embedding(
77 |             len(vocabulary), embedding_dim=32, padding_idx=c2i['<pad>'])
78 |
79 | self.generator = Generator(self.embedding_layer, hidden_size=512, num_layers=2, dropout=0)
80 |
81 | self.discriminator = Discriminator(self.desc_embedding_layer, self.convs, dropout=0)
82 |
83 | def device(self):
84 | return next(self.parameters()).device
85 |
86 | def generator_forward(self, *args, **kwargs):
87 | return self.generator(*args, **kwargs)
88 |
89 | def discriminator_forward(self, *args, **kwargs):
90 | return self.discriminator(*args, **kwargs)
91 |
92 | def forward(self, *args, **kwargs):
93 | return self.sample(*args, **kwargs)
94 |
95 | def char2id(self, c):
96 | if c not in c2i:
97 |             return c2i['<unk>']
98 |
99 | return c2i[c]
100 |
101 | def id2char(self, id):
102 | if id not in i2c:
103 | return i2c[14]
104 |
105 | return i2c[id]
106 |
107 | def string2id(self, string, add_bos=False, add_eos=False):
108 | ids = [self.char2id(c) for c in string]
109 |
110 | if add_bos:
111 |             ids = [c2i['<bos>']] + ids
112 |
113 | if add_eos:
114 |             ids = ids + [c2i['<eos>']]
115 |
116 | return ids
117 |
118 | def ids2string(self, ids, rem_bos=True, rem_eos=True):
119 | if len(ids) == 0:
120 | return ''
121 |         if rem_bos and ids[0] == c2i['<bos>']:
122 | ids = ids[1:]
123 |         if rem_eos and ids[-1] == c2i['<eos>']:
124 | ids = ids[:-1]
125 |
126 | string = ''.join([self.id2char(id) for id in ids])
127 |
128 | return string
129 |
130 | def string2tensor(self, string):
131 | ids = self.string2id(string, add_bos=True, add_eos=True)
132 | tensor = torch.tensor(ids, dtype=torch.long, device=device)
133 |
134 | return tensor
135 |
136 | def tensor2string(self, tensor):
137 | ids = tensor.tolist()
138 | string = self.ids2string(ids, rem_bos=True, rem_eos=True)
139 |
140 | return string
141 |
142 | def sample_tensor(self, n, max_length=100):
143 |         prevs = torch.empty(n, 1,
144 |                             dtype=torch.long, device=device).fill_(c2i['<bos>'])
145 | samples, lengths = self._proceed_sequences(prevs, None, max_length)
146 |
147 | samples = torch.cat([prevs, samples], dim=-1)
148 | lengths += 1
149 |
150 | return samples, lengths
151 |
152 | def sample(self, batch_n=64, max_length=100):
153 | samples, lengths = self.sample_tensor(batch_n, max_length)
154 | samples = [t[:l] for t, l in zip(samples, lengths)]
155 |
156 | return [self.tensor2string(t) for t in samples]
157 |
158 | def _proceed_sequences(self, prevs, states, max_length):
159 | with torch.no_grad():
160 | n_sequences = prevs.shape[0]
161 |
162 | sequences = []
163 | lengths = torch.zeros(n_sequences,
164 | dtype=torch.long, device=device)
165 |
166 | one_lens = torch.ones(n_sequences,
167 | dtype=torch.long, device=device)
168 |             is_end = prevs.eq(c2i['<eos>']).view(-1)
169 |
170 | for _ in range(max_length):
171 | outputs, _, states = self.generator(prevs, one_lens, states)
172 | probs = F.softmax(outputs, dim=-1).view(n_sequences, -1)
173 | currents = torch.multinomial(probs, 1)
174 |
175 |                 currents[is_end, :] = c2i['<pad>']
176 | sequences.append(currents)
177 | lengths[~is_end] += 1
178 |
179 |                 is_end[currents.view(-1) == c2i['<eos>']] = 1
180 | if is_end.sum() == n_sequences:
181 | break
182 |
183 | prevs = currents
184 |
185 | sequences = torch.cat(sequences, dim=-1)
186 |
187 | return sequences, lengths
188 |
189 | def rollout(self, ref_smiles, ref_mols, n_samples, n_rollouts, max_length=100):
190 | with torch.no_grad():
191 | sequences = []
192 | rewards = []
193 | ref_smiles = ref_smiles
194 | ref_mols = ref_mols
195 | lengths = torch.zeros(n_samples, dtype=torch.long, device=device)
196 |
197 |             one_lens = torch.ones(n_samples, dtype=torch.long, device=device)
198 |             prevs = torch.empty(n_samples, 1, dtype=torch.long, device=device).fill_(c2i['<bos>'])
199 | is_end = torch.zeros(n_samples, dtype=torch.uint8, device=device)
200 | states = None
201 |
202 | sequences.append(prevs)
203 | lengths += 1
204 |
205 | for current_len in range(10):
206 | print(current_len)
207 | outputs, _, states = self.generator(prevs, one_lens, states)
208 |
209 | probs = F.softmax(outputs, dim=-1).view(n_samples, -1)
210 | currents = torch.multinomial(probs, 1)
211 |
212 |                 currents[is_end, :] = c2i['<pad>']
213 | sequences.append(currents)
214 | lengths[~is_end] += 1
215 |
216 | rollout_prevs = currents[~is_end, :].repeat(n_rollouts, 1)
217 | rollout_states = (
218 | states[0][:, ~is_end, :].repeat(1, n_rollouts, 1),
219 | states[1][:, ~is_end, :].repeat(1, n_rollouts, 1)
220 | )
221 | rollout_sequences, rollout_lengths = self._proceed_sequences(
222 | rollout_prevs, rollout_states, max_length - current_len
223 | )
224 |
225 | rollout_sequences = torch.cat(
226 | [s[~is_end, :].repeat(n_rollouts, 1) for s in sequences] + [rollout_sequences], dim=-1)
227 | rollout_lengths += lengths[~is_end].repeat(n_rollouts)
228 |
229 | rollout_rewards = torch.sigmoid(
230 | self.discriminator(rollout_sequences).detach()
231 | )
232 |
233 | if self.metrics_reward is not None and self.reward_weight > 0:
234 | strings = [
235 | self.tensor2string(t[:l])
236 | for t, l in zip(rollout_sequences, rollout_lengths)
237 | ]
238 |
239 | obj_rewards = torch.tensor(
240 | self.metrics_reward(strings, ref_smiles, ref_mols)).view(-1, 1)
241 | rollout_rewards = (rollout_rewards * (1 - self.reward_weight) +
242 | obj_rewards * self.reward_weight
243 | )
244 | print('Metrics Rewards = ', obj_rewards)
245 | current_rewards = torch.zeros(n_samples, device=device)
246 |
247 | current_rewards[~is_end] = rollout_rewards.view(
248 | n_rollouts, -1
249 | ).mean(dim=0)
250 | rewards.append(current_rewards.view(-1, 1))
251 |
252 |                 is_end[currents.view(-1) == c2i['<eos>']] = 1
253 | if is_end.sum() >= 10:
254 | break
255 | prevs = currents
256 |
257 | sequences = torch.cat(sequences, dim=1)
258 | rewards = torch.cat(rewards, dim=1)
259 |
260 | return sequences, rewards, lengths
261 |
--------------------------------------------------------------------------------
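The Discriminator mixes its convolutional features through a highway gate: t = sigmoid(Wx) decides, per feature, how much of the transformed signal versus the original signal passes through. A standalone sketch with small illustrative tensors (the real feature dimension, the sum of the filter counts above, is 1400):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

highway = nn.Linear(16, 16)            # illustrative size; 1400 in the model above
x = torch.randn(8, 16)                 # stands in for the pooled conv features
h = highway(x)
t = torch.sigmoid(h)                   # transform gate in [0, 1]
out = t * F.elu(h) + (1 - t) * x       # gated mix of transformed and carried features
print(out.shape)                       # torch.Size([8, 16])
```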
/genmol/ORGAN/NP_Score/README:
--------------------------------------------------------------------------------
1 | RDKit-based implementation of the method described in:
2 |
3 | Natural Product-likeness Score and Its Application for Prioritization of Compound Libraries
4 | Peter Ertl, Silvio Roggo, and Ansgar Schuffenhauer
5 | Journal of Chemical Information and Modeling, 48, 68-74 (2008)
6 | http://pubs.acs.org/doi/abs/10.1021/ci700286x
7 |
8 | Contribution from Peter Ertl
9 |
10 |
--------------------------------------------------------------------------------
/genmol/ORGAN/NP_Score/__pycache__/npscorer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/NP_Score/__pycache__/npscorer.cpython-37.pyc
--------------------------------------------------------------------------------
/genmol/ORGAN/NP_Score/npscorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of natural product-likeness as described in:
3 | #
4 | # Natural Product-likeness Score and Its Application for Prioritization of
5 | # Compound Libraries
6 | # Peter Ertl, Silvio Roggo, and Ansgar Schuffenhauer
7 | # Journal of Chemical Information and Modeling, 48, 68-74 (2008)
8 | # http://pubs.acs.org/doi/abs/10.1021/ci700286x
9 | #
10 | # for the training of this model only openly available data have been used
11 | # ~50,000 natural products collected from various open databases
12 | # ~1 million drug-like molecules from ZINC as a "non-NP background"
13 | #
14 | # peter ertl, august 2015
15 | #
16 |
17 | from __future__ import print_function
18 | from rdkit import Chem
19 | from rdkit.Chem import rdMolDescriptors
20 | import sys
21 | import math
22 | import gzip
23 | import pickle
24 | import os.path
25 | from collections import namedtuple
26 |
27 |
28 | _fscores = None
29 |
30 |
31 | def readNPModel(filename=os.path.join(os.path.dirname(__file__),'publicnp.model.gz')):
32 | """Reads and returns the scoring model,
33 | which has to be passed to the scoring functions."""
34 | global _fscores
35 | _fscores = pickle.load(gzip.open(filename))
36 | return _fscores
37 |
38 |
39 | def scoreMolWConfidence(mol, fscore):
40 | """Next to the NP Likeness Score, this function outputs a confidence value
41 |     between 0..1 that describes how many fragments of the tested molecule
42 | were found in the model data set (1: all fragments were found).
43 |
44 | Returns namedtuple NPLikeness(nplikeness, confidence)"""
45 |
46 | if mol is None:
47 | raise ValueError('invalid molecule')
48 | fp = rdMolDescriptors.GetMorganFingerprint(mol, 2)
49 | bits = fp.GetNonzeroElements()
50 |
51 | # calculating the score
52 | score = 0.0
53 | bits_found = 0
54 | for bit in bits:
55 | if bit in fscore:
56 | bits_found += 1
57 | score += fscore[bit]
58 |
59 | score /= float(mol.GetNumAtoms())
60 | confidence = float(bits_found / len(bits))
61 |
62 | # preventing score explosion for exotic molecules
63 | if score > 4:
64 | score = 4. + math.log10(score - 4. + 1.)
65 | elif score < -4:
66 | score = -4. - math.log10(-4. - score + 1.)
67 | NPLikeness = namedtuple("NPLikeness", "nplikeness,confidence")
68 | return NPLikeness(score, confidence)
69 |
70 |
71 | def scoreMol(mol, fscore=None):
72 | """Calculates the Natural Product Likeness of a molecule.
73 |
74 | Returns the score as float in the range -5..5."""
75 | if _fscores is None:
76 | readNPModel()
77 | fscore = fscore or _fscores
78 | return scoreMolWConfidence(mol, fscore).nplikeness
79 |
80 |
81 | def processMols(fscore, suppl):
82 | print("calculating ...", file=sys.stderr)
83 | n = 0
84 | for i, m in enumerate(suppl):
85 | if m is None:
86 | continue
87 |
88 | n += 1
89 | score = "%.3f" % scoreMol(m, fscore)
90 |
91 | smiles = Chem.MolToSmiles(m, True)
92 | name = m.GetProp('_Name')
93 | print(smiles + "\t" + name + "\t" + score)
94 |
95 | print("finished, " + str(n) + " molecules processed", file=sys.stderr)
96 |
97 |
98 | if __name__ == '__main__':
99 | fscore = readNPModel() # fills fscore
100 |
101 | suppl = Chem.SmilesMolSupplier(
102 | sys.argv[1], smilesColumn=0, nameColumn=1, titleLine=False
103 | )
104 | processMols(fscore, suppl)
105 |
106 | #
107 | # Copyright (c) 2015, Novartis Institutes for BioMedical Research Inc.
108 | # All rights reserved.
109 | #
110 | # Redistribution and use in source and binary forms, with or without
111 | # modification, are permitted provided that the following conditions are
112 | # met:
113 | #
114 | # * Redistributions of source code must retain the above copyright
115 | # notice, this list of conditions and the following disclaimer.
116 | # * Redistributions in binary form must reproduce the above
117 | # copyright notice, this list of conditions and the following
118 | # disclaimer in the documentation and/or other materials provided
119 | # with the distribution.
120 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
121 | # nor the names of its contributors may be used to endorse or promote
122 | # products derived from this software without specific prior written
123 | # permission.
124 | #
125 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
126 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
127 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
128 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
129 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
130 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
131 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
132 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
133 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
134 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
135 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
136 | #
137 |
--------------------------------------------------------------------------------
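A minimal usage sketch of the NP score, assuming RDKit is installed and the package layout used by RewardMetrics.py (`from NP_Score import npscorer`); the molecule is illustrative:

```python
from rdkit import Chem
from NP_Score import npscorer

fscore = npscorer.readNPModel()                   # loads publicnp.model.gz next to npscorer.py
mol = Chem.MolFromSmiles('CC(=O)Nc1ccc(O)cc1')    # paracetamol, for illustration
print(npscorer.scoreMol(mol, fscore))             # NP-likeness, roughly in -5..5
```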
/genmol/ORGAN/NP_Score/publicnp.model.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/NP_Score/publicnp.model.gz
--------------------------------------------------------------------------------
/genmol/ORGAN/RewardMetrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import Counter
3 | from functools import partial
4 | import numpy as np
5 | import pandas as pd
6 | import scipy.sparse
7 | import torch
8 | from rdkit import Chem
9 | from rdkit.Chem import AllChem
10 | from rdkit.Chem import MACCSkeys
11 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
12 | from rdkit.Chem.QED import qed
13 | from rdkit.Chem.Scaffolds import MurckoScaffold
14 | from rdkit.Chem import Descriptors
15 | from multiprocessing import Pool
16 | from SA_Score import sascorer
17 |
18 | from NP_Score import npscorer
19 | _base_dir = os.path.split(__file__)[0]
20 | _mcf = pd.read_csv(os.path.join(_base_dir, 'mcf.csv'))
21 | _pains = pd.read_csv(os.path.join(_base_dir, 'wehi_pains.csv'),
22 | names=['smarts', 'names'])
23 | _filters = [Chem.MolFromSmarts(x) for x in
24 | _mcf.append(_pains, sort=True)['smarts'].values]
25 |
26 | def mapper(n_jobs):
27 | # n_jobs = 8
28 | '''
29 | Returns function for map call.
30 | If n_jobs == 1, will use standard map
31 | If n_jobs > 1, will use multiprocessing pool
32 | If n_jobs is a pool object, will return its map function
33 | '''
34 | if n_jobs == 1:
35 | def _mapper(*args, **kwargs):
36 | return list(map(*args, **kwargs))
37 |
38 | return _mapper
39 | elif isinstance(n_jobs, int):
40 | pool = Pool(n_jobs)
41 |
42 | def _mapper(*args, **kwargs):
43 | try:
44 | result = pool.map(*args, **kwargs)
45 | finally:
46 | pool.terminate()
47 | return result
48 |
49 | return _mapper
50 | else:
51 | return n_jobs.map
52 |
53 |
54 | def get_mol(smiles_or_mol):
55 |     '''
56 | Loads SMILES/molecule into RDKit's object
57 | '''
58 | if isinstance(smiles_or_mol, str):
59 | if len(smiles_or_mol) == 0:
60 | return None
61 | mol = Chem.MolFromSmiles(smiles_or_mol)
62 | if mol is None:
63 | return None
64 | try:
65 | Chem.SanitizeMol(mol)
66 | except ValueError:
67 | return None
68 | return mol
69 | else:
70 | return smiles_or_mol
71 |
72 |
73 | def canonic_smiles(smiles_or_mol):
74 | mol = get_mol(smiles_or_mol)
75 | if mol is None:
76 | return None
77 | return Chem.MolToSmiles(mol)
78 |
79 |
80 | def logP(mol):
81 | """
82 | Computes RDKit's logP
83 | """
84 | return Chem.Crippen.MolLogP(mol)
85 |
86 |
87 | def SA(mol):
88 | """
89 | Computes RDKit's Synthetic Accessibility score
90 | """
91 | return sascorer.calculateScore(mol)
92 |
93 |
94 | def NP(mol):
95 | """
96 | Computes RDKit's Natural Product-likeness score
97 | """
98 | return npscorer.scoreMol(mol)
99 |
100 |
101 | def QED(mol):
102 | """
103 | Computes RDKit's QED score
104 | """
105 | return qed(mol)
106 |
107 |
108 | def weight(mol):
109 | """
110 | Computes molecular weight for given molecule.
111 | Returns float,
112 | """
113 | return Descriptors.MolWt(mol)
114 |
115 |
116 | def get_n_rings(mol):
117 | """
118 | Computes the number of rings in a molecule
119 | """
120 | return mol.GetRingInfo().NumRings()
121 |
122 |
123 | def fragmenter(mol):
124 | """
125 | fragment mol using BRICS and return smiles list
126 | """
127 | fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol))
128 | fgs_smi = Chem.MolToSmiles(fgs).split(".")
129 | return fgs_smi
130 |
131 |
132 | def compute_fragments(mol_list, n_jobs=1):
133 | """
134 | fragment list of mols using BRICS and return smiles list
135 | """
136 | fragments = Counter()
137 | for mol_frag in mapper(n_jobs)(fragmenter, mol_list):
138 | fragments.update(mol_frag)
139 | return fragments
140 |
141 |
142 | def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
143 | """
144 |     Extracts a scaffold from a molecule in the form of a canonical SMILES
145 | """
146 | scaffolds = Counter()
147 | map_ = mapper(n_jobs)
148 | scaffolds = Counter(
149 | map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
150 | if None in scaffolds:
151 | scaffolds.pop(None)
152 | return scaffolds
153 |
154 |
155 | def compute_scaffold(mol, min_rings=2):
156 | mol = get_mol(mol)
157 | scaffold = MurckoScaffold.GetScaffoldForMol(mol)
158 | n_rings = get_n_rings(scaffold)
159 | scaffold_smiles = Chem.MolToSmiles(scaffold)
160 | if scaffold_smiles == '' or n_rings < min_rings:
161 | return None
162 | else:
163 | return scaffold_smiles
164 |
165 |
166 | def average_agg_tanimoto(stock_vecs, gen_vecs,
167 | batch_size=5000, agg='max',
168 | device='cpu', p=1):
169 | """
170 | For each molecule in gen_vecs finds closest molecule in stock_vecs.
171 |     Returns the average Tanimoto score between these molecules
172 |
173 | Parameters:
174 | stock_vecs: numpy array
175 | gen_vecs: numpy array
176 | agg: max or mean
177 | p: power for averaging: (mean x^p)^(1/p)
178 | """
179 | assert agg in ['max', 'mean'], "Can aggregate only max or mean"
180 | agg_tanimoto = np.zeros(len(gen_vecs))
181 | total = np.zeros(len(gen_vecs))
182 | for j in range(0, stock_vecs.shape[0], batch_size):
183 | x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
184 | for i in range(0, gen_vecs.shape[0], batch_size):
185 | y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
186 | y_gen = y_gen.transpose(0, 1)
187 | tp = torch.mm(x_stock, y_gen)
188 | jac = (tp / (x_stock.sum(1, keepdim=True) +
189 | y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
190 | jac[np.isnan(jac)] = 1
191 | if p != 1:
192 | jac = jac ** p
193 | if agg == 'max':
194 | agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
195 | agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
196 | elif agg == 'mean':
197 | agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
198 | total[i:i + y_gen.shape[1]] += jac.shape[0]
199 | if agg == 'mean':
200 | agg_tanimoto /= total
201 | if p != 1:
202 | agg_tanimoto = (agg_tanimoto) ** (1 / p)
203 | return np.mean(agg_tanimoto)
204 |
205 |
206 | def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
207 | morgan__n=1024, *args, **kwargs):
208 | """
209 | Generates fingerprint for SMILES
210 | If smiles is invalid, returns None
211 | Returns numpy array of fingerprint bits
212 |
213 | Parameters:
214 | smiles: SMILES string
215 | type: type of fingerprint: [MACCS|morgan]
216 | dtype: if not None, specifies the dtype of returned array
217 | """
218 | fp_type = fp_type.lower()
219 | molecule = get_mol(smiles_or_mol, *args, **kwargs)
220 | if molecule is None:
221 | return None
222 | if fp_type == 'maccs':
223 | keys = MACCSkeys.GenMACCSKeys(molecule)
224 | keys = np.array(keys.GetOnBits())
225 | fingerprint = np.zeros(166, dtype='uint8')
226 | if len(keys) != 0:
227 | fingerprint[keys - 1] = 1 # We drop 0-th key that is always zero
228 | elif fp_type == 'morgan':
229 | fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
230 | dtype='uint8')
231 | else:
232 | raise ValueError("Unknown fingerprint type {}".format(fp_type))
233 | if dtype is not None:
234 | fingerprint = fingerprint.astype(dtype)
235 | return fingerprint
236 |
237 |
238 | def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
239 |                  **kwargs):
240 |
241 | '''
242 | Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
243 |     e.g. fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10)
244 | Inserts np.NaN to rows corresponding to incorrect smiles.
245 | IMPORTANT: if there is at least one np.NaN, the dtype would be float
246 | Parameters:
247 | smiles_mols_array: list/array/pd.Series of smiles or already computed
248 | RDKit molecules
249 |         n_jobs: number of parallel workers to execute
250 | already_unique: flag for performance reasons, if smiles array is big
251 | and already unique. Its value is set to True if smiles_mols_array
252 | contain RDKit molecules already.
253 | '''
254 | if isinstance(smiles_mols_array, pd.Series):
255 | smiles_mols_array = smiles_mols_array.values
256 | else:
257 | smiles_mols_array = np.asarray(smiles_mols_array)
258 | if not isinstance(smiles_mols_array[0], str):
259 | already_unique = True
260 |
261 | if not already_unique:
262 | smiles_mols_array, inv_index = np.unique(smiles_mols_array,
263 | return_inverse=True)
264 |
265 | fps = mapper(n_jobs)(
266 | partial(fingerprint, *args, **kwargs), smiles_mols_array
267 | )
268 |
269 | length = 1
270 | for fp in fps:
271 | if fp is not None:
272 | length = fp.shape[-1]
273 | first_fp = fp
274 | break
275 | fps = [fp if fp is not None else np.array([np.NaN]).repeat(length)[None, :]
276 | for fp in fps]
277 | if scipy.sparse.issparse(first_fp):
278 | fps = scipy.sparse.vstack(fps).tocsr()
279 | else:
280 | fps = np.vstack(fps)
281 | if not already_unique:
282 | return fps[inv_index]
283 | else:
284 | return fps
285 |
286 |
287 | def mol_passes_filters(mol,
288 | allowed=None,
289 | isomericSmiles=False):
290 | """
291 | Checks if mol
292 | * passes MCF and PAINS filters,
293 | * has only allowed atoms
294 | * is not charged
295 | """
296 | allowed = allowed or {'C', 'N', 'S', 'O', 'F', 'Cl', 'Br', 'H'}
297 | mol = get_mol(mol)
298 | if mol is None:
299 | return False
300 | ring_info = mol.GetRingInfo()
301 | if ring_info.NumRings() != 0 and any(
302 | len(x) >= 8 for x in ring_info.AtomRings()
303 | ):
304 | return False
305 | h_mol = Chem.AddHs(mol)
306 | if any(atom.GetFormalCharge() != 0 for atom in mol.GetAtoms()):
307 | return False
308 | if any(atom.GetSymbol() not in allowed for atom in mol.GetAtoms()):
309 | return False
310 | if any(h_mol.HasSubstructMatch(smarts) for smarts in _filters):
311 | return False
312 | smiles = Chem.MolToSmiles(mol, isomericSmiles=isomericSmiles)
313 | if smiles is None or len(smiles) == 0:
314 | return False
315 | if Chem.MolFromSmiles(smiles) is None:
316 | return False
317 | return True
318 |
--------------------------------------------------------------------------------
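A hedged usage sketch of the fingerprint and similarity helpers above. It assumes RDKit is installed and the script is run from the ORGAN directory (importing RewardMetrics reads mcf.csv and wehi_pains.csv at import time); the SMILES are illustrative:

```python
from RewardMetrics import fingerprints, average_agg_tanimoto, mol_passes_filters

ref = ['CCO', 'c1ccccc1O', 'CC(=O)O']
gen = ['CCN', 'c1ccccc1']

ref_fps = fingerprints(ref, fp_type='morgan', n_jobs=1)
gen_fps = fingerprints(gen, fp_type='morgan', n_jobs=1)

print(average_agg_tanimoto(ref_fps, gen_fps))   # mean nearest-neighbour Tanimoto of gen vs ref
print(mol_passes_filters('CCO'))                # True: passes MCF/PAINS/atom/charge filters
```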
/genmol/ORGAN/Run.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from Metrics_Reward import *
3 | from Data import *
4 | from Trainer import fit
5 | from Model import ORGAN
6 |
7 |
8 | def sampler(model):
9 | n_samples = 100
10 | samples = []
11 |     with tqdm(total=n_samples, desc='Generating Samples') as T:
12 | while n_samples > 0:
13 | current_samples = model.sample(min(n_samples, 64), max_length=100)
14 | samples.extend(current_samples)
15 | n_samples -= len(current_samples)
16 | T.update(len(current_samples))
17 |
18 | return samples
19 |
20 |
21 | def evaluate(test, samples, test_scaffolds=None, ptest=None, ptest_scaffolds=None):
22 | gen = samples
23 | metrics = get_all_metrics(test, gen, k=[1000, 1000], n_jobs=1,
24 | device=device,
25 | test_scaffolds=test_scaffolds,
26 | ptest=ptest, ptest_scaffolds=ptest_scaffolds)
27 | for name, value in metrics.items():
28 | print('{}, {}'.format(name, value))
29 |
30 |
31 | model = ORGAN()
32 | fit(model, train_data)
33 | samples = sampler(model)
34 | evaluate(test_data, samples, test_scaffold)
35 |
--------------------------------------------------------------------------------
/genmol/ORGAN/SA_Score/README:
--------------------------------------------------------------------------------
1 | RDKit-based implementation of the method described in:
2 |
3 | Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
4 | Peter Ertl and Ansgar Schuffenhauer
5 | Journal of Cheminformatics 1:8 (2009)
6 | http://www.jcheminf.com/content/1/1/8
7 |
8 | Contribution from Peter Ertl and Greg Landrum
9 |
10 |
--------------------------------------------------------------------------------
/genmol/ORGAN/SA_Score/UnitTestSAScore.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import unittest
4 |
5 | import sascorer
6 | from rdkit import Chem
7 |
8 | print(sascorer.__file__)
9 |
10 |
11 | class TestCase(unittest.TestCase):
12 |
13 | def test1(self):
14 | with open('data/zim.100.txt') as f:
15 | testData = [x.strip().split('\t') for x in f]
16 | testData.pop(0)
17 | for row in testData:
18 | smi = row[0]
19 | m = Chem.MolFromSmiles(smi)
20 | tgt = float(row[2])
21 | val = sascorer.calculateScore(m)
22 | self.assertAlmostEqual(tgt, val, 3)
23 |
24 |
25 | if __name__ == '__main__':
26 | import sys
27 | import getopt
28 | import re
29 |
30 | doLong = 0
31 | if len(sys.argv) > 1:
32 | args, extras = getopt.getopt(sys.argv[1:], 'l')
33 | for arg, val in args:
34 | if arg == '-l':
35 | doLong = 1
36 | sys.argv.remove('-l')
37 | if doLong:
38 | for methName in dir(TestCase):
39 | if re.match('_test', methName):
40 | newName = re.sub('_test', 'test', methName)
41 | exec('TestCase.%s = TestCase.%s' % (newName, methName))
42 |
43 | unittest.main()
44 |
--------------------------------------------------------------------------------
/genmol/ORGAN/SA_Score/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/SA_Score/__init__.py
--------------------------------------------------------------------------------
/genmol/ORGAN/SA_Score/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/SA_Score/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/genmol/ORGAN/SA_Score/__pycache__/sascorer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/SA_Score/__pycache__/sascorer.cpython-37.pyc
--------------------------------------------------------------------------------
/genmol/ORGAN/SA_Score/data/zim.100.txt:
--------------------------------------------------------------------------------
1 | smiles Name sa_score
2 | Cc1c(C(=O)NCCO)[n+](=O)c2ccccc2n1[O-] ZINC21984717 3.166
3 | Cn1cc(NC=O)cc1C(=O)Nc1cc(C(=O)Nc2cc(C(=O)NCCC(N)=[NH2+])n(C)c2)n(C)c1 ZINC03872327 3.328
4 | OC(c1ccncc1)c1ccc(OCC[NH+]2CCCC2)cc1 ZINC34421620 3.822
5 | CC(C(=O)[O-])c1ccc(-c2ccccc2)cc1 ZINC00000361 2.462
6 | C[NH+](C)CC(O)Cn1c2ccc(Br)cc2c2cc(Br)ccc21 ZINC00626529 3.577
7 | NC(=[NH2+])NCC1COc2ccccc2O1 ZINC00000357 3.290
8 | CCC(C)(C)[NH2+]CC(O)COc1ccccc1C#N ZINC04214111 3.698
9 | C[NH+](C)CC(O)Cn1c2ccc(Br)cc2c2cc(Br)ccc21 ZINC00626528 3.577
10 | CC12CCC3C(CCC4CC(=O)CCC43C)C1CCC2=O ZINC04081985 3.912
11 | COc1ccc(OC(=O)N(CC(=O)[O-])Cc2ccc(OCCc3nc(-c4ccccc4)oc3C)cc2)cc1 ZINC03935839 2.644
12 | COc1ccccc1OC(=O)c1ccccc1 ZINC00000349 1.342
13 | CC(C)CC[NH2+]CC1COc2ccccc2O1 ZINC04214115 3.701
14 | CN1CCN(C(=O)OC2c3nccnc3C(=O)N2c2ccc(Cl)cn2)CC1 ZINC19632834 3.196
15 | CCC1(c2ccccc2)C(=O)N(COC)C(=O)N(COC)C1=O ZINC02986592 2.759
16 | Nc1ccc(S(=O)(=O)Nc2ccccc2)cc1 ZINC00141883 1.529
17 | O=C([O-])CCCNC(=O)NC1CCCCC1 ZINC08754389 2.493
18 | CCC(C)C(C(=O)OC1CC[N+](C)(C)CC1)c1ccccc1 ZINC00000595 3.399
19 | CCC(C)SSc1ncc[nH]1 ZINC13209429 3.983
20 | CC[N+](C)(CC)CCOC(=O)C(O)(c1cccs1)C1CCCC1 ZINC01690860 3.471
21 | CC12CCC3C(CCC4CC(=O)CCC43C)C1CCC2O ZINC03814360 3.994
22 | CC12CCC3C4CCC(=O)C=C4CCC3C1CCC2O ZINC03814379 4.056
23 | OCC1OC(OC2C(CO)OC(O)C(O)C2O)C(O)C(O)C1O ZINC04095762 4.282
24 | CC(C)CC(CC[NH+](C(C)C)C(C)C)(C(N)=O)c1ccccn1 ZINC02016048 4.092
25 | C=CC1(C)CC(=O)C2(O)C(C)(O1)C(OC(C)=O)C(OC(=O)CC[NH+](C)C)C1C(C)(C)CCC(O)C12C ZINC38595287 5.519
26 | C=CC[NH+]1CCCC1CNC(=O)c1cc(S(N)(=O)=O)cc(OC)c1OC ZINC00601278 4.286
27 | CC(=O)OC1C[NH+]2CCC1CC2 ZINC00492792 5.711
28 | CC12CCC3C(CCC4CC(=O)CCC43C)C1CCC2O ZINC03814418 3.994
29 | CC1(O)CCC2C3CCC4=CC(=O)CCC4(C)C3CCC21C ZINC03814422 4.022
30 | CC(=O)OC1(C(C)=O)CCC2C3C=C(Cl)C4=CC(=O)C5CC5C4(C)C3CCC21C ZINC03814423 4.827
31 | C#CC1(O)CCC2C3CCc4cc(OC)ccc4C3CCC21C ZINC03815424 3.810
32 | C=CC1(C)CC(OC(=O)CSCC[NH+](CC)CC)C2(C)C3C(=O)CCC3(CCC2C)C(C)C1O ZINC25757051 6.200
33 | O=C([O-])C(=O)Nc1nc(-c2ccc3c(c2)OCCO3)cs1 ZINC03623428 2.594
34 | CC[NH+]1CCCC1CNC(=O)C(O)(c1ccccc1)c1ccccc1 ZINC00900569 3.950
35 | CC(C)(OCc1nn(Cc2ccccc2)c2ccccc12)C(=O)[O-] ZINC00004594 2.573
36 | Cc1nnc(C(C)C)n1C1CC2CCC(C1)[NH+]2CCC(NC(=O)C1CCC(F)(F)CC1)c1ccccc1 ZINC03817234 5.316
37 | Nc1ncnc2c1ncn2C1OC(COP(=O)([O-])OP(=O)([O-])OP(=O)([O-])[O-])C(O)C1O ZINC03871612 5.290
38 | O=C([O-])CNC(=O)c1ccccc1 ZINC00097685 2.097
39 | Nc1ncnc2c1ncn2C1OC(COP(=O)([O-])OP(=O)([O-])OP(=O)([O-])[O-])C(O)C1O ZINC03871613 5.290
40 | Nc1ncnc2c1ncn2C1OC(COP(=O)([O-])OP(=O)([O-])OP(=O)([O-])[O-])C(O)C1O ZINC03871614 5.290
41 | c1ccc(OCc2ccc(CCCN3CCOCC3)cc2)cc1 ZINC19865692 1.702
42 | CC=CC1=C(C(=O)[O-])N2C(=O)C(NC(=O)C(N)c3ccc(O)cc3)C2SC1 ZINC20444132 4.042
43 | C[NH+]1CCCC1COc1cccnc1 ZINC03805141 4.510
44 | O=C([O-])C(O)CC(O)C(O)CO ZINC04803503 4.398
45 | O=C([O-])C(O)CC(O)C(O)CO ZINC01696607 4.398
46 | C[NH+]1CCCC1Cc1c[nH]c2ccc(CCS(=O)(=O)c3ccccc3)cc12 ZINC03823475 3.921
47 | C(=Cc1ccccc1)C[NH+]1CCN(C(c2ccccc2)c2ccccc2)CC1 ZINC19632891 2.973
48 | Nc1ncnc2c1ncn2C1OC(COP(=O)([O-])OP(=O)([O-])OP(=O)([O-])[O-])C(O)C1O ZINC03871615 5.290
49 | CC(c1ccccc1)N(C)C=O ZINC06932229 2.562
50 | CC(=O)C1CCC2C3CCC4CC(C)(O)CCC4(C)C3CCC12C ZINC03824281 4.279
51 | O=C([O-])C(O)CC(O)C(O)CO ZINC04803506 4.398
52 | COc1cc(O)c(C(=O)c2ccccc2)c(O)c1 ZINC00000187 1.868
53 | O=C([O-])C(O)CC(O)C(O)CO ZINC04803507 4.398
54 | COc1c2c(cc3c1C(O)N(C)CC3)OCO2 ZINC00000186 3.183
55 | CCC(C(=O)[O-])c1ccc(CC(C)C)cc1 ZINC00015537 2.827
56 | O=C([O-])C1[NH+]=C(c2ccccc2)c2cc(Cl)ccc2NC1(O)O ZINC38611850 4.011
57 | O=C([O-])C1[NH+]=C(c2ccccc2)c2cc(Cl)ccc2NC1(O)O ZINC38611851 4.011
58 | OCC(O)COc1ccc(Cl)cc1 ZINC00000135 2.102
59 | NC(=O)NC(=O)C(Cl)c1ccccc1 ZINC00000134 2.455
60 | OC(c1ccccc1)(c1ccccc1)C1C[NH+]2CCC1CC2 ZINC01298963 4.530
61 | C[NH2+]CC(C)c1ccccc1 ZINC04298801 3.471
62 | Clc1cccc(Cl)c1N=C1NCCO1 ZINC13835972 3.267
63 | [NH3+]C(Cc1ccccc1)C(=O)CCl ZINC02504633 3.251
64 | CC(C)Cn1cnc2c1c1ccccc1nc2N ZINC19632912 2.230
65 | CC(O)CN(C)c1ccc(NN)nn1 ZINC00000624 3.193
66 | CC1(O)CCC2C3CCC4=CC(=O)CCC4=C3C=CC21C ZINC00001727 4.461
67 | CCC(C(=O)[O-])c1ccc(-c2ccccc2)cc1 ZINC00000111 2.505
68 | CC(=O)OCC1OC(n2ncc(=O)[nH]c2=O)C(OC(C)=O)C1OC(C)=O ZINC03830255 3.832
69 | CC(=O)OCC1OC(n2ncc(=O)[nH]c2=O)C(OC(C)=O)C1OC(C)=O ZINC03830256 3.832
70 | Cn1cc(C(=O)c2cccc3ccccc32)cc1C(=O)[O-] ZINC00001783 2.456
71 | CC(=O)OCC1OC(n2ncc(=O)[nH]c2=O)C(OC(C)=O)C1OC(C)=O ZINC03830257 3.832
72 | Cc1cccc(-c2nc3ccccc3c(Nc3ccc4[nH]ncc4c3)n2)n1 ZINC39279791 2.358
73 | O=C([O-])C1CC2CCCCC2[NH2+]1 ZINC04899687 5.422
74 | CC(=O)OCC(=O)C1CCC2C3CC=C4CC(O)CCC4(C)C3CCC12C ZINC00538219 4.187
75 | O=C([O-])C1CC2CCCCC2[NH2+]1 ZINC04899686 5.422
76 | O=C(OCc1ccccc1)C(O)c1ccccc1 ZINC00000078 2.038
77 | CC(=O)OCC(=O)C1(O)CCC2C3CCC4=CC(=O)C=CC4(C)C3C(O)CC21C ZINC00608041 4.394
78 | Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1 ZINC02570895 2.144
79 | COCc1cccc(CC(O)C=CC2C(O)CC(=O)C2CCSCCCC(=O)OC)c1 ZINC03940680 3.934
80 | CCC(=O)N(c1ccccc1)C1CC[NH+](C(C)Cc2ccccc2)CC1 ZINC01664586 3.582
81 | CCC(=O)N(c1ccccc1)C1CC[NH+](C(C)Cc2ccccc2)CC1 ZINC01664587 3.582
82 | CCOC(=O)Nc1ccc2c(c1)N(C(=O)CCN1CCOCC1)c1ccccc1S2 ZINC19340795 2.446
83 | O=C([O-])Cc1cc(=O)[nH]c(=O)[nH]1 ZINC00403617 3.258
84 | NC(=O)C([NH3+])Cc1c[nH]c2ccccc12 ZINC04899521 3.224
85 | NC(=O)C([NH3+])Cc1ccc(O)cc1 ZINC04899513 3.280
86 | O=C(c1cc2ccccc2o1)N1CCN(Cc2ccccc2)CC1 ZINC19632922 1.799
87 | O=C(CO)C(O)C(O)CO ZINC00902219 3.473
88 | CC(Cc1ccccc1)NC(=O)C([NH3+])CCCC[NH3+] ZINC11680943 3.967
89 | C[NH+]1CCC(c2c(O)cc(=O)c3c(O)cc(-c4ccccc4Cl)oc2-3)C(O)C1 ZINC05966679 4.616
90 | CN(C)c1ccc(O)c2c1CC1CC3C([NH+](C)C)C(=O)C(C(N)=O)=C(O)C3(O)C(=O)C1=C2O ZINC04019704 4.713
91 | Cc1cc2nc3c(=O)[nH]c(=O)nc-3n(CC(O)C(O)C(O)CO)c2cc1C ZINC03650334 3.791
92 | C[NH+]1C2CCC1CC(OC(=O)c1c[nH]c3ccccc13)C2 ZINC18130447 4.892
93 | Cc1ccccc1NC(=O)C(C)[NH+]1CCCC1 ZINC00000051 3.809
94 | O=S(=O)([O-])CCN1CCOCC1 ZINC19419111 2.776
95 | C[NH+]1CCN(CC(=O)N2c3ccccc3C(=O)Nc3cccnc32)CC1 ZINC19632927 3.379
96 | CCCCCC=CCC=CCCCCCCCC(=O)[O-] ZINC03802188 2.805
97 | CC(CC([NH3+])C(=O)[O-])C(=O)[O-] ZINC01747048 5.690
98 | CC1c2cccc(O)c2C(=O)C2=C(O)C3(O)C(O)=C(C(N)=O)C(=O)C([NH+](C)C)C3C(O)C21 ZINC04019706 5.069
99 | Cc1cc2nc3nc([O-])[nH]c(=O)c3nc2cc1C ZINC12446789 3.079
100 | CC1=CC(C)C2(CO)COC(c3ccc(O)cc3)C1C2C ZINC38190856 4.749
101 | CC[NH+]1CCC(=C2c3ccccc3CCc3ccccc32)C1C ZINC02020004 3.925
102 |
--------------------------------------------------------------------------------
/genmol/ORGAN/SA_Score/fpscores.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bayeslabs/genmol/b783aa41f4989bbdbfe2038dd9433dcb49b4a3b3/genmol/ORGAN/SA_Score/fpscores.pkl.gz
--------------------------------------------------------------------------------
/genmol/ORGAN/SA_Score/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on
5 | # Molecular Complexity and Fragment Contributions
6 | # Peter Ertl and Ansgar Schuffenhauer
7 | # Journal of Cheminformatics 1:8 (2009)
8 | # http://www.jcheminf.com/content/1/1/8
9 | #
10 | # several small modifications to the original paper are included
11 | # particularly slightly different formula for macrocyclic penalty
12 | # and taking into account also molecule symmetry (fingerprint density)
13 | #
14 | # for a set of 10k diverse molecules the agreement between the original method
15 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
16 | #
17 | # peter ertl & greg landrum, september 2013
18 | #
19 | from __future__ import print_function
20 |
21 | import math
22 | import os.path as op
23 |
24 | from rdkit import Chem
25 | from rdkit.Chem import rdMolDescriptors
26 | from rdkit.six import iteritems
27 | import pickle
28 |
29 | _fscores = None
30 |
31 |
32 | def readFragmentScores(name='fpscores'):
33 | import gzip
34 | global _fscores
35 | # generate the full path filename:
36 | if name == "fpscores":
37 | name = op.join(op.dirname(__file__), name)
38 | _fscores = pickle.load(gzip.open('%s.pkl.gz' % name))
39 | outDict = {}
40 | for i in _fscores:
41 | for j in range(1, len(i)):
42 | outDict[i[j]] = float(i[0])
43 | _fscores = outDict
44 |
45 |
46 | def numBridgeheadsAndSpiro(mol, ri=None):
47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
49 | return nBridgehead, nSpiro
50 |
51 |
52 | def calculateScore(m):
53 | if _fscores is None:
54 | readFragmentScores()
55 |
56 | # fragment score
57 | fp = rdMolDescriptors.GetMorganFingerprint(
58 | m, 2 # <- 2 is the *radius* of the circular fingerprint
59 | )
60 | fps = fp.GetNonzeroElements()
61 | score1 = 0.
62 | nf = 0
63 | for bitId, v in iteritems(fps):
64 | nf += v
65 | sfp = bitId
66 | score1 += _fscores.get(sfp, -4) * v
67 | score1 /= nf
68 |
69 | # features score
70 | nAtoms = m.GetNumAtoms()
71 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
72 | ri = m.GetRingInfo()
73 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
74 | nMacrocycles = 0
75 | for x in ri.AtomRings():
76 | if len(x) > 8:
77 | nMacrocycles += 1
78 |
79 | sizePenalty = nAtoms ** 1.005 - nAtoms
80 | stereoPenalty = math.log10(nChiralCenters + 1)
81 | spiroPenalty = math.log10(nSpiro + 1)
82 | bridgePenalty = math.log10(nBridgeheads + 1)
83 | macrocyclePenalty = 0.
84 | # ---------------------------------------
85 | # This differs from the paper, which defines:
86 | # macrocyclePenalty = math.log10(nMacrocycles+1)
87 | # This form generates better results when 2 or more macrocycles are present
88 | if nMacrocycles > 0:
89 | macrocyclePenalty = math.log10(2)
90 |
91 | score2 = (0. - sizePenalty - stereoPenalty -
92 | spiroPenalty - bridgePenalty - macrocyclePenalty)
93 |
94 | # correction for the fingerprint density
95 | # not in the original publication, added in version 1.1
96 | # to make highly symmetrical molecules easier to synthesise
97 | score3 = 0.
98 | if nAtoms > len(fps):
99 | score3 = math.log(float(nAtoms) / len(fps)) * .5
100 |
101 | sascore = score1 + score2 + score3
102 |
103 | # need to transform "raw" value into scale between 1 and 10
104 | min = -4.0
105 | max = 2.5
106 | sascore = 11. - (sascore - min + 1) / (max - min) * 9.
107 | # smooth the 10-end
108 | if sascore > 8.:
109 | sascore = 8. + math.log(sascore + 1. - 9.)
110 | if sascore > 10.:
111 | sascore = 10.0
112 | elif sascore < 1.:
113 | sascore = 1.0
114 |
115 | return sascore
116 |
117 |
118 | def processMols(mols):
119 | print('smiles\tName\tsa_score')
120 | for i, m in enumerate(mols):
121 | if m is None:
122 | continue
123 |
124 | s = calculateScore(m)
125 |
126 | smiles = Chem.MolToSmiles(m)
127 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
128 |
129 |
130 | if __name__ == '__main__':
131 | import sys
132 | import time
133 |
134 | t1 = time.time()
135 | readFragmentScores("fpscores")
136 | t2 = time.time()
137 |
138 | suppl = Chem.SmilesMolSupplier(sys.argv[1])
139 | t3 = time.time()
140 | processMols(suppl)
141 | t4 = time.time()
142 |
143 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % (
144 | (t2 - t1), (t4 - t3)),
145 | file=sys.stderr)
146 |
147 | #
148 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
149 | # All rights reserved.
150 | #
151 | # Redistribution and use in source and binary forms, with or without
152 | # modification, are permitted provided that the following conditions are
153 | # met:
154 | #
155 | # * Redistributions of source code must retain the above copyright
156 | # notice, this list of conditions and the following disclaimer.
157 | # * Redistributions in binary form must reproduce the above
158 | # copyright notice, this list of conditions and the following
159 | # disclaimer in the documentation and/or other materials provided
160 | # with the distribution.
161 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
162 | # nor the names of its contributors may be used to endorse or promote
163 | # products derived from this software without specific prior written
164 | # permission.
165 | #
166 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
167 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
168 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
169 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
170 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
171 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
172 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
173 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
174 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
175 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
176 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
177 | #
178 |
--------------------------------------------------------------------------------
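
A minimal usage sketch for the SA scorer listed above, assuming RDKit is installed and sascorer.py is on the import path (the SMILES string is illustrative, taken from zim.100.txt; note the module still imports rdkit.six, which very recent RDKit releases have dropped):

    # Score a single molecule with the Ertl/Landrum SA scorer defined above.
    from rdkit import Chem
    import sascorer

    mol = Chem.MolFromSmiles('COc1ccccc1OC(=O)c1ccccc1')  # also appears in zim.100.txt
    score = sascorer.calculateScore(mol)  # roughly 1 (easy to make) .. 10 (hard to make)
    print(round(score, 3))
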
/genmol/ORGAN/Trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from tqdm import tqdm
5 | from torch.nn.utils.rnn import pad_sequence
6 | from torch.utils.data import DataLoader
7 | from torch.optim import Adam
8 | import random
9 |
10 | from Data import *
11 |
12 | n_batch = 64
13 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
14 | discriminator_pretrain_epochs = 50
15 | discriminator_epochs = 10
16 | generator_pretrain_epochs = 50
17 | max_length = 100
18 | save_frequency = 25
19 | generator_updates = 1
20 | discriminator_updates = 1
21 | n_samples = 64
22 | n_rollouts = 16
23 | pg_iters = 10
24 |
25 | class PolicyGradientLoss(nn.Module):
26 | def forward(self, outputs, targets, rewards, lengths):
27 | log_probs = F.log_softmax(outputs, dim=2)
28 | items = torch.gather(
29 | log_probs, 2, targets.unsqueeze(2)
30 | ) * rewards.unsqueeze(2)
31 | loss = -sum(
32 | [t[:l].sum() for t, l in zip(items, lengths)]
33 | ) / lengths.sum().float()
34 | return loss
35 |
36 |
37 | def generator_collate_fn(model):
38 | def collate(data):
39 | data.sort(key=len, reverse=True)
40 | tensors = [model.string2tensor(string)
41 | for string in data]
42 |
43 | prevs = pad_sequence(
44 | [t[:-1] for t in tensors],
45 |             batch_first=True, padding_value=c2i['<pad>']
46 | )
47 | nexts = pad_sequence(
48 | [t[1:] for t in tensors],
49 |             batch_first=True, padding_value=c2i['<pad>']
50 | )
51 | lens = torch.tensor(
52 | [len(t) - 1 for t in tensors],
53 | dtype=torch.long, device=device)
54 | return prevs, nexts, lens
55 |
56 | return collate
57 |
58 |
59 | def get_dataloader(training_data, collate_fn):
60 | return DataLoader(training_data, batch_size=n_batch,
61 | shuffle=True, num_workers=8, collate_fn=collate_fn, worker_init_fn=None)
62 |
63 |
64 | def _pretrain_generator_epoch(model, tqdm_data, criterion, optimizer):
65 | model.discriminator.eval()
66 | if optimizer is None:
67 | model.eval()
68 | else:
69 | model.train()
70 |
71 | postfix = {'loss': 0, 'running_loss': 0}
72 |
73 | for i, batch in enumerate(tqdm_data):
74 | (prevs, nexts, lens) = (data.to(device) for data in batch)
75 | outputs, _, _, = model.generator_forward(prevs, lens)
76 |
77 | loss = criterion(outputs.view(-1, outputs.shape[-1]),
78 | nexts.view(-1))
79 |
80 | if optimizer is not None:
81 | optimizer.zero_grad()
82 | loss.backward()
83 | optimizer.step()
84 |
85 | postfix['loss'] = loss.item()
86 | postfix['running_loss'] += (
87 | loss.item() - postfix['running_loss']
88 | ) / (i + 1)
89 | tqdm_data.set_postfix(postfix)
90 |
91 | postfix['mode'] = ('Pretrain: eval generator'
92 | if optimizer is None
93 | else 'Pretrain: train generator')
94 | return postfix
95 |
96 |
97 | def _pretrain_generator(model, train_loader):
98 | generator = model.generator
99 |     criterion = nn.CrossEntropyLoss(ignore_index=c2i['<pad>'])
100 | optimizer = torch.optim.Adam(model.generator.parameters(), lr=1e-4)
101 |
102 | model.zero_grad()
103 | for epoch in range(generator_pretrain_epochs):
104 | tqdm_data = tqdm(train_loader, desc='Generator training (epoch #{})'.format(epoch))
105 | postfix = _pretrain_generator_epoch(model, tqdm_data, criterion, optimizer)
106 | if epoch % save_frequency == 0:
107 | generator = generator.to('cpu')
108 | torch.save(generator.state_dict(), 'model.csv'[:-4] +
109 | '_generator_{0:03d}.csv'.format(epoch))
110 | generator = generator.to(device)
111 |
112 |
113 | def discriminator_collate_fn(model):
114 | def collate(data):
115 | data.sort(key=len, reverse=True)
116 | tensors = [model.string2tensor(string) for string in data]
117 |         inputs = pad_sequence(tensors, batch_first=True, padding_value=c2i['<pad>'])
118 |
119 | return inputs
120 |
121 | return collate
122 |
123 |
124 | def _pretrain_discriminator_epoch(model, tqdm_data,
125 | criterion, optimizer=None):
126 |     model.generator.eval()
127 | if optimizer is None:
128 | model.eval()
129 | else:
130 | model.train()
131 |
132 | postfix = {'loss': 0,
133 | 'running_loss': 0}
134 | for i, inputs_from_data in enumerate(tqdm_data):
135 | inputs_from_data = inputs_from_data.to(device)
136 | inputs_from_model, _ = model.sample_tensor(n_batch, 100)
137 |
138 | targets = torch.zeros(n_batch, 1, device=device)
139 | outputs = model.discriminator_forward(inputs_from_model)
140 | loss = criterion(outputs, targets) / 2
141 |
142 | targets = torch.ones(inputs_from_data.shape[0], 1, device=device)
143 | outputs = model.discriminator_forward(inputs_from_data)
144 | loss += criterion(outputs, targets) / 2
145 |
146 | if optimizer is not None:
147 | optimizer.zero_grad()
148 | loss.backward()
149 | optimizer.step()
150 |
151 | postfix['loss'] = loss.item()
152 | postfix['running_loss'] += (loss.item() -
153 | postfix['running_loss']) / (i + 1)
154 | tqdm_data.set_postfix(postfix)
155 |
156 | postfix['mode'] = ('Pretrain: eval discriminator'
157 | if optimizer is None
158 | else 'Pretrain: train discriminator')
159 | return postfix
160 |
161 |
162 | def _pretrain_discriminator(model, train_loader):
163 | discriminator = model.discriminator
164 | criterion = nn.BCEWithLogitsLoss()
165 | optimizer = torch.optim.Adam(model.discriminator.parameters(),
166 | lr=1e-4)
167 |
168 | model.zero_grad()
169 | for epoch in range(discriminator_pretrain_epochs):
170 | tqdm_data = tqdm(
171 | train_loader,
172 | desc='Discriminator training (epoch #{})'.format(epoch)
173 | )
174 | postfix = _pretrain_discriminator_epoch(
175 | model, tqdm_data, criterion, optimizer
176 | )
177 | if epoch % save_frequency == 0:
178 | discriminator = discriminator.to('cpu')
179 | torch.save(discriminator.state_dict(), 'model.csv'[:-4] + '_discriminator_{0:03d}.csv'.format(epoch))
180 | discriminator = discriminator.to(device)
181 |
182 |
183 | def _policy_gradient_iter(model, train_loader, criterion, optimizer, iter_, ref_smiles, ref_mols):
184 | smooth = 0.1
185 |
186 | # Generator
187 | gen_postfix = {'generator_loss': 0,
188 | 'smoothed_reward': 0}
189 |
190 | gen_tqdm = tqdm(range(generator_updates),
191 | desc='PG generator training (iter #{})'.format(iter_))
192 | for _ in gen_tqdm:
193 | model.eval()
194 | sequences, rewards, lengths = model.rollout(ref_smiles, ref_mols, n_samples=n_samples,
195 | n_rollouts=n_rollouts, max_len=max_length)
196 | model.train()
197 |
198 | lengths, indices = torch.sort(lengths, descending=True)
199 | sequences = sequences[indices, ...]
200 | rewards = rewards[indices, ...]
201 |
202 | generator_outputs, lengths, _ = model.generator_forward(
203 | sequences[:, :-1], lengths - 1
204 | )
205 | generator_loss = criterion['generator'](
206 | generator_outputs, sequences[:, 1:], rewards, lengths
207 | )
208 |
209 | optimizer['generator'].zero_grad()
210 | generator_loss.backward()
211 | nn.utils.clip_grad_value_(model.generator.parameters(), clip_value=5)
212 | optimizer['generator'].step()
213 |
214 | gen_postfix['generator_loss'] += (
215 | generator_loss.item() -
216 | gen_postfix['generator_loss']
217 | ) * smooth
218 | mean_episode_reward = torch.cat(
219 | [t[:l] for t, l in zip(rewards, lengths)]
220 | ).mean().item()
221 | gen_postfix['smoothed_reward'] += (
222 | mean_episode_reward - gen_postfix['smoothed_reward']
223 | ) * smooth
224 | gen_tqdm.set_postfix(gen_postfix)
225 |
226 | # Discriminator
227 | discrim_postfix = {'discrim-r_loss': 0}
228 | discrim_tqdm = tqdm(
229 | range(discriminator_updates),
230 | desc='PG discrim-r training (iter #{})'.format(iter_)
231 | )
232 | for _ in discrim_tqdm:
233 | model.generator.eval()
234 | n_batches = (
235 | len(train_loader) + n_batch - 1
236 | ) // n_batch
237 | sampled_batches = [
238 | model.sample_tensor(n_batch,
239 | max_length=max_length)[0]
240 | for _ in range(n_batches)
241 | ]
242 |
243 | for _ in range(discriminator_epochs):
244 | random.shuffle(sampled_batches)
245 |
246 | for inputs_from_model, inputs_from_data in zip(
247 | sampled_batches, train_loader
248 | ):
249 | # print(inputs_from_model)
250 | inputs_from_data = inputs_from_data.to(device)
251 | print(inputs_from_data)
252 |
253 | discrim_outputs = model.discriminator_forward(
254 | inputs_from_model
255 | )
256 | discrim_targets = torch.zeros(len(discrim_outputs),
257 | 1, device=device)
258 | discrim_loss = criterion['discriminator'](
259 | discrim_outputs, discrim_targets
260 | ) / 2
261 |
262 | discrim_outputs = model.discriminator.forward(
263 | inputs_from_data)
264 | discrim_targets = torch.ones(
265 | len(discrim_outputs), 1, device=device)
266 | discrim_loss += criterion['discriminator'](
267 | discrim_outputs, discrim_targets
268 | ) / 2
269 | optimizer['discriminator'].zero_grad()
270 | discrim_loss.backward()
271 | optimizer['discriminator'].step()
272 |
273 | discrim_postfix['discrim-r_loss'] += (
274 | discrim_loss.item() -
275 | discrim_postfix['discrim-r_loss']
276 | ) * smooth
277 |
278 | discrim_tqdm.set_postfix(discrim_postfix)
279 |
280 | postfix = {**gen_postfix, **discrim_postfix}
281 | postfix['mode'] = 'Policy Gradient (iter #{})'.format(iter_)
282 | return postfix
283 |
284 |
285 | def _train_policy_gradient(model, pg_train_loader, ref_smiles, ref_mols):
286 | criterion = {
287 | 'generator': PolicyGradientLoss(),
288 | 'discriminator': nn.BCEWithLogitsLoss(),
289 | }
290 |
291 | optimizer = {
292 | 'generator': torch.optim.Adam(model.generator.parameters(),
293 | lr=1e-4),
294 | 'discriminator': torch.optim.Adam(
295 | model.discriminator.parameters(), lr=1e-4)
296 | }
297 | ref_smiles = ref_smiles
298 | ref_mols = ref_mols
299 | model.zero_grad()
300 | for iter_ in range(pg_iters):
301 | postfix = _policy_gradient_iter(model, pg_train_loader, criterion, optimizer, iter_, ref_smiles, ref_mols)
302 |
303 |
304 | def fit(model, train_data):
305 | # Generator
306 | gen_collate_fn = generator_collate_fn(model)
307 | gen_train_loader = get_dataloader(train_data, gen_collate_fn)
308 | _pretrain_generator(model, gen_train_loader)
309 |
310 | # Discriminator
311 | dsc_collate_fn = discriminator_collate_fn(model)
312 | desc_train_loader = get_dataloader(train_data, dsc_collate_fn)
313 | _pretrain_discriminator(model, desc_train_loader)
314 |
315 | # Policy Gradient
316 | if model.metrics_reward is not None:
317 | (ref_smiles, ref_mols) = model.metrics_reward.get_reference_data(train_data)
318 |
319 | pg_train_loader = desc_train_loader
320 | _train_policy_gradient(model, pg_train_loader, ref_smiles, ref_mols)
321 |
322 | del ref_smiles
323 | del ref_mols
324 | #
325 | return model
326 |
--------------------------------------------------------------------------------
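
A minimal sketch of the computation performed by PolicyGradientLoss in Trainer.py above, run on dummy tensors (the batch size, sequence length, and vocabulary size are illustrative): the loss is the reward-weighted negative log-likelihood of the sampled tokens, normalised by the total number of valid steps, i.e. a REINFORCE-style objective.

    import torch
    import torch.nn.functional as F

    B, T, V = 2, 5, 10                      # batch, max sequence length, vocab size
    outputs = torch.randn(B, T, V)          # generator logits for each step
    targets = torch.randint(0, V, (B, T))   # token ids actually sampled in the rollout
    rewards = torch.rand(B, T)              # per-step rollout rewards
    lengths = torch.tensor([5, 3])          # number of valid steps per sequence

    log_probs = F.log_softmax(outputs, dim=2)
    items = torch.gather(log_probs, 2, targets.unsqueeze(2)) * rewards.unsqueeze(2)
    loss = -sum(t[:l].sum() for t, l in zip(items, lengths)) / lengths.sum().float()
    print(loss)  # scalar; minimising it raises the likelihood of high-reward tokens
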
/genmol/ORGAN/mcf.csv:
--------------------------------------------------------------------------------
1 | names,smarts
2 | MCF1,[#6]=&!@[#6]-[#6]#[#7]
3 | MCF2,[#6]=&!@[#6]-[#16](=[#8])=[#8]
4 | MCF3,[#6]=&!@[#6&!H0]-&!@[#6](=[#8])-&!@[#7]
5 | MCF4,"[H]C([H])([#6])[F,Cl,Br,I]"
6 | MCF5,[#6]1-[#8]-[#6]-1
7 | MCF6,[#6]-[#7]=[#6]=[#8]
8 | MCF7,[#6&!H0]=[#8]
9 | MCF8,"[#6](=&!@[#7&!H0])-&!@[#6,#7,#8,#16]"
10 | MCF9,[#6]1-[#7]-[#6]-1
11 | MCF10,[#6]~&!@[#7]~&!@[#7]~&!@[#6]
12 | MCF11,[#7]=&!@[#7]
13 | MCF12,[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#8]-1
14 | MCF13,[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#16]-1
15 | MCF14,"[#17,#35,#53]-c(:*):[!#1!#6]:*"
16 | MCF15,[H][#7]([H])-[#6]-1=[#6]-[#6]=[#6]-[#6]=[#6]-1
17 | MCF16,[#16]~[#16]
18 | MCF17,[#7]~&!@[#7]~&!@[#7]
19 | MCF18,[#7]-&!@[#6&!H0&!H1]-&!@[#7]
20 | MCF19,[#6&!H0](-&!@[#8])-&!@[#8]
21 | MCF20,[#35].[#35].[#35]
22 | MCF21,[#17].[#17].[#17].[#17]
23 | MCF22,[#9].[#9].[#9].[#9].[#9].[#9].[#9]
24 |
--------------------------------------------------------------------------------
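
A plausible way to use a SMARTS table like mcf.csv above as a structural-alert filter with RDKit (a minimal sketch, assuming pandas and RDKit are installed; the file path and the test SMILES are illustrative, and acrylonitrile should match the MCF1 pattern):

    import pandas as pd
    from rdkit import Chem

    mcf = pd.read_csv('mcf.csv')                               # columns: names, smarts
    patterns = [Chem.MolFromSmarts(s) for s in mcf['smarts']]

    mol = Chem.MolFromSmiles('C=CC#N')                         # acrylonitrile
    hits = [name for name, patt in zip(mcf['names'], patterns)
            if patt is not None and mol.HasSubstructMatch(patt)]
    print(hits)                                                # non-empty list -> molecule is flagged
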
/genmol/ORGAN/test.py:
--------------------------------------------------------------------------------
1 | from Data import *
2 | import unittest
3 | from tqdm import tqdm
4 | from Metrics_Reward import *
5 | from Model import ORGAN
6 |
7 | class test_metrics(unittest.TestCase):
8 | test = ['Oc1ccccc1-c1cccc2cnccc12',
9 | 'COc1cccc(NC(=O)Cc2coc3ccc(OC)cc23)c1']
10 | test_sf = ['COCc1nnc(NC(=O)COc2ccc(C(C)(C)C)cc2)s1',
11 | 'O=C(C1CC2C=CC1C2)N1CCOc2ccccc21',
12 | 'Nc1c(Br)cccc1C(=O)Nc1ccncn1']
13 | gen = ['CNC', 'Oc1ccccc1-c1cccc2cnccc12',
14 | 'INVALID', 'CCCP',
15 | 'Cc1noc(C)c1CN(C)C(=O)Nc1cc(F)cc(F)c1',
16 | 'Cc1nc(NCc2ccccc2)no1-c1ccccc1']
17 | target = {'valid': 2 / 3,
18 | 'unique@3': 1.0,
19 | 'FCD/Test': 52.58371754126664,
20 | 'SNN/Test': 0.3152585653588176,
21 | 'Frag/Test': 0.3,
22 | 'Scaf/Test': 0.5,
23 | 'IntDiv': 0.7189187309761661,
24 | 'Filters': 0.75,
25 | 'logP': 4.9581881764518005,
26 | 'SA': 0.5086898026154574,
27 | 'QED': 0.045033731661603064,
28 | 'NP': 0.2902816615644048,
29 | 'weight': 14761.927533455337}
30 |
31 | def test_get_all_metrics_multiprocess(self):
32 | metrics = get_all_metrics(test_data, samples, k=3)
33 | fail = set()
34 | for metric in self.target:
35 | if not np.allclose(metrics[metric], self.target[metric]):
36 | warnings.warn(
37 | "Metric `{}` value does not match expected "
38 | "value. Got {}, expected {}".format(metric,
39 | metrics[metric],
40 | self.target[metric])
41 | )
42 | fail.add(metric)
43 | assert len(fail) == 0, f"Some metrics didn't pass tests: {fail}"
44 |
45 | def test_get_all_metrics_scaffold(self):
46 | get_all_metrics(self.test, self.gen,
47 | test_scaffolds=self.test_sf,
48 | k=3, n_jobs=2)
49 | mols = ['CCNC', 'CCC', 'INVALID', 'CCC']
50 | assert np.allclose(fraction_valid(mols), 3 / 4), "Failed valid"
51 | assert np.allclose(fraction_unique(mols, check_validity=False),
52 | 3 / 4), "Failed unique"
53 | assert np.allclose(fraction_unique(mols, k=2), 1), "Failed unique"
54 | mols = [Chem.MolFromSmiles(x) for x in mols]
55 | assert np.allclose(fraction_valid(mols), 3 / 4), "Failed valid"
56 | assert np.allclose(fraction_unique(mols, check_validity=False),
57 | 3 / 4), "Failed unique"
58 | assert np.allclose(fraction_unique(mols, k=2), 1), "Failed unique"
59 |
60 | def sampler(model):
61 | n_samples = 100000
62 | samples = []
63 | with tqdm(total=n_samples, desc='Generating Samples')as T:
64 | while n_samples > 0:
65 | current_samples = model.sample(min(n_samples, 64), max_length=100)
66 | samples.extend(current_samples)
67 | n_samples -= len(current_samples)
68 | T.update(len(current_samples))
69 |
70 | return samples
71 |
72 |
73 | def evaluate(test, samples, test_scaffolds=None, ptest=None, ptest_scaffolds=None):
74 | gen = samples
75 | k = [50, 99]
76 | n_jobs = 1
77 | batch_size = 20
78 | ptest = ptest
79 | ptest_scaffolds = 20
80 | pool = None
81 | gpu = None
82 | metrics = get_all_metrics(test, gen, k, n_jobs, device, batch_size, test_scaffolds, ptest, ptest_scaffolds)
83 | for name, value in metrics.items():
84 | print('{}, {}'.format(name, value))
85 |
86 |
87 | model = ORGAN()
88 | samples = sampler(model)
89 | model = model.to(device)
90 | evaluate(test_data, samples)
91 |
--------------------------------------------------------------------------------
/genmol/aae/data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 |
4 | data = pd.read_csv('C:/Users/ASUS/Desktop/intern things/dataset_iso_v1.csv')
5 | train_data1 = data[data['SPLIT'] == 'train']
6 | train_data_smiles2 = (train_data1["SMILES"].squeeze()).astype(str).tolist()
7 | train_data = train_data_smiles2
8 |
9 | chars = set()
10 | for string in train_data:
11 | chars.update(string)
12 | all_sys = sorted(list(chars)) + ['<bos>', '<eos>', '<pad>', '<unk>']
13 | vocab = all_sys
14 | c2i = {c: i for i, c in enumerate(all_sys)}
15 | i2c = {i: c for i, c in enumerate(all_sys)}
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 | vector = torch.eye(len(c2i))
18 |
19 |
20 | def char2id(char):
21 | if char not in c2i:
22 |         return c2i['<unk>']
23 | else:
24 | return c2i[char]
25 |
26 |
27 | def id2char(id):
28 | if id not in i2c:
29 | return i2c[32]
30 | else:
31 | return i2c[id]
32 |
33 | def string2ids(string,add_bos=False, add_eos=False):
34 | ids = [char2id(c) for c in string]
35 | if add_bos:
36 |         ids = [c2i['<bos>']] + ids
37 |     if add_eos:
38 |         ids = ids + [c2i['<eos>']]
39 | return ids
40 | def ids2string(ids, rem_bos=True, rem_eos=True):
41 | if len(ids) == 0:
42 | return ''
43 |     if rem_bos and ids[0] == c2i['<bos>']:
44 |         ids = ids[1:]
45 |     if rem_eos and ids[-1] == c2i['<eos>']:
46 | ids = ids[:-1]
47 | string = ''.join([id2char(id) for id in ids])
48 | return string
49 | def string2tensor(string, device=device):
50 |     ids = string2ids(string, add_bos=True, add_eos=True)
51 |     tensor = torch.tensor(ids, dtype=torch.long, device=device)
52 |     return tensor
53 |
54 | vector = torch.eye(len(c2i))
--------------------------------------------------------------------------------
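
A minimal round-trip sketch for the helpers defined in data.py above (it assumes the hard-coded CSV is available so the module-level vocabulary can be built; the SMILES string is illustrative):

    from data import string2tensor, ids2string, device

    t = string2tensor('CCO', device=device)   # ids for <bos> C C O <eos>
    print(t)
    print(ids2string(t.tolist()))              # 'CCO' -- <bos>/<eos> are stripped again
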
/genmol/aae/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import pandas as pd
4 | from torch.nn.utils.rnn import pad_sequence
5 | import torch.nn as nn
6 |
7 |
8 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
9 |
10 |
11 | from data import *
12 |
13 |
14 | emb_dim = 30
15 | hidden_dim = 64
16 | latent_dim = 4
17 | disc_input = 64
18 | disc_output = 84
19 | batch_size = 50
20 |
21 |
22 | class encoder(nn.Module):
23 | def __init__(self, vocab, emb_dim, hidden_dim, latent_dim):
24 | super(encoder, self).__init__()
25 | self.hidden_dim = hidden_dim
26 | self.latent_dim = latent_dim
27 | self.emb_dim = emb_dim
28 | self.vocab = vocab
29 |
30 |         self.embeddings_layer = nn.Embedding(len(vocab), emb_dim, padding_idx=c2i['<pad>'])
31 |
32 | self.rnn = nn.LSTM(emb_dim, hidden_dim)
33 | self.fc = nn.Linear(hidden_dim, latent_dim)
34 | self.relu = nn.ReLU()
35 |         self.drop = nn.Dropout(p=0.25)
36 |
37 | def forward(self, x, lengths):
38 | batch_size = x.shape[0]
39 |
40 | x = self.embeddings_layer(x)
41 | x = pack_padded_sequence(x, lengths, batch_first=True)
42 | output, (_, x) = self.rnn(x)
43 |
44 | x = x.permute(1, 2, 0).view(batch_size, -1)
45 | x = self.fc(x)
46 | state = self.relu(x)
47 | return state
48 |
49 |
50 | class decoder(nn.Module):
51 | def __init__(self, vocab, emb_dim, latent_dim, hidden_dim):
52 | super(decoder, self).__init__()
53 | self.latent_dim = latent_dim
54 | self.hidden_dim = hidden_dim
55 | self.emb_dim = emb_dim
56 | self.vocab = vocab
57 |
58 | self.latent = nn.Linear(latent_dim, hidden_dim)
59 |         self.embeddings_layer = nn.Embedding(len(vocab), emb_dim, padding_idx=c2i['<pad>'])
60 | self.rnn = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
61 | self.fc = nn.Linear(hidden_dim, len(vocab))
62 |
63 | def forward(self, x, lengths, state, is_latent_state=False):
64 | if is_latent_state:
65 | c0 = self.latent(state)
66 |
67 | c0 = c0.unsqueeze(0)
68 | h0 = torch.zeros_like(c0)
69 |
70 | state = (h0, c0)
71 |
72 | x = self.embeddings_layer(x)
73 |
74 | x = pack_padded_sequence(x, lengths, batch_first=True)
75 |
76 | x, state = self.rnn(x, state)
77 |
78 | x, lengths = pad_packed_sequence(x, batch_first=True)
79 | x = self.fc(x)
80 |
81 | return x, lengths, state
82 |
83 |
84 | class Discriminator(nn.Module):
85 | def __init__(self, latent_dim, disc_input, disc_output):
86 | super(Discriminator, self).__init__()
87 | self.latent_dim = latent_dim
88 | self.disc_input = disc_input
89 | self.disc_output = disc_output
90 |
91 | self.lin1 = nn.Linear(latent_dim, disc_input)
92 | self.lin2 = nn.Linear(disc_input, disc_output)
93 | self.lin3 = nn.Linear(disc_output, 1)
94 | self.sig = nn.Sigmoid()
95 |
96 | def forward(self, x):
97 | x = self.lin1(x)
98 | x = self.lin2(x)
99 | x = self.lin3(x)
100 |
101 | x = self.sig(x)
102 | return x
103 | class AAE(nn.Module):
104 | def __init__(self):
105 | super(AAE,self).__init__()
106 | self.encoder = encoder(vocab,emb_dim,hidden_dim,latent_dim)
107 | self.decoder = decoder(vocab,emb_dim,latent_dim,hidden_dim)
108 | self.discriminator = Discriminator(latent_dim,disc_input,disc_output)
109 |
--------------------------------------------------------------------------------
/genmol/aae/run.py:
--------------------------------------------------------------------------------
1 | from data import *
2 | from model import *
3 | from train import *
4 | from sample import *
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 | model = AAE().to(device)
8 | fit(model,train_data)
9 | model.eval()
10 | get_samples(model)
11 |
12 |
--------------------------------------------------------------------------------
/genmol/aae/sample.py:
--------------------------------------------------------------------------------
1 | from model import *
2 | from tqdm import tqdm
3 | from data import *
4 | def sample(model,n_batch, max_len=100):
5 | with torch.no_grad():
6 | samples = []
7 | lengths = torch.zeros(n_batch, dtype=torch.long, device=device)
8 | state = sample_latent(n_batch)
9 |         prevs = torch.empty(n_batch, 1, dtype=torch.long, device=device).fill_(c2i["<bos>"])
10 | one_lens = torch.ones(n_batch, dtype=torch.long, device=device)
11 | is_end = torch.zeros(n_batch, dtype=torch.uint8, device=device)
12 | for i in range(max_len):
13 | logits, _, state = model.decoder(prevs, one_lens, state, i == 0)
14 | currents = torch.argmax(logits, dim=-1)
15 |             is_end[currents.view(-1) == c2i["<eos>"]] = 1
16 | if is_end.sum() == max_len:
17 | break
18 |
19 |             currents[is_end, :] = c2i["<pad>"]
20 | samples.append(currents)
21 | lengths[~is_end] += 1
22 | prevs = currents
23 | if len(samples):
24 | samples = torch.cat(samples, dim=-1)
25 | samples = [tensor2string(t[:l]) for t, l in zip(samples, lengths)]
26 | else:
27 | samples = ['' for _ in range(n_batch)]
28 | return samples
29 |
30 |
31 | def get_samples(model):
32 | samples = []
33 | n = 300
34 | max_len = 100
35 | with tqdm(total=300, desc='Generating samples') as T:
36 | while n > 0:
37 | current_samples = sample(model,min(n, batch_size), max_len)
38 | samples.extend(current_samples)
39 | n -= len(current_samples)
40 | T.update(len(current_samples))
41 | print(samples)
42 |
43 | def tensor2string(tensor):
44 | ids = tensor.tolist()
45 | string = ids2string(ids, rem_bos=True, rem_eos=True)
46 | return string
47 | def sample_latent(n):
48 | return torch.randn(n,latent_dim)
--------------------------------------------------------------------------------
/genmol/aae/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.utils.rnn import pad_sequence
3 | import torch.nn as nn
4 | from torch.utils.data import DataLoader
5 |
6 | import torch.nn.functional as F
7 | from model import *
8 |
9 | def pretrain(model, train_loader):
10 | criterion = nn.CrossEntropyLoss()
11 | optimizer = torch.optim.Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=0.001)
12 | model.zero_grad()
13 | for epoch in range(4):
14 | if optimizer is None:
15 |             model.eval()
16 |         else:
17 |             model.train()
18 | for i, (encoder_inputs, decoder_inputs, decoder_targets) in enumerate(train_loader):
19 | encoder_inputs = (data.to(device) for data in encoder_inputs)
20 | decoder_inputs = (data.to(device) for data in decoder_inputs)
21 | decoder_targets = (data.to(device) for data in decoder_targets)
22 |
23 | latent_code = model.encoder(*encoder_inputs)
24 | decoder_output, decoder_output_lengths, states = model.decoder(*decoder_inputs, latent_code,
25 | is_latent_state=True)
26 |
27 | decoder_outputs = torch.cat([t[:l] for t, l in zip(decoder_output, decoder_output_lengths)], dim=0)
28 | decoder_targets = torch.cat([t[:l] for t, l in zip(*decoder_targets)], dim=0)
29 | loss = criterion(decoder_outputs, decoder_targets)
30 |
31 | if optimizer is not None:
32 | optimizer.zero_grad()
33 | loss.backward()
34 | optimizer.step()
35 |
36 |
37 | def train(model, train_loader):
38 | criterion = {"enc": nn.CrossEntropyLoss(), "gen": lambda t: -torch.mean(F.logsigmoid(t)),"disc": nn.BCEWithLogitsLoss()}
39 |
40 | optimizers = {'auto': torch.optim.Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=0.001),
41 | 'gen': torch.optim.Adam(model.encoder.parameters(), lr=0.001),
42 | 'disc': torch.optim.Adam(model.discriminator.parameters(), lr=0.001)}
43 |
44 | model.zero_grad()
45 | for epoch in range(10):
46 | if optimizers is None:
47 |             model.eval()
48 |         else:
49 |             model.train()
50 |
51 | for i, (encoder_inputs, decoder_inputs, decoder_targets) in enumerate(train_loader):
52 | encoder_inputs = (data.to(device) for data in encoder_inputs)
53 | decoder_inputs = (data.to(device) for data in decoder_inputs)
54 | decoder_targets = (data.to(device) for data in decoder_targets)
55 |
56 | latent_code = model.encoder(*encoder_inputs)
57 | decoder_output, decoder_output_lengths, states = model.decoder(*decoder_inputs, latent_code,
58 | is_latent_state=True)
59 | discriminator_output = model.discriminator(latent_code)
60 |
61 | decoder_outputs = torch.cat([t[:l] for t, l in zip(decoder_output, decoder_output_lengths)], dim=0)
62 | decoder_targets = torch.cat([t[:l] for t, l in zip(*decoder_targets)], dim=0)
63 |
64 | autoencoder_loss = criterion["enc"](decoder_outputs, decoder_targets)
65 | generation_loss = criterion["gen"](discriminator_output)
66 |
67 | if i % 2 == 0:
68 | discriminator_input = torch.randn(batch_size, latent_dim)
69 | discriminator_output = model.discriminator(discriminator_input)
70 | discriminator_targets = torch.ones(batch_size, 1)
71 | else:
72 | discriminator_targets = torch.zeros(batch_size, 1)
73 | discriminator_loss = criterion["disc"](discriminator_output, discriminator_targets)
74 |
75 | if optimizers is not None:
76 | optimizers["auto"].zero_grad()
77 | autoencoder_loss.backward(retain_graph=True)
78 | optimizers["auto"].step()
79 |
80 | optimizers["gen"].zero_grad()
81 |                 generation_loss.backward(retain_graph=True)
82 | optimizers["gen"].step()
83 |
84 | optimizers["disc"].zero_grad()
85 |                 discriminator_loss.backward()
86 | optimizers["disc"].step()
87 |
88 | def fit(model,train_data):
89 | train_loader = get_dataloader(model, train_data, collate_fn=None, shuffle=True)
90 | pretrain(model,train_loader)
91 | train(model,train_loader)
92 |
93 | def get_collate_device(model):
94 | return device
95 | def get_dataloader(model, data, collate_fn=None, shuffle=True):
96 | if collate_fn is None:
97 | collate_fn = get_collate_fn(model)
98 | return DataLoader(data, batch_size= batch_size,shuffle=shuffle,collate_fn=collate_fn)
99 |
100 |
101 | def get_collate_fn(model):
102 | device = get_collate_device(model)
103 |
104 | def collate(data):
105 | data.sort(key=lambda x: len(x), reverse=True)
106 |
107 | tensors = [string2tensor(string, device=device) for string in data]
108 | lengths = torch.tensor([len(t) for t in tensors], dtype=torch.long, device=device)
109 |
110 |         encoder_inputs = pad_sequence(tensors, batch_first=True, padding_value=c2i["<pad>"])
111 | encoder_input_lengths = lengths - 2
112 |
113 |         decoder_inputs = pad_sequence([t[:-1] for t in tensors], batch_first=True, padding_value=c2i["<pad>"])
114 |         decoder_input_lengths = lengths - 1
115 |         decoder_targets = pad_sequence([t[1:] for t in tensors], batch_first=True, padding_value=c2i["<pad>"])
116 | decoder_target_lengths = lengths - 1
117 | return (encoder_inputs, encoder_input_lengths), (decoder_inputs, decoder_input_lengths), (decoder_targets, decoder_target_lengths)
118 |
119 | return collate
--------------------------------------------------------------------------------
/genmol/models.txt:
--------------------------------------------------------------------------------
1 | Char-Rnn
2 | ORGAN
3 | MOLGAN
4 | VAE
5 | AAE
6 | JTVAE
7 | Release
8 |
--------------------------------------------------------------------------------
/genmol/vae/data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 |
4 |
5 | data = pd.read_csv('C:/Users/ASUS/Desktop/intern things/dataset_iso_v1.csv')
6 | train_data1 = data[data['SPLIT'] == 'train']
7 | train_data_smiles2 = (train_data1["SMILES"].squeeze()).astype(str).tolist()
8 | train_data = train_data_smiles2
9 |
10 | chars = set()
11 | for string in train_data:
12 | chars.update(string)
13 | all_sys = sorted(list(chars)) + ['<bos>', '<eos>', '<pad>', '<unk>']
14 | vocab = all_sys
15 | c2i = {c: i for i, c in enumerate(all_sys)}
16 | i2c = {i: c for i, c in enumerate(all_sys)}
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 | vector = torch.eye(len(c2i))
19 |
20 |
21 | def char2id(char):
22 | if char not in c2i:
23 |         return c2i['<unk>']
24 | else:
25 | return c2i[char]
26 |
27 |
28 | def id2char(id):
29 | if id not in i2c:
30 | return i2c[32]
31 | else:
32 | return i2c[id]
33 |
34 | def string2ids(string,add_bos=False, add_eos=False):
35 | ids = [char2id(c) for c in string]
36 | if add_bos:
37 |         ids = [c2i['<bos>']] + ids
38 |     if add_eos:
39 |         ids = ids + [c2i['<eos>']]
40 | return ids
41 | def ids2string(ids, rem_bos=True, rem_eos=True):
42 | if len(ids) == 0:
43 | return ''
44 |     if rem_bos and ids[0] == c2i['<bos>']:
45 |         ids = ids[1:]
46 |     if rem_eos and ids[-1] == c2i['<eos>']:
47 | ids = ids[:-1]
48 | string = ''.join([id2char(id) for id in ids])
49 | return string
50 | def string2tensor(string, device=device):
51 |     ids = string2ids(string, add_bos=True, add_eos=True)
52 |     tensor = torch.tensor(ids, dtype=torch.long, device=device)
53 |     return tensor
54 | tensor = [string2tensor(string, device=device) for string in train_data]
55 | vector = torch.eye(len(c2i))
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/genmol/vae/run.py:
--------------------------------------------------------------------------------
1 |
2 | from trainer import *
3 | from vae_model import VAE
4 | from data import *
5 | from samples import *
6 |
7 | model = VAE(vocab,vector).to(device)
8 | fit(model, train_data)
9 | model.eval()
10 | samples = sample.take_samples(model, n_batch)
11 | print(samples)
--------------------------------------------------------------------------------
/genmol/vae/samples.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from tqdm import tqdm
3 | import pandas as pd
4 | n_samples = 3000
5 | n_jobs = 1
6 | max_len = 100
7 |
8 | class sample():
9 | def take_samples(model,n_batch):
10 | n = n_samples
11 | samples = []
12 | with tqdm(total=n_samples, desc='Generating samples') as T:
13 | while n > 0:
14 | current_samples = model.sample(min(n, n_batch), max_len)
15 | samples.extend(current_samples)
16 | n -= len(current_samples)
17 | T.update(len(current_samples))
18 | samples = pd.DataFrame(samples, columns=['SMILES'])
19 | return samples
--------------------------------------------------------------------------------
/genmol/vae/trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from torch.optim.lr_scheduler import _LRScheduler
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader
5 | from torch.nn.utils import clip_grad_norm_
6 | import math
7 | import numpy as np
8 | from collections import UserList, defaultdict
9 | n_last = 1000
10 | n_batch = 32
11 | kl_start = 0
12 | kl_w_start = 0.0
13 | kl_w_end = 1.0
14 | n_epoch = 50
15 | n_workers = 0
16 |
17 | clip_grad = 50
18 | lr_start = 0.003
19 | lr_n_period = 10
20 | lr_n_mult = 1
21 | lr_end = 3 * 1e-4
22 | lr_n_restarts = 6
23 | from data import *
24 |
25 | def _n_epoch():
26 | return sum(lr_n_period * (lr_n_mult ** i) for i in range(lr_n_restarts))
27 |
28 | def _train_epoch(model, epoch, train_loader, kl_weight, optimizer=None):
29 | if optimizer is None:
30 | model.eval()
31 | else:
32 | model.train()
33 |
34 | kl_loss_values = CircularBuffer(n_last)
35 | recon_loss_values = CircularBuffer(n_last)
36 | loss_values = CircularBuffer(n_last)
37 | for i, input_batch in enumerate(train_loader):
38 | input_batch = tuple(data.to(device) for data in input_batch)
39 |
40 | #forward
41 | kl_loss, recon_loss = model(input_batch)
42 | loss = kl_weight * kl_loss + recon_loss
43 | #backward
44 | if optimizer is not None:
45 | optimizer.zero_grad()
46 | loss.backward()
47 | clip_grad_norm_(get_optim_params(model),clip_grad)
48 | optimizer.step()
49 |
50 | kl_loss_values.add(kl_loss.item())
51 | recon_loss_values.add(recon_loss.item())
52 | loss_values.add(loss.item())
53 | lr = (optimizer.param_groups[0]['lr'] if optimizer is not None else None)
54 |
55 | #update train_loader
56 | kl_loss_value = kl_loss_values.mean()
57 | recon_loss_value = recon_loss_values.mean()
58 | loss_value = loss_values.mean()
59 |     # summarise the epoch; lr is None when running in eval mode
60 | postfix = {'epoch': epoch,'kl_weight': kl_weight,'lr': lr,'kl_loss': kl_loss_value,'recon_loss': recon_loss_value,'loss': loss_value,'mode': 'Eval' if optimizer is None else 'Train'}
61 | return postfix
62 |
63 | def _train(model, train_loader, val_loader=None, logger=None):
64 | optimizer = optim.Adam(get_optim_params(model),lr= lr_start)
65 |
66 | lr_annealer = CosineAnnealingLRWithRestart(optimizer)
67 |
68 | model.zero_grad()
69 | for epoch in range(n_epoch):
70 |
71 | kl_annealer = KLAnnealer(n_epoch)
72 | kl_weight = kl_annealer(epoch)
73 | postfix = _train_epoch(model, epoch,train_loader, kl_weight, optimizer)
74 | lr_annealer.step()
75 | def fit(model, train_data, val_data=None):
76 |     logger = Logger()
77 | train_loader = get_dataloader(model,train_data,shuffle=True)
78 |
79 |
80 |
81 | val_loader = None if val_data is None else get_dataloader(model, val_data, shuffle=False)
82 | _train(model, train_loader, val_loader, logger)
83 | return model
84 | def get_collate_device(model):
85 | return model.device
86 | def get_dataloader(model, train_data, collate_fn=None, shuffle=True):
87 | if collate_fn is None:
88 | collate_fn = get_collate_fn(model)
89 | print(collate_fn)
90 | return DataLoader(train_data, batch_size=n_batch, shuffle=shuffle, num_workers=n_workers, collate_fn=collate_fn)
91 |
92 | def get_collate_fn(model):
93 | device = get_collate_device(model)
94 |
95 | def collate(train_data):
96 | train_data.sort(key=len, reverse=True)
97 | tensors = [string2tensor(string, device=device) for string in train_data]
98 | return tensors
99 |
100 | return collate
101 |
102 | def get_optim_params(model):
103 | return (p for p in model.parameters() if p.requires_grad)
104 |
105 | class KLAnnealer:
106 | def __init__(self,n_epoch):
107 | self.i_start = kl_start
108 | self.w_start = kl_w_start
109 | self.w_max = kl_w_end
110 | self.n_epoch = n_epoch
111 |
112 |
113 | self.inc = (self.w_max - self.w_start) / (self.n_epoch - self.i_start)
114 |
115 | def __call__(self, i):
116 | k = (i - self.i_start) if i >= self.i_start else 0
117 | return self.w_start + k * self.inc
118 |
119 |
120 |
121 | class CosineAnnealingLRWithRestart(_LRScheduler):
122 | def __init__(self , optimizer):
123 | self.n_period = lr_n_period
124 | self.n_mult = lr_n_mult
125 | self.lr_end = lr_end
126 |
127 | self.current_epoch = 0
128 | self.t_end = self.n_period
129 |
130 | # Also calls first epoch
131 | super().__init__(optimizer, -1)
132 |
133 | def get_lr(self):
134 | return [self.lr_end + (base_lr - self.lr_end) *
135 | (1 + math.cos(math.pi * self.current_epoch / self.t_end)) / 2
136 | for base_lr in self.base_lrs]
137 |
138 | def step(self, epoch=None):
139 | if epoch is None:
140 | epoch = self.last_epoch + 1
141 | self.last_epoch = epoch
142 | self.current_epoch += 1
143 |
144 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
145 | param_group['lr'] = lr
146 |
147 | if self.current_epoch == self.t_end:
148 | self.current_epoch = 0
149 | self.t_end = self.n_mult * self.t_end
150 |
151 |
152 |
153 |
154 | class CircularBuffer:
155 | def __init__(self, size):
156 | self.max_size = size
157 | self.data = np.zeros(self.max_size)
158 | self.size = 0
159 | self.pointer = -1
160 |
161 | def add(self, element):
162 | self.size = min(self.size + 1, self.max_size)
163 | self.pointer = (self.pointer + 1) % self.max_size
164 | self.data[self.pointer] = element
165 | return element
166 |
167 | def last(self):
168 | assert self.pointer != -1, "Can't get an element from an empty buffer!"
169 | return self.data[self.pointer]
170 |
171 | def mean(self):
172 |         return self.data[:self.size].mean() if self.size > 0 else 0.0
173 |
174 |
175 | class Logger(UserList):
176 | def __init__(self, data=None):
177 | super().__init__()
178 | self.sdata = defaultdict(list)
179 | for step in (data or []):
180 | self.append(step)
181 |
182 | def __getitem__(self, key):
183 | if isinstance(key, int):
184 | return self.data[key]
185 | elif isinstance(key, slice):
186 | return Logger(self.data[key])
187 | else:
188 | ldata = self.sdata[key]
189 | if isinstance(ldata[0], dict):
190 | return Logger(ldata)
191 | else:
192 | return ldata
193 |
194 | def append(self, step_dict):
195 | super().append(step_dict)
196 | for k, v in step_dict.items():
197 | self.sdata[k].append(v)
198 |
199 |
200 |
201 |
202 |
--------------------------------------------------------------------------------
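
A small self-contained check of the KL weight schedule produced by KLAnnealer in trainer.py above, using the module-level settings (kl_start=0, kl_w_start=0.0, kl_w_end=1.0, n_epoch=50): the weight ramps linearly from 0 toward 1 by 0.02 per epoch, so reconstruction dominates early training and the KL term is phased in gradually.

    inc = (1.0 - 0.0) / (50 - 0)                 # (kl_w_end - kl_w_start) / (n_epoch - kl_start)
    weights = [0.0 + epoch * inc for epoch in range(50)]
    print(weights[0], weights[25], weights[49])  # 0.0 0.5 0.98
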
/genmol/vae/vae_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | q_bidir = True
7 | q_d_h = 256
8 | q_n_layers = 1
9 | q_dropout = 0.5
10 | d_n_layers = 3
11 | d_dropout = 0
12 | d_z = 128
13 | d_d_h = 512
14 | from data import *
15 | class VAE(nn.Module):
16 | def __init__(self,vocab,vector):
17 | super().__init__()
18 | self.vocabulary = vocab
19 | self.vector = vector
20 |
21 | n_vocab, d_emb = len(vocab), vector.size(1)
22 |         self.x_emb = nn.Embedding(n_vocab, d_emb, c2i['<pad>'])
23 | self.x_emb.weight.data.copy_(vector)
24 |
25 | #ENCODER
26 |
27 | self.encoder_rnn = nn.GRU(d_emb,q_d_h,num_layers=q_n_layers,batch_first=True,dropout=q_dropout if q_n_layers > 1 else 0,bidirectional=q_bidir)
28 | q_d_last = q_d_h * (2 if q_bidir else 1)
29 | self.q_mu = nn.Linear(q_d_last, d_z)
30 | self.q_logvar = nn.Linear(q_d_last, d_z)
31 |
32 |
33 |
34 | # Decoder
35 | self.decoder_rnn = nn.GRU(d_emb + d_z,d_d_h,num_layers=d_n_layers,batch_first=True,dropout=d_dropout if d_n_layers > 1 else 0)
36 | self.decoder_latent = nn.Linear(d_z, d_d_h)
37 | self.decoder_fullyc = nn.Linear(d_d_h, n_vocab)
38 |
39 |
40 |
41 | # Grouping the model's parameters
42 | self.encoder = nn.ModuleList([self.encoder_rnn,self.q_mu,self.q_logvar])
43 | self.decoder = nn.ModuleList([self.decoder_rnn,self.decoder_latent,self.decoder_fullyc])
44 | self.vae = nn.ModuleList([self.x_emb,self.encoder,self.decoder])
45 |
46 |
47 |
48 | @property
49 | def device(self):
50 | return next(self.parameters()).device
51 |
52 | def string2tensor(self, string, device='model'):
53 | ids = string2ids(string, add_bos=True, add_eos=True)
54 | tensor = torch.tensor(ids, dtype=torch.long,device=self.device if device == 'model' else device)
55 | return tensor
56 |
57 | def tensor2string(self, tensor):
58 | ids = tensor.tolist()
59 | string = ids2string(ids, rem_bos=True, rem_eos=True)
60 | return string
61 |
62 | def forward(self,x):
63 | z, kl_loss = self.forward_encoder(x)
64 | recon_loss = self.forward_decoder(x, z)
65 | print("forward")
66 | return kl_loss, recon_loss
67 |
68 | def forward_encoder(self,x):
69 | x = [self.x_emb(i_x) for i_x in x]
70 | x = nn.utils.rnn.pack_sequence(x)
71 | _, h = self.encoder_rnn(x, None)
72 | h = h[-(1 + int(self.encoder_rnn.bidirectional)):]
73 | h = torch.cat(h.split(1), dim=-1).squeeze(0)
74 | mu, logvar = self.q_mu(h), self.q_logvar(h)
75 | eps = torch.randn_like(mu)
76 | z = mu + (logvar / 2).exp() * eps
77 | kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
78 | return z, kl_loss
79 |
80 | def forward_decoder(self,x, z):
81 | lengths = [len(i_x) for i_x in x]
82 |         x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=c2i['<pad>'])
83 | x_emb = self.x_emb(x)
84 | z_0 = z.unsqueeze(1).repeat(1, x_emb.size(1), 1)
85 | x_input = torch.cat([x_emb, z_0], dim=-1)
86 | x_input = nn.utils.rnn.pack_padded_sequence(x_input, lengths, batch_first=True)
87 | h_0 = self.decoder_latent(z)
88 | h_0 = h_0.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1)
89 | output, _ = self.decoder_rnn(x_input, h_0)
90 | output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
91 | y = self.decoder_fullyc(output)
92 |
93 |         recon_loss = F.cross_entropy(y[:, :-1].contiguous().view(-1, y.size(-1)), x[:, 1:].contiguous().view(-1), ignore_index=c2i['<pad>'])
94 | return recon_loss
95 |
96 |
97 | def sample_z_prior(self,n_batch):
98 | return torch.randn(n_batch,self.q_mu.out_features,device= self.x_emb.weight.device)
99 | def sample(self,n_batch, max_len=100, z=None, temp=1.0):
100 | with torch.no_grad():
101 | if z is None:
102 | z = self.sample_z_prior(n_batch)
103 | z = z.to(self.device)
104 | z_0 = z.unsqueeze(1)
105 | h = self.decoder_latent(z)
106 | h = h.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1)
107 |             w = torch.tensor(c2i['<bos>'], device=self.device).repeat(n_batch)
108 |             x = torch.tensor([c2i['<pad>']], device=self.device).repeat(n_batch, max_len)
109 |             x[:, 0] = c2i['<bos>']
110 | end_pads = torch.tensor([max_len], device=self.device).repeat(n_batch)
111 | eos_mask = torch.zeros(n_batch, dtype=torch.uint8, device=self.device)
112 |
113 |
114 | for i in range(1, max_len):
115 | x_emb = self.x_emb(w).unsqueeze(1)
116 | x_input = torch.cat([x_emb, z_0], dim=-1)
117 |
118 | o, h = self.decoder_rnn(x_input, h)
119 | y = self.decoder_fullyc(o.squeeze(1))
120 | y = F.softmax(y / temp, dim=-1)
121 |
122 | w = torch.multinomial(y, 1)[:, 0]
123 | x[~eos_mask, i] = w[~eos_mask]
124 |                 i_eos_mask = ~eos_mask & (w == c2i['<eos>'])
125 | end_pads[i_eos_mask] = i + 1
126 | eos_mask = eos_mask | i_eos_mask
127 |
128 |
129 | new_x = []
130 | for i in range(x.size(0)):
131 | new_x.append(x[i, :end_pads[i]])
132 |
133 |
134 | return [self.tensor2string(i_x) for i_x in new_x]
--------------------------------------------------------------------------------