├── GraphBP ├── config.py ├── count_atom_dist.py ├── data │ └── crossdock2020 │ │ └── it2_tt_0_lowrmsd_mols_train0_fixed.types ├── dataset.py ├── dataset_from_scratch.py ├── download_data.sh ├── environment.yml ├── main.py ├── main_eval.py ├── main_gen.py ├── model │ ├── __init__.py │ ├── features.py │ ├── geometric_computing.py │ ├── graphbp.py │ ├── net_utils.py │ └── schnet.py ├── runner.py ├── scripts │ └── split_sdf.py ├── trained_model │ └── model_33.pth └── utils │ ├── __init__.py │ └── bond_adding.py ├── LICENSE ├── README.md └── assets └── GraphBP.png /GraphBP/config.py: -------------------------------------------------------------------------------- 1 | conf = {} 2 | 3 | 4 | 5 | 6 | # ## skip600000_hidden32_numinter6_lr1e-4_wd1e-5_wbindingsite_cutoff10_bs16 7 | conf_model = {} 8 | conf_model['cutoff'] = 10.0 9 | conf_model['num_node_types'] = 46 # lig_types + rec_types 10 | conf_model['num_lig_node_types'] = 27 # lig_types 11 | conf_model['num_interactions'] = 6 12 | conf_model['num_filters'] = 32 13 | conf_model['num_gaussians'] = 50 14 | conf_model['hidden_channels'] = 32 15 | conf_model['basis_emb_size'] = 32 16 | conf_model['num_spherical'] = 7 17 | conf_model['num_radial'] = 6 18 | conf_model['num_flow_layers'] = 6 19 | conf_model['deq_coeff'] = 0.9 20 | conf_model['use_gpu'] = True 21 | 22 | conf_optim = {'lr': 0.0001, 'weight_decay': 0.00001} 23 | 24 | conf['model'] = conf_model 25 | conf['optim'] = conf_optim 26 | conf['verbose'] = 100 27 | conf['batch_size'] = 16 28 | conf['epochs'] = 100 29 | conf['chunk_size'] = 20 30 | conf['num_workers'] = 4 31 | 32 | conf['gen_model'] = 'GraphBP' 33 | -------------------------------------------------------------------------------- /GraphBP/count_atom_dist.py: -------------------------------------------------------------------------------- 1 | from dataset import CrossDocked2020_SBDD, collate_mols 2 | from torch.utils.data import DataLoader 3 | 4 | from rdkit import Chem 5 | dataset = CrossDocked2020_SBDD() 6 
# Tally how often each heavy-atom element occurs across the dataset.
# Despite their names, both dicts map an element to its occurrence COUNT,
# not to a type index.
atomic_num_to_type = {}      # ligand atoms: atomic number -> count
atomic_element_to_type = {}  # receptor atoms: element symbol -> count

# Walk the dataset from the last sample down to the first.
for i in range(len(dataset) - 1, -1, -1):
    if i % 1000 == 0:
        # Periodic progress report with the running tallies.
        print(i)
        print(atomic_num_to_type)
        print(atomic_element_to_type)
        print('=======')
    # NOTE(review): this unpacking expects dataset[i] to yield
    # (rec_structure, lig_supplier, rec_src, lig_src) -- confirm against the
    # dataset class that is actually imported.
    rec_structure, lig_supplier, rec_src, lig_src = dataset[i]

    # Receptor: count every non-hydrogen atom by element symbol.
    for atom in rec_structure.get_atoms():
        if atom.element != 'H':
            # The first occurrence must count as 1 (the previous code stored 0
            # for it, which left every tally one short).
            atomic_element_to_type[atom.element] = atomic_element_to_type.get(atom.element, 0) + 1

    # Ligand: drop hydrogens, then count by atomic number.
    lig_mol = Chem.rdmolops.RemoveAllHs(lig_supplier[0], sanitize=False)
    for atom in lig_mol.GetAtoms():
        atom_num = atom.GetAtomicNum()
        atomic_num_to_type[atom_num] = atomic_num_to_type.get(atom_num, 0) + 1

    # Release the parsed structures before loading the next sample.
    del lig_supplier
    del rec_structure

print(atomic_num_to_type)
print(atomic_element_to_type)


### results
# NOTE(review): the figures below were presumably recorded with the earlier
# counting that started each element at 0; if so, each is one lower than the
# true number of occurrences.

# atomic_num_to_type = {6: 7896717, 7: 1467608, 8: 1827282, 16: 132998, 17: 80952, 15: 122108, 9: 136520, 23: 145, 12: 63, 42: 1002, 33: 45, 35: 21643, 53: 4453, 5: 3295, 14: 55, 34: 312, 30: 18, 21: 4, 26: 223, 44: 214, 45: 19, 74: 47, 39: 1, 29: 32, 79: 6, 51: 3, 13: 29}

# atomic_element_to_type = {'N': 280071932, 'C': 1049822037, 'O': 306441297, 'S': 8411128, 'MG': 101404, 'CA': 99148, 'CL': 62654, 'SE': 21025, 'NA': 46982, 'CD': 7655, 'ZN': 100731, 'P': 47588, 'K': 14544, 'CU': 513, 'CO': 8088, 'CS': 41, 'I': 28335, 'HG': 1941, 'MN': 25}
def collate_mols(mol_dicts):
    """Merge per-sample generation dicts into one flat batch dict.

    Every sample stores its tensors concatenated over its own generation
    steps, with step ids (``batch``) and atom indices (``focus``,
    ``c1_focus``, ``c2_c1_focus``, ``contact_y_or_n``) counted from zero
    within that sample. Concatenating samples therefore requires shifting
    those ids by the number of steps / atoms contributed by the samples
    that come before.

    Args:
        mol_dicts: iterable of sample dicts (a list, or a one-shot iterable
            such as a generator or ``filter`` object). It is materialised
            once, so it may be consumed only a single time by this call.

    Returns:
        dict with the same keys as a sample dict; each value is the
        concatenation over samples along dim 0, with step ids and atom
        indices shifted to be unique across the batch.
    """
    # Materialise once: the samples are traversed several times below, which
    # a one-shot iterable would not survive.
    mol_dicts = list(mol_dicts)
    data_batch = {}

    def _exclusive_offsets(sizes):
        # Exclusive prefix sum: the offset of sample k is the total size of
        # samples 0..k-1, so the first sample is never shifted.
        return torch.cumsum(torch.tensor([0] + sizes[:-1]), dim=0)

    def _cat(key):
        return torch.cat([mol_dict[key] for mol_dict in mol_dicts], dim=0)

    # Keys that hold no indices are concatenated as they are.
    for key in ('atom_type', 'position', 'rec_mask', 'cannot_contact',
                'new_atom_type', 'new_dist', 'new_angle', 'new_torsion',
                'cannot_focus'):
        data_batch[key] = _cat(key)

    # Step ids: shift by the number of generation steps (one per entry of
    # 'new_atom_type') in the preceding samples.
    step_offsets = _exclusive_offsets([len(mol_dict['new_atom_type']) for mol_dict in mol_dicts])
    repeats = torch.tensor([len(mol_dict['batch']) for mol_dict in mol_dicts])
    data_batch['batch'] = _cat('batch') + torch.repeat_interleave(step_offsets, repeats)

    # Atom indices: shift by the number of atom rows in the preceding samples.
    atom_offsets = _exclusive_offsets([len(mol_dict['atom_type']) for mol_dict in mol_dicts])
    for key in ('focus', 'c1_focus', 'c2_c1_focus', 'contact_y_or_n'):
        repeats = torch.tensor([len(mol_dict[key]) for mol_dict in mol_dicts])
        offsets = torch.repeat_interleave(atom_offsets, repeats)
        if key != 'contact_y_or_n':
            # These keys are 2-D (one row per step); broadcast over columns.
            offsets = offsets[:, None]
        data_batch[key] = _cat(key) + offsets

    return data_batch
class CrossDocked2020_SBDD(Dataset):
    """Receptor-ligand pairs unrolled into step-by-step atom placement samples.

    Each row of the ``.types`` index file names one receptor ``.pdb`` and one
    ligand ``.sdf``. ``__getitem__`` loads the pair, keeps the receptor atoms
    near the ligand, orders the ligand atoms, and emits one graph per
    placement step together with the targets for that step (new atom type,
    distance, angle, torsion and the reference atoms they are measured from).
    Hydrogens are ignored on both sides.
    """

    def __init__(
        self,
        data_root='./data/crossdock2020',
        data_file='./data/crossdock2020/it2_tt_0_lowrmsd_mols_train0_fixed.types',
        atomic_num_to_type=atomic_num_to_type,
        atomic_element_to_type=atomic_element_to_type,
        binding_site_range=15.0,
    ):
        """
        Args:
            data_root: directory that the receptor/ligand paths in
                ``data_file`` are relative to.
            data_file: space-separated index file with six columns per row.
            atomic_num_to_type: ligand atomic number -> node type id.
            atomic_element_to_type: receptor element symbol -> node type id.
            binding_site_range: receptor atoms farther than this from the
                ligand centroid are dropped.
        """
        super().__init__()
        data_cols = [
            'low_rmsd',
            'true_aff',
            'xtal_rmsd',
            'rec_src',
            'lig_src',
            'vina_aff',
        ]
        self.data_lines = pd.read_csv(
            data_file, sep=' ', names=data_cols, index_col=False
        )
        self.data_root = data_root

        self.atomic_num_to_type = atomic_num_to_type
        self.atomic_element_to_type = atomic_element_to_type
        self.bond_to_type = {BondType.SINGLE: 1, BondType.DOUBLE: 2, BondType.TRIPLE: 3}
        self.binding_site_range = binding_site_range
        self.pdb_parser = PDBParser()

    def __len__(self):
        """Number of rows in the index file."""
        return len(self.data_lines)

    def read_rec_mol(self, mol_src):
        '''
        mol_src: path of a .pdb file, relative to data_root
        return: the structure object produced by Bio.PDB.PDBParser
        '''
        with warnings.catch_warnings():
            # PDB files routinely trigger construction warnings; silence them.
            warnings.simplefilter('ignore', PDBConstructionWarning)
            structure = self.pdb_parser.get_structure('', os.path.join(self.data_root, mol_src))
        return structure

    def read_lig_mol(self, mol_src):
        '''
        mol_src: path of a .sdf file, relative to data_root
        return: rdkit.Chem.rdmolfiles.SDMolSupplier
        '''
        supp = Chem.SDMolSupplier()
        sdf_file = os.path.join(self.data_root, mol_src)
        # Close the file as soon as its text has been handed to RDKit.
        with open(sdf_file) as sdf_handle:
            supp.SetData(sdf_handle.read(), removeHs=False, sanitize=False)
        return supp

    def get_rec_mol(self, mol_src):
        """Load a receptor; thin wrapper around read_rec_mol."""
        return self.read_rec_mol(mol_src)

    def get_lig_mol(self, mol_src):
        """Load a ligand; thin wrapper around read_lig_mol."""
        return self.read_lig_mol(mol_src)

    @staticmethod
    def _bond_angle(focus_pos, c1_pos, new_pos):
        """Angle at focus_pos between the directions to c1_pos and new_pos."""
        to_c1 = c1_pos - focus_pos
        to_new = new_pos - focus_pos
        cos_part = (to_c1 * to_new).sum(dim=-1)
        sin_part = torch.cross(to_c1, to_new, dim=-1).norm(dim=-1)
        return torch.atan2(sin_part, cos_part)

    @staticmethod
    def _torsion_angle(focus_pos, c1_pos, c2_pos, new_pos):
        """Signed dihedral of new_pos and c2_pos about the c1_pos -> focus_pos axis."""
        axis = focus_pos - c1_pos
        plane1 = torch.cross(axis, new_pos - c1_pos, dim=-1)
        plane2 = torch.cross(axis, c2_pos - c1_pos, dim=-1)
        cos_part = (plane1 * plane2).sum(dim=-1)
        sin_part = (torch.cross(plane1, plane2, dim=-1) * axis).sum(dim=-1) / torch.norm(axis)
        return torch.atan2(sin_part, cos_part)

    def __getitem__(self, index):
        '''
        Build the step-wise training sample for one receptor-ligand pair.
        Note that H atoms are not considered in both lig and rec.

        With N selected receptor atoms and M ligand atoms, step i (0-based)
        is a graph holding the first i ligand atoms followed by the N
        receptor atoms; all steps are concatenated along dim 0.

        return: dict of tensors (shapes are listed where it is assembled
        at the end of this method).
        '''
        example = self.data_lines.iloc[index]
        rec_structure = self.get_rec_mol(example.rec_src)
        # Read .sdf file instead of .sdf.gz file; Why? (https://github.com/rdkit/rdkit/issues/1938)
        lig_supplier = self.get_lig_mol(example.lig_src.rsplit('.', 1)[0])

        heavy_rec_atoms = [atom for atom in rec_structure.get_atoms() if atom.element != 'H']
        rec_atom_type = torch.tensor([self.atomic_element_to_type[atom.element] for atom in heavy_rec_atoms])  # [rec_n_atoms]
        rec_position = torch.tensor(np.stack([atom.coord for atom in heavy_rec_atoms], axis=0))  # [rec_n_atoms, 3]
        del heavy_rec_atoms
        del rec_structure

        lig_mol = Chem.rdmolops.RemoveAllHs(lig_supplier[0], sanitize=False)
        lig_n_atoms = lig_mol.GetNumAtoms()
        # Coordinates are parsed from the raw record text: the slice starts at
        # the 5th line and takes one line per heavy atom.
        # NOTE(review): this presumably relies on hydrogens being listed after
        # all heavy atoms in the file -- confirm.
        lig_pos = lig_supplier.GetItemText(0).split('\n')[4:4 + lig_n_atoms]
        lig_position = np.array([[float(x) for x in line.split()[:3]] for line in lig_pos], dtype=np.float32)
        lig_atom_type = np.array([self.atomic_num_to_type[atom.GetAtomicNum()] for atom in lig_mol.GetAtoms()])
        lig_con_mat = np.zeros([lig_n_atoms, lig_n_atoms], dtype=int)
        for bond in lig_mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            bond_type = self.bond_to_type[bond.GetBondType()]
            lig_con_mat[start, end] = bond_type
            lig_con_mat[end, start] = bond_type
        lig_atom_type = torch.tensor(lig_atom_type)  # [lig_n_atoms]
        lig_position = torch.tensor(lig_position)  # [lig_n_atoms, 3]
        lig_atom_bond_valency = torch.tensor(np.sum(lig_con_mat, axis=1))  # [lig_n_atoms]
        lig_con_mat = torch.tensor(lig_con_mat)  # [lig_n_atoms, lig_n_atoms]
        del lig_supplier

        # Get binding site: receptor atoms within binding_site_range of the ligand centroid.
        lig_center = torch.mean(lig_position, dim=0)
        rec_atom_dist_to_lig_center = torch.sqrt(torch.sum(torch.square(rec_position - lig_center), dim=-1))
        selected_mask = rec_atom_dist_to_lig_center <= self.binding_site_range
        # At least 3 receptor atoms are needed (focus, c1, c2 in the first step).
        # Checked explicitly rather than with `assert`, which `python -O` removes.
        if torch.sum(selected_mask) < 3:
            print('One sample does not bind tightly. We can ignore it!')
            # Substitute a neighbouring sample. NOTE: if that one fails the
            # check too this recurses again; indices 0 and 1 both failing
            # would bounce between each other indefinitely.
            index = index - 1 if index > 0 else index + 1
            return self.__getitem__(index)
        rec_atom_type = rec_atom_type[selected_mask]
        rec_position = rec_position[selected_mask]
        rec_n_atoms = len(rec_atom_type)

        lig_rec_squared_dist = torch.sum(torch.square(lig_position[:, None, :] - rec_position[None, :, :]), dim=-1)  # [lig_n_atoms, rec_n_atoms]
        lig_internal_squared_dist = torch.sum(torch.square(lig_position[:, None, :] - lig_position[None, :, :]), dim=-1)  # [lig_n_atoms, lig_n_atoms]

        # To find contact nodes and node in rec that are furthest from lig
        min_index = torch.argmin(lig_rec_squared_dist)  # index into the flattened [lig, rec] matrix
        lig_contact_id = min_index // rec_n_atoms
        rec_contact_id = min_index % rec_n_atoms
        rec_n_contact_id = torch.argmax(torch.sum(lig_rec_squared_dist, dim=0))

        # Start from the contact node in the lig: swap it into position 0.
        perm = torch.arange(0, lig_n_atoms, dtype=int)
        perm[0] = lig_contact_id
        perm[lig_contact_id] = 0
        lig_atom_type, lig_position, lig_atom_bond_valency, lig_rec_squared_dist = lig_atom_type[perm], lig_position[perm], lig_atom_bond_valency[perm], lig_rec_squared_dist[perm]
        lig_con_mat, lig_internal_squared_dist = lig_con_mat[perm][:, perm], lig_internal_squared_dist[perm][:, perm]

        # Decide the order among lig nodes: minimum spanning tree over squared distances.
        # from_numpy_array is used because from_numpy_matrix no longer exists in NetworkX 3.
        nx_graph = nx.from_numpy_array(lig_internal_squared_dist.numpy())
        edges = list(tree.minimum_spanning_edges(nx_graph, algorithm='prim', data=False))  # return edges starts from the 0-th node (i.e., the contact node here) by default
        if edges:
            focus_node_id, target_node_id = zip(*edges)
        else:
            # A one-atom ligand has no tree edges; zip(*[]) cannot be unpacked.
            focus_node_id, target_node_id = (), ()
        node_perm = torch.cat((torch.tensor([0]), torch.tensor(target_node_id, dtype=torch.long)))
        lig_atom_type, lig_position, lig_atom_bond_valency, lig_rec_squared_dist = lig_atom_type[node_perm], lig_position[node_perm], lig_atom_bond_valency[node_perm], lig_rec_squared_dist[node_perm]
        lig_con_mat, lig_internal_squared_dist = lig_con_mat[node_perm][:, node_perm], lig_internal_squared_dist[node_perm][:, node_perm]

        # Prepare training data for sequential generation
        focus_node_id = torch.tensor(focus_node_id, dtype=torch.long)
        # focus_ids holds, for each tree edge, the focus atom's position in the node_perm order.
        focus_ids = torch.nonzero(focus_node_id[:, None] == node_perm[None, :])[:, 1]

        # Step i starts at global row (0+1+...+(i-1)) + i*rec_n_atoms.
        idx_offsets = torch.cumsum(torch.arange(lig_n_atoms), dim=0)  # [M]
        idx_offsets_brought_by_rec = rec_n_atoms * torch.arange(1, lig_n_atoms)  # [M-1]

        # ---- Step 0: only the rec exists. The contact classifier applies to this step only.
        contact_y_or_n = torch.tensor([rec_contact_id, rec_n_contact_id], dtype=int)  # The atom IDs of contact node and the node that are furthest from lig.
        cannot_contact = torch.tensor([0, 1], dtype=torch.float)  # The groundtruth for contact_y_or_n

        # Per-step pieces are collected in lists and concatenated once at the end.
        atom_type_steps = [rec_atom_type]
        rec_mask_steps = [torch.ones([rec_n_atoms], dtype=torch.bool)]
        position_steps = [rec_position]
        batch_steps = [torch.tensor([0]).repeat(rec_n_atoms)]
        cannot_focus_steps = [torch.empty([0, 1], dtype=torch.float)]

        # Reference atoms c1, c2: the two rec atoms closest to the contact atom.
        dist_to_focus = torch.sum(torch.square(rec_position[rec_contact_id] - rec_position), dim=-1)
        _, indices = torch.topk(dist_to_focus, 3, largest=False)
        one_step_c1, one_step_c2 = indices[1], indices[2]
        assert indices[0] == rec_contact_id
        c1_focus_steps = [torch.tensor([one_step_c1, rec_contact_id], dtype=int).view(1, 2)]
        c2_c1_focus_steps = [torch.tensor([one_step_c2, one_step_c1, rec_contact_id], dtype=int).view(1, 3)]

        focus_pos, new_pos = rec_position[rec_contact_id], lig_position[0]
        c1_pos, c2_pos = rec_position[one_step_c1], rec_position[one_step_c2]
        dist_steps = [torch.norm(new_pos - focus_pos).view(1, 1)]
        angle_steps = [self._bond_angle(focus_pos, c1_pos, new_pos).view(1, 1)]
        torsion_steps = [self._torsion_angle(focus_pos, c1_pos, c2_pos, new_pos).view(1, 1)]

        # ---- Steps 1..M-1: the graph is the first i lig atoms followed by the rec atoms.
        for i in range(1, lig_n_atoms):
            atom_type_steps.append(torch.cat((lig_atom_type[:i], rec_atom_type), dim=0))
            rec_mask_steps.append(torch.cat((torch.zeros([i], dtype=torch.bool), torch.ones([rec_n_atoms], dtype=torch.bool)), dim=0))
            position_steps.append(torch.cat((lig_position[:i], rec_position), dim=0))
            batch_steps.append(torch.tensor([i]).repeat(i + rec_n_atoms))

            # A placed lig atom cannot be a focus once the bonds among the
            # placed atoms already add up to its full bond-order sum.
            bond_sum = lig_con_mat[:i, :i].sum(dim=1, keepdim=True)
            cannot_focus_steps.append((bond_sum == lig_atom_bond_valency[:i, None]).float())

            focus_id = focus_ids[i - 1]
            if i == 1:  # c1, c2 must be in rec
                _, indices = torch.topk(lig_rec_squared_dist[focus_id], 2, largest=False)
                one_step_c1, one_step_c2 = indices[0], indices[1]
                rec_offset = idx_offsets_brought_by_rec[i - 1]
                # c1/c2 are rec-local indices, so they are additionally shifted by idx_offsets[i].
                ref_offset = idx_offsets[i] + rec_offset
                c1_focus_row = [one_step_c1 + ref_offset, focus_id + rec_offset]
                c2_c1_focus_row = [one_step_c2 + ref_offset, one_step_c1 + ref_offset, focus_id + rec_offset]
                c1_pos, c2_pos = rec_position[one_step_c1], rec_position[one_step_c2]
            else:  # c1, c2 could be in both lig and rec
                dist_to_focus = torch.cat((lig_internal_squared_dist[focus_id, :i], lig_rec_squared_dist[focus_id]))
                _, indices = torch.topk(dist_to_focus, 3, largest=False)
                one_step_c1, one_step_c2 = indices[1], indices[2]
                step_offset = idx_offsets[i - 1] + idx_offsets_brought_by_rec[i - 1]
                c1_focus_row = [one_step_c1 + step_offset, focus_id + step_offset]
                c2_c1_focus_row = [one_step_c2 + step_offset, one_step_c1 + step_offset, focus_id + step_offset]
                # Indices below i address placed lig atoms, the rest address rec atoms.
                c1_pos = lig_position[one_step_c1] if one_step_c1 < i else rec_position[one_step_c1 - i]
                c2_pos = lig_position[one_step_c2] if one_step_c2 < i else rec_position[one_step_c2 - i]
            c1_focus_steps.append(torch.tensor(c1_focus_row, dtype=int).view(1, 2))
            c2_c1_focus_steps.append(torch.tensor(c2_c1_focus_row, dtype=int).view(1, 3))

            focus_pos, new_pos = lig_position[focus_id], lig_position[i]
            dist_steps.append(torch.norm(new_pos - focus_pos).view(1, 1))
            angle_steps.append(self._bond_angle(focus_pos, c1_pos, new_pos).view(1, 1))
            torsion_steps.append(self._torsion_angle(focus_pos, c1_pos, c2_pos, new_pos).view(1, 1))

        steps_focus = torch.cat((torch.tensor([rec_contact_id], dtype=int), focus_ids + idx_offsets[:-1] + idx_offsets_brought_by_rec), dim=0)

        # For example, for a rec-lig pair, rec has N atoms and lig has M atoms
        data_batch = {}
        data_batch['atom_type'] = torch.cat(atom_type_steps)  # [N+(1+N)+(2+N)+...+(M-1+N)], which correspond to M steps
        data_batch['position'] = torch.cat(position_steps)  # [N+(1+N)+(2+N)+...+(M-1+N), 3]
        data_batch['rec_mask'] = torch.cat(rec_mask_steps)  # [N+(1+N)+(2+N)+...+(M-1+N)]
        data_batch['batch'] = torch.cat(batch_steps)  # [N+(1+N)+(2+N)+...+(M-1+N)]
        data_batch['contact_y_or_n'] = contact_y_or_n  # [2]
        data_batch['cannot_contact'] = cannot_contact  # [2]
        data_batch['new_atom_type'] = lig_atom_type  # [M]

        data_batch['focus'] = steps_focus[:, None]  # [M, 1]
        data_batch['c1_focus'] = torch.cat(c1_focus_steps, dim=0)  # [M, 2]
        data_batch['c2_c1_focus'] = torch.cat(c2_c1_focus_steps, dim=0)  # [M, 3]

        data_batch['new_dist'] = torch.cat(dist_steps, dim=0)  # [M, 1]
        data_batch['new_angle'] = torch.cat(angle_steps, dim=0)  # [M, 1]
        data_batch['new_torsion'] = torch.cat(torsion_steps, dim=0)  # [M, 1]

        data_batch['cannot_focus'] = torch.cat(cannot_focus_steps).view(-1)  # [1+2+...+(M-1)]

        return data_batch
libstdcxx-ng=9.3.0=hd4cf53a_17 24 | - libuuid=2.32.1=h7f98852_1000 25 | - libuv=1.42.0=h7f98852_0 26 | - libxcb=1.14=h7b6447c_0 27 | - libxml2=2.9.12=h72842e0_0 28 | - lz4-c=1.9.3=h9c3ff4c_1 29 | - ncurses=6.3=h7f8727e_2 30 | - nodejs=14.17.4=h92b4a50_0 31 | - openbabel=3.1.1=py38hf4b5c11_1 32 | - openssl=1.1.1m=h7f8727e_0 33 | - pcre=8.45=h9c3ff4c_0 34 | - pip=21.2.4=py38h06a4308_0 35 | - pixman=0.40.0=h36c2ea0_0 36 | - python=3.8.12=h12debd9_0 37 | - python_abi=3.8=2_cp38 38 | - readline=8.1=h27cfd23_0 39 | - setuptools=58.0.4=py38h06a4308_0 40 | - sqlite=3.37.0=hc218d9a_0 41 | - tk=8.6.11=h1ccaba5_0 42 | - wheel=0.37.0=pyhd3eb1b0_1 43 | - xz=5.2.5=h7b6447c_0 44 | - zlib=1.2.11=h7f8727e_4 45 | - zstd=1.4.9=ha95c52a_0 46 | - pip: 47 | - anyio==3.4.0 48 | - argon2-cffi==21.3.0 49 | - argon2-cffi-bindings==21.2.0 50 | - attrs==21.4.0 51 | - babel==2.9.1 52 | - backcall==0.2.0 53 | - biopython==1.79 54 | - bleach==4.1.0 55 | - cffi==1.15.0 56 | - charset-normalizer==2.0.9 57 | - cycler==0.11.0 58 | - debugpy==1.5.1 59 | - decorator==5.1.0 60 | - defusedxml==0.7.1 61 | - entrypoints==0.3 62 | - fonttools==4.28.5 63 | - future==0.18.2 64 | - idna==3.3 65 | - importlib-resources==5.4.0 66 | - ipykernel==6.6.1 67 | - ipython==7.30.1 68 | - ipython-genutils==0.2.0 69 | - jedi==0.18.1 70 | - jinja2==3.0.3 71 | - joblib==1.1.0 72 | - json5==0.9.6 73 | - jsonschema==4.3.3 74 | - jupyter-client==7.1.0 75 | - jupyter-core==4.9.1 76 | - jupyter-server==1.13.1 77 | - jupyterlab==3.2.5 78 | - jupyterlab-pygments==0.1.2 79 | - jupyterlab-server==2.10.2 80 | - kiwisolver==1.3.2 81 | - markupsafe==2.0.1 82 | - matplotlib==3.5.1 83 | - matplotlib-inline==0.1.3 84 | - mistune==0.8.4 85 | - mpmath==1.2.1 86 | - nbclassic==0.3.4 87 | - nbclient==0.5.9 88 | - nbconvert==6.4.0 89 | - nbformat==5.1.3 90 | - nest-asyncio==1.5.4 91 | - networkx==2.6.3 92 | - nonechucks==0.4.2 93 | - notebook==6.4.6 94 | - numpy==1.22.0 95 | - packaging==21.3 96 | - pandas==1.3.5 97 | - pandocfilters==1.5.0 
98 | - parso==0.8.3 99 | - pexpect==4.8.0 100 | - pickleshare==0.7.5 101 | - pillow==9.0.0 102 | - prometheus-client==0.12.0 103 | - prompt-toolkit==3.0.24 104 | - ptyprocess==0.7.0 105 | - py3dmol==1.7.0 106 | - pycparser==2.21 107 | - pygments==2.11.1 108 | - pyparsing==3.0.6 109 | - pyrsistent==0.18.0 110 | - python-dateutil==2.8.2 111 | - pytz==2021.3 112 | - pyzmq==22.3.0 113 | - rdkit-pypi==2021.9.3 114 | - requests==2.27.0 115 | - scikit-learn==1.0.2 116 | - scipy==1.7.3 117 | - seaborn==0.11.2 118 | - send2trash==1.8.0 119 | - six==1.16.0 120 | - sniffio==1.2.0 121 | - sym==0.3.5 122 | - sympy==1.9 123 | - terminado==0.12.1 124 | - testpath==0.5.0 125 | - threadpoolctl==3.0.0 126 | - torch==1.9.0+cu111 127 | - torch-cluster==1.5.9 128 | - torch-geometric==1.7.2 129 | - torch-scatter==2.0.9 130 | - torch-sparse==0.6.12 131 | - torch-spline-conv==1.2.1 132 | - torchaudio==0.9.0 133 | - torchvision==0.10.0+cu111 134 | - tornado==6.1 135 | - tqdm==4.62.3 136 | - traitlets==5.1.1 137 | - typing-extensions==4.0.1 138 | - urllib3==1.26.7 139 | - wcwidth==0.2.5 140 | - webencodings==0.5.1 141 | - websocket-client==1.2.3 142 | - zipp==3.7.0 143 | -------------------------------------------------------------------------------- /GraphBP/main.py: -------------------------------------------------------------------------------- 1 | from config import conf 2 | from runner import Runner 3 | import os 4 | 5 | 6 | binding_site_range = 15.0 7 | 8 | 9 | out_path = 'trained_model' 10 | if not os.path.isdir(out_path): 11 | os.mkdir(out_path) 12 | 13 | runner = Runner(conf, out_path=out_path) 14 | runner.train(binding_site_range) -------------------------------------------------------------------------------- /GraphBP/main_eval.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | from rdkit.Chem import Draw 4 | 5 | from utils import BondAdder 6 | from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule 7 | 
def check_chemical_validity(mol):
    """
    Checks the chemical validity of the mol object. Existing mol object is
    not modified. Radicals pass this test.

    Args:
        mol: Rdkit mol object

    :rtype:
        :class:`bool`, True if chemically valid, False otherwise
    """
    smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
    # MolFromSmiles implicitly sanitizes and yields None when that fails.
    return Chem.MolFromSmiles(smiles) is not None


def rd_mol_to_sdf(rd_mol, sdf_file, kekulize=False, name=''):
    """
    Write a single RDKit molecule to an SDF file.

    Args:
        rd_mol: Rdkit mol object to write.
        sdf_file: Destination path handed to ``Chem.SDWriter``.
        kekulize: Forwarded to ``SDWriter.SetKekulize``.
        name: If non-empty, stored in the molecule's ``_Name`` property
            before writing (this mutates ``rd_mol``).
    """
    writer = Chem.SDWriter(sdf_file)
    try:
        writer.SetKekulize(kekulize)
        if name:
            rd_mol.SetProp('_Name', name)
        writer.write(rd_mol)
    finally:
        # Always release the file handle, even when writing fails.
        writer.close()
def get_rd_atom_res_id(rd_atom):
    '''
    Return an object that uniquely
    identifies the residue that the
    atom belongs to in a given PDB.
    '''
    res_info = rd_atom.GetPDBResidueInfo()
    return (
        res_info.GetChainId(),
        res_info.GetResidueNumber()
    )


def get_pocket(lig_mol, rec_mol, max_dist=8):
    '''
    Return a copy of rec_mol reduced to the residues that have at least
    one atom within max_dist of any atom of lig_mol.

    Distances are measured between the conformer coordinates of both
    molecules. rec_mol itself is not modified; the returned RWMol is
    sanitized before being returned.
    '''
    # Import the submodule explicitly: a bare `import scipy` is not
    # guaranteed to expose `scipy.spatial` on older SciPy releases.
    from scipy.spatial.distance import cdist

    lig_coords = lig_mol.GetConformer().GetPositions()
    rec_coords = rec_mol.GetConformer().GetPositions()
    dist = cdist(lig_coords, rec_coords)

    # indexes of atoms in rec_mol that are
    # within max_dist of an atom in lig_mol
    pocket_atom_idxs = set(np.nonzero(dist < max_dist)[1])

    # determine pocket residues
    pocket_res_ids = {
        get_rd_atom_res_id(rec_mol.GetAtomWithIdx(int(i)))
        for i in pocket_atom_idxs
    }

    # copy mol and delete atoms
    pkt_mol = Chem.RWMol(rec_mol)
    doomed = [
        atom.GetIdx()
        for atom in pkt_mol.GetAtoms()
        if get_rd_atom_res_id(atom) not in pocket_res_ids
    ]
    # Remove from the highest index down so that the indices still to be
    # removed are not shifted by earlier removals.
    for idx in sorted(doomed, reverse=True):
        pkt_mol.RemoveAtom(idx)

    Chem.SanitizeMol(pkt_mol)
    return pkt_mol
# Post-process generated molecules: add bonds, keep the chemically valid
# ones, optionally relax them with UFF, and save SDF/PNG files plus two
# index -> source-file dictionaries.

# SECURITY NOTE: pickle.load can execute arbitrary code. Only open
# .mol_dict files that were written by a trusted main_gen.py run.
with open(all_mols_dict_path, 'rb') as f:
    all_mols_dict = pickle.load(f)

bond_adder = BondAdder()


all_results_dict = {}
# Every artifact of this run is written below this directory.
gen_dir = os.path.join(path, 'gen_mols' + '_epoch_' + str(epoch) + '/')
os.makedirs(gen_dir, exist_ok=True)

global_index = 0
global_index_to_rec_src = {}
global_index_to_ref_lig_src = {}
num_valid = 0
for index in all_mols_dict:
    mol_dicts = all_mols_dict[index]
    for num_atom in mol_dicts:
        # Only integer keys hold generated molecules; the other keys read
        # below ('rec_src', 'lig_src') are per-entry metadata.
        if type(num_atom) is not int:
            continue
        mol_dicts_w_num_atom = mol_dicts[num_atom]
        num_mol_w_num_atom = len(mol_dicts_w_num_atom['_atomic_numbers'])
        for j in range(num_mol_w_num_atom):
            global_index += 1

            ### Add bonds
            atomic_numbers = mol_dicts_w_num_atom['_atomic_numbers'][j]
            positions = mol_dicts_w_num_atom['_positions'][j]
            rd_mol, ob_mol = bond_adder.make_mol(atomic_numbers, positions)

            ### check validity
            if not check_chemical_validity(rd_mol):
                continue
            num_valid += 1
            print('Valid molecules:', num_valid)

            rd_mol = Chem.AddHs(rd_mol, explicitOnly=True, addCoords=True)
            if save_sdf_before_uff:
                sdf_file = os.path.join(gen_dir, str(global_index) + '_beforeuff.sdf')
                rd_mol_to_sdf(rd_mol, sdf_file)
                print('Saving' + str(sdf_file))

            ### UFF minimization
            if uff:
                try:
                    UFFOptimizeMolecule(rd_mol)
                    print("Performing UFF...")
                except Exception:
                    # Best effort: keep the unoptimized geometry.
                    print('Skip UFF...')

            if uff_w_rec:
                rd_mol = Chem.RWMol(rd_mol)
                rec_mol = Chem.MolFromPDBFile(os.path.join(data_root, mol_dicts['rec_src']), sanitize=True)
                rec_mol = get_pocket(rd_mol, rec_mol)

                # Pocket atoms come first, ligand atoms last.
                uff_mol = Chem.CombineMols(rec_mol, rd_mol)

                try:
                    Chem.SanitizeMol(uff_mol)
                except Chem.AtomValenceException:
                    print('Invalid valence')
                except (Chem.AtomKekulizeException, Chem.KekulizeException):
                    print('Failed to kekulize')
                try:
                    # Use a dedicated name: binding this object to `uff`
                    # would overwrite the config flag and force plain UFF
                    # on for every later molecule.
                    force_field = AllChem.UFFGetMoleculeForceField(
                        uff_mol, confId=0, ignoreInterfragInteractions=False
                    )
                    force_field.Initialize()
                    for i in range(rec_mol.GetNumAtoms()):  # Fix the rec atoms
                        force_field.AddFixedPoint(i)
                    converged = False
                    n_iters = 200
                    n_tries = 2
                    while n_tries > 0 and not converged:
                        print('.', end='', flush=True)
                        converged = not force_field.Minimize(maxIts=n_iters)
                        n_tries -= 1
                    print(flush=True)
                    print("Performed UFF with binding site...")
                except Exception:
                    print('Skip UFF...')
                # Copy the ligand coordinates (trailing atoms of uff_mol)
                # back onto rd_mol.
                coords = uff_mol.GetConformer().GetPositions()
                rd_conf = rd_mol.GetConformer()
                for i, xyz in enumerate(coords[-rd_mol.GetNumAtoms():]):
                    rd_conf.SetAtomPosition(i, xyz)

            if save_sdf:
                try:
                    rd_mol = Chem.RemoveHs(rd_mol)
                    print("Remove H atoms before saving mol...")
                except Exception:
                    print("Cannot remove H atoms...")

                sdf_file = os.path.join(gen_dir, str(global_index) + '.sdf')
                rd_mol_to_sdf(rd_mol, sdf_file)
                print('Saving' + str(sdf_file))
                global_index_to_rec_src[global_index] = mol_dicts['rec_src']
                global_index_to_ref_lig_src[global_index] = mol_dicts['lig_src']

            if save_mol:
                try:
                    img_path = os.path.join(gen_dir, str(global_index) + '.png')
                    img = Draw.MolsToGridImage([rd_mol])
                    img.save(img_path)
                    print('Saving' + str(img_path))
                except Exception as err:
                    # Drawing is optional; report the failure and move on.
                    print('Cannot save image:', err)
            print('------------------------------------------------')

if save_sdf:
    print('Saving dicts...')
    # Join the file names onto gen_dir rather than calling .format() on a
    # joined path, which would break if `path` ever contained braces.
    with open(os.path.join(gen_dir, 'global_index_to_rec_src.dict'), 'wb') as f:
        pickle.dump(global_index_to_rec_src, f)
    with open(os.path.join(gen_dir, 'global_index_to_ref_lig_src.dict'), 'wb') as f:
        pickle.dump(global_index_to_ref_lig_src, f)

print('Done!!!')
# Sample molecules from trained GraphBP checkpoints and pickle the result
# next to each checkpoint as '<epoch>_mols.mol_dict'.
runner = Runner(conf)

known_binding_site = True


# Sampling temperatures, passed to runner.generate in this order.
node_temp = 0.5
dist_temp = 0.3
angle_temp = 0.4
torsion_temp = 1.0

min_atoms = 10
max_atoms = 45
focus_th = 0.5
contact_th = 0.5
num_gen = 100  # number generate for each reference rec-lig pair

trained_model_path = 'trained_model'
epochs = [33]


for epoch in epochs:
    print('Epoch:', epoch)
    checkpoint_file = '{}/model_{}.pth'.format(trained_model_path, epoch)
    state_dict = torch.load(checkpoint_file)
    runner.model.load_state_dict(state_dict)

    temperatures = [node_temp, dist_temp, angle_temp, torsion_temp]
    all_mol_dicts = runner.generate(
        num_gen,
        temperature=temperatures,
        max_atoms=max_atoms,
        min_atoms=min_atoms,
        focus_th=focus_th,
        contact_th=contact_th,
        add_final=True,
        known_binding_site=known_binding_site,
    )

    output_file = '{}/{}_mols.mol_dict'.format(trained_model_path, epoch)
    with open(output_file, 'wb') as f:
        pickle.dump(all_mol_dicts, f)
def Jn(r, n):
    """Return the spherical Bessel function of the first kind, order n, at r.

    Computed from the ordinary Bessel function of half-integer order:
    j_n(r) = sqrt(pi / (2 r)) * J_{n + 1/2}(r).
    """
    return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r)


def Jn_zeros(n, k):
    """Return the first k positive zeros of j_0 .. j_{n-1} as an (n, k) float32 array.

    The zeros of j_0 are the multiples of pi. Each higher order is found by
    root-bracketing between consecutive zeros of the previous order.
    """
    zerosj = np.zeros((n, k), dtype='float32')
    zerosj[0] = np.arange(1, k + 1) * np.pi
    points = np.arange(1, k + n) * np.pi
    racines = np.zeros(k + n - 1, dtype='float32')
    for i in range(1, n):
        for j in range(k + n - 1 - i):
            foo = brentq(Jn, points[j], points[j + 1], (i, ))
            racines[j] = foo
        # From here on `points` aliases `racines`. That is safe because
        # slot j is only overwritten after brackets j and j + 1 were read.
        points = racines
        zerosj[i][:k] = racines[:k]

    return zerosj


def spherical_bessel_formulas(n):
    """Return sympy expressions in x for the spherical Bessel functions j_0 .. j_{n-1}."""
    x = sym.symbols('x')

    f = [sym.sin(x) / x]
    a = sym.sin(x) / x
    for i in range(1, n):
        b = sym.diff(a, x) / x
        f += [sym.simplify(b * (-x)**i)]
        a = sym.simplify(b)
    return f


def bessel_basis(n, k):
    """Return an n x k nested list of normalized radial basis expressions.

    Entry [order][i] is j_order evaluated at (i-th zero of j_order) * x,
    scaled by 1 / sqrt(0.5 * j_{order+1}(zero)^2).
    """
    zeros = Jn_zeros(n, k)
    normalizer = []
    for order in range(n):
        normalizer_tmp = []
        for i in range(k):
            normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2]
        normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5
        normalizer += [normalizer_tmp]

    f = spherical_bessel_formulas(n)
    x = sym.symbols('x')
    bess_basis = []
    for order in range(n):
        bess_basis_tmp = []
        for i in range(k):
            bess_basis_tmp += [
                sym.simplify(normalizer[order][i] *
                             f[order].subs(x, zeros[order, i] * x))
            ]
        bess_basis += [bess_basis_tmp]
    return bess_basis


def sph_harm_prefactor(k, m):
    """Return the normalization constant of the spherical harmonic of degree k, order m.

    Value: sqrt((2k + 1) * (k - |m|)! / (4 * pi * (k + |m|)!)).
    """
    # `np.math` was an accidental alias of the stdlib module and is gone in
    # NumPy 2.0, so take factorial from the standard library directly.
    from math import factorial

    return ((2 * k + 1) * factorial(k - abs(m)) /
            (4 * np.pi * factorial(k + abs(m))))**0.5
def associated_legendre_polynomials(k, zero_m_only=True):
    """Return associated Legendre polynomials P_l^m in z for degrees l < k.

    The result is a triangular nested list: row l has l + 1 entries, entry m
    being P_l^m. Entries are plain ints or sympy expressions. With
    zero_m_only=True only column m = 0 is filled in; the other entries
    keep their initial value 0.

    k == 0 yields [] and k == 1 yields [[1]]; neither touches sympy.
    """
    P_l_m = [[0] * (j + 1) for j in range(k)]
    if k == 0:
        return P_l_m

    P_l_m[0][0] = 1
    if k > 1:
        # Row 1 only exists for k >= 2. Writing it unconditionally raised
        # IndexError for k == 1.
        z = sym.symbols('z')
        P_l_m[1][0] = z

        for j in range(2, k):
            P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] -
                                        (j - 1) * P_l_m[j - 2][0]) / j)
        if not zero_m_only:
            for i in range(1, k):
                P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
                if i + 1 < k:
                    P_l_m[i + 1][i] = sym.simplify(
                        (2 * i + 1) * z * P_l_m[i][i])
                for j in range(i + 2, k):
                    P_l_m[j][i] = sym.simplify(
                        ((2 * j - 1) * z * P_l_m[j - 1][i] -
                         (i + j - 1) * P_l_m[j - 2][i]) / (j - i))

    return P_l_m


def real_sph_harm(l, zero_m_only=False, spherical_coordinates=True):
    """
    Computes formula strings of the real part of the spherical harmonics up to order l (excluded).
    Variables are either cartesian coordinates x,y,z on the unit sphere or spherical coordinates phi and theta.

    Returns a nested list where row i has 2*i + 1 entries: index 0 is m = 0,
    index j (1 <= j <= i) is m = +j and index -j is m = -j. Entries that are
    not computed (all m != 0 when zero_m_only=True) remain the string '0'.
    """
    if not zero_m_only:
        x = sym.symbols('x')
        y = sym.symbols('y')
        # Multiplying by the symbol keeps both seeds sympy objects, so
        # .subs() below works on every entry.
        S_m = [x*0]
        C_m = [1+0*x]
        for i in range(1, l):
            S_m += [x*S_m[i-1] + y*C_m[i-1]]
            C_m += [x*C_m[i-1] - y*S_m[i-1]]

    P_l_m = associated_legendre_polynomials(l, zero_m_only)
    if spherical_coordinates:
        theta = sym.symbols('theta')
        z = sym.symbols('z')
        for i in range(len(P_l_m)):
            for j in range(len(P_l_m[i])):
                # Plain int placeholders have no .subs(); skip them.
                if not isinstance(P_l_m[i][j], int):
                    P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))
        if not zero_m_only:
            phi = sym.symbols('phi')
            for i in range(len(S_m)):
                S_m[i] = S_m[i].subs(x, sym.sin(
                    theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi))
            for i in range(len(C_m)):
                C_m[i] = C_m[i].subs(x, sym.sin(
                    theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi))

    Y_func_l_m = [['0']*(2*j + 1) for j in range(l)]
    for i in range(l):
        Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])

    if not zero_m_only:
        for i in range(1, l):
            for j in range(1, i + 1):
                Y_func_l_m[i][j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])
        for i in range(1, l):
            for j in range(1, i + 1):
                Y_func_l_m[i][-j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])

    return Y_func_l_m
class Envelope(torch.nn.Module):
    """Polynomial envelope 1/x + a*x^(p-1) + b*x^p + c*x^(p+1) with p = exponent + 1."""

    def __init__(self, exponent):
        super().__init__()
        p = exponent + 1
        self.p = p
        self.a = -(p + 1) * (p + 2) / 2
        self.b = p * (p + 2)
        self.c = -p * (p + 1) / 2

    def forward(self, x):
        # Build the three consecutive powers by repeated multiplication.
        low = x.pow(self.p - 1)
        mid = low * x
        high = mid * x
        return 1. / x + self.a * low + self.b * mid + self.c * high


class dist_emb(torch.nn.Module):
    """Radial basis embedding: envelope(d / cutoff) * sin(freq * d / cutoff).

    `freq` is a non-trainable parameter initialised to pi, 2*pi, ...,
    num_radial*pi. The output gains one trailing axis of size num_radial.
    """

    def __init__(self, num_radial, cutoff=5.0, envelope_exponent=5):
        super().__init__()
        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)

        # Uninitialised storage; filled in by reset_parameters().
        self.freq = torch.nn.Parameter(torch.Tensor(num_radial), requires_grad=False)

        self.reset_parameters()

    def reset_parameters(self):
        count = self.freq.numel()
        torch.arange(1, count + 1, out=self.freq).mul_(PI)

    def forward(self, dist):
        scaled = dist.unsqueeze(-1) / self.cutoff
        return self.envelope(scaled) * (self.freq * scaled).sin()
class angle_emb(torch.nn.Module):
    """Joint distance/angle embedding built from Bessel and zonal harmonic terms.

    For every order below num_spherical there is one m = 0 spherical
    harmonic (a function of the angle) and num_radial Bessel basis
    functions (of distance / cutoff). forward() multiplies them order by
    order and returns num_spherical * num_radial features per row.
    """

    def __init__(self, num_spherical, num_radial, cutoff=5.0,
                 envelope_exponent=5):
        super().__init__()
        assert num_radial <= 64
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.cutoff = cutoff
        # envelope_exponent is accepted but not used by this class.

        radial_forms = bessel_basis(num_spherical, num_radial)
        angular_forms = real_sph_harm(num_spherical)
        self.sph_funcs = []
        self.bessel_funcs = []

        x, theta = sym.symbols('x theta')
        backend = {'sin': torch.sin, 'cos': torch.cos}
        for order in range(num_spherical):
            zonal = sym.lambdify([theta], angular_forms[order][0], backend)
            if order == 0:
                # Order 0 is a constant: evaluate it once and broadcast it
                # to the shape of the input.
                const = zonal(0)
                self.sph_funcs.append(lambda t: torch.zeros_like(t) + const)
            else:
                self.sph_funcs.append(zonal)
            self.bessel_funcs.extend(
                sym.lambdify([x], radial_forms[order][j], backend)
                for j in range(num_radial)
            )

    def forward(self, dist, angle, idx_kj=None):
        scaled = dist / self.cutoff

        rbf = torch.stack([fn(scaled) for fn in self.bessel_funcs], dim=1)
        cbf = torch.stack([fn(angle) for fn in self.sph_funcs], dim=1)

        n, k = self.num_spherical, self.num_radial
        # Without idx_kj (encoding in generative modeling) the rows of rbf
        # and cbf pair up one to one; with it (SphereNet physical
        # representation) rbf rows are gathered first.
        radial = rbf if idx_kj is None else rbf[idx_kj]
        out = (radial.view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k)
        return out


class torsion_emb(torch.nn.Module):
    """Joint distance/angle/torsion embedding.

    Uses every real spherical harmonic (all m) below num_spherical, i.e.
    num_spherical ** 2 angular functions of (angle, phi), combined with the
    same Bessel radial basis as angle_emb. forward() returns
    num_spherical ** 2 * num_radial features per row.
    """

    def __init__(self, num_spherical, num_radial, cutoff=5.0,
                 envelope_exponent=5):
        super().__init__()
        assert num_radial <= 64
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.cutoff = cutoff
        # envelope_exponent is accepted but not used by this class.

        radial_forms = bessel_basis(num_spherical, num_radial)
        angular_forms = real_sph_harm(num_spherical, zero_m_only=False)
        self.sph_funcs = []
        self.bessel_funcs = []

        x = sym.symbols('x')
        theta = sym.symbols('theta')
        phi = sym.symbols('phi')
        backend = {'sin': torch.sin, 'cos': torch.cos}
        for order in range(self.num_spherical):
            if order == 0:
                base = sym.lambdify([theta, phi], angular_forms[order][0], backend)
                # Constant term, broadcast against both inputs.
                self.sph_funcs.append(
                    lambda a, b: torch.zeros_like(a) + torch.zeros_like(b) + base(0, 0))
            else:
                for m in range(-order, order + 1):
                    self.sph_funcs.append(
                        sym.lambdify([theta, phi], angular_forms[order][m + order], backend))
            for j in range(self.num_radial):
                self.bessel_funcs.append(
                    sym.lambdify([x], radial_forms[order][j], backend))

    def forward(self, dist, angle, phi, idx_kj):
        scaled = dist / self.cutoff
        rbf = torch.stack([fn(scaled) for fn in self.bessel_funcs], dim=1)
        cbf = torch.stack([fn(angle, phi) for fn in self.sph_funcs], dim=1)

        n, k = self.num_spherical, self.num_radial
        gathered = rbf[idx_kj].view(-1, 1, n, k)
        out = (gathered * cbf.view(-1, n, n, 1)).view(-1, n * n * k)
        return out
def get_nearst_node(pos, batch):
    """Return sparse adjacencies of the nearest and second-nearest neighbour.

    Both are built from knn_graph restricted to `batch`. The second result
    is the k=2 adjacency minus the k=1 adjacency, computed densely, so it
    materialises two num_nodes x num_nodes dense matrices.
    """
    num_nodes = batch.size(0)

    def knn_adjacency(k):
        # knn_graph's first row is stored as the column index and its
        # second row as the row index.
        col, row = knn_graph(pos, k, batch)
        ones = torch.ones_like(col, device=col.device)
        return SparseTensor(row=row, col=col, value=ones,
                            sparse_sizes=(num_nodes, num_nodes))

    adj_nearest = knn_adjacency(1)
    adj_nearest2 = knn_adjacency(2)

    second_only = adj_nearest2.to_dense() - adj_nearest.to_dense()
    return adj_nearest, SparseTensor.from_dense(second_only)
def xyztoda(pos, edge_index, num_nodes):
    """Compute edge distances and triplet angles from coordinates.

    Args:
        pos: Node coordinates, indexed by node id; the last axis holds xyz.
        edge_index: Pair (j, i) of index tensors describing edges j -> i.
        num_nodes: Number of nodes, used to size the sparse adjacency.

    Returns:
        (dist, angle, i, j, idx_kj, idx_ji): dist has one entry per edge;
        angle has one entry per triplet k -> j -> i with k != i, in
        [0, pi]; idx_kj / idx_ji are the edge ids of the two edges that
        form each triplet.
    """
    j, i = edge_index  # j->i

    # Calculate distances. # number of edges
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    # Store each edge's id as the value of the sparse adjacency so that
    # row lookups return edge ids.
    value = torch.arange(j.size(0), device=j.device)
    adj_t = SparseTensor(row=i, col=j, value=value, sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[j]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = i.repeat_interleave(num_triplets)
    idx_j = j.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = idx_i != idx_k  # drop degenerate triplets i -> j -> i
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k-j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    # Calculate angles. 0 to pi
    pos_ji = pos[idx_i] - pos[idx_j]
    pos_jk = pos[idx_k] - pos[idx_j]
    a = (pos_ji * pos_jk).sum(dim=-1)  # cos_angle * |pos_ji| * |pos_jk|
    # dim=-1 is explicit: without it torch.cross uses the first axis of
    # size 3, which is the triplet axis when there are exactly 3 triplets.
    b = torch.cross(pos_ji, pos_jk, dim=-1).norm(dim=-1)  # sin_angle * |pos_ji| * |pos_jk|
    angle = torch.atan2(b, a)

    return dist, angle, i, j, idx_kj, idx_ji
def xyztodat(pos, edge_index, num_nodes, batch):
    """Compute edge distances, triplet angles and torsions from coordinates.

    Args:
        pos: Node coordinates, indexed by node id; the last axis holds xyz.
        edge_index: Pair (j, i) of index tensors describing edges j -> i.
        num_nodes: Number of nodes, used to size the sparse adjacency.
        batch: Per-node graph assignment, forwarded to get_nearst_node.

    Returns:
        (dist, angle, torsion, i, j, idx_kj, idx_ji): dist is per edge;
        angle (0 to pi) and torsion (shifted into 0 to 2*pi) are per
        triplet k -> j -> i with k != i.
    """
    j, i = edge_index  # j->i

    # Calculate distances. # number of edges
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    # Store each edge's id as the value of the sparse adjacency so that
    # row lookups return edge ids.
    value = torch.arange(j.size(0), device=j.device)
    adj_t = SparseTensor(row=i, col=j, value=value, sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[j]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = i.repeat_interleave(num_triplets)
    idx_j = j.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = idx_i != idx_k  # drop degenerate triplets i -> j -> i
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k-j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    # Calculate angles. 0 to pi
    # dim=-1 is explicit on every cross product: without it torch.cross
    # uses the first axis of size 3, which is the triplet axis when there
    # are exactly 3 triplets.
    pos_ji = pos[idx_i] - pos[idx_j]
    pos_jk = pos[idx_k] - pos[idx_j]
    a = (pos_ji * pos_jk).sum(dim=-1)  # cos_angle * |pos_ji| * |pos_jk|
    b = torch.cross(pos_ji, pos_jk, dim=-1).norm(dim=-1)  # |sin_angle| * |pos_ji| * |pos_jk|
    angle = torch.atan2(b, a)

    # Reference neighbour of j for the torsion: its nearest node, or the
    # second-nearest one when the nearest is i itself.
    adj_nearest, adj_nearest2 = get_nearst_node(pos, batch)

    adj_nearest_row = adj_nearest[idx_j]
    adj_nearest2_row = adj_nearest2[idx_j]
    idx_k_n = adj_nearest_row.storage.col()
    idx_k_n2 = adj_nearest2_row.storage.col()
    mask = idx_k_n == idx_i
    idx_k_n[mask] = idx_k_n2[mask]

    # Calculate torsions.
    pos_j0 = pos[idx_k] - pos[idx_j]
    pos_ji = pos[idx_i] - pos[idx_j]
    pos_jk = pos[idx_k_n] - pos[idx_j]
    dist_ji = pos_ji.pow(2).sum(dim=-1).sqrt()
    plane1 = torch.cross(pos_ji, pos_j0, dim=-1)
    plane2 = torch.cross(pos_ji, pos_jk, dim=-1)
    a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
    b = (torch.cross(plane1, plane2, dim=-1) * pos_ji).sum(dim=-1) / dist_ji
    torsion = torch.atan2(b, a)  # -pi to pi
    torsion[torsion <= 0] += 2 * PI  # 0 to 2pi

    return dist, angle, torsion, i, j, idx_kj, idx_ji
def dattoxyz(f, c1, c2, d, angle, torsion):
    """Place a new point from a distance, an angle and a torsion.

    Args:
        f: Focus positions; the last axis holds xyz.
        c1: First reference positions, same shape as f.
        c2: Second reference positions, same shape as f.
        d: Distance from f to the new point; indexed as d[:, :, None], so
            it has one axis fewer than f.
        angle: Angle at f between the direction to c1 and the direction to
            the new point; same shape as d.
        torsion: Rotation about the c1 -> f axis, measured from the
            component of c2 perpendicular to that axis; same shape as d.

    Returns:
        Tensor of new positions with the same shape as f.
    """
    c1c2 = c2 - c1
    c1f = f - c1
    # c3 is the projection of c2 onto the line through c1 and f.
    c1c3 = c1f * torch.sum(c1c2 * c1f, dim=-1, keepdim=True) / torch.sum(c1f * c1f, dim=-1, keepdim=True)
    c3 = c1c3 + c1

    # Rotate the perpendicular c3 -> c2 about the c1 -> f axis by `torsion`.
    c3c2 = c2 - c3
    c3c4_1 = c3c2 * torch.cos(torsion[:, :, None])
    # dim=-1 is explicit: without it torch.cross uses the first axis of
    # size 3, which is the batch axis whenever the leading size is 3.
    c3c4_2 = torch.cross(c3c2, c1f, dim=-1) / torch.norm(c1f, dim=-1, keepdim=True) * torch.sin(torsion[:, :, None])
    c3c4 = c3c4_1 + c3c4_2

    # Combine the axial component (towards c1) with the perpendicular one.
    new_pos = -c1f / torch.norm(c1f, dim=-1, keepdim=True) * d[:, :, None] * torch.cos(angle[:, :, None])
    new_pos += c3c4 / torch.norm(c3c4, dim=-1, keepdim=True) * d[:, :, None] * torch.sin(angle[:, :, None])
    new_pos += f

    return new_pos
self).__init__() 15 | self.use_gpu = use_gpu 16 | self.num_node_types = num_node_types 17 | self.num_lig_node_types = num_lig_node_types 18 | 19 | self.feat_net = SchNet(num_node_types, hidden_channels, num_filters, num_interactions, num_gaussians, cutoff) 20 | 21 | node_feat_dim, dist_feat_dim, angle_feat_dim, torsion_feat_dim = hidden_channels, hidden_channels, hidden_channels * 2, hidden_channels * 3 22 | 23 | self.node_flow_layers = nn.ModuleList([ST_Net_Exp(node_feat_dim, num_lig_node_types, hid_dim=hidden_channels, bias=True) for _ in range(num_flow_layers)]) 24 | self.dist_flow_layers = nn.ModuleList([ST_Net_Exp(dist_feat_dim, 1, hid_dim=hidden_channels, bias=True) for _ in range(num_flow_layers)]) 25 | self.angle_flow_layers = nn.ModuleList([ST_Net_Exp(angle_feat_dim, 1, hid_dim=hidden_channels, bias=True) for _ in range(num_flow_layers)]) 26 | self.torsion_flow_layers = nn.ModuleList([ST_Net_Exp(torsion_feat_dim, 1, hid_dim=hidden_channels, bias=True) for _ in range(num_flow_layers)]) 27 | self.focus_mlp = MLP(hidden_channels) 28 | self.contact_mlp = MLP(hidden_channels) 29 | self.deq_coeff = deq_coeff 30 | 31 | 32 | 33 | self.dist_emb = dist_emb(num_radial, cutoff, envelope_exponent=5) 34 | self.angle_emb = angle_emb(num_spherical, num_radial, cutoff, envelope_exponent=5) 35 | 36 | self.dist_lb2 = LB2(num_radial, basis_emb_size, hidden_channels) 37 | self.angle_lb2 = LB2(num_spherical * num_radial, basis_emb_size, hidden_channels) 38 | 39 | if use_gpu: 40 | self.feat_net = self.feat_net.to('cuda') 41 | self.node_flow_layers = self.node_flow_layers.to('cuda') 42 | self.dist_flow_layers = self.dist_flow_layers.to('cuda') 43 | self.angle_flow_layers = self.angle_flow_layers.to('cuda') 44 | self.torsion_flow_layers = self.torsion_flow_layers.to('cuda') 45 | self.focus_mlp = self.focus_mlp.to('cuda') 46 | self.contact_mlp = self.contact_mlp.to('cuda') 47 | self.dist_lb2 = self.dist_lb2.to('cuda') 48 | self.angle_lb2 = self.angle_lb2.to('cuda') 49 | 
self.dist_emb = self.dist_emb.to('cuda') 50 | self.angle_emb = self.angle_emb.to('cuda') 51 | 52 | 53 | def forward(self, data_batch): 54 | z, pos, batch = data_batch['atom_type'], data_batch['position'], data_batch['batch'] 55 | node_feat = self.feat_net(z, pos, batch) 56 | focus_score = self.focus_mlp(node_feat[~data_batch['rec_mask']]) 57 | contact_score = self.contact_mlp(node_feat[data_batch['contact_y_or_n']]) 58 | 59 | new_atom_type, focus = data_batch['new_atom_type'], data_batch['focus'] 60 | x_z = F.one_hot(new_atom_type, num_classes=self.num_lig_node_types).float() 61 | x_z += self.deq_coeff * torch.rand(x_z.size(), device=x_z.device) 62 | 63 | local_node_type_feat = node_feat[focus[:,0]] 64 | node_latent, node_log_jacob = flow_forward(self.node_flow_layers, x_z, local_node_type_feat) 65 | node_type_emb_block = self.feat_net.embedding 66 | node_type_emb = node_type_emb_block(new_atom_type) 67 | node_emb = node_feat * node_type_emb[batch] 68 | 69 | c1_focus, c2_c1_focus = data_batch['c1_focus'], data_batch['c2_c1_focus'] 70 | dist, angle, torsion = data_batch['new_dist'], data_batch['new_angle'], data_batch['new_torsion'] 71 | 72 | local_dist_feat = node_emb[focus[:,0]] 73 | dist_latent, dist_log_jacob = flow_forward(self.dist_flow_layers, dist, local_dist_feat) 74 | 75 | ### d --> theta 76 | 77 | dist_emb = self.dist_lb2(self.dist_emb(dist.squeeze()[batch].to(torch.float))) 78 | node_emb = node_emb * dist_emb # [N, hidden] * [N, hidden]. 
N is the total number of steps for all molecules in the batch 79 | 80 | 81 | node_emb_clone = node_emb.clone() # Avoid changing node_emb in-place --> cannot comput gradient otherwise 82 | local_angle_feat = torch.cat((node_emb_clone[c1_focus[:,1]], node_emb_clone[c1_focus[:,0]]), dim=1) 83 | angle_latent, angle_log_jacob = flow_forward(self.angle_flow_layers, angle, local_angle_feat) 84 | 85 | 86 | 87 | ### d, theta --> phi 88 | dist_angle_emd = self.angle_lb2(self.angle_emb(dist.squeeze()[batch].to(torch.float), angle.squeeze()[batch].to(torch.float))) 89 | 90 | node_emb = node_emb * dist_angle_emd 91 | 92 | local_torsion_feat = torch.cat((node_emb[c2_c1_focus[:,2]], node_emb[c2_c1_focus[:,1]], node_emb[c2_c1_focus[:,0]]), dim=1) 93 | torsion_latent, torsion_log_jacob = flow_forward(self.torsion_flow_layers, torsion, local_torsion_feat) 94 | 95 | return (node_latent, node_log_jacob), focus_score, contact_score, (dist_latent, dist_log_jacob), (angle_latent, angle_log_jacob), (torsion_latent, torsion_log_jacob) 96 | 97 | 98 | def generate(self, type_to_atomic_number, rec_atom_type, rec_position, num_gen=100, temperature=[1.0, 1.0, 1.0, 1.0], min_atoms=2, max_atoms=35, focus_th=0.5, contact_th=0.5, add_final=False, contact_prob=False): 99 | with torch.no_grad(): 100 | if self.use_gpu: 101 | prior_node = torch.distributions.normal.Normal(torch.zeros([self.num_lig_node_types]).cuda(), temperature[0] * torch.ones([self.num_lig_node_types]).cuda()) 102 | prior_dist = torch.distributions.normal.Normal(torch.zeros([1]).cuda(), temperature[1] * torch.ones([1]).cuda()) 103 | prior_angle = torch.distributions.normal.Normal(torch.zeros([1]).cuda(), temperature[2] * torch.ones([1]).cuda()) 104 | prior_torsion = torch.distributions.normal.Normal(torch.zeros([1]).cuda(), temperature[3] * torch.ones([1]).cuda()) 105 | else: 106 | prior_node = torch.distributions.normal.Normal(torch.zeros([self.num_lig_node_types]), temperature[0] * torch.ones([self.num_lig_node_types])) 107 | 
prior_dist = torch.distributions.normal.Normal(torch.zeros([1]), temperature[1] * torch.ones([1])) 108 | prior_angle = torch.distributions.normal.Normal(torch.zeros([1]), temperature[2] * torch.ones([1])) 109 | prior_torsion = torch.distributions.normal.Normal(torch.zeros([1]), temperature[3] * torch.ones([1])) 110 | 111 | rec_n_atoms = len(rec_atom_type) 112 | node_type_emb_block = self.feat_net.embedding 113 | z_lig = torch.empty([num_gen, 0], dtype=int) 114 | pos_lig = torch.empty([num_gen, 0, 3], dtype=torch.float32) 115 | focuses = torch.empty([num_gen, 0], dtype=int) # Note that the 1st focus ID is the contact ID in rec 116 | if self.use_gpu: 117 | z_lig, pos_lig, focuses = z_lig.cuda(), pos_lig.cuda(), focuses.cuda() 118 | rec_atom_type, rec_position = rec_atom_type.cuda(), rec_position.cuda() 119 | out_dict = {} 120 | 121 | feat_index = lambda node_id, f: f[torch.arange(num_gen), node_id] 122 | pos_index = lambda node_id, p: p[torch.arange(num_gen), node_id].view(num_gen,1,3) 123 | 124 | for i in range(max_atoms): 125 | # print(i) 126 | batch = torch.arange(num_gen, device=z_lig.device).view(num_gen, 1).repeat(1, i+rec_n_atoms) 127 | z = torch.cat((z_lig, rec_atom_type.repeat(num_gen, 1)), dim=1) 128 | pos = torch.cat((pos_lig, rec_position.repeat(num_gen, 1, 1)), dim=1) 129 | node_feat = self.feat_net(z.view(-1), pos.view(-1,3), batch.view(-1)) 130 | 131 | if i == 0: 132 | contact_score = self.contact_mlp(node_feat).view(num_gen, rec_n_atoms) 133 | if contact_prob: # The prob of selecting a atom is propotional to the predicted prob 134 | contact_mask = contact_score > contact_th 135 | can_contact = contact_score 136 | can_contact[contact_mask] = 0 137 | else: # Contact atom is selected randomly from nodes with predicted score < contact_th 138 | can_contact = contact_score < contact_th 139 | focus_node_id = torch.multinomial(can_contact.float(), 1).view(num_gen) 140 | 141 | node_feat = node_feat.view(num_gen, rec_n_atoms, -1) 142 | 143 | else: 144 | 
rec_mask = torch.cat((torch.zeros([i], dtype=torch.bool), torch.ones([rec_n_atoms], dtype=torch.bool))).repeat(num_gen) 145 | focus_score = self.focus_mlp(node_feat[~rec_mask]).view(num_gen, i) 146 | can_focus = (focus_score < focus_th) 147 | complete_mask = (can_focus.sum(dim=-1) == 0) 148 | if i > max(0, min_atoms-1) and torch.sum(complete_mask) > 0: 149 | out_dict[i] = {} 150 | out_node_types = z_lig[complete_mask].view(-1, i).cpu().numpy() 151 | out_dict[i]['_atomic_numbers'] = type_to_atomic_number[out_node_types] 152 | out_dict[i]['_positions'] = pos_lig[complete_mask].view(-1, i, 3).cpu().numpy() 153 | out_dict[i]['_focus'] = focuses[complete_mask].view(-1, i).cpu().numpy() 154 | 155 | continue_mask = torch.logical_not(complete_mask) 156 | dirty_mask = torch.nonzero(torch.isnan(focus_score).sum(dim=-1))[:,0] 157 | if len(dirty_mask) > 0: 158 | continue_mask[dirty_mask] = False 159 | dirty_mask = torch.nonzero(torch.isinf(focus_score).sum(dim=-1))[:,0] 160 | if len(dirty_mask) > 0: 161 | continue_mask[dirty_mask] = False 162 | 163 | if torch.sum(continue_mask) == 0: 164 | break 165 | 166 | node_feat = node_feat.view(num_gen, i+rec_n_atoms, -1) 167 | num_gen = torch.sum(continue_mask).cpu().item() 168 | z, pos, can_focus, focuses = z[continue_mask], pos[continue_mask], can_focus[continue_mask], focuses[continue_mask] 169 | z_lig, pos_lig = z_lig[continue_mask], pos_lig[continue_mask] 170 | focus_node_id = torch.multinomial(can_focus.float(), 1).view(num_gen) 171 | node_feat = node_feat[continue_mask] 172 | 173 | latent_node = prior_node.sample([num_gen]) 174 | 175 | local_node_type_feat = feat_index(focus_node_id, node_feat) 176 | 177 | latent_node = flow_reverse(self.node_flow_layers, latent_node, local_node_type_feat) 178 | node_type_id = torch.argmax(latent_node, dim=1) 179 | node_type_emb = node_type_emb_block(node_type_id) 180 | node_emb = node_feat * node_type_emb.view(num_gen, 1, -1) 181 | 182 | latent_dist = prior_dist.sample([num_gen]) 183 | 184 | 
local_dist_feat = feat_index(focus_node_id, node_emb) 185 | 186 | dist = flow_reverse(self.dist_flow_layers, latent_dist, local_dist_feat) 187 | 188 | dist_emb = self.dist_lb2(self.dist_emb(dist.to(torch.float))) 189 | node_emb = node_emb * dist_emb.view(num_gen, 1, -1) 190 | 191 | # print(pos.shape) 192 | dist_to_focus = torch.sum(torch.square(pos - pos_index(focus_node_id, pos)), dim=-1) 193 | _, indices = torch.topk(dist_to_focus, 3, largest=False) 194 | c1_node_id, c2_node_id = indices[:,1], indices[:,2] 195 | 196 | 197 | latent_angle = prior_angle.sample([num_gen]) 198 | local_angle_feat = torch.cat((feat_index(focus_node_id, node_emb), feat_index(c1_node_id, node_emb)), dim=1) 199 | 200 | angle = flow_reverse(self.angle_flow_layers, latent_angle, local_angle_feat) 201 | 202 | 203 | dist_angle_emd = self.angle_lb2(self.angle_emb(dist.to(torch.float), angle.to(torch.float))) 204 | node_emb = node_emb * dist_angle_emd.view(num_gen, 1, -1) 205 | 206 | 207 | latent_torsion = prior_torsion.sample([num_gen]) 208 | 209 | local_torsion_feat = torch.cat((feat_index(focus_node_id, node_emb), feat_index(c1_node_id, node_emb), feat_index(c2_node_id, node_emb)), dim=1) 210 | 211 | torsion = flow_reverse(self.torsion_flow_layers, latent_torsion, local_torsion_feat) 212 | new_pos = dattoxyz(pos_index(focus_node_id, pos), pos_index(c1_node_id, pos), pos_index(c2_node_id, pos), dist, angle, torsion) 213 | 214 | 215 | # print(z_lig.shape) 216 | # print(node_type_id.shape) 217 | z_lig = torch.cat((z_lig, node_type_id[:, None]), dim=1) 218 | pos_lig = torch.cat((pos_lig, new_pos.view(num_gen, 1, 3)), dim=1) 219 | focuses = torch.cat((focuses, focus_node_id[:,None]), dim=1) 220 | 221 | if add_final and torch.sum(continue_mask) > 0: 222 | out_dict[i+1] = {} 223 | out_node_types = z_lig.view(-1,i+1).cpu().numpy() 224 | out_dict[i+1]['_atomic_numbers'] = type_to_atomic_number[out_node_types] 225 | out_dict[i+1]['_positions'] = pos_lig.view(-1, i+1, 3).cpu().numpy() 226 | 
class ST_Net_Exp(nn.Module):
    """Conditioner MLP predicting the scale and shift of one affine flow layer.

    A two-layer tanh MLP maps a conditioning feature to ``2 * output_dim``
    values. The first half becomes the log-scale ``s`` (squashed by tanh and
    multiplied by a learnable ``Rescale`` factor), the second half the shift
    ``t``. ``flow_forward`` / ``flow_reverse`` exponentiate ``s`` before use.

    Args:
        input_dim: size of the conditioning feature.
        output_dim: size of the variable being transformed.
        hid_dim: hidden width of the MLP.
        num_layers: unused; kept for interface compatibility.
        bias: whether both linear layers carry a bias.
    """

    def __init__(self, input_dim, output_dim, hid_dim=64, num_layers=2, bias=True):
        super().__init__()
        self.num_layers = num_layers  # unused
        self.input_dim = input_dim
        self.hid_dim = hid_dim
        self.output_dim = output_dim
        self.bias = bias

        self.linear1 = nn.Linear(input_dim, hid_dim, bias=bias)
        self.linear2 = nn.Linear(hid_dim, output_dim * 2, bias=bias)
        self.rescale1 = Rescale()
        self.tanh = nn.Tanh()

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise so that the layer starts out (almost) as the identity map."""
        nn.init.xavier_uniform_(self.linear1.weight)
        # Near-zero output weights make s ~ 0 and t ~ 0 at the start of training.
        nn.init.constant_(self.linear2.weight, 1e-10)
        if self.bias:
            nn.init.constant_(self.linear1.bias, 0.)
            nn.init.constant_(self.linear2.bias, 0.)

    def forward(self, x):
        '''
        :param x: (batch * repeat_num for node/edge, emb)
        :return: w and b for affine operation
        '''
        x = self.linear2(self.tanh(self.linear1(x)))
        s = x[:, :self.output_dim]
        t = x[:, self.output_dim:]
        s = self.rescale1(torch.tanh(s))
        return s, t


class Rescale(nn.Module):
    """Multiply the input by ``exp(weight)``, a learnable positive scalar."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros([1]))

    def forward(self, x):
        # Compute the factor once and reuse it for both the check and the product.
        scale = torch.exp(self.weight)
        if torch.isnan(scale).any():
            print(self.weight)
            raise RuntimeError('Rescale factor has NaN entries')
        return scale * x


def init_layer(layer: torch.nn.Linear, w_scale=1.0) -> torch.nn.Linear:
    """Orthogonally initialise ``layer.weight`` (scaled by ``w_scale``) and zero its bias.

    Layers built with ``bias=False`` are supported; their missing bias is skipped.
    Returns the same layer so the call can be used inline.
    """
    torch.nn.init.orthogonal_(layer.weight.data)
    layer.weight.data.mul_(w_scale)  # type: ignore
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias.data, 0)
    return layer


class MLP(nn.Module):
    """Linear -> ReLU -> Linear -> Sigmoid scorer returning one value per input row."""

    def __init__(self, input_dim, hidden_units=128):
        super().__init__()
        self.layers = nn.Sequential(
            init_layer(nn.Linear(input_dim, hidden_units)),
            nn.ReLU(),
            init_layer(nn.Linear(hidden_units, 1)),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Flatten the trailing singleton dimension: one score per row.
        return self.layers(x).view(-1)


class LB2(nn.Module):
    """Two stacked linear layers with no activation in between."""

    def __init__(self, input_dim, hidden_units, output_dim, bias=False):
        super().__init__()
        self.lin1 = nn.Linear(input_dim, hidden_units, bias=bias)
        self.lin2 = nn.Linear(hidden_units, output_dim, bias=bias)

    def forward(self, x):
        return self.lin2(self.lin1(x))


def flow_reverse(flow_layers, latent, feat):
    """Invert ``flow_forward``: map a latent sample back to data space.

    Layers are applied last-to-first; each one undoes ``x -> (x + t) * exp(s)``
    with ``latent -> latent / exp(s) - t``, where ``(s, t) = layer(feat)``.
    """
    for i in reversed(range(len(flow_layers))):
        s, t = flow_layers[i](feat)
        s = s.exp()
        latent = (latent / s) - t
    return latent


def flow_forward(flow_layers, x, feat):
    """Push ``x`` through the affine flow conditioned on ``feat``.

    Each layer applies ``x -> (x + t) * exp(s)`` with ``(s, t) = layer(feat)``.

    Returns:
        ``(z, log_jacob)``: the transformed variable and the element-wise sum
        of ``log|exp(s)|`` over all layers. With no layers the input is
        returned unchanged together with an all-zero log-Jacobian (the
        original raised ``UnboundLocalError`` in that case).
    """
    x_log_jacob = None
    for layer in flow_layers:
        s, t = layer(feat)
        s = s.exp()
        x = (x + t) * s
        # 1e-20 guards the log against an underflowed scale.
        step_log_jacob = (torch.abs(s) + 1e-20).log()
        if x_log_jacob is None:
            x_log_jacob = step_log_jacob
        else:
            x_log_jacob = x_log_jacob + step_log_jacob
    if x_log_jacob is None:
        x_log_jacob = torch.zeros_like(x)
    return x, x_log_jacob
class GaussianSmearing(torch.nn.Module):
    """Expand scalar distances into a vector of Gaussian radial basis values.

    ``num_gaussians`` centres are spaced evenly on ``[start, stop]``; every
    Gaussian has a standard deviation equal to the spacing between centres.

    Raises:
        ValueError: if fewer than two centres are requested or the centres
            coincide, because the spacing (and hence the width) is undefined.
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super().__init__()
        if num_gaussians < 2:
            raise ValueError(
                'num_gaussians must be at least 2, got {}'.format(num_gaussians))
        offset = torch.linspace(start, stop, num_gaussians)
        spacing = (offset[1] - offset[0]).item()
        if spacing == 0:
            raise ValueError('start and stop must differ to define the Gaussian width')
        self.coeff = -0.5 / spacing**2
        # Buffer (not parameter): follows .to()/.cuda() but is not trained.
        self.register_buffer('offset', offset)

    def forward(self, dist):
        # (num_dist, 1) - (1, num_gaussians) -> (num_dist, num_gaussians)
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))


class ShiftedSoftplus(torch.nn.Module):
    """Softplus shifted down by ``log(2)`` so that the activation is 0 at 0."""

    def __init__(self):
        super().__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x):
        return F.softplus(x) - self.shift
class CFConv(MessagePassing):
    """Continuous-filter convolution: sum-aggregated messages ``x_j * W``.

    ``nn`` turns the expanded edge distances into per-edge filters, which are
    damped by a cosine cutoff of the raw edge length before being applied.
    """

    def __init__(self, in_channels, out_channels, num_filters, nn, cutoff):
        super().__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = nn
        self.cutoff = cutoff

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        # Cosine envelope: 1 at distance 0, 0 at distance == cutoff.
        C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)
        W = self.nn(edge_attr) * C.view(-1, 1)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j, W):
        return x_j * W


class InteractionBlock(torch.nn.Module):
    """One interaction layer: filter MLP + ``CFConv`` + shifted softplus + linear."""

    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
        super().__init__()
        self.mlp = Sequential(
            Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters, self.mlp, cutoff)
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        # Fixed: the original zeroed mlp[0].bias a second time and left
        # mlp[2].bias at its default (non-zero) initialisation.
        self.mlp[2].bias.data.fill_(0)
        self.conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        x = self.conv(x, edge_index, edge_weight, edge_attr)
        x = self.act(x)
        x = self.lin(x)
        return x


class SchNet(torch.nn.Module):
    """SchNet encoder returning per-node embeddings (no readout head).

    Args:
        num_node_types: size of the node-type embedding table.
        hidden_channels: width of the node embeddings.
        num_filters: width of the continuous filters.
        num_interactions: number of stacked interaction blocks.
        num_gaussians: number of radial basis functions for distances.
        cutoff: neighbour radius; also the upper end of the distance expansion.
    """

    def __init__(self, num_node_types, hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff

        self.embedding = Embedding(num_node_types, hidden_channels)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)
        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(hidden_channels, num_gaussians,
                                     num_filters, cutoff)
            self.interactions.append(block)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()

    def forward(self, z, pos, batch):
        """Embed node types ``z`` and refine them with distance-based messages.

        Edges connect nodes of the same ``batch`` entry that lie within
        ``self.cutoff`` of each other; each interaction block is applied
        residually.
        """
        h = self.embedding(z)
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        return h
# Ligand atom types: atomic number -> node type id (0..26).
atomic_num_to_type = {5:0, 6:1, 7:2, 8:3, 9:4, 12:5, 13:6, 14:7, 15:8, 16:9, 17:10, 21:11, 23:12, 26:13, 29:14, 30:15, 33:16, 34:17, 35:18, 39:19, 42:20, 44:21, 45:22, 51:23, 53:24, 74:25, 79:26}

# Receptor atom types: PDB element symbol -> node type id (27..45).
atomic_element_to_type = {'C':27, 'N':28, 'O':29, 'NA':30, 'MG':31, 'P':32, 'S':33, 'CL':34, 'K':35, 'CA':36, 'MN':37, 'CO':38, 'CU':39, 'ZN':40, 'SE':41, 'CD':42, 'I':43, 'CS':44, 'HG':45}


class Runner():
    """Training and sampling driver for the GraphBP generative model."""

    def __init__(self, conf, out_path=None):
        """Build the model, optimizer and losses described by ``conf``.

        Raises:
            ValueError: if ``conf['gen_model']`` is not ``'GraphBP'``. The
                original only printed a hint and then failed later with an
                ``AttributeError`` on the missing model.
        """
        self.conf = conf
        if conf['gen_model'] != 'GraphBP':
            raise ValueError(
                'Please give correct gen_model name! Got {!r}'.format(conf['gen_model']))
        self.model = GraphBP(**conf['model'])
        self.optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), **conf['optim'])
        self.focus_ce = torch.nn.BCELoss()
        self.contact_ce = torch.nn.BCELoss()
        self.out_path = out_path

    def _train_epoch(self, loader):
        """Run one optimisation pass over ``loader``.

        Batches with more than 600000 atoms are skipped to avoid running out
        of memory; the check now happens before the batch is copied to the GPU.

        Returns:
            Average node / dist / angle / torsion flow losses, focus and
            contact cross-entropies over the trained batches, followed by the
            number of skipped batches. If no batch was trained the averages
            are 0 (the original divided by zero or hit an unbound variable).
        """
        self.model.train()
        total_ll_node, total_ll_dist, total_ll_angle, total_ll_torsion, total_focus_ce, total_contact_ce = 0, 0, 0, 0, 0, 0
        skip_batch_num = 0
        trained_batch_num = 0
        for iter_num, data_batch in enumerate(loader):

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t_start = time.perf_counter()

            if data_batch['atom_type'].size(0) > 600000:
                skip_batch_num += 1
                print("Skip batch to avoid OOM!")
                continue

            if self.conf['model']['use_gpu']:
                for key in data_batch:
                    data_batch[key] = data_batch[key].to('cuda')

            node_out, focus_score, contact_score, dist_out, angle_out, torsion_out = self.model(data_batch)
            cannot_focus = data_batch['cannot_focus']
            cannot_contact = data_batch['cannot_contact']

            # Each *_out pair is combined as 0.5 * out[0]**2 - out[1]; presumably
            # (latent, log-Jacobian) from the flow -- confirm against GraphBP.forward.
            ll_node = torch.mean(1/2 * (node_out[0] ** 2) - node_out[1])
            ll_dist = torch.mean(1/2 * (dist_out[0] ** 2) - dist_out[1])
            ll_angle = torch.mean(1/2 * (angle_out[0] ** 2) - angle_out[1])
            ll_torsion = torch.mean(1/2 * (torsion_out[0] ** 2) - torsion_out[1])
            focus_ce = self.focus_ce(focus_score, cannot_focus)
            contact_ce = self.contact_ce(contact_score, cannot_contact)

            loss = ll_node + ll_dist + ll_angle + ll_torsion + focus_ce + contact_ce

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_losses = [term.item() for term in (ll_node, ll_dist, ll_angle, ll_torsion, focus_ce, contact_ce)]
            total_ll_node += batch_losses[0]
            total_ll_dist += batch_losses[1]
            total_ll_angle += batch_losses[2]
            total_ll_torsion += batch_losses[3]
            total_focus_ce += batch_losses[4]
            total_contact_ce += batch_losses[5]
            trained_batch_num += 1

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t_end = time.perf_counter()

            duration = t_end - t_start

            if iter_num % self.conf['verbose'] == 0:
                print('Training iteration {} | loss node {:.6f} dist {:.6f} angle {:.6f} torsion {:.6f} focus {:.6f} contact {:.6f} duration {:.6f}'.format(iter_num, *batch_losses, duration))

        denom = max(trained_batch_num, 1)
        return total_ll_node / denom, total_ll_dist / denom, total_ll_angle / denom, total_ll_torsion / denom, total_focus_ce / denom, total_contact_ce / denom, skip_batch_num

    def train(self, binding_site_range):
        """Train for ``conf['epochs']`` epochs on CrossDocked2020.

        After every epoch the averages are printed and, when ``out_path`` is
        set, a checkpoint ``model_{epoch}.pth`` is saved and the summary line
        is appended to ``record.txt``.
        """
        dataset = CrossDocked2020_SBDD(binding_site_range=binding_site_range)
        loader = DataLoader(dataset, batch_size=self.conf['batch_size'], shuffle=True, collate_fn=collate_mols, num_workers=self.conf['num_workers'])

        epochs = self.conf['epochs']
        for epoch in range(epochs):
            avg_ll_node, avg_ll_dist, avg_ll_angle, avg_ll_torsion, avg_focus_ce, avg_contact_ce, skip_batch_num = self._train_epoch(loader)
            summary = 'Training | Average loss node {:.6f} dist {:.6f} angle {:.6f} torsion {:.6f} focus {:.6f} contact {:.6f}'.format(avg_ll_node, avg_ll_dist, avg_ll_angle, avg_ll_torsion, avg_focus_ce, avg_contact_ce)
            print('=============================================')
            print(summary)
            print('Skip batch nums:', skip_batch_num)
            print('=============================================')
            if self.out_path is not None:
                torch.save(self.model.state_dict(), os.path.join(self.out_path, 'model_{}.pth'.format(epoch)))
                # 'with' closes the record file even if the write fails.
                with open(os.path.join(self.out_path, 'record.txt'), 'a') as file_obj:
                    file_obj.write(summary + '\n')

    def generate(self, num_gen, temperature=(1.0, 1.0, 1.0, 1.0), min_atoms=2, max_atoms=35, focus_th=0.5, contact_th=0.5, add_final=False, contact_prob=False, data_root='./data/crossdock2020', data_file='./data/crossdock2020/selected_test_targets.types', atomic_num_to_type=atomic_num_to_type, atomic_element_to_type=atomic_element_to_type, known_binding_site=False, binding_site_range=15.0):
        """Sample ``num_gen`` ligands for every receptor listed in ``data_file``.

        Args:
            num_gen: number of molecules to generate per target.
            temperature: four values forwarded to the model (indexable sequence).
            min_atoms, max_atoms, focus_th, contact_th, add_final, contact_prob:
                forwarded unchanged to ``self.model.generate``.
            data_root: directory that the paths in ``data_file`` are relative to.
            data_file: space-separated ``.types`` file listing the targets.
            atomic_num_to_type: ligand atomic number -> node type id.
            atomic_element_to_type: receptor element symbol -> node type id.
            known_binding_site: if true, keep only receptor atoms within
                ``binding_site_range`` of the reference ligand's centre.
            binding_site_range: radius used when ``known_binding_site`` is set.

        Returns:
            ``{row_index: mol_dicts}`` where ``mol_dicts`` maps an atom count
            to arrays ``'_atomic_numbers'``, ``'_positions'`` and ``'_focus'``
            and additionally stores ``'rec_src'`` and ``'lig_src'``.

        Raises:
            ValueError: if no receptor atom lies inside the requested binding site.
        """
        data_cols = [
            'low_rmsd',
            'true_aff',
            'xtal_rmsd',
            'rec_src',
            'lig_src',
            'vina_aff'
        ]
        data_lines = pd.read_csv(
            data_file, sep=' ', names=data_cols, index_col=False
        )
        pdb_parser = PDBParser()

        # Lookup table: node type id -> atomic number (same for every target).
        type_to_atomic_number_dict = {atomic_num_to_type[k]: k for k in atomic_num_to_type}
        type_to_atomic_number = np.zeros([max(type_to_atomic_number_dict.keys())+1], dtype=int)
        for k in type_to_atomic_number_dict:
            type_to_atomic_number[k] = type_to_atomic_number_dict[k]

        all_mol_dicts = {}

        for index in range(len(data_lines)):
            example = data_lines.iloc[index]
            rec_src = example.rec_src
            lig_src = example.lig_src.rsplit('.', 1)[0]
            print(rec_src)
            print(lig_src)
            print("=============")

            with warnings.catch_warnings():
                warnings.simplefilter('ignore', PDBConstructionWarning)
                rec_structure = pdb_parser.get_structure('', os.path.join(data_root, rec_src))

            heavy_atoms = [atom for atom in rec_structure.get_atoms() if atom.element != 'H']
            rec_atom_type = torch.tensor([atomic_element_to_type[atom.element] for atom in heavy_atoms])
            rec_position = torch.tensor(np.stack([atom.coord for atom in heavy_atoms], axis=0))

            if known_binding_site:
                supp = Chem.SDMolSupplier()
                print("Generate molecules with given binding site infomation...")
                sdf_file = os.path.join(data_root, lig_src)
                with open(sdf_file) as sdf_handle:
                    supp.SetData(sdf_handle.read(), removeHs=False, sanitize=False)
                lig_mol = Chem.rdmolops.RemoveAllHs(supp[0], sanitize=False)
                lig_n_atoms = lig_mol.GetNumAtoms()
                # Coordinates are read straight from the mol block's atom lines.
                lig_pos = supp.GetItemText(0).split('\n')[4:4+lig_n_atoms]
                lig_position = np.array([[float(x) for x in line.split()[:3]] for line in lig_pos], dtype=np.float32)
                lig_position = torch.tensor(lig_position)
                lig_center = torch.mean(lig_position, dim=0)
                rec_atom_dist_to_lig_center = torch.sqrt(torch.sum(torch.square(rec_position - lig_center), dim=-1))
                selected_mask = rec_atom_dist_to_lig_center <= binding_site_range
                if torch.sum(selected_mask) == 0:
                    raise ValueError(
                        'No receptor atom within {} of the ligand center for {}'.format(binding_site_range, rec_src))
                rec_atom_type = rec_atom_type[selected_mask]
                rec_position = rec_position[selected_mask]
                del supp

            num_remain = num_gen
            one_time_gen = self.conf['chunk_size']
            mol_dicts = {}

            self.model.eval()
            # NOTE: loops until enough molecules were returned; a chunk that
            # yields nothing is simply retried.
            while num_remain > 0:
                chunk = min(one_time_gen, num_remain)
                mols = self.model.generate(type_to_atomic_number, rec_atom_type, rec_position, chunk, temperature, min_atoms, max_atoms, focus_th, contact_th, add_final, contact_prob)

                for num_atom in mols:
                    if num_atom not in mol_dicts:
                        mol_dicts[num_atom] = mols[num_atom]
                    else:
                        for field in ('_atomic_numbers', '_positions', '_focus'):
                            mol_dicts[num_atom][field] = np.concatenate((mol_dicts[num_atom][field], mols[num_atom][field]), axis=0)
                    num_mol = len(mols[num_atom]['_atomic_numbers'])
                    num_remain -= num_mol

                print('{} molecules are generated!'.format(num_gen-num_remain))
            mol_dicts['rec_src'] = rec_src
            mol_dicts['lig_src'] = lig_src
            all_mol_dicts[index] = mol_dicts

        return all_mol_dicts
def read_rd_mols_from_sdf_file(sdf_file, removeHs=False, sanitize=False):
    """Read every molecule from a plain or gzip-compressed (``.gz``) SDF file.

    Returns the molecules as a list, in file order.
    """
    if sdf_file.endswith('.gz'):
        # The forward supplier streams from the handle, so it must be
        # exhausted before the gzip file is closed.
        with gzip.open(sdf_file) as f:
            suppl = Chem.ForwardSDMolSupplier(f, removeHs=removeHs, sanitize=sanitize)
            return [mol for mol in suppl]
    suppl = Chem.SDMolSupplier(sdf_file, removeHs=removeHs, sanitize=sanitize)
    return [mol for mol in suppl]


def write_rd_mol_to_sdf_file(sdf_file, mol, name='', kekulize=True):
    """Write a single rdkit molecule; see ``write_rd_mols_to_sdf_file``."""
    return write_rd_mols_to_sdf_file(sdf_file, [mol], name, kekulize)


def write_rd_mols_to_sdf_file(sdf_file, mols, name='', kekulize=True):
    '''
    Write a list of rdkit molecules to a file
    or io stream in sdf format.

    A string path ending in ``.gz`` is written gzip-compressed. When ``name``
    is non-empty it is stored as the ``_Name`` property of every molecule.
    The writer (and the gzip handle, if one was opened here) is closed even
    when writing fails.
    '''
    use_gzip = (
        isinstance(sdf_file, str) and sdf_file.endswith('.gz')
    )
    if use_gzip:
        sdf_file = gzip.open(sdf_file, 'wt')
    try:
        writer = Chem.SDWriter(sdf_file)
        try:
            writer.SetKekulize(kekulize)
            for mol in mols:
                if name:
                    mol.SetProp('_Name', name)
                writer.write(mol)
        finally:
            writer.close()
    finally:
        if use_gzip:
            sdf_file.close()
def split_sdf(sdf_file):
    '''
    Split an sdf file into several files
    that each contain one molecular pose.

    Pose ``k`` (0-based, in file order) is written next to the input as
    ``<name>_<k>.sdf``, where ``<name>`` is the input's base name up to its
    first dot.

    Raises:
        FileNotFoundError: if ``sdf_file`` does not exist (previously an
            ``assert``, which is stripped under ``python -O``).
    '''
    if not os.path.isfile(sdf_file):
        raise FileNotFoundError(sdf_file + ' does not exist')
    print('Splitting', sdf_file)
    in_dir, in_base = os.path.split(sdf_file)
    mol_name = in_base.split('.', 1)[0]
    for pose_index, mol in enumerate(read_rd_mols_from_sdf_file(sdf_file)):
        out_base = '{}_{}.sdf'.format(mol_name, pose_index)
        out_file = os.path.join(in_dir, out_base)
        write_rd_mol_to_sdf_file(out_file, mol, name=mol_name, kekulize=True)
        print('\tWriting', out_file)
69 | ''' 70 | if os.path.isfile(sdf_file): 71 | print('Found', sdf_file) 72 | return 73 | # need to find and split multi-pose file 74 | in_prefix = sdf_file.split('.', 1)[0] # strip file extension 75 | in_prefix = in_prefix.rsplit('_', 1)[0] # strip pose index 76 | multi_sdf_file = in_prefix + '.sdf.gz' 77 | split_sdf(multi_sdf_file) 78 | assert os.path.isfile(sdf_file), sdf_file + ' was not created' 79 | 80 | 81 | if __name__ == '__main__': 82 | _, data_file, data_root = sys.argv 83 | with open(data_file) as f: 84 | lines = f.readlines() 85 | n_lines = len(lines) 86 | for i, line in enumerate(lines): 87 | pct = 100*(i+1)/n_lines 88 | print(f'[{pct:.2f}%] ', end='') 89 | sdf_file = os.path.join(data_root, line.split()[4].rsplit('.', 1)[0]) # Use .sdf file instead of .sdf.gz file 90 | find_and_split_sdf(sdf_file) 91 | print('[100.00%] Done') 92 | -------------------------------------------------------------------------------- /GraphBP/trained_model/model_33.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/GraphBP/c5a2a2a1a5781d471754ed091ce4f79d15c5e341/GraphBP/trained_model/model_33.pth -------------------------------------------------------------------------------- /GraphBP/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .bond_adding import BondAdder 2 | -------------------------------------------------------------------------------- /GraphBP/utils/bond_adding.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://github.com/luost26/3D-Generative-SBDD/tree/6f9c7d92784e58474b9c22a74c8113f0344ca795 and https://github.com/mattragoza/liGAN 2 | 3 | import numpy as np 4 | 5 | from openbabel import openbabel as ob 6 | from rdkit.Chem import AllChem as Chem 7 | from rdkit import Geometry 8 | 9 | from scipy.spatial.distance import pdist 10 | from scipy.spatial.distance 
class MolReconsError(Exception):
    """Raised when a molecule cannot be reconstructed."""


class BondAdder():
    '''
    An algorithm for constructing a valid molecule
    from a structure of atomic coordinates and types.

    First, it converts the structure to OBAtoms and
    tries to maintain as many of the atomic proper-
    ties defined by the atom types as possible.

    Next, it adds bonds to the atoms, using the atom
    properties and coordinates as constraints.
    '''
    def __init__(
        self,
        min_bond_len=0.01,
        max_bond_len=4.0,
        max_bond_stretch=0.45,
        min_bond_angle=45
    ):
        # Distance window in which two atoms are initially bonded.
        self.min_bond_len, self.max_bond_len = min_bond_len, max_bond_len

        # Thresholds used later to prune geometrically poor bonds.
        self.max_bond_stretch, self.min_bond_angle = max_bond_stretch, min_bond_angle

        # Next-higher bond order, used when repairing missing bond orders.
        self.UPGRADE_BOND_ORDER = {
            Chem.BondType.SINGLE: Chem.BondType.DOUBLE,
            Chem.BondType.DOUBLE: Chem.BondType.TRIPLE,
        }

    def to_ob_mol(self, xyz, atomic_nums):
        '''
        Convert numpy arrays to ob_mol

        Returns the new OBMol and the list of its atoms, in input order.
        '''
        mol = ob.OBMol()
        mol.BeginModify()
        atoms = []
        for (x, y, z), atomic_num in zip(xyz, atomic_nums):
            new_atom = mol.NewAtom()
            new_atom.SetAtomicNum(atomic_num)
            new_atom.SetVector(x, y, z)
            atoms.append(new_atom)
        return mol, atoms

    def fixup(self, mol):
        '''Pin aromaticity-related atom properties so OpenBabel does not re-perceive them.'''
        mol.SetAromaticPerceived(True) #avoid perception
        for atom in ob.OBMolAtomIter(mol):
            if atom.IsAromatic():
                atom.SetHyb(2)

            is_ring_n_or_o = (atom.GetAtomicNum() in (7, 8)) and atom.IsInRing()
            if not is_ring_n_or_o:
                continue
            #this is a little iffy, omitting until there is more evidence it is a net positive
            #we don't have aromatic types for nitrogen, but if it
            #is in a ring with aromatic carbon mark it aromatic as well
            aromatic_nbrs = 0
            for nbr in ob.OBAtomAtomIter(atom):
                if nbr.IsAromatic():
                    aromatic_nbrs += 1
            if aromatic_nbrs > 1:
                atom.SetAromatic(True)
def connect_the_dots(self, mol, atoms):
    '''
    Add bonds based on distance

    BondAdder method. Works in four stages on ``mol`` (modified in place):

    1. bond every atom pair whose distance lies strictly between
       ``self.min_bond_len`` and ``self.max_bond_len``;
    2. delete bonds between two atoms that may each carry only one bond;
    3. for atoms exceeding their allowed valence, delete their most
       stretched bonds until the valence fits;
    4. delete over-stretched or small-angle bonds, unless that would
       disconnect the two atoms.

    Returns ``mol``, or its fragment with the most atoms if the result is
    disconnected.
    '''
    pt = Chem.GetPeriodicTable()

    mol.BeginModify()

    #just going to to do n^2 comparisons, can worry about efficiency later
    coords = np.array([(a.GetX(),a.GetY(),a.GetZ()) for a in atoms])
    dists = squareform(pdist(coords))

    for (i,a) in enumerate(atoms):
        for (j,b) in enumerate(atoms):
            # Stopping at the diagonal visits each unordered pair (j < i) once.
            if i == j: # Note that this differs from https://github.com/luost26/3D-Generative-SBDD/blob/6f9c7d92784e58474b9c22a74c8113f0344ca795/utils/reconstruct.py#L93
                break
            if self.min_bond_len < dists[i,j] < self.max_bond_len:
                flag = 0
                ### Aromatic
                if a.IsAromatic() and b.IsAromatic():
                    flag = ob.OB_AROMATIC_BOND
                mol.AddBond(a.GetIdx(),b.GetIdx(),1,flag)

    # Maximum number of bonds allowed per atom, keyed by OpenBabel atom index.
    atom_maxb = {}
    for (i,a) in enumerate(atoms):
        #set max valance to the smallest max allowed by openbabel or rdkit
        #since we want the molecule to be valid for both (rdkit is usually lower)
        maxb = ob.GetMaxBonds(a.GetAtomicNum())
        maxb = min(maxb,pt.GetDefaultValence(a.GetAtomicNum()))
        if a.GetAtomicNum() == 16: # sulfone check
            if self.count_nbrs_of_elem(a, 8) >= 2:
                maxb = 6
        atom_maxb[a.GetIdx()] = maxb

    #remove any impossible bonds between halogens
    for bond in ob.OBMolBondIter(mol):
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        if atom_maxb[a1.GetIdx()] == 1 and atom_maxb[a2.GetIdx()] == 1:
            mol.DeleteBond(bond)

    def get_bond_info(biter):
        '''Return bonds sorted by their distortion'''
        bonds = [b for b in biter]
        binfo = []
        for bond in bonds:
            bdist = bond.GetLength()
            #compute how far away from optimal we are
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            ideal = ob.GetCovalentRad(a1.GetAtomicNum()) + ob.GetCovalentRad(a2.GetAtomicNum())
            stretch = bdist-ideal
            binfo.append((stretch,bdist,bond))
        binfo.sort(reverse=True, key=lambda t: t[:2]) #most stretched bonds first
        return binfo

    #prioritize removing hypervalency causing bonds, do more valent constrained atoms first since their bonds introduce the most problems with reachability (e.g. oxygen)
    hypers = sorted([(atom_maxb[a.GetIdx()],a.GetExplicitValence() - atom_maxb[a.GetIdx()], a) for a in atoms],key=lambda aa: (aa[0],-aa[1]))
    for mb,diff,a in hypers:
        # Valences are re-read here because earlier deletions may already have fixed this atom.
        if a.GetExplicitValence() <= atom_maxb[a.GetIdx()]:
            continue
        binfo = get_bond_info(ob.OBAtomBondIter(a))
        for stretch,bdist,bond in binfo:
            #can we remove this bond without disconnecting the molecule?
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()

            #get right valence
            if a1.GetExplicitValence() > atom_maxb[a1.GetIdx()] or \
                a2.GetExplicitValence() > atom_maxb[a2.GetIdx()]:
                #don't fragment the molecule
                # if not self.reachable(a1,a2):
                #     continue
                mol.DeleteBond(bond)
                if a.GetExplicitValence() <= atom_maxb[a.GetIdx()]:
                    break #let nbr atoms choose what bonds to throw out


    binfo = get_bond_info(ob.OBMolBondIter(mol))
    #now eliminate geometrically poor bonds
    for stretch,bdist,bond in binfo:
        #can we remove this bond without disconnecting the molecule?
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()

        #as long as we aren't disconnecting, let's remove things
        #that are excessively far away (0.45 from ConnectTheDots)
        #get bonds to be less than max allowed
        #also remove tight angles, because that is what ConnectTheDots does
        if stretch > self.max_bond_stretch or self.forms_small_angle(a1,a2) or self.forms_small_angle(a2,a1):
            #don't fragment the molecule
            if not self.reachable(a1,a2):
                continue
            mol.DeleteBond(bond)

    mol.EndModify()
    ### Use the largest fragment if the mol is separated
    if len(mol.Separate()) > 1:
        sep_mols = sorted([sep_mol for sep_mol in mol.Separate()], key=lambda x: x.NumAtoms())
        print("Using LCC with num_atoms: ", sep_mols[-1].NumAtoms())
        return sep_mols[-1]

    return mol
def forms_small_angle(self, a, b):
    '''Return true if bond between a and b is part of a small angle
    with a neighbor of a only.

    BondAdder method; "small" means below ``self.min_bond_angle`` degrees.'''
    cutoff = self.min_bond_angle
    return any(
        b.GetAngle(a, nbr) < cutoff
        for nbr in ob.OBAtomAtomIter(a)
        if nbr != b
    )


def reachable(self, a, b):
    '''Return true if atom b is reachable from a without using the bond between them.

    BondAdder method.'''
    only_bond_of_a = a.GetExplicitDegree() == 1
    if only_bond_of_a or b.GetExplicitDegree() == 1:
        return False #this is the _only_ bond for one atom
    #otherwise do recursive traversal, with the direct a-b bond marked as used
    visited_bonds = {a.GetBond(b).GetIdx()}
    return self.reachable_r(a, b, visited_bonds)


def reachable_r(self, a, b, seenbonds):
    '''Recursive helper: depth-first search from a towards b that never
    reuses a bond index already in ``seenbonds`` (updated in place).'''
    for nbr in ob.OBAtomAtomIter(a):
        bond_idx = a.GetBond(nbr).GetIdx()
        if bond_idx in seenbonds:
            continue
        seenbonds.add(bond_idx)
        if nbr == b or self.reachable_r(nbr, b, seenbonds):
            return True
    return False


def count_nbrs_of_elem(self, atom, atomic_num):
    '''
    Count the number of neighbors atoms
    of atom with the given atomic_num.
    '''
    return sum(
        1 for nbr in ob.OBAtomAtomIter(atom)
        if nbr.GetAtomicNum() == atomic_num
    )
def calc_valence(self, rdatom):
    '''Can't call GetExplicitValence before sanitize, but need to
    know this to fix up the molecule to prevent sanitization failures.

    Returns the sum of the bond orders of ``rdatom``'s bonds as a float.'''
    return sum((bond.GetBondTypeAsDouble() for bond in rdatom.GetBonds()), 0.0)


def convert_ob_mol_to_rd_mol(self, ob_mol):
    '''
    Convert ob_mol to rd_mol

    BondAdder method. Copies heavy atoms, coordinates and bonds into an RDKit
    molecule, lowers double/triple bonds on over-valent atoms, charges
    four-coordinate nitrogens, adds hydrogens and sanitizes the result
    (without kekulization).

    Raises:
        MolReconsError: if RDKit sanitization fails; the RDKit error is
            chained as the cause.
        Exception: if a bond has an order other than 1, 2 or 3.
    '''
    ob_mol.DeleteHydrogens()
    n_atoms = ob_mol.NumAtoms()
    rd_mol = Chem.RWMol()
    rd_conf = Chem.Conformer(n_atoms)

    for ob_atom in ob.OBMolAtomIter(ob_mol):
        rd_atom = Chem.Atom(ob_atom.GetAtomicNum())
        #TODO copy format charge
        if ob_atom.IsAromatic() and ob_atom.IsInRing() and ob_atom.MemberOfRingSize() <= 6:
            #don't commit to being aromatic unless rdkit will be okay with the ring status
            #(this can happen if the atoms aren't fit well enough)
            rd_atom.SetIsAromatic(True)
        i = rd_mol.AddAtom(rd_atom)
        ob_coords = ob_atom.GetVector()
        rd_coords = Geometry.Point3D(ob_coords.GetX(), ob_coords.GetY(), ob_coords.GetZ())
        rd_conf.SetAtomPosition(i, rd_coords)

    rd_mol.AddConformer(rd_conf)

    order_to_type = {
        1: Chem.BondType.SINGLE,
        2: Chem.BondType.DOUBLE,
        3: Chem.BondType.TRIPLE,
    }
    for ob_bond in ob.OBMolBondIter(ob_mol):
        # OpenBabel atom indices are 1-based, RDKit's are 0-based.
        i = ob_bond.GetBeginAtomIdx()-1
        j = ob_bond.GetEndAtomIdx()-1
        bond_order = ob_bond.GetBondOrder()
        if bond_order not in order_to_type:
            raise Exception('unknown bond order {}'.format(bond_order))
        rd_mol.AddBond(i, j, order_to_type[bond_order])

        if ob_bond.IsAromatic():
            bond = rd_mol.GetBondBetweenAtoms(i, j)
            bond.SetIsAromatic(True)

    rd_mol = Chem.RemoveHs(rd_mol, sanitize=False)

    pt = Chem.GetPeriodicTable()

    # Collect double/triple bonds, longest first, so the least plausible
    # multiple bonds are lowered first.
    positions = rd_mol.GetConformer().GetPositions()
    nonsingles = []
    for bond in rd_mol.GetBonds():
        if bond.GetBondType() == Chem.BondType.DOUBLE or bond.GetBondType() == Chem.BondType.TRIPLE:
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            dist = np.linalg.norm(positions[i]-positions[j])
            nonsingles.append((dist, bond))
    nonsingles.sort(reverse=True, key=lambda t: t[0])

    for (d, bond) in nonsingles:
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()

        if self.calc_valence(a1) > pt.GetDefaultValence(a1.GetAtomicNum()) or \
           self.calc_valence(a2) > pt.GetDefaultValence(a2.GetAtomicNum()):
            # Lower the bond by one order: triple -> double, double -> single.
            btype = Chem.BondType.SINGLE
            if bond.GetBondType() == Chem.BondType.TRIPLE:
                btype = Chem.BondType.DOUBLE
            bond.SetBondType(btype)

    for atom in rd_mol.GetAtoms():
        #set nitrogens with 4 neighbors to have a charge
        if atom.GetAtomicNum() == 7 and atom.GetDegree() == 4:
            atom.SetFormalCharge(1)

    rd_mol = Chem.AddHs(rd_mol, addCoords=True)

    positions = rd_mol.GetConformer().GetPositions()
    center = np.mean(positions[np.all(np.isfinite(positions), axis=1)], axis=0)
    for atom in rd_mol.GetAtoms():
        i = atom.GetIdx()
        pos = positions[i]
        if not np.all(np.isfinite(pos)):
            #hydrogens on C fragment get set to nan (shouldn't, but they do)
            rd_mol.GetConformer().SetAtomPosition(i, center)

    try:
        Chem.SanitizeMol(rd_mol, Chem.SANITIZE_ALL^Chem.SANITIZE_KEKULIZE)
    except Exception as err:
        # 'except Exception' keeps KeyboardInterrupt/SystemExit propagating.
        raise MolReconsError(str(err)) from err

    #but at some point stop trying to enforce our aromaticity -
    #openbabel and rdkit have different aromaticity models so they
    #won't always agree. Remove any aromatic bonds to non-aromatic atoms
    for bond in rd_mol.GetBonds():
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        if bond.GetIsAromatic():
            if not a1.GetIsAromatic() or not a2.GetIsAromatic():
                bond.SetIsAromatic(False)
        elif a1.GetIsAromatic() and a2.GetIsAromatic():
            bond.SetIsAromatic(True)

    return rd_mol
Remove any aromatic bonds to non-aromatic atoms 332 | for bond in rd_mol.GetBonds(): 333 | a1 = bond.GetBeginAtom() 334 | a2 = bond.GetEndAtom() 335 | if bond.GetIsAromatic(): 336 | if not a1.GetIsAromatic() or not a2.GetIsAromatic(): 337 | bond.SetIsAromatic(False) 338 | elif a1.GetIsAromatic() and a2.GetIsAromatic(): 339 | bond.SetIsAromatic(True) 340 | 341 | 342 | return rd_mol 343 | 344 | 345 | 346 | def postprocess_rd_mol_1(self, rdmol): 347 | 348 | rdmol = Chem.RemoveHs(rdmol) 349 | 350 | # Construct bond nbh list 351 | nbh_list = {} 352 | for bond in rdmol.GetBonds(): 353 | begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() 354 | if begin not in nbh_list: nbh_list[begin] = [end] 355 | else: nbh_list[begin].append(end) 356 | 357 | if end not in nbh_list: nbh_list[end] = [begin] 358 | else: nbh_list[end].append(begin) 359 | 360 | # Fix missing bond-order 361 | for atom in rdmol.GetAtoms(): 362 | idx = atom.GetIdx() 363 | num_radical = atom.GetNumRadicalElectrons() 364 | if num_radical > 0: 365 | for j in nbh_list[idx]: 366 | if j <= idx: continue 367 | nb_atom = rdmol.GetAtomWithIdx(j) 368 | nb_radical = nb_atom.GetNumRadicalElectrons() 369 | if nb_radical > 0: 370 | bond = rdmol.GetBondBetweenAtoms(idx, j) 371 | bond.SetBondType(self.UPGRADE_BOND_ORDER[bond.GetBondType()]) 372 | nb_atom.SetNumRadicalElectrons(nb_radical - 1) 373 | num_radical -= 1 374 | if num_radical > 0: 375 | atom.SetNumRadicalElectrons(num_radical) 376 | 377 | num_radical = atom.GetNumRadicalElectrons() 378 | if num_radical > 0: 379 | atom.SetNumRadicalElectrons(0) 380 | num_hs = atom.GetNumExplicitHs() 381 | atom.SetNumExplicitHs(num_hs + num_radical) 382 | 383 | return rdmol 384 | 385 | 386 | def postprocess_rd_mol_2(self, rdmol): 387 | rdmol_edit = Chem.RWMol(rdmol) 388 | 389 | ring_info = rdmol.GetRingInfo() 390 | ring_info.AtomRings() 391 | rings = [set(r) for r in ring_info.AtomRings()] 392 | for i, ring_a in enumerate(rings): 393 | if len(ring_a) == 3: 394 | non_carbon = [] 
395 | atom_by_symb = {} 396 | for atom_idx in ring_a: 397 | symb = rdmol.GetAtomWithIdx(atom_idx).GetSymbol() 398 | if symb != 'C': 399 | non_carbon.append(atom_idx) 400 | if symb not in atom_by_symb: 401 | atom_by_symb[symb] = [atom_idx] 402 | else: 403 | atom_by_symb[symb].append(atom_idx) 404 | if len(non_carbon) == 2: 405 | rdmol_edit.RemoveBond(*non_carbon) 406 | if 'O' in atom_by_symb and len(atom_by_symb['O']) == 2: 407 | rdmol_edit.RemoveBond(*atom_by_symb['O']) 408 | rdmol_edit.GetAtomWithIdx(atom_by_symb['O'][0]).SetNumExplicitHs( 409 | rdmol_edit.GetAtomWithIdx(atom_by_symb['O'][0]).GetNumExplicitHs() + 1 410 | ) 411 | rdmol_edit.GetAtomWithIdx(atom_by_symb['O'][1]).SetNumExplicitHs( 412 | rdmol_edit.GetAtomWithIdx(atom_by_symb['O'][1]).GetNumExplicitHs() + 1 413 | ) 414 | rdmol = rdmol_edit.GetMol() 415 | 416 | for atom in rdmol.GetAtoms(): 417 | if atom.GetFormalCharge() > 0: 418 | atom.SetFormalCharge(0) 419 | 420 | return rdmol 421 | 422 | 423 | def make_mol(self, atomic_numbers, positions): 424 | ''' 425 | Creat molecules with added bonds 426 | atomic_numbers: [N] 427 | positions: [N, 3] 428 | ''' 429 | xyz = positions.tolist() 430 | atomic_nums = atomic_numbers.tolist() 431 | 432 | mol, atoms = self.to_ob_mol(xyz, atomic_nums) 433 | self.fixup(mol) 434 | 435 | ob_mol = self.connect_the_dots(mol, atoms) 436 | self.fixup(ob_mol) 437 | mol.EndModify() 438 | 439 | 440 | self.fixup(ob_mol) 441 | 442 | ob_mol.AddPolarHydrogens() 443 | ob_mol.PerceiveBondOrders() 444 | self.fixup(ob_mol) 445 | 446 | 447 | for (i,a) in enumerate(atoms): 448 | ob.OBAtomAssignTypicalImplicitHydrogens(a) 449 | self.fixup(ob_mol) 450 | 451 | 452 | ob_mol.AddHydrogens() 453 | self.fixup(ob_mol) 454 | 455 | #make rings all aromatic if majority of carbons are aromatic 456 | for ring in ob.OBMolRingIter(ob_mol): 457 | if 5 <= ring.Size() <= 6: 458 | carbon_cnt = 0 459 | aromatic_ccnt = 0 460 | for ai in ring._path: 461 | a = ob_mol.GetAtom(ai) 462 | if a.GetAtomicNum() == 6: 463 | 
carbon_cnt += 1 464 | if a.IsAromatic(): 465 | aromatic_ccnt += 1 466 | if aromatic_ccnt >= carbon_cnt/2 and aromatic_ccnt != ring.Size(): 467 | #set all ring atoms to be aromatic 468 | for ai in ring._path: 469 | a = ob_mol.GetAtom(ai) 470 | a.SetAromatic(True) 471 | 472 | 473 | #bonds must be marked aromatic for smiles to match 474 | for bond in ob.OBMolBondIter(ob_mol): 475 | a1 = bond.GetBeginAtom() 476 | a2 = bond.GetEndAtom() 477 | if a1.IsAromatic() and a2.IsAromatic(): 478 | bond.SetAromatic(True) 479 | 480 | ob_mol.PerceiveBondOrders() 481 | 482 | 483 | rd_mol = self.convert_ob_mol_to_rd_mol(ob_mol) 484 | 485 | # Post-processing 486 | rd_mol = self.postprocess_rd_mol_1(rd_mol) 487 | rd_mol = self.postprocess_rd_mol_2(rd_mol) 488 | 489 | # rd_mol = Chem.RemoveHs(rd_mol) 490 | 491 | 492 | return rd_mol, ob_mol 493 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 
21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. 
The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 
91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 
122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 
155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 
186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. 
This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Update: Our code has been moved into [AIRS](https://github.com/divelab/AIRS/tree/main/OpenMI/GraphBP). Please refer to AIRS for any future updates. This repo is no longer maintained. 2 | 3 | 4 | # Generating 3D Molecules for Target Protein Binding 5 | This is the official implementation of the **GraphBP** method proposed in the following paper. 6 | 7 | Meng Liu, Youzhi Luo, Kanji Uchino, Koji Maruhashi, and Shuiwang Ji. "[Generating 3D Molecules for Target Protein Binding](https://arxiv.org/abs/2204.09410)".
[ICML 2022 **Long Presentation**] 8 | 9 | ![](https://github.com/divelab/GraphBP/blob/main/assets/GraphBP.png) 10 | 11 | 12 | ## Requirements 13 | We include key dependencies below. The versions we used are in parentheses. Our detailed environment setup is available in [environment.yml](https://github.com/divelab/GraphBP/blob/main/GraphBP/environment.yml). 14 | * PyTorch (1.9.0) 15 | * PyTorch Geometric (1.7.2) 16 | * rdkit-pypi (2021.9.3) 17 | * biopython (1.79) 18 | * openbabel (3.3.1) 19 | 20 | 21 | ## Preparing Data 22 | * Download and extract the CrossDocked2020 dataset: 23 | ```linux 24 | wget https://bits.csb.pitt.edu/files/crossdock2020/v1.1/CrossDocked2020_v1.1.tgz -P data/crossdock2020/ 25 | tar -C data/crossdock2020/ -xzf data/crossdock2020/CrossDocked2020_v1.1.tgz 26 | wget https://bits.csb.pitt.edu/files/it2_tt_0_lowrmsd_mols_train0_fixed.types -P data/crossdock2020/ 27 | wget https://bits.csb.pitt.edu/files/it2_tt_0_lowrmsd_mols_test0_fixed.types -P data/crossdock2020/ 28 | ``` 29 | **Note**: (1) Extracting the archive can take a long time; extracting onto an SSD is much faster. (2) Several samples in the training set cannot be processed by our code. Hence, we recommend replacing the `it2_tt_0_lowrmsd_mols_train0_fixed.types` 30 | file with a cleaned version from which these samples have been removed. The cleaned file is available [here](https://github.com/divelab/GraphBP/blob/main/GraphBP/data/crossdock2020/it2_tt_0_lowrmsd_mols_train0_fixed.types). 31 | 32 | * Split data files: 33 | ```linux 34 | python scripts/split_sdf.py data/crossdock2020/it2_tt_0_lowrmsd_mols_train0_fixed.types data/crossdock2020 35 | python scripts/split_sdf.py data/crossdock2020/it2_tt_0_lowrmsd_mols_test0_fixed.types data/crossdock2020 36 | ``` 37 | 38 | ## Run 39 | * Train GraphBP from scratch: 40 | ```linux 41 | CUDA_VISIBLE_DEVICES=${you_gpu_id} python main.py 42 | ``` 43 | **Note**: GraphBP can be trained on a `48GB GPU` with `batchsize=16`.
Our trained model is available [here](https://github.com/divelab/GraphBP/blob/main/GraphBP/trained_model/model_33.pth). 44 | 45 | * Generate atoms in the 3D space with the trained model: 46 | ```linux 47 | CUDA_VISIBLE_DEVICES=${you_gpu_id} python main_gen.py 48 | ``` 49 | 50 | * Postprocess and then save the generated molecules: 51 | ```linux 52 | CUDA_VISIBLE_DEVICES=${you_gpu_id} python main_eval.py 53 | ``` 54 | 55 | 56 | 57 | ## Reference 58 | ``` 59 | @inproceedings{liu2022graphbp, 60 | title={Generating 3D Molecules for Target Protein Binding}, 61 | author={Meng Liu and Youzhi Luo and Kanji Uchino and Koji Maruhashi and Shuiwang Ji}, 62 | booktitle={International Conference on Machine Learning}, 63 | year={2022} 64 | } 65 | ``` 66 | 67 | ## Acknowledgments 68 | This work was supported in part by National Science Foundation grants IIS-2006861 and IIS-1908220. 69 | -------------------------------------------------------------------------------- /assets/GraphBP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/GraphBP/c5a2a2a1a5781d471754ed091ce4f79d15c5e341/assets/GraphBP.png --------------------------------------------------------------------------------