├── .fig_intro.jpg ├── .gitattributes ├── .gitignore ├── .model2.jpg ├── LICENSE ├── README.md ├── commons ├── geometry_utils.py ├── logger.py ├── losses.py ├── process_mols.py └── utils.py ├── configs_clean ├── RDKitCoords_flexible_self_docking.yml ├── inference.yml ├── inference_file_for_reproduce.yml └── rigid_self_docking.yml ├── data ├── timesplit_no_lig_or_rec_overlap_train ├── timesplit_no_lig_or_rec_overlap_val ├── timesplit_no_lig_overlap_train ├── timesplit_no_lig_overlap_val └── timesplit_test ├── data_preparation ├── README.md ├── find_disconnected_proteins.py ├── move_valid_files.py ├── openbabel_receptors.py ├── reduce_receptors.py └── select_protein_chains.py ├── datasets ├── custom_collate.py ├── multiple_ligands.py ├── pdbbind.py └── samplers.py ├── environment.yml ├── environment_cpuonly.yml ├── inference.py ├── models ├── README.md ├── __init__.py └── equibind.py ├── multiligand_inference.py ├── runs ├── flexible_self_docking │ ├── best_checkpoint.pt │ └── train_arguments.yaml └── rigid_redocking │ ├── best_checkpoint.pt │ └── train_arguments.yaml ├── train.py └── trainer ├── README.md ├── binding_trainer.py ├── lr_schedulers.py ├── metrics.py └── trainer.py /.fig_intro.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HannesStark/EquiBind/4cb1b4c562dae914780154518a6b915bb4cba658/.fig_intro.jpg -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored=false 2 | *.ipynb linguist-detectable=false 3 | 4 | /jupyter_notebooks linguist-vendored=false 5 | 6 | jupyter_notebooks/** linguist-vendored 7 | 8 | jupyter_notebooks/** linguist-vendored=false 9 | 10 | 11 | jupyter_notebooks/* linguist-vendored 12 | jupyter_notebooks/* linguist-vendored=false -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | renew.sh 2 | tmux_renew.sh 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | .vscode/ 110 | 111 | 112 | *.zip 113 | 114 | .idea/ 115 | 116 | 117 | #################### Project specific 118 | 119 | # this ignores everything in data except for the file 120 | !/data 121 | /data/* 122 | !/data/PDBBind_deepBSP_filtered/pdbbind_ids_without_overlap_with_casf.data 123 | !/data/timesplit_test 124 | !/data/timesplit_no_lig_overlap_train 125 | !/data/timesplit_no_lig_overlap_val 126 | !/data/timesplit_no_lig_or_rec_overlap_train 127 | !/data/timesplit_no_lig_or_rec_overlap_val 128 | 129 | 130 | cache 131 | 132 | logs 133 | 134 | # temporary files 135 | temp/ 136 | bsub* 137 | stderr* 138 | stdout* 139 | 140 | runs2 141 | # this excludes everything in the runs directory except for that specific run 142 | !/runs 143 | /runs/* 144 | !/runs/rigid_redocking 145 | !/runs/flexible_self_docking 146 | -------------------------------------------------------------------------------- /.model2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HannesStark/EquiBind/4cb1b4c562dae914780154518a6b915bb4cba658/.model2.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Hannes Stärk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # EquiBind: Geometric Deep Learning for Drug Binding Structure Prediction 3 | 4 | ### [Paper on arXiv](https://arxiv.org/abs/2202.05146) 5 | 6 | **Before using EquiBind, also consider checking out our new approach called DiffDock, which improves over EquiBind in multiple ways. 7 | See the DiffDock [GitHub](https://github.com/gcorso/DiffDock) and [paper](https://arxiv.org/abs/2210.01776).** 8 | 9 | EquiBind is an 10 | SE(3)-equivariant geometric deep learning model 11 | that performs direct-shot prediction of both i) the receptor binding location (blind docking) and ii) the 12 | ligand’s bound pose and orientation. EquiBind 13 | achieves significant speed-ups 14 | compared to traditional and recent baselines. 15 | If you have questions, don't hesitate to open an issue or ask me 16 | via [hstark@mit.edu](hstark@mit.edu) 17 | or [social media](https://hannes-stark.com/), or contact Octavian Ganea via [oct@mit.edu](oct@mit.edu). We are happy to hear from you! 18 | 19 | ![](.fig_intro.jpg) 20 | 21 | ![](.model2.jpg) 22 | 23 | # Dataset 24 | 25 | Our preprocessed data (see the dataset section in the paper Appendix) is no longer available on [zenodo](https://zenodo.org/record/6408497). \ 26 | The reason is that the PDBBind license does not allow for redistributing the dataset. 27 | The files in `data` contain the names for the time-based data split. 28 | 29 | This means that you now have to download the data from the PDBBind website (http://www.pdbbind.org.cn/) and place it into `data` such that you have the path `data/PDBBind`. 30 | 31 | 32 | # Use provided model weights to predict binding structures of your own protein-ligand pairs: 33 | 34 | ## Step 1: What you need as input 35 | 36 | Ligand files of the formats ``.mol2``, ``.sdf``, ``.pdbqt``, or ``.pdb`` whose names contain the string `ligand` (your ligand files should contain **all** hydrogens). \ 37 | Receptor files of the format ``.pdb`` whose names contain the string `protein`. We ran [reduce](https://github.com/rlabduke/reduce) on our training proteins. You may want to run it on your proteins as well.\ 38 | For each complex you want to predict, you need a directory containing the ligand and receptor file, like this: 39 | ``` 40 | my_data_folder 41 | └───name1 42 | │ name1_protein.pdb 43 | │ name1_ligand.sdf 44 | └───name2 45 | │ name2_protein.pdb 46 | │ name2_ligand.mol2 47 | ... 48 | ``` 49 | 50 | ## Step 2: Setup Environment 51 | 52 | We will set up the environment using [Anaconda](https://docs.anaconda.com/anaconda/install/index.html). Clone the 53 | current repo: 54 | 55 | git clone https://github.com/HannesStark/EquiBind 56 | 57 | Create a new environment with all required packages using `environment.yml`.
If you have a CUDA GPU, run: 58 | 59 | conda env create -f environment.yml 60 | 61 | If you instead only have a CPU, run: 62 | 63 | conda env create -f environment_cpuonly.yml 64 | 65 | Activate the environment: 66 | 67 | conda activate equibind 68 | 69 | Here are the requirements for the CUDA GPU case if you want to install them manually instead of using `environment.yml`: 70 | ```` 71 | python=3.7 72 | pytorch 1.10 73 | torchvision 74 | cudatoolkit=10.2 75 | torchaudio 76 | dgl-cuda10.2 77 | rdkit 78 | openbabel 79 | biopython 80 | 81 | biopandas 82 | pot 83 | dgllife 84 | joblib 85 | pyaml 86 | icecream 87 | matplotlib 88 | tensorboard 89 | ```` 90 | 91 | ## Step 3: Predict Binding Structures! 92 | 93 | In the config file `configs_clean/inference.yml`, set the path to your input data folder: `inference_path: path_to/my_data_folder`. 94 | Then run: 95 | 96 | python inference.py --config=configs_clean/inference.yml 97 | 98 | Done! :tada: \ 99 | Your results are saved as `.sdf` files in the directory specified 100 | in the config file under ``output_directory: 'data/results/output'`` and as tensors at ``runs/flexible_self_docking/predictions_RDKitFalse.pt``! 101 | 102 | # Inference for multiple ligands in the same .sdf file and a single receptor 103 | 104 | 105 | python multiligand_inference.py -o path/to/output_directory -r path/to/receptor.pdb -l path/to/ligands.sdf 106 | 107 | This runs EquiBind on every ligand in ligands.sdf against the protein in receptor.pdb. The outputs are three files in output_directory with the following names and contents: 108 | 109 | failed.txt - contains the index (in the file ligands.sdf) and name of every molecule for which inference failed in a way that was caught and handled.\ 110 | success.txt - contains the index (in the file ligands.sdf) and name of every molecule for which inference succeeded.\ 111 | output.sdf - contains the conformers produced by EquiBind in .sdf format. 112 | 113 | 114 | 115 | # Reproducing paper numbers 116 | Download the data and place it as described in the "Dataset" section above. 117 | ### Using the provided model weights 118 | To predict binding structures using the provided model weights, run: 119 | 120 | python inference.py --config=configs_clean/inference_file_for_reproduce.yml 121 | 122 | This will give you the results of *EquiBind-U* and then those of *EquiBind* after running the fast ligand point cloud fitting corrections. \ 123 | The numbers are a bit better than what is reported in the paper. We will put the improved numbers into the next update of the paper. 124 | ### Training a model yourself and using those weights 125 | To train the model yourself, run: 126 | 127 | python train.py --config=configs_clean/RDKitCoords_flexible_self_docking.yml 128 | 129 | The model weights are saved in the `runs` directory.\ 130 | You can also start a tensorboard server ``tensorboard --logdir=runs`` and watch the model train. \ 131 | To evaluate the model on the test set, change the ``run_dirs:`` entry of the config file `inference_file_for_reproduce.yml` to point to the directory produced in `runs`. 132 | Then you can run ``python inference.py --config=configs_clean/inference_file_for_reproduce.yml`` as above!
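As a quick sanity check, you can load the predicted poses back into RDKit. Below is a minimal sketch (not part of this repository) that assumes the multiligand `output.sdf` described above; the path is only a placeholder example:

```python
# Minimal sketch: inspect EquiBind output poses with RDKit.
# The path below is a placeholder; point it at your own output directory.
from rdkit import Chem

supplier = Chem.SDMolSupplier('path/to/output_directory/output.sdf', sanitize=False)
for i, mol in enumerate(supplier):
    if mol is None:  # skip entries RDKit could not parse
        continue
    pos = mol.GetConformer().GetAtomPosition(0)  # predicted coordinates of the first atom
    name = mol.GetProp('_Name') if mol.HasProp('_Name') else f'mol_{i}'
    print(name, mol.GetNumAtoms(), 'atoms, first atom at', (pos.x, pos.y, pos.z))
```

Passing `sanitize=False` keeps molecules that RDKit's valence checks would otherwise reject, which matches how permissively the poses were written.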
133 | ## Reference 134 | 135 | :page_with_curl: Paper [on arXiv](https://arxiv.org/abs/2202.05146) 136 | ``` 137 | @inproceedings{equibind, 138 | title={Equibind: Geometric deep learning for drug binding structure prediction}, 139 | author={St{\"a}rk, Hannes and Ganea, Octavian and Pattanaik, Lagnajit and Barzilay, Regina and Jaakkola, Tommi}, 140 | booktitle={International Conference on Machine Learning}, 141 | pages={20503--20521}, 142 | year={2022}, 143 | organization={PMLR} 144 | } 145 | ``` 146 | -------------------------------------------------------------------------------- /commons/geometry_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | from rdkit import Chem 7 | from rdkit.Chem import rdMolTransforms 8 | from scipy.spatial.transform import Rotation 9 | 10 | 11 | def random_rotation_translation(translation_distance): 12 | rotation = Rotation.random(num=1) 13 | rotation_matrix = rotation.as_matrix().squeeze() 14 | 15 | t = np.random.randn(1, 3) 16 | t = t / np.sqrt(np.sum(t * t)) 17 | length = np.random.uniform(low=0, high=translation_distance) 18 | t = t * length 19 | return torch.from_numpy(rotation_matrix.astype(np.float32)), torch.from_numpy(t.astype(np.float32)) 20 | 21 | # R = 3x3 rotation matrix 22 | # t = 3x1 column vector 23 | # This already takes residue identity into account. 24 | def rigid_transform_Kabsch_3D(A, B): 25 | assert A.shape[1] == B.shape[1] 26 | num_rows, num_cols = A.shape 27 | if num_rows != 3: 28 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") 29 | num_rows, num_cols = B.shape 30 | if num_rows != 3: 31 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") 32 | 33 | 34 | # find mean column wise: 3 x 1 35 | centroid_A = np.mean(A, axis=1, keepdims=True) 36 | centroid_B = np.mean(B, axis=1, keepdims=True) 37 | 38 | # subtract mean 39 | Am = A - centroid_A 40 | Bm = B - centroid_B 41 | 42 | H = Am @ Bm.T 43 | 44 | # find rotation 45 | U, S, Vt = np.linalg.svd(H) 46 | 47 | R = Vt.T @ U.T 48 | 49 | # special reflection case 50 | if np.linalg.det(R) < 0: 51 | # print("det(R) < 0, reflection detected! correcting for it ...") 52 | SS = np.diag([1., 1., -1.]) 53 | R = (Vt.T @ SS) @ U.T 54 | assert math.fabs(np.linalg.det(R) - 1) < 1e-5 55 | 56 | t = -R @ centroid_A + centroid_B 57 | return R, t 58 | 59 | # R = 3x3 rotation matrix 60 | # t = 3x1 column vector 61 | # This already takes residue identity into account.
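# Outline of the Kabsch procedure implemented by both Kabsch functions in this file:
#   1. center A and B by subtracting their column-wise centroids
#   2. build the 3x3 cross-covariance matrix H = Am @ Bm.T
#   3. take the SVD H = U S Vt and form the candidate rotation R = Vt.T @ U.T
#   4. if det(R) < 0 the candidate is a reflection, so flip the sign of the last
#      singular direction to obtain a proper rotation
#   5. recover the translation as t = -R @ centroid_A + centroid_B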
62 | def rigid_transform_Kabsch_3D_torch(A, B): 63 | assert A.shape[1] == B.shape[1] 64 | num_rows, num_cols = A.shape 65 | if num_rows != 3: 66 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") 67 | num_rows, num_cols = B.shape 68 | if num_rows != 3: 69 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") 70 | 71 | 72 | # find mean column wise: 3 x 1 73 | centroid_A = torch.mean(A, axis=1, keepdims=True) 74 | centroid_B = torch.mean(B, axis=1, keepdims=True) 75 | 76 | # subtract mean 77 | Am = A - centroid_A 78 | Bm = B - centroid_B 79 | 80 | H = Am @ Bm.T 81 | 82 | # find rotation 83 | U, S, Vt = torch.linalg.svd(H) 84 | 85 | R = Vt.T @ U.T 86 | 87 | # special reflection case 88 | if torch.linalg.det(R) < 0: 89 | # print("det(R) < 0, reflection detected! correcting for it ...") 90 | SS = torch.diag(torch.tensor([1., 1., -1.], device=A.device)) 91 | R = (Vt.T @ SS) @ U.T 92 | assert math.fabs(torch.linalg.det(R) - 1) < 1e-5 93 | 94 | t = -R @ centroid_A + centroid_B 95 | return R, t 96 | 97 | 98 | def get_torsions(mol_list): 99 | atom_counter = 0 100 | torsionList = [] 101 | dihedralList = [] 102 | for m in mol_list: 103 | torsionSmarts = '[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]' 104 | torsionQuery = Chem.MolFromSmarts(torsionSmarts) 105 | matches = m.GetSubstructMatches(torsionQuery) 106 | conf = m.GetConformer() 107 | for match in matches: 108 | idx2 = match[0] 109 | idx3 = match[1] 110 | bond = m.GetBondBetweenAtoms(idx2, idx3) 111 | jAtom = m.GetAtomWithIdx(idx2) 112 | kAtom = m.GetAtomWithIdx(idx3) 113 | for b1 in jAtom.GetBonds(): 114 | if (b1.GetIdx() == bond.GetIdx()): 115 | continue 116 | idx1 = b1.GetOtherAtomIdx(idx2) 117 | for b2 in kAtom.GetBonds(): 118 | if ((b2.GetIdx() == bond.GetIdx()) 119 | or (b2.GetIdx() == b1.GetIdx())): 120 | continue 121 | idx4 = b2.GetOtherAtomIdx(idx3) 122 | # skip 3-membered rings 123 | if (idx4 == idx1): 124 | continue 125 | # skip torsions that include hydrogens 126 | # if ((m.GetAtomWithIdx(idx1).GetAtomicNum() == 1) 127 | # or (m.GetAtomWithIdx(idx4).GetAtomicNum() == 1)): 128 | # continue 129 | if m.GetAtomWithIdx(idx4).IsInRing(): 130 | torsionList.append( 131 | (idx4 + atom_counter, idx3 + atom_counter, idx2 + atom_counter, idx1 + atom_counter)) 132 | break 133 | else: 134 | torsionList.append( 135 | (idx1 + atom_counter, idx2 + atom_counter, idx3 + atom_counter, idx4 + atom_counter)) 136 | break 137 | break 138 | 139 | atom_counter += m.GetNumAtoms() 140 | return torsionList 141 | 142 | def mol_with_atom_index(mol): 143 | atoms = mol.GetNumAtoms() 144 | for idx in range(atoms): 145 | mol.GetAtomWithIdx(idx).SetProp('molAtomMapNumber', str(mol.GetAtomWithIdx(idx).GetIdx())) 146 | return mol 147 | 148 | 149 | def SetDihedral(conf, atom_idx, new_value): 150 | rdMolTransforms.SetDihedralDeg(conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3], new_value) 151 | 152 | 153 | def GetDihedral(conf, atom_idx): 154 | return rdMolTransforms.GetDihedralDeg(conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3]) 155 | 156 | 157 | def GetTransformationMatrix(transformations): 158 | x, y, z, disp_x, disp_y, disp_z = transformations 159 | transMat = np.array([[np.cos(z) * np.cos(y), (np.cos(z) * np.sin(y) * np.sin(x)) - (np.sin(z) * np.cos(x)), 160 | (np.cos(z) * np.sin(y) * np.cos(x)) + (np.sin(z) * np.sin(x)), disp_x], 161 | [np.sin(z) * np.cos(y), (np.sin(z) * np.sin(y) * np.sin(x)) + (np.cos(z) * np.cos(x)), 162 | (np.sin(z) * np.sin(y) * np.cos(x)) - (np.cos(z) * np.sin(x)), disp_y], 163 | [-np.sin(y),
np.cos(y) * np.sin(x), np.cos(y) * np.cos(x), disp_z], 164 | [0, 0, 0, 1]], dtype=np.double) 165 | return transMat 166 | 167 | 168 | def apply_changes(mol, values, rotable_bonds): 169 | opt_mol = copy.deepcopy(mol) 170 | # opt_mol = add_rdkit_conformer(opt_mol) 171 | 172 | # apply rotations 173 | [SetDihedral(opt_mol.GetConformer(), rotable_bonds[r], values[r]) for r in range(len(rotable_bonds))] 174 | 175 | # # apply transformation matrix 176 | # rdMolTransforms.TransformConformer(opt_mol.GetConformer(), GetTransformationMatrix(values[:6])) 177 | 178 | return opt_mol 179 | # Clockwise dihedral2 from https://stackoverflow.com/questions/20305272/dihedral-torsion-angle-from-four-points-in-cartesian-coordinates-in-python 180 | def GetDihedralFromPointCloud(Z, atom_idx): 181 | p = Z[list(atom_idx)] 182 | b = p[:-1] - p[1:] 183 | b[0] *= -1 ######################### 184 | v = np.array( [ v - (v.dot(b[1])/b[1].dot(b[1])) * b[1] for v in [b[0], b[2]] ] ) 185 | # Normalize vectors 186 | v /= np.sqrt(np.einsum('...i,...i', v, v)).reshape(-1,1) 187 | b1 = b[1] / np.linalg.norm(b[1]) 188 | x = np.dot(v[0], v[1]) 189 | m = np.cross(v[0], b1) 190 | y = np.dot(m, v[1]) 191 | return np.degrees(np.arctan2( y, x )) 192 | 193 | def A_transpose_matrix(alpha): 194 | return np.array([[np.cos(np.radians(alpha)), np.sin(np.radians(alpha))], 195 | [-np.sin(np.radians(alpha)), np.cos(np.radians(alpha))]], dtype=np.double) 196 | 197 | def S_vec(alpha): 198 | return np.array([[np.cos(np.radians(alpha))], 199 | [np.sin(np.radians(alpha))]], dtype=np.double) 200 | 201 | def get_dihedral_vonMises(mol, conf, atom_idx, Z): 202 | Z = np.array(Z) 203 | v = np.zeros((2,1)) 204 | iAtom = mol.GetAtomWithIdx(atom_idx[1]) 205 | jAtom = mol.GetAtomWithIdx(atom_idx[2]) 206 | k_0 = atom_idx[0] 207 | i = atom_idx[1] 208 | j = atom_idx[2] 209 | l_0 = atom_idx[3] 210 | for b1 in iAtom.GetBonds(): 211 | k = b1.GetOtherAtomIdx(i) 212 | if k == j: 213 | continue 214 | for b2 in jAtom.GetBonds(): 215 | l = b2.GetOtherAtomIdx(j) 216 | if l == i: 217 | continue 218 | assert k != l 219 | s_star = S_vec(GetDihedralFromPointCloud(Z, (k, i, j, l))) 220 | a_mat = A_transpose_matrix(GetDihedral(conf, (k, i, j, k_0)) + GetDihedral(conf, (l_0, i, j, l))) 221 | v = v + np.matmul(a_mat, s_star) 222 | v = v / np.linalg.norm(v) 223 | v = v.reshape(-1) 224 | return np.degrees(np.arctan2(v[1], v[0])) -------------------------------------------------------------------------------- /commons/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from datetime import datetime 3 | 4 | 5 | class Logger(object): 6 | def __init__(self, logpath, syspart=sys.stdout): 7 | self.terminal = syspart 8 | self.log = open(logpath, "a") 9 | 10 | def write(self, message): 11 | 12 | self.terminal.write(message) 13 | self.log.write(message) 14 | self.log.flush() 15 | 16 | def flush(self): 17 | # this flush method is needed for python 3 compatibility. 18 | # this handles the flush command by doing nothing. 19 | # you might want to specify some extra behavior here. 
20 | pass 21 | 22 | def log(*args): 23 | print(f'[{datetime.now()}]', *args) -------------------------------------------------------------------------------- /commons/losses.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | 4 | import dgl 5 | import ot 6 | import torch 7 | from torch import Tensor, nn 8 | from torch.distributions import MultivariateNormal 9 | from torch.nn.modules.loss import _Loss, L1Loss, MSELoss, BCEWithLogitsLoss 10 | import numpy as np 11 | import torch.nn.functional as F 12 | 13 | 14 | # Ligand residue locations: a_i in R^3. Receptor: b_j in R^3 15 | # Ligand: G_l(x) = -sigma * ln( \sum_i exp(- ||x - a_i||^2 / sigma) ), same for G_r(x) 16 | # Ligand surface: x such that G_l(x) = surface_ct 17 | # Other properties: G_l(a_i) < 0, G_l(x) = infinity if x is far from all a_i 18 | # Intersection of ligand and receptor: points x such that G_l(x) < surface_ct && G_r(x) < surface_ct 19 | # Intersection loss: IL = \avg_i max(0, surface_ct - G_r(a_i)) + \avg_j max(0, surface_ct - G_l(b_j)) 20 | def G_fn(protein_coords, x, sigma): 21 | # protein_coords: (n,3) , x: (m,3), output: (m,) 22 | e = torch.exp(- torch.sum((protein_coords.view(1, -1, 3) - x.view(-1, 1, 3)) ** 2, dim=2) / float(sigma)) # (m, n) 23 | return - sigma * torch.log(1e-3 + e.sum(dim=1)) 24 | 25 | 26 | def compute_body_intersection_loss(model_ligand_coors_deform, bound_receptor_repres_nodes_loc_array, sigma, surface_ct): 27 | loss = torch.mean( 28 | torch.clamp(surface_ct - G_fn(bound_receptor_repres_nodes_loc_array, model_ligand_coors_deform, sigma), 29 | min=0)) + \ 30 | torch.mean( 31 | torch.clamp(surface_ct - G_fn(model_ligand_coors_deform, bound_receptor_repres_nodes_loc_array, sigma), 32 | min=0)) 33 | return loss 34 | 35 | 36 | def compute_sq_dist_mat(X_1, X_2): 37 | '''Computes the l2 squared cost matrix between two point cloud inputs. 
38 | Args: 39 | X_1: [n, #features] point cloud, tensor 40 | X_2: [m, #features] point cloud, tensor 41 | Output: 42 | [n, m] matrix of the l2 distance between point pairs 43 | ''' 44 | n_1, _ = X_1.size() 45 | n_2, _ = X_2.size() 46 | X_1 = X_1.view(n_1, 1, -1) 47 | X_2 = X_2.view(1, n_2, -1) 48 | squared_dist = (X_1 - X_2) ** 2 49 | cost_mat = torch.sum(squared_dist, dim=2) 50 | return cost_mat 51 | 52 | 53 | def compute_ot_emd(cost_mat, device): 54 | cost_mat_detach = cost_mat.detach().cpu().numpy() 55 | a = np.ones([cost_mat.shape[0]]) / cost_mat.shape[0] 56 | b = np.ones([cost_mat.shape[1]]) / cost_mat.shape[1] 57 | ot_mat = ot.emd(a=a, b=b, M=cost_mat_detach, numItermax=10000) 58 | ot_mat_attached = torch.tensor(ot_mat, device=device, requires_grad=False).float() 59 | ot_dist = torch.sum(ot_mat_attached * cost_mat) 60 | return ot_dist, ot_mat_attached 61 | 62 | 63 | def compute_revised_intersection_loss(lig_coords, rec_coords, alpha=0.2, beta=8, aggression=0): 64 | distances = compute_sq_dist_mat(lig_coords, rec_coords) 65 | if aggression > 0: 66 | aggression_term = torch.clamp(-torch.log(torch.sqrt(distances) / aggression + 0.01), min=1) 67 | else: 68 | aggression_term = 1 69 | distance_losses = aggression_term * torch.exp(-alpha * distances * torch.clamp(distances * 4 - beta, min=1)) 70 | return distance_losses.sum() 71 | 72 | class BindingLoss(_Loss): 73 | def __init__(self, ot_loss_weight=1, intersection_loss_weight=0, intersection_sigma=0, geom_reg_loss_weight=1, loss_rescale=True, 74 | intersection_surface_ct=0, key_point_alignmen_loss_weight=0, revised_intersection_loss_weight=0, centroid_loss_weight=0, kabsch_rmsd_weight=0, translated_lig_kpt_ot_loss=False, revised_intersection_alpha=0.1, revised_intersection_beta=8, aggression=0) -> None: 75 | super(BindingLoss, self).__init__() 76 | self.ot_loss_weight = ot_loss_weight 77 | self.intersection_loss_weight = intersection_loss_weight 78 | self.intersection_sigma = intersection_sigma 79 | self.revised_intersection_loss_weight = revised_intersection_loss_weight 80 | self.intersection_surface_ct = intersection_surface_ct 81 | self.key_point_alignmen_loss_weight = key_point_alignmen_loss_weight 82 | self.centroid_loss_weight = centroid_loss_weight 83 | self.translated_lig_kpt_ot_loss = translated_lig_kpt_ot_loss 84 | self.kabsch_rmsd_weight = kabsch_rmsd_weight 85 | self.revised_intersection_alpha = revised_intersection_alpha 86 | self.revised_intersection_beta = revised_intersection_beta 87 | self.aggression = aggression 88 | self.loss_rescale = loss_rescale 89 | self.geom_reg_loss_weight = geom_reg_loss_weight 90 | self.mse_loss = MSELoss() 91 | 92 | def forward(self, ligs_coords, recs_coords, ligs_coords_pred, ligs_pocket_coords, recs_pocket_coords, ligs_keypts, 93 | recs_keypts, rotations, translations, geom_reg_loss, device, **kwargs): 94 | # Compute MSE loss for each protein individually, then average over the minibatch.
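# The per-complex loop below accumulates, gated by the corresponding weights:
# coordinate MSE, the optimal-transport loss between pocket coordinates and
# predicted keypoints, two intersection penalties, a keypoint-alignment MSE,
# a centroid-distance MSE, and a Kabsch-RMSD term; geom_reg_loss is passed in
# precomputed and only rescaled and weighted here.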
95 | ligs_coords_loss = 0 96 | recs_coords_loss = 0 97 | ot_loss = 0 98 | intersection_loss = 0 99 | intersection_loss_revised = 0 100 | keypts_loss = 0 101 | centroid_loss = 0 102 | kabsch_rmsd_loss = 0 103 | 104 | for i in range(len(ligs_coords_pred)): 105 | ## Compute average MSE loss (which is 3 times smaller than average squared RMSD) 106 | ligs_coords_loss = ligs_coords_loss + self.mse_loss(ligs_coords_pred[i], ligs_coords[i]) 107 | 108 | if self.ot_loss_weight > 0: 109 | # Compute the OT loss for the binding pocket: 110 | ligand_pocket_coors = ligs_pocket_coords[i] ## (N, 3), N = num pocket nodes 111 | receptor_pocket_coors = recs_pocket_coords[i] ## (N, 3), N = num pocket nodes 112 | 113 | ## (N, K) cost matrix 114 | if self.translated_lig_kpt_ot_loss: 115 | cost_mat_ligand = compute_sq_dist_mat(receptor_pocket_coors, (rotations[i] @ ligs_keypts[i].t()).t() + translations[i] ) 116 | else: 117 | cost_mat_ligand = compute_sq_dist_mat(ligand_pocket_coors, ligs_keypts[i]) 118 | cost_mat_receptor = compute_sq_dist_mat(receptor_pocket_coors, recs_keypts[i]) 119 | 120 | ot_dist, _ = compute_ot_emd(cost_mat_ligand + cost_mat_receptor, device) 121 | ot_loss += ot_dist 122 | if self.key_point_alignmen_loss_weight > 0: 123 | keypts_loss += self.mse_loss((rotations[i] @ ligs_keypts[i].t()).t() + translations[i], 124 | recs_keypts[i]) 125 | 126 | if self.intersection_loss_weight > 0: 127 | intersection_loss = intersection_loss + compute_body_intersection_loss(ligs_coords_pred[i], 128 | recs_coords[i], 129 | self.intersection_sigma, 130 | self.intersection_surface_ct) 131 | 132 | if self.revised_intersection_loss_weight > 0: 133 | intersection_loss_revised = intersection_loss_revised + compute_revised_intersection_loss(ligs_coords_pred[i], 134 | recs_coords[i], alpha=self.revised_intersection_alpha, beta=self.revised_intersection_beta, aggression=self.aggression) 135 | 136 | if self.kabsch_rmsd_weight > 0: 137 | lig_coords_pred = ligs_coords_pred[i] 138 | lig_coords = ligs_coords[i] 139 | lig_coords_pred_mean = lig_coords_pred.mean(dim=0, keepdim=True) # (1,3) 140 | lig_coords_mean = lig_coords.mean(dim=0, keepdim=True) # (1,3) 141 | 142 | A = (lig_coords_pred - lig_coords_pred_mean).transpose(0, 1) @ (lig_coords - lig_coords_mean) 143 | 144 | U, S, Vt = torch.linalg.svd(A) 145 | 146 | corr_mat = torch.diag(torch.tensor([1, 1, torch.sign(torch.det(A))], device=lig_coords_pred.device)) 147 | rotation = (U @ corr_mat) @ Vt 148 | translation = lig_coords_pred_mean - torch.t(rotation @ lig_coords_mean.t()) # (1,3) 149 | kabsch_rmsd_loss += self.mse_loss((rotation @ lig_coords.t()).t() + translation, lig_coords_pred) 150 | 151 | centroid_loss += self.mse_loss(ligs_coords_pred[i].mean(dim=0), ligs_coords[i].mean(dim=0)) 152 | 153 | if self.loss_rescale: 154 | ligs_coords_loss = ligs_coords_loss / float(len(ligs_coords_pred)) 155 | ot_loss = ot_loss / float(len(ligs_coords_pred)) 156 | intersection_loss = intersection_loss / float(len(ligs_coords_pred)) 157 | keypts_loss = keypts_loss / float(len(ligs_coords_pred)) 158 | centroid_loss = centroid_loss / float(len(ligs_coords_pred)) 159 | kabsch_rmsd_loss = kabsch_rmsd_loss / float(len(ligs_coords_pred)) 160 | intersection_loss_revised = intersection_loss_revised / float(len(ligs_coords_pred)) 161 | geom_reg_loss = geom_reg_loss / float(len(ligs_coords_pred)) 162 | 163 | loss = ligs_coords_loss + self.ot_loss_weight * ot_loss + self.intersection_loss_weight * intersection_loss + keypts_loss * self.key_point_alignmen_loss_weight + centroid_loss * 
self.centroid_loss_weight + kabsch_rmsd_loss * self.kabsch_rmsd_weight + intersection_loss_revised * self.revised_intersection_loss_weight + geom_reg_loss * self.geom_reg_loss_weight 164 | return loss, {'ligs_coords_loss': ligs_coords_loss, 'recs_coords_loss': recs_coords_loss, 'ot_loss': ot_loss, 165 | 'intersection_loss': intersection_loss, 'keypts_loss': keypts_loss, 'centroid_loss': centroid_loss, 'kabsch_rmsd_loss': kabsch_rmsd_loss, 'intersection_loss_revised': intersection_loss_revised, 'geom_reg_loss': geom_reg_loss} 166 | 167 | class TorsionLoss(_Loss): 168 | def __init__(self) -> None: 169 | super(TorsionLoss, self).__init__() 170 | self.mse_loss = MSELoss() 171 | 172 | def forward(self, angles_pred, angles, masks, **kwargs): 173 | return self.mse_loss(angles_pred * masks, angles * masks) -------------------------------------------------------------------------------- /commons/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | from argparse import Namespace 4 | from collections.abc import MutableMapping  # use collections.abc: importing MutableMapping from plain collections was removed in Python 3.10 5 | from typing import Dict, Any 6 | from joblib import Parallel, delayed, cpu_count 7 | import torch 8 | import numpy as np 9 | import dgl 10 | from torch.utils.tensorboard import SummaryWriter 11 | from tqdm import tqdm 12 | 13 | from commons.logger import log 14 | 15 | 16 | def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs): 17 | """ 18 | 19 | Parallel map using joblib. 20 | 21 | Parameters 22 | ---------- 23 | pickleable_fn : callable 24 | Function to map over data. 25 | data : iterable 26 | Data over which we want to parallelize the function call. 27 | n_jobs : int, optional 28 | The maximum number of concurrently running jobs. By default, it is one less than 29 | the number of CPUs. 30 | verbose: int, optional 31 | The verbosity level. If nonzero, the function prints the progress messages. 32 | The frequency of the messages increases with the verbosity level. If above 10, 33 | it reports all iterations. If above 50, it sends the output to stdout. 34 | kwargs 35 | Additional arguments for :attr:`pickleable_fn`. 36 | 37 | Returns 38 | ------- 39 | list 40 | The i-th element of the list corresponds to the output of applying 41 | :attr:`pickleable_fn` to :attr:`data[i]`.
42 | """ 43 | if n_jobs is None: 44 | n_jobs = cpu_count() - 1 45 | 46 | results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)( 47 | delayed(pickleable_fn)(*d, **kwargs) for i, d in tqdm(enumerate(data),desc=desc) 48 | ) 49 | 50 | return results 51 | 52 | def seed_all(seed): 53 | if not seed: 54 | seed = 0 55 | 56 | log("[ Using Seed : ", seed, " ]") 57 | 58 | torch.manual_seed(seed) 59 | torch.cuda.manual_seed_all(seed) 60 | torch.cuda.manual_seed(seed) 61 | np.random.seed(seed) 62 | dgl.random.seed(seed) 63 | random.seed(seed) 64 | # torch.backends.cudnn.deterministic = True 65 | # torch.backends.cudnn.benchmark = False 66 | 67 | 68 | def get_random_indices(length, seed=123): 69 | st0 = np.random.get_state() 70 | np.random.seed(seed) 71 | random_indices = np.random.permutation(length) 72 | np.random.set_state(st0) 73 | return random_indices 74 | 75 | edges_dic = {} 76 | def get_adj_matrix(n_nodes, batch_size, device): 77 | if n_nodes in edges_dic: 78 | edges_dic_b = edges_dic[n_nodes] 79 | if batch_size in edges_dic_b: 80 | return edges_dic_b[batch_size] 81 | else: 82 | # get edges for a single sample 83 | rows, cols = [], [] 84 | for batch_idx in range(batch_size): 85 | for i in range(n_nodes): 86 | for j in range(n_nodes): 87 | rows.append(i + batch_idx*n_nodes) 88 | cols.append(j + batch_idx*n_nodes) 89 | 90 | else: 91 | edges_dic[n_nodes] = {} 92 | return get_adj_matrix(n_nodes, batch_size, device) 93 | 94 | edges = [torch.LongTensor(rows).to(device), torch.LongTensor(cols).to(device)] 95 | return edges 96 | 97 | def flatten_dict(params: Dict[Any, Any], delimiter: str = '/') -> Dict[str, Any]: 98 | """ 99 | Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. 100 | 101 | Args: 102 | params: Dictionary containing the hyperparameters 103 | delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``. 104 | 105 | Returns: 106 | Flattened dict. 
107 | Examples: 108 | flatten_dict({'a': {'b': 'c'}}) 109 | {'a/b': 'c'} 110 | flatten_dict({'a': {'b': 123}}) 111 | {'a/b': 123} 112 | flatten_dict({5: {'a': 123}}) 113 | {'5/a': 123} 114 | """ 115 | 116 | def _dict_generator(input_dict, prefixes=None): 117 | prefixes = prefixes[:] if prefixes else [] 118 | if isinstance(input_dict, MutableMapping): 119 | for key, value in input_dict.items(): 120 | key = str(key) 121 | if isinstance(value, (MutableMapping, Namespace)): 122 | value = vars(value) if isinstance(value, Namespace) else value 123 | for d in _dict_generator(value, prefixes + [key]): 124 | yield d 125 | else: 126 | yield prefixes + [key, value if value is not None else str(None)] 127 | else: 128 | yield prefixes + [input_dict if input_dict is None else str(input_dict)] 129 | 130 | dictionary = {delimiter.join(keys): val for *keys, val in _dict_generator(params)} 131 | for k in dictionary.keys(): 132 | # convert relevant np scalars to python types first (instead of str) 133 | if isinstance(dictionary[k], (np.bool_, np.integer, np.floating)): 134 | dictionary[k] = dictionary[k].item() 135 | elif type(dictionary[k]) not in [bool, int, float, str, torch.Tensor]: 136 | dictionary[k] = str(dictionary[k]) 137 | return dictionary 138 | 139 | 140 | 141 | 142 | def tensorboard_gradient_magnitude(optimizer: torch.optim.Optimizer, writer: SummaryWriter, step, param_groups=[0]): 143 | for i, param_group in enumerate(optimizer.param_groups): 144 | if i in param_groups: 145 | all_params = [] 146 | for params in param_group['params']: 147 | if params.grad is not None: 148 | all_params.append(params.grad.view(-1)) 149 | writer.add_scalar(f'gradient_magnitude_param_group_{i}', torch.cat(all_params).abs().mean(), 150 | global_step=step) 151 | 152 | def move_to_device(element, device): 153 | ''' 154 | takes an arbitrarily nested list and moves every element to device if it is a dgl graph or a torch tensor 155 | :param element: arbitrarily nested list 156 | :param device: 157 | :return: 158 | ''' 159 | if isinstance(element, list): 160 | return [move_to_device(x, device) for x in element] 161 | else: 162 | return element.to(device) if isinstance(element, (torch.Tensor, dgl.DGLGraph)) else element 163 | 164 | def list_detach(element): 165 | ''' 166 | takes an arbitrarily nested list and detaches everything from the computation graph 167 | :param element: arbitrarily nested list 168 | :return: 169 | ''' 170 | if isinstance(element, list): 171 | return [list_detach(x) for x in element] 172 | else: 173 | return element.detach() 174 | 175 | def concat_if_list(tensor_or_tensors): 176 | return torch.cat(tensor_or_tensors) if isinstance(tensor_or_tensors, list) else tensor_or_tensors 177 | 178 | def write_strings_to_txt(strings: list, path): 179 | # every string of the list will be saved in one line 180 | with open(path, "w") as textfile: 181 | for element in strings: 182 | textfile.write(element + "\n") 183 | 184 | def read_strings_from_txt(path): 185 | # every line will be one element of the returned list 186 | with open(path) as file: 187 | lines = file.readlines() 188 | return [line.rstrip() for line in lines] -------------------------------------------------------------------------------- /configs_clean/RDKitCoords_flexible_self_docking.yml: -------------------------------------------------------------------------------- 1 | experiment_name: 'YourExperimentName' 2 | seed: 1 3 | data_seed: 1 4 | 5 | trainer: binding 6 | num_epochs: 1000000 7 | batch_size: 8 8 | log_iterations: 100 9 | patience: 150
10 | num_train: # leave empty to use all (in param train_names below) 11 | num_val: # leave empty to use all 12 | collate_function: graph_collate_revised 13 | loss_func: BindingLoss 14 | loss_params: 15 | ot_loss_weight: 1 16 | key_point_alignmen_loss_weight: 0 # this only works if ot_loss_weight is not 0 17 | centroid_loss_weight: 0 18 | intersection_loss_weight: 1 19 | intersection_sigma: 8 # 8 was determined by grid search over data 20 | intersection_surface_ct: 1 # grid search says 2.5 21 | translated_lig_kpt_ot_loss: False 22 | kabsch_rmsd_weight: 1 23 | 24 | 25 | train_names: 'data/timesplit_no_lig_overlap_train' 26 | val_names: 'data/timesplit_no_lig_overlap_val' 27 | test_names: 'data/timesplit_test' 28 | num_workers: 0 29 | dataset_params: 30 | geometry_regularization_ring: True 31 | use_rdkit_coords: True # 32 | bsp_proteins: False # if this is true then the proteins from deepbsp are used, otherwise those from PDBBind. Note that not all proteins are in deepBSP and this only works with e.g. pdbbind_names_without_casf_rec_and_in_bsp as complex_names_path 33 | dataset_size: # mostly for debugging dataset creation. leave empty to use the whole dataset 34 | translation_distance: 5.0 35 | n_jobs: 20 # leave empty to use num_cpu - 1 36 | chain_radius: 10 # only keep chains that have an atom in this radius around the ligand 37 | rec_graph_radius: 30 38 | c_alpha_max_neighbors: 10 # maximum number of neighbors in the receptor graph unless using rec_atoms 39 | lig_graph_radius: 5 40 | lig_max_neighbors: 41 | pocket_cutoff: 4 42 | pocket_mode: match_atoms_to_lig # [match_terminal_atoms, match_atoms, radius_based, lig_atoms] 43 | remove_h: False 44 | only_polar_hydrogens: False 45 | # the following are only relevant if use_rec_atoms is true 46 | use_rec_atoms: False # if this is true then the same parameter also needs to be true for the model parameters 47 | surface_max_neighbors: 5 48 | surface_graph_cutoff: 5 49 | surface_mesh_cutoff: 2 50 | # the following are only relevant if subgraph augmentation is true 51 | subgraph_augmentation: False # using subgraph augmentation increases CPU usage. Consider also using num_workers > 0 52 | min_shell_thickness: 3 53 | # the following is only relevant for rec_subgraph = True 54 | rec_subgraph: False # don't forget to also set use_rec_atoms to True IN THE MODEL PARAMETERS 55 | subgraph_radius: 10 56 | subgraph_max_neigbor: 8 57 | subgraph_cutoff: 4 58 | 59 | 60 | metrics: 61 | - pearsonr 62 | - rsquared 63 | - mean_rmsd 64 | - median_rmsd 65 | - median_centroid_distance 66 | - centroid_distance_less_than_2 67 | - mean_centroid_distance 68 | - kabsch_rmsd 69 | - rmsd_less_than_2 70 | - rmsd_less_than_5 71 | main_metric: rmsd_less_than_2 # used for early stopping etc 72 | main_metric_goal: 'max' 73 | 74 | optimizer: Adam 75 | optimizer_params: 76 | lr: 1.0e-4 77 | weight_decay: 1.0e-4 # 1.0e-5 in good run 78 | clip_grad: 100 # leave empty for no grad clip 79 | 80 | scheduler_step_per_batch: False 81 | lr_scheduler: ReduceLROnPlateau # leave empty to use none 82 | lr_scheduler_params: 83 | factor: 0.6 84 | patience: 60 85 | min_lr: 8.0e-6 86 | mode: 'max' 87 | verbose: True 88 | 89 | 90 | # Model parameters 91 | model_type: 'EquiBind' 92 | model_parameters: 93 | geometry_reg_step_size: 0.001 94 | geometry_regularization: True 95 | use_evolved_lig: True # Whether or not to use the evolved lig as final prediction 96 | standard_norm_order: True 97 | unnormalized_kpt_weights: False # no softmax for the weights that create the keypoints 98 | lig_evolve: True # whether or not the coordinates are changed in the EGNN layers 99 | rec_evolve: True 100 | rec_no_softmax: False 101 | lig_no_softmax: False 102 | centroid_keypts_construction_rec: False 103 | centroid_keypts_construction_lig: False 104 | centroid_keypts_construction: False # this is old; use the two above 105 | move_keypts_back: True # move the constructed keypoints back to the location of the ligand 106 | normalize_Z_rec_directions: False 107 | normalize_Z_lig_directions: False 108 | n_lays: 8 # 5 in good run 109 | debug: False 110 | use_rec_atoms: False 111 | shared_layers: False # False in good run 112 | noise_decay_rate: 0.5 113 | noise_initial: 1 114 | use_edge_features_in_gmn: True 115 | use_mean_node_features: True 116 | residue_emb_dim: 64 117 | iegmn_lay_hid_dim: 64 118 | num_att_heads: 30 # 20 in good run 119 | dropout: 0.1 120 | nonlin: 'lkyrelu' # ['swish', 'lkyrelu'] 121 | leakyrelu_neg_slope: 1.0e-2 # 1.0e-2 in good run 122 | cross_msgs: True 123 | layer_norm: 'BN' # ['0', 'BN', 'LN'] # BN in good run 124 | layer_norm_coords: '0' # ['0', 'LN'] # 0 in good run 125 | final_h_layer_norm: '0' # ['0', 'GN', 'BN', 'LN'] # 0 in good run 126 | pre_crossmsg_norm_type: '0' # ['0', 'GN', 'BN', 'LN'] 127 | post_crossmsg_norm_type: '0' # ['0', 'GN', 'BN', 'LN'] 128 | use_dist_in_layers: True 129 | skip_weight_h: 0.5 # 0.5 in good run 130 | x_connection_init: 0.25 # 0.25 in good run 131 | random_vec_dim: 0 # set to 0 to have no stochasticity 132 | random_vec_std: 1 133 | use_scalar_features: False # Have a look at lig_feature_dims in process_mols.py to see what features we are talking about. 134 | num_lig_feats: # leave as None to use all ligand features. Have a look at lig_feature_dims in process_mols.py to see what features we are talking about. If this is 1, only the first of those will be used.
135 | normalize_coordinate_update: True 136 | rec_square_distance_scale: 10 # divide square distance by 10 to have a nicer separation instead of many 0.00000 137 | 138 | 139 | eval_on_test: False 140 | # continue training from checkpoint: 141 | #checkpoint: runs/path_to_the_experiment/last_checkpoint.pt -------------------------------------------------------------------------------- /configs_clean/inference.yml: -------------------------------------------------------------------------------- 1 | run_dirs: 2 | - flexible_self_docking # the resulting coordinates will be saved here as tensors in a .pt file (but also as .sdf files if you specify an "output_directory" below) 3 | inference_path: 'data/to_predict' # this should be your input file path as described in the main readme 4 | 5 | test_names: timesplit_test 6 | output_directory: 'data/results/output' # the predicted ligands will be saved as .sdf file here 7 | run_corrections: True 8 | use_rdkit_coords: False # generates the coordinates of the ligand with rdkit instead of using the provided conformer. If you already have a 3D structure that you want to use as initial conformer, then leave this as False 9 | save_trajectories: False 10 | 11 | num_confs: 1 # usually this should be 1 12 | 13 | -------------------------------------------------------------------------------- /configs_clean/inference_file_for_reproduce.yml: -------------------------------------------------------------------------------- 1 | run_dirs: 2 | - flexible_self_docking # the resulting coordinates will be saved here as tensors in a .pt file (but also as .sdf files if you specify an "output_directory" below) 3 | inference_path: # 'data/your_input_path' # this should be your input file path as described in the main readme 4 | 5 | test_names: timesplit_test 6 | output_directory: # the predicted ligands will be saved as .sdf file here 7 | run_corrections: True 8 | use_rdkit_coords: True # generates the coordinates of the ligand with rdkit instead of using the provided conformer. 
If you already have a 3D structure that you want to use as initial conformer, then leave this as False 9 | save_trajectories: False 10 | 11 | num_confs: 1 # usually this should be 1 12 | 13 | 14 | -------------------------------------------------------------------------------- /configs_clean/rigid_self_docking.yml: -------------------------------------------------------------------------------- 1 | experiment_name: 'YourExperimentName' 2 | seed: 1 3 | data_seed: 1 4 | 5 | trainer: binding 6 | num_epochs: 1000000 7 | batch_size: 8 8 | log_iterations: 100 9 | patience: 150 10 | num_train: # leave empty to use all (in param train_names below) 11 | num_val: # leave empty to use all 12 | collate_function: graph_collate_revised 13 | loss_func: BindingLoss 14 | loss_params: 15 | ot_loss_weight: 1 16 | key_point_alignmen_loss_weight: 0 # this only works if ot_loss_weight is not 0 17 | centroid_loss_weight: 0 18 | intersection_loss_weight: 1 19 | intersection_sigma: 8 # 8 was determined by grid search over data 20 | intersection_surface_ct: 1 # grid search says 2.5 21 | translated_lig_kpt_ot_loss: False 22 | kabsch_rmsd_weight: 1 23 | 24 | 25 | train_names: 'data/timesplit_no_lig_overlap_train' 26 | val_names: 'data/timesplit_no_lig_overlap_val' 27 | test_names: 'data/timesplit_test' 28 | num_workers: 0 29 | dataset_params: 30 | geometry_regularization_ring: False 31 | geometry_regularization: False 32 | use_rdkit_coords: False # 33 | bsp_proteins: False # if this is true then the proteins from deepbsp are used, otherwise those from PDBBind. Note that not all proteins are in deepBSP and this only works with e.g. pdbbind_names_without_casf_rec_and_in_bsp as complex_names_path 34 | dataset_size: # mostly for debugging dataset creation. leave empty to use the whole dataset 35 | translation_distance: 5.0 36 | n_jobs: 20 # leave empty to use num_cpu - 1 37 | chain_radius: 10 # only keep chains that have an atom in this radius around the ligand 38 | rec_graph_radius: 30 39 | c_alpha_max_neighbors: 10 # maximum number of neighbors in the receptor graph unless using rec_atoms 40 | lig_graph_radius: 5 41 | lig_max_neighbors: 42 | pocket_cutoff: 4 43 | pocket_mode: match_atoms_to_lig # [match_terminal_atoms, match_atoms, radius_based, lig_atoms, match_atoms_to_lig] 44 | remove_h: False 45 | only_polar_hydrogens: False 46 | # the following are only relevant if use_rec_atoms is true 47 | use_rec_atoms: False # if this is true then the same parameter also needs to be true for the model parameters 48 | surface_max_neighbors: 5 49 | surface_graph_cutoff: 5 50 | surface_mesh_cutoff: 2 51 | # the following are only relevant if subgraph augmentation is true 52 | subgraph_augmentation: False # using subgraph augmentation increases CPU usage. Consider also using num_workers > 0 53 | min_shell_thickness: 3 54 | # the following is only relevant for rec_subgraph = True 55 | rec_subgraph: False # don't forget to also set use_rec_atoms to True IN THE MODEL PARAMETERS 56 | subgraph_radius: 10 57 | subgraph_max_neigbor: 8 58 | subgraph_cutoff: 4 59 | 60 | 61 | metrics: 62 | - pearsonr 63 | - rsquared 64 | - mean_rmsd 65 | - median_rmsd 66 | - median_centroid_distance 67 | - centroid_distance_less_than_2 68 | - mean_centroid_distance 69 | - kabsch_rmsd 70 | - rmsd_less_than_2 71 | - rmsd_less_than_5 72 | main_metric: rmsd_less_than_2 # used for early stopping etc 73 | main_metric_goal: 'max' 74 | 75 | optimizer: Adam 76 | optimizer_params: 77 | lr: 1.0e-4 78 | weight_decay: 1.0e-4 # 1.0e-5 in good run 79 | clip_grad: 100 # leave empty for no grad clip 80 | 81 | scheduler_step_per_batch: False 82 | lr_scheduler: ReduceLROnPlateau # leave empty to use none 83 | lr_scheduler_params: 84 | factor: 0.6 85 | patience: 60 86 | min_lr: 8.0e-6 87 | mode: 'max' 88 | verbose: True 89 | 90 | 91 | # Model parameters 92 | model_type: 'EquiBind' 93 | model_parameters: 94 | geometry_reg_step_size: 0.001 95 | geometry_regularization: False 96 | fine_tune: False 97 | use_evolved_lig: False # Whether or not to use the evolved lig as final prediction 98 | standard_norm_order: True 99 | unnormalized_kpt_weights: False # no softmax for the weights that create the keypoints (THIS DOES NOT WORK WITH use_evolved_lig = True) (well, it does not really work at all) 100 | lig_evolve: True # whether or not the coordinates are changed in the EGNN layers 101 | rec_evolve: True 102 | rec_no_softmax: False 103 | lig_no_softmax: False 104 | centroid_keypts_construction_rec: False 105 | centroid_keypts_construction_lig: False 106 | centroid_keypts_construction: False # this is old; use the two above 107 | move_keypts_back: True # move the constructed keypoints back to the location of the ligand 108 | normalize_Z_rec_directions: False 109 | normalize_Z_lig_directions: False 110 | n_lays: 8 # 5 in good run 111 | debug: False 112 | use_rec_atoms: False 113 | shared_layers: False # False in good run 114 | noise_decay_rate: 0.5 115 | noise_initial: 1 116 | use_edge_features_in_gmn: True 117 | use_mean_node_features: True 118 | residue_emb_dim: 64 119 | iegmn_lay_hid_dim: 64 120 | num_att_heads: 30 # 20 in good run 121 | dropout: 0.1 122 | nonlin: 'lkyrelu' # ['swish', 'lkyrelu'] 123 | leakyrelu_neg_slope: 1.0e-2 # 1.0e-2 in good run 124 | cross_msgs: True 125 | layer_norm: 'BN' # ['0', 'BN', 'LN'] # BN in good run 126 | layer_norm_coords: '0' # ['0', 'LN'] # 0 in good run 127 | final_h_layer_norm: '0' # ['0', 'GN', 'BN', 'LN'] # 0 in good run 128 | pre_crossmsg_norm_type: '0' # ['0', 'GN', 'BN', 'LN'] 129 | post_crossmsg_norm_type: '0' # ['0', 'GN', 'BN', 'LN'] 130 | use_dist_in_layers: True 131 | skip_weight_h: 0.5 # 0.5 in good run 132 | x_connection_init: 0.25 # 0.25 in good run 133 | random_vec_dim: 0 # set to 0 to have no stochasticity 134 | random_vec_std: 1 135 | use_scalar_features: False # Have a look at lig_feature_dims in process_mols.py to see what features we are talking about. 136 | num_lig_feats: # leave as None to use all ligand features. Have a look at lig_feature_dims in process_mols.py to see what features we are talking about. If this is 1, only the first of those will be used.
137 | normalize_coordinate_update: True 138 | rec_square_distance_scale: 10 # divide square distance by 10 to have a nicer separation instead of many 0.00000 139 | 140 | 141 | eval_on_test: False 142 | # continue training from checkpoint: 143 | #checkpoint: runs/path_to_the_experiment/last_checkpoint.pt -------------------------------------------------------------------------------- /data/timesplit_no_lig_or_rec_overlap_val: -------------------------------------------------------------------------------- 1 | 4mi6 2 | 5ylv 3 | 4ozo 4 | 6gip 5 | 3std 6 | 3g2n 7 | 6ax1 8 | 6h96 9 | 5q0m 10 | 5hh5 11 | 4idz 12 | 6cec 13 | 5wqa 14 | 3k3e 15 | 1ppk 16 | 4og7 17 | 4b5w 18 | 4bgg 19 | 4jgv 20 | 2g9v 21 | 4og3 22 | 5lz2 23 | 6chq 24 | 5aqv 25 | 4k67 26 | 5j3v 27 | 5lz9 28 | 5t2y 29 | 4el5 30 | 2z60 31 | 4zed 32 | 4pks 33 | 5cuu 34 | 5q0x 35 | 4x5z 36 | 3hs9 37 | 3v7s 38 | 2qyn 39 | 5ehn 40 | 1sz0 41 | 4x6i 42 | 4u82 43 | 2vrj 44 | 6g2l 45 | 2bow 46 | 5o9y 47 | 4mji 48 | 6ccs 49 | 1yvm 50 | 3sym 51 | 4fz3 52 | 5tdi 53 | 2ie4 54 | 1n4m 55 | 1c3i 56 | 5eie 57 | 5ye8 58 | 5in9 59 | 4b7z 60 | 5ncz 61 | 5lz4 62 | 6bw4 63 | 3sl4 64 | 4dmy 65 | 4cmt 66 | 5a9u 67 | 1fki 68 | 4mc1 69 | 5o9o 70 | 4b5t 71 | 4m3g 72 | 1pyg 73 | 3b67 74 | 5l8o 75 | 5mkj 76 | 3oyp 77 | 4anq 78 | 5hog 79 | 4de7 80 | 5tkb 81 | 2vle 82 | 3f7i 83 | 3v51 84 | 3l7d 85 | 2v6n 86 | 4qna 87 | 4cd0 88 | 3iog 89 | 4i8w 90 | 2xup 91 | 3t3i 92 | 1db4 93 | 5es1 94 | 5i9i 95 | 6ccm 96 | 2xui 97 | 1q91 98 | 1bgo 99 | 1akt 100 | 1q84 101 | 1yt7 102 | 2l75 103 | 5aac 104 | 6nao 105 | 5iuh 106 | 3oof 107 | 4ona 108 | 5q0n 109 | 3sfi 110 | 2g9q 111 | 1hlf 112 | 5aqj 113 | 1g9c 114 | 1ayu 115 | 6co4 116 | 6bd1 117 | 4yur 118 | 3vw0 119 | 6cea 120 | 4nyf 121 | 3v43 122 | 2ya8 123 | 1b4d 124 | 2ccb 125 | 5q1f 126 | 1fkb 127 | 3bcs 128 | 1h46 129 | 6dgt 130 | 2ftd 131 | 5t2l 132 | 3i7c 133 | 6ckw 134 | 3csl 135 | 1j07 136 | 3omm 137 | 4g17 138 | 3v49 139 | 4fny 140 | 1fkh 141 | 4u0e 142 | 2g9r 143 | 6hd4 144 | 5oss 145 | 2adm 146 | 4b85 147 | 2fm0 148 | 4hgs 149 | 2qn3 150 | 1ddm 151 | 3fal 152 | 5q0y 153 | 5q1a 154 | 3g2k 155 | 2j7b 156 | 6ee2 157 | 2jjr 158 | 5q1i 159 | 6b7a 160 | 5kh7 161 | 6gqm 162 | 3mta 163 | 3g2l 164 | 4i4f 165 | 5iaw 166 | 5q0i 167 | 3h1z 168 | 4jbl 169 | 5lvx 170 | 6chm 171 | 6fdc 172 | 2ax6 173 | 6cnj 174 | 1pwu 175 | 4fht 176 | 3th9 177 | 5db1 178 | 2hvc 179 | 4pp5 180 | 5q15 181 | 3u8h 182 | 4el0 183 | 4jt8 184 | 5std 185 | 2pwd 186 | 5wh6 187 | 3kf4 188 | 1q83 189 | 2xml 190 | 1c2t 191 | 4in9 192 | 2jt5 193 | 1icj 194 | 3oy3 195 | 3g2h 196 | 3qo2 197 | 3tu9 198 | 3s0j 199 | 1e3g 200 | 6d1x 201 | 6mx8 202 | 3aqt 203 | 6aol 204 | 4m8e 205 | 1ado 206 | 5hh6 207 | 4lkt 208 | 2j78 209 | 5q0o 210 | 1d8m 211 | 3suu 212 | 1gyy 213 | 5aqt 214 | 3oap 215 | 4zs3 216 | 2qn1 217 | 2p98 218 | 4cmu 219 | 4ie0 220 | 1w3j 221 | 5y59 222 | 2mpa 223 | 1akw 224 | 5tyh 225 | 5lb7 226 | 1fkf 227 | 3djv 228 | 5nxq 229 | 1kti 230 | 3mrv 231 | 5t2i 232 | 4uuq 233 | 2gfj 234 | 4poh 235 | 6by8 236 | 6b7e 237 | 4u0f 238 | 5lp1 239 | 4jym 240 | 3suv 241 | 6fse 242 | 5e1s 243 | 2eum 244 | 6hai 245 | 6h7f 246 | 3suw 247 | 2gg7 248 | 3np7 249 | 5l13 250 | 6f8x 251 | 2evc 252 | 1z6q 253 | 5o9r 254 | 4pkt 255 | 1haa 256 | 5jan 257 | 4oiv 258 | 3djq 259 | 3p7i 260 | 2j7h 261 | 3v7c 262 | 2fw6 263 | 3diw 264 | 4dpt 265 | 1tu6 266 | 4pku 267 | 1ggn 268 | 2usn 269 | 6baw 270 | 1zxv 271 | 3bl7 272 | 5vdu 273 | 6fsd 274 | 4u0i 275 | 4j8s 276 | 1g49 277 | 4hni 278 | 6e83 279 | 6f0y 280 | 4mra 281 | 1d5r 282 | 3v3q 283 | 6f8w 284 | 1c7e 285 | 4b5s 286 | 4ara 287 | 
4glr 288 | 4cff 289 | 2q93 290 | 5i8p 291 | 1apw 292 | 3sur 293 | 4b84 294 | 4duh 295 | 3l7c 296 | 1j4r 297 | 3mrt 298 | 1l7x 299 | 4dv8 300 | 1iep 301 | 1bsk 302 | 1i8h 303 | 5xih 304 | 5knj 305 | 5ick 306 | 2v3d 307 | 1sln 308 | 4k64 309 | 6cck 310 | 4ajw 311 | 2hb9 312 | 6std 313 | 2w87 314 | 4a23 315 | 1ayv 316 | 4bvb 317 | 1nc1 318 | 5q0p 319 | 3zlv 320 | 3okh 321 | 1pwq 322 | 3pp0 323 | 4pgh 324 | 2xbp 325 | 2qrp 326 | 4mho 327 | 6gpb 328 | 1aku 329 | 6b7f 330 | 3pcu 331 | 1fm9 332 | 5ddc 333 | 5q0u 334 | 1biw 335 | 3ery 336 | 2evo 337 | 4cxx 338 | 4g2l 339 | 5h6v 340 | 5yyz 341 | 4nw2 342 | 1em6 343 | 3shv 344 | 6cz3 345 | 4lh2 346 | 2gfk 347 | 2z78 348 | 4yv8 349 | 2jt6 350 | 3kdt 351 | 6chp 352 | 2xba 353 | 5wbl 354 | 5t2d 355 | 3fqa 356 | 5xii 357 | 2j83 358 | 4x6j 359 | 2q94 360 | 1vsn 361 | 1c8k 362 | 4b81 363 | 5y86 364 | 1caq 365 | 1k3t 366 | 2qrm 367 | 4u68 368 | 1exw 369 | 1nlj 370 | 3kdu 371 | 1axr 372 | 1kkq 373 | 3mrx 374 | 4z2h 375 | 3oki 376 | 6chl 377 | 1ppl 378 | 2q95 379 | 5q13 380 | 2j7g 381 | 2w2u 382 | 5aqq 383 | 1ms0 384 | 2b2v 385 | 6bfa 386 | 5kew 387 | 2cbu 388 | 1nli 389 | 6bw3 390 | 4h4e 391 | 2ha5 392 | 2aq9 393 | 1g98 394 | 2pri 395 | 1apv 396 | 1gar 397 | 3szm 398 | 8gpb 399 | 1noi 400 | 5foo 401 | 2z4w 402 | 3r93 403 | 2z7i 404 | 5vdr 405 | 5ylu 406 | 6f8r 407 | 3upx 408 | 3zyr 409 | 2jnp 410 | 2nm1 411 | 4kab 412 | 4n8r 413 | 5z4h 414 | 4f9v 415 | 3l7a 416 | 2am9 417 | 2wec 418 | 1h6h 419 | 3nfl 420 | 3wd1 421 | 4mdt 422 | 3sxf 423 | 6ftn 424 | 3gt9 425 | 1oim 426 | 1ywh 427 | 1u9w 428 | 1w70 429 | 3u1i 430 | 2v3e 431 | 5jap 432 | 3b68 433 | 2qnb 434 | 5hbs 435 | 2ama 436 | 2web 437 | 6e4t 438 | 4xm6 439 | 6hd6 440 | 4og4 441 | 3il6 442 | 4zs2 443 | 2z50 444 | 1nki 445 | 4my6 446 | 3vvy 447 | 3nc4 448 | 2z4y 449 | 2euk 450 | 3g4k 451 | 2y2i 452 | 1usn 453 | 1y6r 454 | 3bu6 455 | 4u0c 456 | 2ces 457 | 5ye7 458 | 2cbv 459 | 5twg 460 | 1f40 461 | 5m2q 462 | 4hvs 463 | 4z2i 464 | 3g0e 465 | 3dct 466 | 6bnl 467 | 5v8q 468 | 1au2 469 | 3h0a 470 | 2z52 471 | 4ie5 472 | 2auz 473 | 2qn2 474 | 4lh3 475 | 5lyy 476 | 4z2l 477 | 2bmz 478 | 2evm 479 | 5wvd 480 | 3o8h 481 | 6e4w 482 | 1syo 483 | 1yk7 484 | 5q11 485 | 2ha0 486 | 6ccn 487 | 1mh5 488 | 2ai8 489 | 3rde 490 | 5q17 491 | 4zsh 492 | 6aom 493 | 1pot 494 | 5i5x 495 | 4cqe 496 | 3l3x 497 | 4ie6 498 | 6cnk 499 | 1c12 500 | 1gfz 501 | 4jal 502 | 4ie7 503 | 5tvn 504 | 2qoh 505 | 4z2g 506 | 6b7b 507 | 5kxi 508 | 6f8v 509 | 1xor 510 | 1aqi 511 | 6q73 512 | 5t27 513 | 5q0r 514 | 1y2k 515 | 5q1b 516 | 5ko1 517 | 1a8i 518 | 3g2j 519 | 5gwz 520 | 1c8l 521 | 4o42 522 | 3r2a 523 | 1d7i 524 | 3ovz 525 | 3l3z 526 | 4i32 527 | 2ych 528 | 5aqp 529 | 4xkc 530 | 4og8 531 | 3g8i 532 | 2z7h 533 | 3kfa 534 | 3vfa 535 | 5q0v 536 | 3k41 537 | 2pwg 538 | 1au0 539 | 1g9d 540 | 4i8z 541 | 6cef 542 | 4mhs 543 | 4xm7 544 | 1c3e 545 | 3kx1 546 | 2gj4 547 | 5q16 548 | 4pli 549 | 2j7e 550 | 2j7d 551 | 6q6y 552 | 4gq6 553 | 3k5v 554 | 4j09 555 | 3bl9 556 | 2xuf 557 | 1e1y 558 | 6gi6 559 | 5z4o 560 | 5q18 561 | 4foc 562 | 6bcy 563 | 4aaw 564 | 2ha6 565 | 4pl5 566 | 3vw1 567 | 6bq0 568 | 2aux 569 | 1mkd 570 | 2q92 571 | 1xon 572 | 3aig 573 | 3oxz 574 | 2r6n 575 | 4mmp 576 | 5hki 577 | 2gg0 578 | 4qip 579 | 3ms2 580 | 5d1t 581 | 2ot1 582 | 2xpc 583 | 4zcs 584 | 5db3 585 | 3t3h 586 | 6cz4 587 | 4h4d 588 | 3h78 589 | 5q10 590 | 4jsr 591 | 4qll 592 | 4e90 593 | 6dry 594 | 5aqu 595 | 5oei 596 | 5hz5 597 | 3kwz 598 | 4m3d 599 | 5xig 600 | 3u3z 601 | 2y2k 602 | 5lz5 603 | 1o8b 604 | 4hlw 605 | 4tq3 606 | 1fkg 607 | 3pkn 608 | 2xb7 609 | 5ov9 610 | 
3ggv 611 | 3cke 612 | 4m84 613 | 6cho 614 | 1m6p 615 | 5jal 616 | 1a0q 617 | 4eky 618 | 3vw2 619 | 2y2n 620 | 5q0q 621 | 6gin 622 | 2nmb 623 | 3rik 624 | 1akv 625 | 4m3e 626 | 4kwg 627 | 5yun 628 | 3mqf 629 | 4pkw 630 | 3k3h 631 | 3t3v 632 | 3t1n 633 | 2bdl 634 | 2prj 635 | 5wh5 636 | 4qhc 637 | 4eoy 638 | 2bv4 639 | 1i7g 640 | 5iui 641 | 4og5 642 | 5evk 643 | 2fsa 644 | 3sdg 645 | 3g2i 646 | 1y2d 647 | 1c7f 648 | 1qkn 649 | 2etm 650 | 3o1g 651 | 3t3d 652 | 1bxo 653 | 2z5o 654 | 3bla 655 | 5o9p 656 | 5g3n 657 | 5v1y 658 | 4gq4 659 | 2vvo 660 | 4u0b 661 | 1opi 662 | 3sut 663 | 3wd2 664 | 4xm8 665 | 4kp4 666 | 1hy7 667 | 1g05 668 | 5aaa 669 | 5wmt 670 | 2fj0 671 | 1bxq 672 | 5t2b 673 | 1o6i 674 | 4xdo 675 | 5ez0 676 | 5wqj 677 | 5t8e 678 | 6g22 679 | 3o0u 680 | 2gfd 681 | 5fpp 682 | 1tuf 683 | 4v0i 684 | 4og6 685 | 3g4g 686 | 2std 687 | 1xnz 688 | 2dw7 689 | 4oue 690 | 6ds0 691 | 5jar 692 | 4ibm 693 | 1d5j 694 | 2hrp 695 | 1koj 696 | 1d7j 697 | 4ryl 698 | 2f6j 699 | 4eke 700 | 4btl 701 | 6b7d 702 | 3bwk 703 | 5aqg 704 | 4i80 705 | 1c3x 706 | 2qrq 707 | 1oif 708 | 2p9a 709 | 5f67 710 | 4mc9 711 | 4dpu 712 | 3il5 713 | 6bnk 714 | 4lh7 715 | 6ccl 716 | 4m3b 717 | 6drz 718 | 4ebw 719 | 6et8 720 | 1g9b 721 | 3vvz 722 | 5q12 723 | 1jys 724 | 1g9a 725 | 5q1c 726 | 4mc6 727 | 2gg9 728 | 5t2m 729 | 3gta 730 | 5q0w 731 | 5oa2 732 | 3mt9 733 | 5iql 734 | 5q0t 735 | 2gkl 736 | 1z95 737 | 6c91 738 | 2z4z 739 | 3syr 740 | 4g16 741 | 3qi3 742 | 1z6p 743 | 3p8o 744 | 1qpl 745 | 2pix 746 | 4crj 747 | 2cet 748 | 4wf6 749 | 4qfr 750 | 1y2c 751 | 4gh6 752 | 1ct8 753 | 3guz 754 | 1oyn 755 | 1d8f 756 | 4x6h 757 | 3gp0 758 | 2srt 759 | 4k63 760 | 1pwp 761 | 4k66 762 | 4ql8 763 | 4ie4 764 | 2fm5 765 | 3g4l 766 | 5ix1 767 | 5d1u 768 | 4y8c 769 | 2evl 770 | 5dde 771 | 5y7w 772 | 6clv 773 | 2fu8 774 | 3hg1 775 | 4xe0 776 | 5k1i 777 | 3c9e 778 | 1gpy 779 | 2gg2 780 | 5vdv 781 | 5eyz 782 | 2wc4 783 | 4qlk 784 | 3t3g 785 | 4xrq 786 | 3v5p 787 | 1exv 788 | 1std 789 | 5jjm 790 | 5cc2 791 | 4f9u 792 | 5jao 793 | 5dda 794 | 3eta 795 | 6f6u 796 | 6cee 797 | 4pl6 798 | 3ms9 799 | 4kwf 800 | 5q0s 801 | 5q1e 802 | 5o83 803 | 5lz7 804 | 5kq5 805 | 5xij 806 | 5kh3 807 | 4qfg 808 | 3ebo 809 | 2zdx 810 | 6q74 811 | 5bpe 812 | 4poj 813 | 4qgi 814 | 2n7b 815 | 1ow7 816 | 3sx9 817 | 2e92 818 | 2amv 819 | 4std 820 | 5ur9 821 | 2jdl 822 | 3ktr 823 | 1ogg 824 | 1onh 825 | 4ad6 826 | 3sl5 827 | 5v8o 828 | 4yrd 829 | 4dce 830 | 3rcd 831 | 3g4i 832 | 3zqt 833 | 3olf 834 | 1j1a 835 | 1aqj 836 | 3fq7 837 | 4cmo 838 | 3b66 839 | 4htp 840 | 5vdw 841 | 3l79 842 | 3usn 843 | 4i4e 844 | 3d27 845 | 2qrh 846 | 2wc3 847 | 4djh 848 | 1jif 849 | 3g58 850 | 3mt7 851 | 4yec 852 | 6b7c 853 | 1y2b 854 | 2v3u 855 | 2qlm 856 | 7std 857 | 2vpe 858 | 2qln 859 | 5wbk 860 | 5ftq 861 | 3fei 862 | 2nsx 863 | 4ebv 864 | 4b82 865 | 4pp3 866 | 5ddd 867 | 1l5r 868 | 4psb 869 | 4cnh 870 | 1azl 871 | 5evb 872 | 5dpx 873 | 4k9y 874 | 1jq3 875 | 2ggb 876 | 2gm9 877 | 4g2j 878 | 5lz8 879 | 2ylq 880 | 5kz0 881 | 5t8j 882 | 4cts 883 | 2j75 884 | 3mt8 885 | 1n3z 886 | 4rme 887 | 2gg8 888 | 4z2k 889 | 2ha7 890 | 4mi3 891 | 4k6i 892 | 3nfk 893 | 4pl4 894 | 6ce8 895 | 6bx6 896 | 1xom 897 | 4u0a 898 | 3djp 899 | 4zeb 900 | 5q0j 901 | 2wos 902 | 5vtb 903 | 1u9x 904 | 1g4k 905 | 1nc3 906 | 4gu9 907 | 6e4u 908 | 4b83 909 | 2p99 910 | 5dtj 911 | 3o8g 912 | 2fwp 913 | 4fod 914 | 3np9 915 | 5o9q 916 | 3v9b 917 | 1i1e 918 | 3jsw 919 | 4gu6 920 | 1snk 921 | 1bm6 922 | 4fnz 923 | 4qfs 924 | 4nj9 925 | 2ha2 926 | 4ej2 927 | 6ced 928 | 3jsi 929 | 2d1o 930 | 4mc2 931 | 6g2m 932 | 6e86 933 | 
1l5q 934 | 3g0f 935 | 2jal 936 | 5y1u 937 | 2ya7 938 | 5jau 939 | 4c4n 940 | 3tmk 941 | 1t46 942 | 5ddb 943 | 5n6s 944 | 1ppm 945 | 5tbn 946 | 2wbg 947 | 6d28 948 | 5q14 949 | 3ik3 950 | 5w99 951 | 5q1h 952 | 4joa 953 | 5ha1 954 | 3m3z 955 | 4pzv 956 | 5dd9 957 | 2e91 958 | 1mem 959 | 1rdt 960 | 5vds 961 | 2xwd 962 | 5k32 963 | 3g4f 964 | 4x5y 965 | 3mtb 966 | 2cc7 967 | 4pkr 968 | 1gyx 969 | 5jas 970 | 1xoq 971 | 1u9v 972 | 3mtd 973 | 3kwb 974 | 5aqn 975 | 4ac3 976 | 2ylp 977 | 3p0g 978 | 3bz3 979 | 1xow 980 | 3ew2 981 | 1akq 982 | 5da3 983 | 4lh6 984 | 1db5 985 | 1g27 986 | 2ao6 987 | 5z9e 988 | 5zun 989 | 4cwb 990 | 2ccc 991 | 5tbp 992 | 1nl6 993 | 4pkv 994 | 2ww2 995 | 3upz 996 | 5aab 997 | 2ha4 998 | 3mss 999 | 1zkn 1000 | 4y87 1001 | 2pyi 1002 | 2yhd 1003 | 3rw9 1004 | 3f7h 1005 | 4q9s 1006 | 2g9u 1007 | 4jt9 1008 | 5twh 1009 | 4bj8 1010 | 4pl3 1011 | 2y2h 1012 | 4mi9 1013 | 5cdh 1014 | 4n5g 1015 | 7gpb 1016 | 2wr8 1017 | 3i7b 1018 | 1q9m 1019 | 1p2g 1020 | 4kao 1021 | 5l8n 1022 | 1bl4 1023 | 3iad 1024 | 1q6k 1025 | 4i31 1026 | 4fob 1027 | 5mlj 1028 | 5hm3 1029 | 2oz7 1030 | 5ehq 1031 | 4u0d 1032 | 6b2q 1033 | 4m3f 1034 | 3tcg 1035 | 6ccq 1036 | 4x0u 1037 | 1y6q 1038 | 3iof 1039 | 5db0 1040 | 1n4k 1041 | 4wht 1042 | 4dpy 1043 | 4cli 1044 | 3msc 1045 | 2ylo 1046 | 4x7q 1047 | 1g2a 1048 | 4arb 1049 | 5ncy 1050 | 1zaj 1051 | 3qt6 1052 | 3npa 1053 | 5aqh 1054 | 5oku 1055 | 1yon 1056 | 3ekn 1057 | 2bb7 1058 | 1akr 1059 | 5h2u 1060 | 4cfe 1061 | 4why 1062 | 3ril 1063 | 5q1d 1064 | 5aqo 1065 | 4cxw 1066 | 5osy 1067 | 4m8h 1068 | 1h5u 1069 | 5yea 1070 | 5t2g 1071 | 1c50 1072 | 5l3j 1073 | 4cxy 1074 | 6cco 1075 | 1ow8 1076 | 4k4j 1077 | 5q19 1078 | 5oxg 1079 | 3sus 1080 | 3kw9 1081 | 5wqk 1082 | 6f8u 1083 | 4i33 1084 | 4z2j 1085 | 1y2e 1086 | 4xpj 1087 | 6h0b 1088 | 2wor 1089 | 3ldq 1090 | 3ebp 1091 | 1bqo 1092 | 3ook 1093 | 3l7b 1094 | 1ow6 1095 | 5ye9 1096 | 2off 1097 | 1noj 1098 | 2aig 1099 | 1iup 1100 | 5eou 1101 | 5db2 1102 | 4wcu 1103 | 3ewc 1104 | 6ce6 1105 | 5fto 1106 | 4zei 1107 | 4b80 1108 | 3qi4 1109 | 2xi7 1110 | 2bqv 1111 | 5fkj 1112 | 4wj7 1113 | 6ez6 1114 | 1yhm 1115 | 2z92 1116 | 3sz9 1117 | 5ytu 1118 | 6f8t 1119 | 3amv 1120 | 3eyf 1121 | 5iug 1122 | 5d7a 1123 | 4clj 1124 | 5fum 1125 | 3v5t 1126 | 3ms7 1127 | 1yqy 1128 | 3aox 1129 | 4yjn 1130 | 3o4l 1131 | 2ax9 1132 | 5yto 1133 | 2wed 1134 | 3ozj 1135 | 2whp 1136 | 2qrg 1137 | 2gg5 1138 | 1k08 1139 | 2flh 1140 | 1l5s 1141 | 3n51 1142 | 2vpg 1143 | 5jat 1144 | 6drx 1145 | 4ktc 1146 | 4k8a 1147 | 2zof 1148 | 5aa9 1149 | 1kcs 1150 | 1y4z 1151 | 5oa6 1152 | 4du8 1153 | 2xwe 1154 | 3ms4 1155 | 2y2j 1156 | 6chn 1157 | 5q0l 1158 | 5aa8 1159 | 2qdt 1160 | 4a16 1161 | 3u8d 1162 | 5t28 1163 | 4xkb 1164 | 4hgl 1165 | 4l4v 1166 | 2gg3 1167 | 5ddf 1168 | 4ra1 1169 | 3t3u 1170 | 1ciz 1171 | 2j7x 1172 | 1x8d 1173 | 1kvo 1174 | 1b8y 1175 | 4yik 1176 | 1osv 1177 | 2hdx 1178 | 1k06 1179 | 3g1m 1180 | 5aqf 1181 | 1d7x 1182 | 5yf1 1183 | 3b5r 1184 | 3r0h 1185 | 6b41 1186 | 4mic 1187 | 2rin 1188 | 3bpc 1189 | 2e5y 1190 | 1n5r 1191 | 2j77 1192 | 1gag 1193 | 3djo 1194 | 4zec 1195 | 5xwr 1196 | 5d1s 1197 | 1uz1 1198 | 3sl8 1199 | 2j79 1200 | 3r5m 1201 | 3b65 1202 | 2e95 1203 | 3t3e 1204 | 5cj6 1205 | 1nok 1206 | 5wpb 1207 | 1hfs 1208 | 6e5x 1209 | 5evd 1210 | 5ikb 1211 | 5aqr 1212 | 3p8n 1213 | 5q0z 1214 | 1dg9 1215 | 3qt7 1216 | 5jah 1217 | 5ax9 1218 | 2q96 1219 | 2j7f 1220 | 5q1g 1221 | 2y2p 1222 | 5v84 1223 | 4pji 1224 | -------------------------------------------------------------------------------- /data/timesplit_no_lig_overlap_val: 
-------------------------------------------------------------------------------- 1 | 4lp9 2 | 1me7 3 | 2zv9 4 | 2qo8 5 | 1cw2 6 | 3k5c 7 | 2o65 8 | 4kqq 9 | 3rdv 10 | 1d4w 11 | 1q4l 12 | 4b5w 13 | 4bgg 14 | 4mm5 15 | 3iej 16 | 3ftu 17 | 830c 18 | 2xye 19 | 1olu 20 | 2wk2 21 | 4pxf 22 | 5o0j 23 | 1my2 24 | 5czm 25 | 4jit 26 | 5mb1 27 | 1sqp 28 | 3zlw 29 | 4xqu 30 | 3hkq 31 | 6fns 32 | 5e0l 33 | 2p8o 34 | 4gzw 35 | 3n87 36 | 1lhc 37 | 4itj 38 | 4m7c 39 | 4olh 40 | 4q1e 41 | 5l7e 42 | 3faa 43 | 5vqx 44 | 3pka 45 | 5x54 46 | 5a9u 47 | 4n9e 48 | 4est 49 | 1il9 50 | 4igr 51 | 3t2t 52 | 6dar 53 | 3gol 54 | 3vbg 55 | 2ydk 56 | 4zpf 57 | 5zo7 58 | 4xnw 59 | 1fpy 60 | 2r1y 61 | 6m8w 62 | 2jds 63 | 5icx 64 | 1hwr 65 | 6bj2 66 | 4b4m 67 | 1zsb 68 | 4do3 69 | 3t3i 70 | 1f8a 71 | 2ke1 72 | 5ezx 73 | 3p78 74 | 4rvm 75 | 3ovn 76 | 5wzv 77 | 4udb 78 | 1okz 79 | 1mpl 80 | 5npc 81 | 5ff6 82 | 1hlf 83 | 1nvq 84 | 4bhf 85 | 4y4g 86 | 5mkz 87 | 2o0u 88 | 3bcs 89 | 1wvc 90 | 4fsl 91 | 3oz1 92 | 6dgt 93 | 1me8 94 | 2puy 95 | 4odp 96 | 1hpx 97 | 4nrq 98 | 1z2b 99 | 3uik 100 | 3mfv 101 | 3vqh 102 | 4w9g 103 | 4xek 104 | 4jok 105 | 2wap 106 | 1g50 107 | 4j0p 108 | 2o9a 109 | 3m94 110 | 4i1c 111 | 5a82 112 | 4i9h 113 | 1k1i 114 | 4uro 115 | 2f7i 116 | 5fpk 117 | 2lgf 118 | 4l7f 119 | 1g3d 120 | 4ir5 121 | 3mta 122 | 3jzg 123 | 5f94 124 | 4nrt 125 | 4yax 126 | 5nhv 127 | 2xtk 128 | 4qh7 129 | 1tok 130 | 4b6p 131 | 3rg2 132 | 3q8d 133 | 3obu 134 | 4awj 135 | 3daj 136 | 2j50 137 | 5l2z 138 | 5bml 139 | 2bba 140 | 5n34 141 | 2xvn 142 | 1dpu 143 | 5fnt 144 | 1jyc 145 | 4zz1 146 | 6hm7 147 | 4rrv 148 | 4rww 149 | 5orv 150 | 3qo2 151 | 3uii 152 | 6d1x 153 | 3juq 154 | 4qk4 155 | 6mr5 156 | 5hjc 157 | 2p4s 158 | 2hnc 159 | 1k4g 160 | 4g0c 161 | 2y5g 162 | 4u3f 163 | 3tv5 164 | 1i3z 165 | 4mw7 166 | 3n2c 167 | 6cvw 168 | 3v66 169 | 3wzp 170 | 3s7m 171 | 5ujv 172 | 1p06 173 | 3ipy 174 | 4wkt 175 | 4ie0 176 | 5fot 177 | 5i59 178 | 5za9 179 | 4gii 180 | 4h2o 181 | 4yrs 182 | 5a6h 183 | 2xo8 184 | 4e3n 185 | 4m5k 186 | 3dga 187 | 6fse 188 | 6ck6 189 | 1sqc 190 | 4x1r 191 | 3dnj 192 | 3rvi 193 | 2a58 194 | 4bf6 195 | 3zlk 196 | 4mbj 197 | 4tpm 198 | 4d8c 199 | 1ejn 200 | 4yt6 201 | 2x7x 202 | 4qp1 203 | 4de3 204 | 5yg4 205 | 1x7b 206 | 5n9s 207 | 2fme 208 | 1ydt 209 | 2bdf 210 | 6baw 211 | 6fsd 212 | 2xn3 213 | 4tk0 214 | 3q4j 215 | 1u9l 216 | 1oqp 217 | 5htz 218 | 4glr 219 | 5kj0 220 | 5ukl 221 | 3fun 222 | 4wk2 223 | 4ht6 224 | 5hv1 225 | 1uze 226 | 4bcc 227 | 3ff6 228 | 5if6 229 | 1tsm 230 | 2r59 231 | 3iqh 232 | 2v7a 233 | 5d10 234 | 5nvh 235 | 3eqr 236 | 1jq9 237 | 1u1b 238 | 6cer 239 | 5uq9 240 | 1u3s 241 | 5icy 242 | 3exh 243 | 2oqs 244 | 1pzp 245 | 1d4i 246 | 4x6p 247 | 4mb9 248 | 5emk 249 | 1iky 250 | 6b7f 251 | 3chq 252 | 3h5s 253 | 5zmq 254 | 4ib5 255 | 2wej 256 | 6fjm 257 | 5ewa 258 | 2igx 259 | 2z78 260 | 5lpm 261 | 4wet 262 | 3lxl 263 | 2xba 264 | 5wbl 265 | 5zla 266 | 2x6x 267 | 4mw9 268 | 5t2d 269 | 4j3m 270 | 4aqh 271 | 3lbk 272 | 4djp 273 | 4odl 274 | 4x6j 275 | 1ero 276 | 5f3t 277 | 4k3q 278 | 5ta4 279 | 1caq 280 | 2eg7 281 | 1f73 282 | 3rxg 283 | 6ezq 284 | 1qkt 285 | 5l3e 286 | 5c28 287 | 4pp9 288 | 4bgk 289 | 3iaf 290 | 5vrp 291 | 5zz4 292 | 5ur5 293 | 3ft2 294 | 5ech 295 | 4jjq 296 | 5iz6 297 | 5dhr 298 | 4l2g 299 | 4r17 300 | 3wk6 301 | 4h1e 302 | 2aq9 303 | 5g1n 304 | 3zm9 305 | 5c4l 306 | 5mfs 307 | 1fzj 308 | 2ltw 309 | 4x7i 310 | 4c94 311 | 2cfg 312 | 2va5 313 | 3vb6 314 | 2hob 315 | 5ah2 316 | 5syn 317 | 3g6g 318 | 3rwj 319 | 5sz4 320 | 4f9v 321 | 5n2d 322 | 3n9r 323 | 5ldo 324 | 3vb7 325 | 1sqo 326 | 
3drg 327 | 5j9y 328 | 6b96 329 | 4yz9 330 | 1vcj 331 | 5epr 332 | 4tx6 333 | 3dz6 334 | 3czv 335 | 5v49 336 | 1ahy 337 | 3wzq 338 | 1bq4 339 | 5u8c 340 | 6bj3 341 | 2qnb 342 | 4a9m 343 | 3d4f 344 | 5oui 345 | 5wmg 346 | 6ma4 347 | 4x5q 348 | 5cbr 349 | 6msy 350 | 5avi 351 | 1g3b 352 | 2wi4 353 | 3kjn 354 | 4dhn 355 | 4o7e 356 | 5kit 357 | 5y5t 358 | 3hfj 359 | 2qd8 360 | 5vsj 361 | 2y2i 362 | 5m0m 363 | 3tcp 364 | 4bhz 365 | 1jd6 366 | 5idn 367 | 4zzx 368 | 4kn4 369 | 2a5c 370 | 6hly 371 | 1au2 372 | 4jbo 373 | 5cgj 374 | 3ske 375 | 3lq2 376 | 4pxm 377 | 2wxg 378 | 5tb6 379 | 2vc7 380 | 3iw4 381 | 5hct 382 | 3skf 383 | 5lyy 384 | 3fmz 385 | 4p5z 386 | 5ktw 387 | 6e4w 388 | 1cx9 389 | 6em7 390 | 4mjr 391 | 4u7t 392 | 3rde 393 | 4ux4 394 | 4i6f 395 | 3l3x 396 | 4ie6 397 | 4j70 398 | 1jd0 399 | 4iaw 400 | 1szm 401 | 2afw 402 | 3ess 403 | 3sap 404 | 1olx 405 | 1bzh 406 | 5hfb 407 | 4x3h 408 | 5we9 409 | 3zsw 410 | 5ny6 411 | 1hn2 412 | 3l3z 413 | 4qp2 414 | 1d4p 415 | 4xkc 416 | 2is0 417 | 6c7e 418 | 5zku 419 | 4fai 420 | 6g9a 421 | 4xu3 422 | 5dry 423 | 4d8z 424 | 3zcz 425 | 3kbz 426 | 2y59 427 | 4nal 428 | 4rpv 429 | 4yje 430 | 3vf8 431 | 4bqx 432 | 4z9l 433 | 4ep2 434 | 4ylk 435 | 5mme 436 | 4dht 437 | 2uy4 438 | 6mu3 439 | 3kx1 440 | 5o0s 441 | 4bch 442 | 5c4k 443 | 2br1 444 | 4ddh 445 | 2f9k 446 | 2w2i 447 | 4ogn 448 | 4up5 449 | 5o4y 450 | 5hjd 451 | 2qw1 452 | 5y8z 453 | 4kqr 454 | 1o2t 455 | 6e05 456 | 3u7l 457 | 2mip 458 | 3hvg 459 | 2p59 460 | 4d3h 461 | 4pl5 462 | 3tzd 463 | 2vnp 464 | 4e3m 465 | 3vgc 466 | 5bqi 467 | 1b7h 468 | 1lhu 469 | 3rlr 470 | 3h22 471 | 2wnc 472 | 2wot 473 | 5d1t 474 | 3mo0 475 | 4wn5 476 | 3p3u 477 | 1nfs 478 | 4e90 479 | 5aqu 480 | 1bmq 481 | 3kwz 482 | 6f6n 483 | 4rj5 484 | 4omd 485 | 6min 486 | 1ujj 487 | 4ppa 488 | 4uxl 489 | 5y3n 490 | 6df2 491 | 4wvl 492 | 1xt3 493 | 5oaj 494 | 4a9r 495 | 5mli 496 | 4p4e 497 | 3juo 498 | 1z9g 499 | 2ykc 500 | 5a0e 501 | 3g0w 502 | 5t9w 503 | 1sqa 504 | 3wci 505 | 1fkw 506 | 5u4g 507 | 4mfe 508 | 4kpx 509 | 3nti 510 | 3azb 511 | 2xog 512 | 3c3r 513 | 2buc 514 | 1hyz 515 | 4dcd 516 | 6azl 517 | 3t3d 518 | 3q4l 519 | 4few 520 | 1q95 521 | 4u0b 522 | 3b7u 523 | 4bo4 524 | 4o10 525 | 5wmt 526 | 5v9t 527 | 5aok 528 | 1jtq 529 | 5uit 530 | 2vgc 531 | 2gfd 532 | 3mna 533 | 1aqc 534 | 4xtt 535 | 4z0d 536 | 4ty9 537 | 2yiv 538 | 2hrp 539 | 4zh2 540 | 2z4o 541 | 1qku 542 | 2xdw 543 | 4n7j 544 | 4yp1 545 | 3exf 546 | 4c6z 547 | 6ccu 548 | 2wxn 549 | 1bwb 550 | 2gvf 551 | 1hiy 552 | 5c4t 553 | 2za5 554 | 2xkf 555 | 4q18 556 | 1o2p 557 | 5th2 558 | 4dj7 559 | 3eyd 560 | 4j0r 561 | 2m3o 562 | 2b53 563 | 4m3b 564 | 2izl 565 | 2vtr 566 | 2x6d 567 | 2i0a 568 | 5ehg 569 | 6cw4 570 | 4c37 571 | 3cwj 572 | 1azm 573 | 2qci 574 | 5sz0 575 | 2gkl 576 | 2z4z 577 | 6awo 578 | 1v11 579 | 4l53 580 | 3p55 581 | 2ynn 582 | 2vu3 583 | 4dli 584 | 2bcd 585 | 4l0s 586 | 4uda 587 | 3m37 588 | 5j5t 589 | 2p16 590 | 4gh6 591 | 1mfg 592 | 3s3i 593 | 4j73 594 | 2v5x 595 | 2h4n 596 | 4jsz 597 | 4wk1 598 | 4igt 599 | 4k63 600 | 3qqk 601 | 16pk 602 | 5aom 603 | 1hyv 604 | 5a3w 605 | 3veh 606 | 3g4l 607 | 2ph8 608 | 5mkx 609 | 5c4u 610 | 4gto 611 | 3cj5 612 | 4prj 613 | 2vd7 614 | 5duc 615 | 3odi 616 | 6bg5 617 | 1qwu 618 | 5jn8 619 | 1v1m 620 | 1qpe 621 | 5v3r 622 | 2wc4 623 | 2vte 624 | 1a52 625 | 4dhq 626 | 2qta 627 | 6ccy 628 | 4jog 629 | 4bgy 630 | 5u9i 631 | 3az9 632 | 1gt1 633 | 2jew 634 | 3pdc 635 | 1n3i 636 | 5fyx 637 | 4f49 638 | 4nzn 639 | 6hm2 640 | 4a4l 641 | 5xij 642 | 5vk0 643 | 4xsx 644 | 2aj8 645 | 4odq 646 | 2n7b 647 | 4ygf 648 | 2a4q 649 | 
2jc0 650 | 4jsa 651 | 1inq 652 | 3dc3 653 | 5tob 654 | 4urn 655 | 6bik 656 | 4ju4 657 | 5nya 658 | 5oh2 659 | 5znr 660 | 5ct2 661 | 3u4u 662 | 4x7h 663 | 3max 664 | 3rbm 665 | 3krj 666 | 1aj6 667 | 1pmv 668 | 5n0e 669 | 4nhy 670 | 4oem 671 | 6fi4 672 | 4e3j 673 | 1fq4 674 | 5myr 675 | 2hkf 676 | 1os0 677 | 3rqg 678 | 4ivc 679 | 5c7b 680 | 3lq4 681 | 1u6q 682 | 1qxz 683 | 1l5r 684 | 4xxh 685 | 3m40 686 | 5or9 687 | 4okg 688 | 4d89 689 | 2gm9 690 | 5x33 691 | 4de0 692 | 4gr8 693 | 5lz8 694 | 1p93 695 | 2brp 696 | 2gg8 697 | 6fdt 698 | 5cxh 699 | 1jvu 700 | 3wp1 701 | 1fzm 702 | 5cxa 703 | 2gbg 704 | 2g78 705 | 5aml 706 | 2y34 707 | 2qnp 708 | 1v16 709 | 1njj 710 | 2a5u 711 | 4z88 712 | 4wmx 713 | 5vo2 714 | 4fod 715 | 2pou 716 | 3jsw 717 | 2ow2 718 | 5g3m 719 | 3odl 720 | 3o9e 721 | 3eyh 722 | 4ej2 723 | 3c4e 724 | 4b6f 725 | 1pl0 726 | 3pb8 727 | 6fap 728 | 4iax 729 | 2bua 730 | 6fgg 731 | 2o4h 732 | 4uwh 733 | 5wbf 734 | 2yxj 735 | 1ff1 736 | 2giu 737 | 1qbt 738 | 2ovq 739 | 4bak 740 | 2y3p 741 | 2iwu 742 | 3hvi 743 | 2w0x 744 | 3fcl 745 | 1zpa 746 | 5czb 747 | 3t1l 748 | 2cfd 749 | 3k3g 750 | 4cfw 751 | 2e91 752 | 5op8 753 | 3hig 754 | 6h7y 755 | 3mtb 756 | 4eb9 757 | 4lkg 758 | 5ehv 759 | 5ier 760 | 4ode 761 | 1xoq 762 | 5d6p 763 | 3kwa 764 | 5np8 765 | 5v82 766 | 6ma1 767 | 3bz3 768 | 3myq 769 | 4j0s 770 | 4f4p 771 | 4lh6 772 | 1uef 773 | 4j3d 774 | 4yx4 775 | 4amx 776 | 4ptg 777 | 2c97 778 | 4ec4 779 | 4r1v 780 | 1zc9 781 | 4nuf 782 | 3g2u 783 | 6hlx 784 | 5vij 785 | 2x4o 786 | 6hlz 787 | 4lkj 788 | 3s75 789 | 2gz8 790 | 1gvk 791 | 2yhd 792 | 3hqz 793 | 3pb7 794 | 1thr 795 | 4ris 796 | 5twh 797 | 4gql 798 | 3n3l 799 | 3acx 800 | 5yvx 801 | 3gy2 802 | 1xmu 803 | 5l6p 804 | 5l8n 805 | 4msn 806 | 4rz1 807 | 3f66 808 | 3ucj 809 | 5hcl 810 | 1t1r 811 | 3kce 812 | 3u15 813 | 1wbg 814 | 5khi 815 | 3er5 816 | 4qew 817 | 5mft 818 | 6eqp 819 | 5gsw 820 | 2qd7 821 | 4cli 822 | 3f9w 823 | 3msc 824 | 1jgl 825 | 3kid 826 | 1ymx 827 | 1ui0 828 | 3d1f 829 | 1pxl 830 | 5kos 831 | 3vzd 832 | 5fcz 833 | 3ara 834 | 4li6 835 | 5ks7 836 | 4wym 837 | 5j7q 838 | 4qsh 839 | 2ce9 840 | 5vqz 841 | 3o2m 842 | 4bcm 843 | 5orx 844 | 1i41 845 | 3c5u 846 | 4kai 847 | 6gjy 848 | 4tsz 849 | 5o0e 850 | 6drt 851 | 1y57 852 | 3kqb 853 | 3jup 854 | 5ork 855 | 3ikc 856 | 3gwu 857 | 4wke 858 | 4x7l 859 | 3lp1 860 | 5ivy 861 | 3f16 862 | 4c36 863 | 1w2x 864 | 2d06 865 | 1hbj 866 | 1ols 867 | 1iup 868 | 5aix 869 | 1ydd 870 | 5w4r 871 | 3h23 872 | 3rj7 873 | 4ish 874 | 1ebw 875 | 1fcy 876 | 1d09 877 | 5hdv 878 | 4x1n 879 | 5boj 880 | 2xn7 881 | 4b6s 882 | 3f82 883 | 4clj 884 | 4zzz 885 | 5j5d 886 | 2vts 887 | 1k08 888 | 3u3f 889 | 4jk6 890 | 4csy 891 | 6hth 892 | 2mnz 893 | 2vpg 894 | 2qd6 895 | 4jkw 896 | 3ml5 897 | 1ih0 898 | 4at5 899 | 5dgu 900 | 4g31 901 | 5n0d 902 | 5aa9 903 | 4u4s 904 | 5oa6 905 | 2wzm 906 | 4b4q 907 | 6fi1 908 | 6chn 909 | 1z4u 910 | 5aa8 911 | 1lpk 912 | 3cib 913 | 5d75 914 | 5x4o 915 | 1ydb 916 | 5dhq 917 | 5t28 918 | 4zz0 919 | 3evf 920 | 5vyy 921 | 6eip 922 | 1q63 923 | 3ldw 924 | 5tq4 925 | 5uxf 926 | 2j7x 927 | 4kil 928 | 1yda 929 | 3bc4 930 | 2ew5 931 | 6ee3 932 | 4yrr 933 | 3wax 934 | 3bzf 935 | 5ody 936 | 1k06 937 | 4j84 938 | 5l6h 939 | 5eok 940 | 5nne 941 | 5m6m 942 | 2a4r 943 | 3p1d 944 | 2ayp 945 | 3iux 946 | 4b0g 947 | 1jr1 948 | 4qo9 949 | 4bh4 950 | 4xt9 951 | 2ok1 952 | 2r7g 953 | 4uib 954 | 5mmn 955 | 5akj 956 | 3hs4 957 | 5wpb 958 | 6e5x 959 | 5vnd 960 | 5evd 961 | 5wlg 962 | 5l4m 963 | 4kiu 964 | 4own 965 | 5oh9 966 | 6arv 967 | 1xr9 968 | 4hv7 969 | 
-------------------------------------------------------------------------------- /data/timesplit_test: -------------------------------------------------------------------------------- 1 | 6qqw 2 | 6d08 3 | 6jap 4 | 6np2 5 | 6uvp 6 | 6oxq 7 | 6jsn 8 | 6hzb 9 | 6qrc 10 | 6oio 11 | 6jag 12 | 6moa 13 | 6hld 14 | 6i9a 15 | 6e4c 16 | 6g24 17 | 6jb4 18 | 6s55 19 | 6seo 20 | 6dyz 21 | 5zk5 22 | 6jid 23 | 5ze6 24 | 6qlu 25 | 6a6k 26 | 6qgf 27 | 6e3z 28 | 6te6 29 | 6pka 30 | 6g2o 31 | 6jsf 32 | 5zxk 33 | 6qxd 34 | 6n97 35 | 6jt3 36 | 6qtr 37 | 6oy1 38 | 6n96 39 | 6qzh 40 | 6qqz 41 | 6qmt 42 | 6ibx 43 | 6hmt 44 | 5zk7 45 | 6k3l 46 | 6cjs 47 | 6n9l 48 | 6ibz 49 | 6ott 50 | 6gge 51 | 6hot 52 | 6e3p 53 | 6md6 54 | 6hlb 55 | 6fe5 56 | 6uwp 57 | 6npp 58 | 6g2f 59 | 6mo7 60 | 6bqd 61 | 6nsv 62 | 6i76 63 | 6n53 64 | 6g2c 65 | 6eeb 66 | 6n0m 67 | 6uvy 68 | 6ovz 69 | 6olx 70 | 6v5l 71 | 6hhg 72 | 5zcu 73 | 6dz2 74 | 6mjq 75 | 6efk 76 | 6s9w 77 | 6gdy 78 | 6kqi 79 | 6ueg 80 | 6oxt 81 | 6oy0 82 | 6qr7 83 | 6i41 84 | 6cyg 85 | 6qmr 86 | 6g27 87 | 6ggb 88 | 6g3c 89 | 6n4e 90 | 6fcj 91 | 6quv 92 | 6iql 93 | 6i74 94 | 6qr4 95 | 6rnu 96 | 6jib 97 | 6izq 98 | 6qw8 99 | 6qto 100 | 6qrd 101 | 6hza 102 | 6e5s 103 | 6dz3 104 | 6e6w 105 | 6cyh 106 | 5zlf 107 | 6om4 108 | 6gga 109 | 6pgp 110 | 6qqv 111 | 6qtq 112 | 6gj6 113 | 6os5 114 | 6s07 115 | 6i77 116 | 6hhj 117 | 6ahs 118 | 6oxx 119 | 6mjj 120 | 6hor 121 | 6jb0 122 | 6i68 123 | 6pz4 124 | 6mhb 125 | 6uim 126 | 6jsg 127 | 6i78 128 | 6oxy 129 | 6gbw 130 | 6mo0 131 | 6ggf 132 | 6qge 133 | 6cjr 134 | 6oxp 135 | 6d07 136 | 6i63 137 | 6ten 138 | 6uii 139 | 6qlr 140 | 6sen 141 | 6oxv 142 | 6g2b 143 | 5zr3 144 | 6kjf 145 | 6qr9 146 | 6g9f 147 | 6e6v 148 | 5zk9 149 | 6pnn 150 | 6nri 151 | 6uwv 152 | 6ooz 153 | 6npi 154 | 6oip 155 | 6miv 156 | 6s57 157 | 6p8x 158 | 6hoq 159 | 6qts 160 | 6ggd 161 | 6pnm 162 | 6oy2 163 | 6oi8 164 | 6mhd 165 | 6agt 166 | 6i5p 167 | 6hhr 168 | 6p8z 169 | 6c85 170 | 6g5u 171 | 6j06 172 | 6qsz 173 | 6jbb 174 | 6hhp 175 | 6np5 176 | 6nlj 177 | 6qlp 178 | 6n94 179 | 6e13 180 | 6qls 181 | 6uil 182 | 6st3 183 | 6n92 184 | 6s56 185 | 6hzd 186 | 6uhv 187 | 6k05 188 | 6q36 189 | 6ic0 190 | 6hhi 191 | 6e3m 192 | 6qtx 193 | 6jse 194 | 5zjy 195 | 6o3y 196 | 6rpg 197 | 6rr0 198 | 6gzy 199 | 6qlt 200 | 6ufo 201 | 6o0h 202 | 6o3x 203 | 5zjz 204 | 6i8t 205 | 6ooy 206 | 6oiq 207 | 6od6 208 | 6nrh 209 | 6qra 210 | 6hhh 211 | 6m7h 212 | 6ufn 213 | 6qr0 214 | 6o5u 215 | 6h14 216 | 6jwa 217 | 6ny0 218 | 6jan 219 | 6ftf 220 | 6oxw 221 | 6jon 222 | 6cf7 223 | 6rtn 224 | 6jsz 225 | 6o9c 226 | 6mo8 227 | 6qln 228 | 6qqu 229 | 6i66 230 | 6mja 231 | 6gwe 232 | 6d3z 233 | 6oxr 234 | 6r4k 235 | 6hle 236 | 6h9v 237 | 6hou 238 | 6nv9 239 | 6py0 240 | 6qlq 241 | 6nv7 242 | 6n4b 243 | 6jaq 244 | 6i8m 245 | 6dz0 246 | 6oxs 247 | 6k2n 248 | 6cjj 249 | 6ffg 250 | 6a73 251 | 6qqt 252 | 6a1c 253 | 6oxu 254 | 6qre 255 | 6qtw 256 | 6np4 257 | 6hv2 258 | 6n55 259 | 6e3o 260 | 6kjd 261 | 6sfc 262 | 6qi7 263 | 6hzc 264 | 6k04 265 | 6op0 266 | 6q38 267 | 6n8x 268 | 6np3 269 | 6uvv 270 | 6pgo 271 | 6jbe 272 | 6i75 273 | 6qqq 274 | 6i62 275 | 6j9y 276 | 6g29 277 | 6h7d 278 | 6mo9 279 | 6jao 280 | 6jmf 281 | 6hmy 282 | 6qfe 283 | 5zml 284 | 6i65 285 | 6e7m 286 | 6i61 287 | 6rz6 288 | 6qtm 289 | 6qlo 290 | 6oie 291 | 6miy 292 | 6nrf 293 | 6gj5 294 | 6jad 295 | 6mj4 296 | 6h12 297 | 6d3y 298 | 6qr2 299 | 6qxa 300 | 6o9b 301 | 6ckl 302 | 6oir 303 | 6d40 304 | 6e6j 305 | 6i7a 306 | 6g25 307 | 6oin 308 | 6jam 309 | 6oxz 310 | 6hop 311 | 6rot 312 | 6uhu 313 | 6mji 314 | 6nrj 315 | 6nt2 316 | 6op9 
317 | 6pno 318 | 6e4v 319 | 6k1s 320 | 6a87 321 | 6oim 322 | 6cjp 323 | 6pyb 324 | 6h13 325 | 6qrf 326 | 6mhc 327 | 6j9w 328 | 6nrg 329 | 6fff 330 | 6n93 331 | 6jut 332 | 6g2e 333 | 6nd3 334 | 6os6 335 | 6dql 336 | 6inz 337 | 6i67 338 | 6quw 339 | 6qwi 340 | 6npm 341 | 6i64 342 | 6e3n 343 | 6qrg 344 | 6nxz 345 | 6iby 346 | 6gj7 347 | 6qr3 348 | 6qr1 349 | 6s9x 350 | 6q4q 351 | 6hbn 352 | 6nw3 353 | 6tel 354 | 6p8y 355 | 6d5w 356 | 6t6a 357 | 6o5g 358 | 6r7d 359 | 6pya 360 | 6ffe 361 | 6d3x 362 | 6gj8 363 | 6mo2 364 | -------------------------------------------------------------------------------- /data_preparation/README.md: -------------------------------------------------------------------------------- 1 | # Files for the data preprocessing described in the paper -------------------------------------------------------------------------------- /data_preparation/find_disconnected_proteins.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import os 4 | 5 | import networkx as nx 6 | from biopandas.pdb import PandasPdb 7 | 8 | from scipy import spatial 9 | from tqdm import tqdm 10 | import numpy as np 11 | import pandas as pd 12 | 13 | from commons.utils import write_strings_to_txt 14 | 15 | 16 | pdb_path = 'data/PDBBind' 17 | casf_names = os.listdir('data/deepBSP/casf_test') 18 | bsp_names = os.listdir('data/deepBSP/pdbbind_filtered') 19 | pdbbind_names = os.listdir(pdb_path) 20 | 21 | df_pdb_id = pd.read_csv('data/PDBbind_index/INDEX_general_PL_name.2020', sep=" ", comment='#', header=None, names=['complex_name', 'year', 'pdb_id', 'd', 'e','f','g','h','i','j','k','l','m','n','o']) 22 | df_pdb_id = df_pdb_id[['complex_name','year','pdb_id']] 23 | 24 | df_data = pd.read_csv('data/PDBbind_index/INDEX_general_PL_data.2020', sep=" ", comment='#', header=None, names=['complex_name','resolution','year', 'logkd', 'kd', 'reference', 'ligand_name', 'a', 'b', 'c']) 25 | df_data = df_data[['complex_name','resolution','year', 'logkd', 'kd', 'reference', 'ligand_name']] 26 | 27 | cutoff = 5 28 | connected = [] 29 | for name in tqdm(pdbbind_names): 30 | df = PandasPdb().read_pdb(os.path.join(pdb_path, name, f'{name}_protein_obabel_reduce.pdb')).df['ATOM'] 31 | df.rename(columns={'chain_id': 'chain', 'residue_number': 'residue', 'residue_name': 'resname', 32 | 'x_coord': 'x', 'y_coord': 'y', 'z_coord': 'z', 'element_symbol': 'element'}, inplace=True) 33 | df = list(df.groupby(['chain'])) ## Not the same as sequence order ! 
34 | 35 | chain_coords_list = [] 36 | for chain in df: 37 | chain_coords_list.append(chain[1][['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32)) 38 | 39 | num_chains = len(chain_coords_list) 40 | distance = np.full((num_chains, num_chains), -np.inf) 41 | for i in range(num_chains - 1): 42 | for j in range((i + 1), num_chains): 43 | pairwise_dis = spatial.distance.cdist(chain_coords_list[i],chain_coords_list[j]) 44 | distance[i, j] = np.min(pairwise_dis) 45 | distance[j, i] = np.min(pairwise_dis) 46 | src_list = [] 47 | dst_list = [] 48 | for i in range(num_chains): 49 | dst = list(np.where(distance[i, :] < cutoff)[0]) 50 | src = [i] * len(dst) 51 | src_list.extend(src) 52 | dst_list.extend(dst) 53 | graph = nx.Graph() 54 | graph.add_edges_from(zip(src_list, dst_list)) 55 | if nx.is_connected(graph): 56 | connected.append(name) 57 | else: 58 | print(f'not connected: {name}') 59 | write_strings_to_txt(connected, f'data/complex_names_connected_by_{cutoff}') 60 | print(len(connected)) -------------------------------------------------------------------------------- /data_preparation/move_valid_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shutil import copyfile 3 | 4 | from tqdm import tqdm 5 | 6 | from commons.utils import read_strings_from_txt 7 | 8 | data_path = '../data/PDBBind' 9 | overwrite = False 10 | names = sorted(os.listdir(data_path)) 11 | invalid_names = read_strings_from_txt('select_chains.log') 12 | 13 | valid_names = list(set(names) - set(invalid_names)) 14 | 15 | if not os.path.exists('../data/PDBBind_processed'): 16 | os.mkdir('../data/PDBBind_processed') 17 | for i, name in tqdm(enumerate(valid_names)): 18 | if not os.path.exists(f'../data/PDBBind_processed/{name}'): 19 | os.mkdir(f'../data/PDBBind_processed/{name}') 20 | rec_path = os.path.join(data_path, name, f'{name}_protein.pdb') 21 | 22 | copyfile(os.path.join(data_path, name, f'{name}_protein_processed.pdb'), f'../data/PDBBind_processed/{name}/{name}_protein_processed.pdb') 23 | copyfile(os.path.join(data_path, name, f'{name}_ligand.mol2'), f'../data/PDBBind_processed/{name}/{name}_ligand.mol2') 24 | copyfile(os.path.join(data_path, name, f'{name}_ligand.sdf'), 25 | f'../data/PDBBind_processed/{name}/{name}_ligand.sdf') 26 | 27 | 28 | -------------------------------------------------------------------------------- /data_preparation/openbabel_receptors.py: -------------------------------------------------------------------------------- 1 | # you need openbabel installed to use this (can be installed with anaconda) 2 | import os 3 | import subprocess 4 | 5 | import time 6 | 7 | from tqdm import tqdm 8 | 9 | start_time = time.time() 10 | data_path = 'data/PDBBind' 11 | overwrite = False 12 | names = sorted(os.listdir(data_path)) 13 | 14 | for i, name in tqdm(enumerate(names)): 15 | rec_path = os.path.join(data_path, name, f'{name}_protein.pdb') 16 | return_code = subprocess.run( 17 | f"obabel {rec_path} -O{os.path.join(data_path, name, f'{name}_protein_obabel.pdb')}", shell=True) 18 | print(return_code) 19 | 20 | 21 | print("--- %s seconds ---" % (time.time() - start_time)) 22 | -------------------------------------------------------------------------------- /data_preparation/reduce_receptors.py: -------------------------------------------------------------------------------- 1 | # you need reduce installed to use this: https://github.com/rlabduke/reduce 2 | import os 3 | import subprocess 4 | 5 | import time 6 | 7 | from tqdm import tqdm 8 | 9 | 
start_time = time.time() 10 | data_path = 'PDBBind' 11 | overwrite = False 12 | names = sorted(os.listdir(data_path)) 13 | 14 | for i, name in tqdm(enumerate(names)): 15 | rec_path = os.path.join(data_path, name, f'{name}_protein_obabel.pdb') 16 | return_code = subprocess.run( 17 | f"reduce -Trim {rec_path} > {os.path.join(data_path, name, f'{name}_protein_obabel_reduce_tmp.pdb')}", shell=True) 18 | print(return_code) 19 | return_code2 = subprocess.run( 20 | f"reduce -HIS {os.path.join(data_path, name, f'{name}_protein_obabel_reduce_tmp.pdb')} > {os.path.join(data_path, name, f'{name}_protein_obabel_reduce.pdb')}", shell=True) 21 | print(return_code2) 22 | return_code2 = subprocess.run( 23 | f"rm {os.path.join(data_path, name, f'{name}_protein_obabel_reduce_tmp.pdb')}", 24 | shell=True) 25 | print(return_code2) 26 | 27 | 28 | 29 | print("--- %s seconds ---" % (time.time() - start_time)) 30 | -------------------------------------------------------------------------------- /data_preparation/select_protein_chains.py: -------------------------------------------------------------------------------- 1 | # in this file we perform the removal of chains that have none of their atoms within a 10 A radius of the ligand. 2 | # this additionally needs "conda install prody" 3 | 4 | import os 5 | import warnings 6 | 7 | import numpy as np 8 | import prody 9 | from Bio.PDB import PDBIO, PDBParser 10 | from Bio.PDB.PDBExceptions import PDBConstructionWarning 11 | from scipy import spatial 12 | from tqdm import tqdm 13 | from commons.process_mols import read_molecule 14 | 15 | cutoff = 10 16 | data_dir = '../data/PDBBind' 17 | names = os.listdir(data_dir) 18 | 19 | io = PDBIO() 20 | biopython_parser = PDBParser() 21 | for name in tqdm(names): 22 | rec_path = os.path.join(data_dir, name, f'{name}_protein_obabel_reduce.pdb') 23 | lig = read_molecule(os.path.join(data_dir, name, f'{name}_ligand.sdf'), sanitize=True, remove_hs=False) 24 | if lig == None: 25 | lig = read_molecule(os.path.join(data_dir, name, f'{name}_ligand.mol2'), sanitize=True, remove_hs=False) 26 | if lig == None: 27 | print('ligand was None for ', name) 28 | with open('select_chains.log', 'a') as file: 29 | file.write(f'{name}\n') 30 | continue 31 | conf = lig.GetConformer() 32 | lig_coords = conf.GetPositions() 33 | with warnings.catch_warnings(): 34 | warnings.filterwarnings("ignore", category=PDBConstructionWarning) 35 | structure = biopython_parser.get_structure('random_id', rec_path) 36 | rec = structure[0] 37 | min_distances = [] 38 | coords = [] 39 | valid_chain_ids = [] 40 | lengths = [] 41 | for i, chain in enumerate(rec): 42 | chain_coords = [] # num_residues, num_atoms, 3 43 | chain_is_water = False 44 | count = 0 45 | invalid_res_ids = [] 46 | for res_idx, residue in enumerate(chain): 47 | if residue.get_resname() == 'HOH': 48 | chain_is_water = True 49 | residue_coords = [] 50 | c_alpha, n, c = None, None, None 51 | for atom in residue: 52 | if atom.name == 'CA': 53 | c_alpha = list(atom.get_vector()) 54 | if atom.name == 'N': 55 | n = list(atom.get_vector()) 56 | if atom.name == 'C': 57 | c = list(atom.get_vector()) 58 | residue_coords.append(list(atom.get_vector())) 59 | if c_alpha != None and n != None and c != None: # only append residue if it is an amino acid and not some weird molecule that is part of the complex 60 | chain_coords.append(np.array(residue_coords)) 61 | count += 1 62 | else: 63 | invalid_res_ids.append(residue.get_id()) 64 | for res_id in invalid_res_ids: 65 | chain.detach_child(res_id) 66 | if
len(chain_coords) > 0: 67 | all_chain_coords = np.concatenate(chain_coords, axis=0) 68 | distances = spatial.distance.cdist(lig_coords, all_chain_coords) 69 | min_distance = distances.min() 70 | else: 71 | min_distance = np.inf 72 | if chain_is_water: 73 | min_distances.append(np.inf) 74 | else: 75 | min_distances.append(min_distance) 76 | lengths.append(count) 77 | coords.append(chain_coords) 78 | if min_distance < cutoff and not chain_is_water: 79 | valid_chain_ids.append(chain.get_id()) 80 | min_distances = np.array(min_distances) 81 | if len(valid_chain_ids) == 0: 82 | valid_chain_ids.append([c.get_id() for c in rec][int(np.argmin(min_distances))]) # fall back to the closest chain: map the argmin index to its chain id 83 | valid_coords = [] 84 | valid_lengths = [] 85 | invalid_chain_ids = [] 86 | for i, chain in enumerate(rec): 87 | if chain.get_id() in valid_chain_ids: 88 | valid_coords.append(coords[i]) 89 | valid_lengths.append(lengths[i]) 90 | else: 91 | invalid_chain_ids.append(chain.get_id()) 92 | 93 | # Many thanks to Professor David Ryan Koes for spotting that the commented code below only removes waters and invalid residues while keeping all receptor chains, even those far from the ligand. 94 | # While directly modifying the .pdb file as a text file is an option, we can again follow Prof. Koes's excellent advice and use the prody library as in the code below. 95 | # io.set_structure(structure) 96 | # io.save(os.path.join(data_dir,name,f'{name}_protein_processed2.pdb')) 97 | prot = prody.parsePDB(rec_path) 98 | sel = prot.select(' or '.join(map(lambda c: f'chain {c}', valid_chain_ids))) # keep only the chains close enough to the ligand 99 | prody.writePDB(os.path.join(data_dir,name,f'{name}_protein_processed2.pdb'),sel) -------------------------------------------------------------------------------- /datasets/custom_collate.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Tuple, List 3 | 4 | import dgl 5 | import torch 6 | 7 | from commons.geometry_utils import random_rotation_translation 8 | from commons.process_mols import lig_rec_graphs_to_complex_graph 9 | 10 | 11 | def graph_collate(batch): 12 | complex_graphs, ligs_coords, recs_coords, pockets_coords_lig, pockets_coords_rec,geometry_graph, complex_names, idx = map(list, zip(*batch)) 13 | geometry_graph = dgl.batch(geometry_graph) if geometry_graph[0] != None else None 14 | return dgl.batch(complex_graphs), ligs_coords, recs_coords, pockets_coords_lig, pockets_coords_rec,geometry_graph, complex_names, idx 15 | 16 | def graph_collate_revised(batch): 17 | lig_graphs, rec_graphs, ligs_coords, recs_coords,all_rec_coords, pockets_coords_lig,geometry_graph, complex_names, idx = map(list, zip(*batch)) 18 | geometry_graph = dgl.batch(geometry_graph) if geometry_graph[0] != None else None 19 | return dgl.batch(lig_graphs), dgl.batch(rec_graphs), ligs_coords, recs_coords, all_rec_coords, pockets_coords_lig,geometry_graph, complex_names, idx 20 | 21 | def torsion_collate(batch): 22 | lig_graphs, rec_graphs, angles, masks, ligs_coords, recs_coords,all_rec_coords, pockets_coords_lig,geometry_graph, complex_names, idx = map(list, zip(*batch)) 23 | geometry_graph = torch.cat(geometry_graph,dim=0) if geometry_graph[0] != None else None 24 | return dgl.batch(lig_graphs), dgl.batch(rec_graphs), torch.cat(angles,dim=0), torch.cat(masks, dim=0), ligs_coords, recs_coords, all_rec_coords, pockets_coords_lig,geometry_graph, complex_names, idx 25 | 26 | 27 | class AtomSubgraphCollate(object): 28 | def __init__(self, random_rec_atom_subgraph_radius=10): 29 | self.random_rec_atom_subgraph_radius = random_rec_atom_subgraph_radius 30 | def __call__(self, batch:
List[Tuple]): 31 | lig_graphs, rec_graphs, ligs_coords, recs_coords, all_rec_coords, pockets_coords_lig,geometry_graph, complex_names, idx = map( 32 | list, zip(*batch)) 33 | 34 | rec_subgraphs = [] 35 | for i, (lig_graph, rec_graph) in enumerate(zip(lig_graphs,rec_graphs)): 36 | rot_T, rot_b = random_rotation_translation(translation_distance=2) 37 | translated_lig_coords = ligs_coords[i] + rot_b 38 | min_distances, _ = torch.cdist(rec_graph.ndata['x'], translated_lig_coords.to(rec_graph.ndata['x'].device)).min(dim=1) 39 | rec_subgraphs.append(dgl.node_subgraph(rec_graph, min_distances < self.random_rec_atom_subgraph_radius)) 40 | 41 | geometry_graph = dgl.batch(geometry_graph) if geometry_graph[0] != None else None 42 | 43 | return dgl.batch(lig_graphs), dgl.batch(rec_subgraphs), ligs_coords, recs_coords, all_rec_coords, pockets_coords_lig,geometry_graph, complex_names, idx 44 | 45 | class SubgraphAugmentationCollate(object): 46 | def __init__(self, min_shell_thickness=2): 47 | self.min_shell_thickness = min_shell_thickness 48 | def __call__(self, batch: List[Tuple]): 49 | lig_graphs, rec_graphs, ligs_coords, recs_coords, all_rec_coords, pockets_coords_lig,geometry_graph, complex_names, idx = map( 50 | list, zip(*batch)) 51 | 52 | rec_subgraphs = [] 53 | for lig_graph, rec_graph in zip(lig_graphs,rec_graphs): 54 | lig_centroid = lig_graph.ndata['x'].mean(dim=0) 55 | distances = torch.norm(rec_graph.ndata['x'] - lig_centroid, dim=1) 56 | max_distance = torch.max(distances) 57 | min_distance = torch.min(distances) 58 | radius = min_distance + self.min_shell_thickness + random.random() * (max_distance - min_distance- self.min_shell_thickness).abs() 59 | rec_subgraphs.append(dgl.node_subgraph(rec_graph, distances <= radius)) 60 | geometry_graph = dgl.batch(geometry_graph) if geometry_graph[0] != None else None 61 | 62 | return dgl.batch(lig_graphs), dgl.batch(rec_subgraphs), ligs_coords, recs_coords, all_rec_coords, pockets_coords_lig,geometry_graph, complex_names, idx -------------------------------------------------------------------------------- /datasets/multiple_ligands.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from commons.process_mols import get_geometry_graph, get_lig_graph_revised, get_rdkit_coords 3 | from dgl import batch 4 | from rdkit.Chem import SDMolSupplier, SanitizeMol, SanitizeFlags, PropertyMol, SmilesMolSupplier, AddHs 5 | 6 | 7 | class Ligands(Dataset): 8 | def __init__(self, ligpath, rec_graph, args, lazy = None, slice = None, skips = None, ext = None, addH = None, rdkit_seed = None): 9 | self.ligpath = ligpath 10 | self.rec_graph = rec_graph 11 | self.args = args 12 | self.dp = args.dataset_params 13 | self.use_rdkit_coords = args.use_rdkit_coords 14 | self.device = args.device 15 | self.rdkit_seed = rdkit_seed 16 | 17 | ##Default argument handling 18 | self.skips = skips if skips is not None else set() 19 | 20 | extensions_requiring_conformer_generation = ["smi"] 21 | extensions_defaulting_to_lazy = ["smi"] 22 | 23 | if ext is None: 24 | try: 25 | ext = ligpath.split(".")[-1] 26 | except (AttributeError, KeyError): 27 | ext = "sdf" 28 | 29 | 30 | if lazy is None: 31 | if ext in extensions_defaulting_to_lazy: 32 | self.lazy = True 33 | else: 34 | self.lazy = False 35 | else: 36 | self.lazy = lazy 37 | 38 | if addH is None: 39 | if ext == "smi": 40 | addH = True 41 | else: 42 | addH = False 43 | self.addH = addH 44 | 45 | self.generate_conformer = ext in 
extensions_requiring_conformer_generation 46 | 47 | suppliers = {"sdf": SDMolSupplier, "smi": SmilesMolSupplier} 48 | supp_kwargs = {"sdf": dict(sanitize = False, removeHs = False), 49 | "smi": dict(sanitize = False)} 50 | self.supplier = suppliers[ext](ligpath, **supp_kwargs[ext]) 51 | 52 | if slice is None: 53 | self.slice = 0, len(self.supplier) 54 | else: 55 | slice = (slice[0] if slice[0] >= 0 else len(self.supplier)+slice[0], slice[1] if slice[1] >= 0 else len(self.supplier)+slice[1]) 56 | self.slice = tuple(slice) 57 | 58 | self.failed_ligs = [] 59 | self.true_idx = [] 60 | 61 | if not self.lazy: 62 | self.ligs = [] 63 | for i in range(*self.slice): 64 | if i in self.skips: 65 | continue 66 | lig = self.supplier[i] 67 | lig, name = self._process(lig) 68 | if lig is not None: 69 | self.ligs.append(PropertyMol.PropertyMol(lig)) 70 | self.true_idx.append(i) 71 | else: 72 | self.failed_ligs.append((i, name)) 73 | 74 | if self.lazy: 75 | self._len = self.slice[1]-self.slice[0] 76 | else: 77 | self._len = len(self.ligs) 78 | 79 | def _process(self, lig): 80 | if lig is None: 81 | return None, None 82 | if self.addH: 83 | lig = AddHs(lig) 84 | if self.generate_conformer: 85 | get_rdkit_coords(lig, self.rdkit_seed) 86 | sanitize_succeded = (SanitizeMol(lig, catchErrors = True) is SanitizeFlags.SANITIZE_NONE) 87 | if sanitize_succeded: 88 | return lig, lig.GetProp("_Name") 89 | else: 90 | return None, lig.GetProp("_Name") 91 | 92 | def __len__(self): 93 | return self._len 94 | 95 | def __getitem__(self, idx): 96 | if self.lazy: 97 | if idx < 0: 98 | nonneg_idx = self._len + idx 99 | else: 100 | nonneg_idx = idx 101 | 102 | if nonneg_idx >= self._len or nonneg_idx < 0: 103 | raise IndexError(f"Index {idx} out of range for Ligands dataset with length {len(self)}") 104 | 105 | 106 | true_index = nonneg_idx + self.slice[0] 107 | if true_index in self.skips: 108 | return true_index, "Skipped" 109 | lig = self.supplier[true_index] 110 | lig, name = self._process(lig) 111 | if lig is not None: 112 | lig = PropertyMol.PropertyMol(lig) 113 | else: 114 | self.failed_ligs.append((true_index, name)) 115 | return true_index, name 116 | elif not self.lazy: 117 | lig = self.ligs[idx] 118 | true_index = self.true_idx[idx] 119 | 120 | 121 | try: 122 | lig_graph = get_lig_graph_revised(lig, lig.GetProp('_Name'), max_neighbors=self.dp['lig_max_neighbors'], 123 | use_rdkit_coords=self.use_rdkit_coords, radius=self.dp['lig_graph_radius']) 124 | except AssertionError: 125 | self.failed_ligs.append((true_index, lig.GetProp("_Name"))) 126 | return true_index, lig.GetProp("_Name") 127 | 128 | geometry_graph = get_geometry_graph(lig) if self.dp['geometry_regularization'] else None 129 | 130 | lig_graph.ndata["new_x"] = lig_graph.ndata["x"] 131 | return lig, lig_graph.ndata["new_x"], lig_graph, self.rec_graph, geometry_graph, true_index 132 | 133 | @staticmethod 134 | def collate(_batch): 135 | sample_succeeded = lambda sample: not isinstance(sample[0], int) 136 | sample_failed = lambda sample: isinstance(sample[0], int) 137 | clean_batch = tuple(filter(sample_succeeded, _batch)) 138 | failed_in_batch = tuple(filter(sample_failed, _batch)) 139 | if len(clean_batch) == 0: 140 | return None, None, None, None, None, None, failed_in_batch 141 | ligs, lig_coords, lig_graphs, rec_graphs, geometry_graphs, true_indices = map(list, zip(*clean_batch)) 142 | output = ( 143 | ligs, 144 | lig_coords, 145 | batch(lig_graphs), 146 | batch(rec_graphs), 147 | batch(geometry_graphs) if geometry_graphs[0] is not None else None, 148 | 
true_indices, 149 | failed_in_batch 150 | ) 151 | return output 152 | -------------------------------------------------------------------------------- /datasets/pdbbind.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from copy import deepcopy, copy 4 | 5 | from dgl import save_graphs, load_graphs 6 | 7 | from joblib import Parallel, delayed, cpu_count 8 | import torch 9 | import dgl 10 | from biopandas.pdb import PandasPdb 11 | from joblib.externals.loky import get_reusable_executor 12 | 13 | from rdkit import Chem 14 | from rdkit.Chem import MolFromPDBFile 15 | from rdkit.Chem.rdmolops import GetAdjacencyMatrix 16 | from torch.utils.data import Dataset 17 | import numpy as np 18 | import pandas as pd 19 | from tqdm import tqdm 20 | import torch.nn.functional as F 21 | 22 | from commons.geometry_utils import random_rotation_translation, rigid_transform_Kabsch_3D_torch 23 | from commons.process_mols import get_rdkit_coords, get_receptor, get_pocket_coords, \ 24 | read_molecule, get_rec_graph, get_lig_graph_revised, get_receptor_atom_subgraph, get_lig_structure_graph, \ 25 | get_geometry_graph, get_lig_graph_multiple_conformer, get_geometry_graph_ring 26 | from commons.utils import pmap_multi, read_strings_from_txt, log 27 | 28 | 29 | class PDBBind(Dataset): 30 | """""" 31 | 32 | def __init__(self, device='cuda:0', 33 | complex_names_path='data/', 34 | bsp_proteins=False, 35 | bsp_ligands=False, 36 | pocket_cutoff=8.0, 37 | use_rec_atoms=False, 38 | n_jobs=None, 39 | chain_radius=7, 40 | c_alpha_max_neighbors=10, 41 | lig_max_neighbors=20, 42 | translation_distance=5.0, 43 | lig_graph_radius=30, 44 | rec_graph_radius=30, 45 | surface_max_neighbors=5, 46 | surface_graph_cutoff=5, 47 | surface_mesh_cutoff=1.7, 48 | deep_bsp_preprocessing=True, 49 | only_polar_hydrogens=False, 50 | use_rdkit_coords=False, 51 | pocket_mode='match_terminal_atoms', 52 | dataset_size=None, 53 | remove_h=False, 54 | rec_subgraph=False, 55 | is_train_data=False, 56 | min_shell_thickness=2, 57 | subgraph_radius=10, 58 | subgraph_max_neigbor=8, 59 | subgraph_cutoff=4, 60 | lig_structure_graph= False, 61 | random_rec_atom_subgraph= False, 62 | subgraph_augmentation=False, 63 | lig_predictions_name=None, 64 | geometry_regularization= False, 65 | multiple_rdkit_conformers = False, 66 | random_rec_atom_subgraph_radius= 10, 67 | geometry_regularization_ring= False, 68 | num_confs=10, 69 | transform=None, **kwargs): 70 | # subset name is either 'pdbbind_filtered' or 'casf_test' 71 | self.chain_radius = chain_radius 72 | self.pdbbind_dir = 'data/PDBBind' 73 | self.bsp_dir = 'data/deepBSP' 74 | self.only_polar_hydrogens = only_polar_hydrogens 75 | self.complex_names_path = complex_names_path 76 | self.pocket_cutoff = pocket_cutoff 77 | self.use_rec_atoms = use_rec_atoms 78 | self.deep_bsp_preprocessing = deep_bsp_preprocessing 79 | self.device = device 80 | self.lig_graph_radius = lig_graph_radius 81 | self.rec_graph_radius = rec_graph_radius 82 | self.surface_max_neighbors = surface_max_neighbors 83 | self.surface_graph_cutoff = surface_graph_cutoff 84 | self.surface_mesh_cutoff = surface_mesh_cutoff 85 | self.dataset_size = dataset_size 86 | self.c_alpha_max_neighbors = c_alpha_max_neighbors 87 | self.lig_max_neighbors = lig_max_neighbors 88 | self.n_jobs = cpu_count() - 1 if n_jobs == None else n_jobs 89 | self.translation_distance = translation_distance 90 | self.pocket_mode = pocket_mode 91 | self.use_rdkit_coords = use_rdkit_coords 92 | 
self.bsp_proteins = bsp_proteins 93 | self.bsp_ligands = bsp_ligands 94 | self.remove_h = remove_h 95 | self.is_train_data = is_train_data 96 | self.subgraph_augmentation = subgraph_augmentation 97 | self.min_shell_thickness = min_shell_thickness 98 | self.rec_subgraph = rec_subgraph 99 | self.subgraph_radius = subgraph_radius 100 | self.subgraph_max_neigbor=subgraph_max_neigbor 101 | self.subgraph_cutoff=subgraph_cutoff 102 | self.random_rec_atom_subgraph = random_rec_atom_subgraph 103 | self.lig_structure_graph =lig_structure_graph 104 | self.random_rec_atom_subgraph_radius = random_rec_atom_subgraph_radius 105 | self.lig_predictions_name = lig_predictions_name 106 | self.geometry_regularization = geometry_regularization 107 | self.geometry_regularization_ring = geometry_regularization_ring 108 | self.multiple_rdkit_conformers = multiple_rdkit_conformers 109 | self.num_confs = num_confs 110 | self.conformer_id = 0 111 | if self.lig_predictions_name ==None: 112 | self.rec_subgraph_path = f'rec_subgraphs_cutoff{self.subgraph_cutoff}_radius{self.subgraph_radius}_maxNeigh{self.subgraph_max_neigbor}.pt' 113 | else: 114 | self.rec_subgraph_path = f'rec_subgraphs_cutoff{self.subgraph_cutoff}_radius{self.subgraph_radius}_maxNeigh{self.subgraph_max_neigbor}_{self.lig_predictions_name}' 115 | 116 | self.processed_dir = f'data/processed/size{self.dataset_size}_INDEX{os.path.splitext(os.path.basename(self.complex_names_path))[0]}_Hpolar{int(self.only_polar_hydrogens)}_H{int(not self.remove_h)}_BSPprot{int(self.bsp_proteins)}_BSPlig{int(self.bsp_ligands)}_surface{int(self.use_rec_atoms)}_pocketRad{self.pocket_cutoff}_ligRad{self.lig_graph_radius}_recRad{self.rec_graph_radius}_recMax{self.c_alpha_max_neighbors}_ligMax{self.lig_max_neighbors}_chain{self.chain_radius}_POCKET{self.pocket_mode}' 117 | print(f'using processed directory: {self.processed_dir}') 118 | if self.use_rdkit_coords: 119 | self.lig_graph_path = 'lig_graphs_rdkit_coords.pt' 120 | else: 121 | self.lig_graph_path = 'lig_graphs.pt' 122 | if self.multiple_rdkit_conformers: 123 | self.lig_graph_path = 'lig_graphs_rdkit_multiple_conformers.pt' 124 | if not os.path.exists('data/processed/'): 125 | os.mkdir('data/processed/') 126 | if (not os.path.exists(os.path.join(self.processed_dir, 'geometry_regularization.pt')) and self.geometry_regularization) or (not os.path.exists(os.path.join(self.processed_dir, 'geometry_regularization_ring.pt')) and self.geometry_regularization_ring) or not os.path.exists(os.path.join(self.processed_dir, 'rec_graphs.pt')) or not os.path.exists(os.path.join(self.processed_dir, 'pocket_and_rec_coords.pt')) or not os.path.exists(os.path.join(self.processed_dir, self.lig_graph_path)) or (not os.path.exists(os.path.join(self.processed_dir, self.rec_subgraph_path)) and self.rec_subgraph) or (not os.path.exists(os.path.join(self.processed_dir, 'lig_structure_graphs.pt')) and self.lig_structure_graph): 127 | self.process() 128 | log('loading data into memory') 129 | coords_dict = torch.load(os.path.join(self.processed_dir, 'pocket_and_rec_coords.pt')) 130 | self.pockets_coords = coords_dict['pockets_coords'] 131 | self.lig_graphs, _ = load_graphs(os.path.join(self.processed_dir, self.lig_graph_path)) 132 | if self.multiple_rdkit_conformers: 133 | self.lig_graphs = [self.lig_graphs[i:i + self.num_confs] for i in range(0, len(self.lig_graphs), self.num_confs)] 134 | self.rec_graphs, _ = load_graphs(os.path.join(self.processed_dir, 'rec_graphs.pt')) 135 | if self.rec_subgraph: 136 | self.rec_atom_subgraphs, _ = 
load_graphs(os.path.join(self.processed_dir, self.rec_subgraph_path)) 137 | if self.lig_structure_graph: 138 | self.lig_structure_graphs, _ = load_graphs(os.path.join(self.processed_dir, 'lig_structure_graphs.pt')) 139 | masks_angles = torch.load(os.path.join(self.processed_dir, 'torsion_masks_and_angles.pt')) 140 | self.angles = masks_angles['angles'] 141 | self.masks = masks_angles['masks'] 142 | if self.geometry_regularization: 143 | print(os.path.join(self.processed_dir, 'geometry_regularization.pt')) 144 | self.geometry_graphs, _ = load_graphs(os.path.join(self.processed_dir, 'geometry_regularization.pt')) 145 | if self.geometry_regularization_ring: 146 | print(os.path.join(self.processed_dir, 'geometry_regularization_ring.pt')) 147 | self.geometry_graphs, _ = load_graphs(os.path.join(self.processed_dir, 'geometry_regularization_ring.pt')) 148 | self.complex_names = coords_dict['complex_names'] 149 | assert len(self.lig_graphs) == len(self.rec_graphs) 150 | log('finish loading data into memory') 151 | self.cache = {} 152 | 153 | 154 | def __len__(self): 155 | return len(self.lig_graphs) 156 | 157 | def __getitem__(self, idx): 158 | pocket_coords = self.pockets_coords[idx] 159 | if self.lig_structure_graph: 160 | lig_graph = deepcopy(self.lig_structure_graphs[idx]) 161 | else: 162 | if self.multiple_rdkit_conformers: 163 | lig_graph = deepcopy(self.lig_graphs[idx][self.conformer_id]) 164 | else: 165 | lig_graph = deepcopy(self.lig_graphs[idx]) 166 | lig_coords = lig_graph.ndata['x'] 167 | rec_graph = self.rec_graphs[idx] 168 | 169 | # Randomly rotate and translate the ligand. 170 | rot_T, rot_b = random_rotation_translation(translation_distance=self.translation_distance) 171 | if self.use_rdkit_coords: 172 | lig_coords_to_move = lig_graph.ndata['new_x'] 173 | else: 174 | lig_coords_to_move = lig_coords 175 | mean_to_remove = lig_coords_to_move.mean(dim=0, keepdims=True) 176 | lig_graph.ndata['new_x'] = (rot_T @ (lig_coords_to_move - mean_to_remove).T).T + rot_b 177 | new_pocket_coords = (rot_T @ (pocket_coords - mean_to_remove).T).T + rot_b 178 | 179 | if self.subgraph_augmentation and self.is_train_data: 180 | with torch.no_grad(): 181 | if idx in self.cache: 182 | max_distance, min_distance, distances = self.cache[idx] 183 | else: 184 | lig_centroid = lig_graph.ndata['x'].mean(dim=0) 185 | distances = torch.norm(rec_graph.ndata['x'] - lig_centroid, dim=1) 186 | max_distance = torch.max(distances) 187 | min_distance = torch.min(distances) 188 | self.cache[idx] = (max_distance.item(), min_distance.item(), distances) # cache in the same (max, min, distances) order that is unpacked above 189 | radius = min_distance + self.min_shell_thickness + random.random() * abs(( 190 | max_distance - min_distance - self.min_shell_thickness)) 191 | rec_graph = dgl.node_subgraph(rec_graph, distances <= radius) 192 | assert rec_graph.num_nodes() > 0 193 | if self.rec_subgraph: 194 | rec_graph = self.rec_atom_subgraphs[idx] 195 | if self.random_rec_atom_subgraph: 196 | rot_T, rot_b = random_rotation_translation(translation_distance=2) 197 | translated_lig_coords = lig_coords + rot_b 198 | min_distances, _ = torch.cdist(rec_graph.ndata['x'],translated_lig_coords).min(dim=1) 199 | rec_graph = dgl.node_subgraph(rec_graph, min_distances < self.random_rec_atom_subgraph_radius) 200 | assert rec_graph.num_nodes() > 0 201 | 202 | geometry_graph = self.geometry_graphs[idx] if self.geometry_regularization or self.geometry_regularization_ring else None 203 | if self.lig_structure_graph: 204 | return lig_graph.to(self.device), rec_graph.to(self.device), self.masks[idx],
self.angles[idx], lig_coords, rec_graph.ndata['x'], new_pocket_coords, pocket_coords,geometry_graph, self.complex_names[idx], idx 205 | else: 206 | return lig_graph.to(self.device), rec_graph.to(self.device), lig_coords, rec_graph.ndata['x'], new_pocket_coords, pocket_coords, geometry_graph, self.complex_names[idx], idx 207 | 208 | def process(self): 209 | log(f'Processing complexes from [{self.complex_names_path}] and saving them to [{self.processed_dir}]') 210 | 211 | complex_names = read_strings_from_txt(self.complex_names_path) 212 | if self.dataset_size != None: 213 | complex_names = complex_names[:self.dataset_size] 214 | if (self.remove_h or self.only_polar_hydrogens) and '4acu' in complex_names: 215 | complex_names.remove('4acu') # in this complex's ligand the hydrogens cannot be removed 216 | log(f'Loading {len(complex_names)} complexes.') 217 | ligs = [] 218 | to_remove = [] 219 | for name in tqdm(complex_names, desc='loading ligands'): 220 | if self.bsp_ligands: 221 | lig = read_molecule(os.path.join(self.bsp_dir, name, f'Lig_native.pdb'), sanitize=True, remove_hs=self.remove_h) 222 | if lig == None: 223 | to_remove.append(name) 224 | continue 225 | else: 226 | lig = read_molecule(os.path.join(self.pdbbind_dir, name, f'{name}_ligand.sdf'), sanitize=True, 227 | remove_hs=self.remove_h) 228 | if lig == None: # read mol2 file if sdf file cannot be sanitized 229 | lig = read_molecule(os.path.join(self.pdbbind_dir, name, f'{name}_ligand.mol2'), sanitize=True, 230 | remove_hs=self.remove_h) 231 | if self.only_polar_hydrogens: 232 | for atom in lig.GetAtoms(): 233 | if atom.GetAtomicNum() == 1 and [x.GetAtomicNum() for x in atom.GetNeighbors()] == [6]: 234 | atom.SetAtomicNum(0) 235 | lig = Chem.DeleteSubstructs(lig, Chem.MolFromSmarts('[#0]')) 236 | Chem.SanitizeMol(lig) 237 | ligs.append(lig) 238 | for name in to_remove: 239 | complex_names.remove(name) 240 | 241 | if self.bsp_proteins: 242 | rec_paths = [os.path.join(self.bsp_dir, name, f'Rec.pdb') for name in complex_names] 243 | else: 244 | rec_paths = [os.path.join(self.pdbbind_dir, name, f'{name}_protein_processed.pdb') for name in 245 | complex_names] 246 | 247 | if not os.path.exists(self.processed_dir): 248 | os.mkdir(self.processed_dir) 249 | 250 | if not os.path.exists(os.path.join(self.processed_dir, 'rec_graphs.pt')) or not os.path.exists(os.path.join(self.processed_dir, 'pocket_and_rec_coords.pt')) or (not os.path.exists(os.path.join(self.processed_dir, self.rec_subgraph_path)) and self.rec_subgraph): 251 | log('Get receptors, filter chains, and get their coordinates') 252 | receptor_representatives = pmap_multi(get_receptor, zip(rec_paths, ligs), n_jobs=self.n_jobs, cutoff=self.chain_radius, desc='Get receptors') 253 | recs, recs_coords, c_alpha_coords, n_coords, c_coords = map(list, zip(*receptor_representatives)) 254 | # rec coords is a list with n_residues many arrays of shape: [n_atoms_in_residue, 3] 255 | 256 | 257 | if not os.path.exists(os.path.join(self.processed_dir, 'pocket_and_rec_coords.pt')): 258 | log('Get Pocket Coordinates') 259 | pockets_coords = pmap_multi(get_pocket_coords, zip(ligs, recs_coords), n_jobs=self.n_jobs, 260 | cutoff=self.pocket_cutoff, pocket_mode=self.pocket_mode, 261 | desc='Get pocket coords') 262 | recs_coords_concat = [torch.tensor(np.concatenate(rec_coords, axis=0)) for rec_coords in recs_coords] 263 | torch.save({'pockets_coords': pockets_coords, 264 | 'all_rec_coords': recs_coords_concat, 265 | # coords of all atoms and not only those included in graph 266 | 'complex_names':
complex_names, 267 | }, os.path.join(self.processed_dir, 'pocket_and_rec_coords.pt')) 268 | else: 269 | log('pocket_and_rec_coords.pt already exists. Using those instead of creating new ones.') 270 | 271 | if not os.path.exists(os.path.join(self.processed_dir, 'rec_graphs.pt')): 272 | log('Get receptor Graphs') 273 | rec_graphs = pmap_multi(get_rec_graph, 274 | zip(recs, recs_coords, c_alpha_coords, n_coords, c_coords), n_jobs=self.n_jobs, 275 | use_rec_atoms=self.use_rec_atoms, rec_radius=self.rec_graph_radius, 276 | surface_max_neighbors=self.surface_max_neighbors, 277 | surface_graph_cutoff=self.surface_graph_cutoff, 278 | surface_mesh_cutoff=self.surface_mesh_cutoff, 279 | c_alpha_max_neighbors=self.c_alpha_max_neighbors, 280 | desc='Convert receptors to graphs') 281 | save_graphs(os.path.join(self.processed_dir, 'rec_graphs.pt'), rec_graphs) 282 | else: 283 | log('rec_graphs.pt already exists. Using those instead of creating new ones.') 284 | log('Done converting to graphs') 285 | 286 | if self.lig_predictions_name != None: 287 | ligs_coords = torch.load(os.path.join('data/processed', self.lig_predictions_name))['predictions'][:len(ligs)] 288 | else: 289 | ligs_coords = [None] * len(ligs) 290 | if self.rec_subgraph and not os.path.exists(os.path.join(self.processed_dir, self.rec_subgraph_path)): 291 | log('Get receptor subgraphs') 292 | rec_subgraphs = pmap_multi(get_receptor_atom_subgraph, 293 | zip(recs, recs_coords, ligs, ligs_coords), n_jobs=self.n_jobs, 294 | max_neighbor=self.subgraph_max_neigbor, subgraph_radius=self.subgraph_radius, 295 | graph_cutoff=self.subgraph_cutoff, 296 | desc='get receptor subgraphs') 297 | save_graphs(os.path.join(self.processed_dir, self.rec_subgraph_path), rec_subgraphs) 298 | else: 299 | log(os.path.join(self.processed_dir, self.rec_subgraph_path), ' already exists. Using those instead of creating new ones.') 300 | log('Done creating receptor subgraphs') 301 | 302 | if not os.path.exists(os.path.join(self.processed_dir, self.lig_graph_path)): 303 | log('Convert ligands to graphs') 304 | if self.multiple_rdkit_conformers: 305 | lig_graphs = pmap_multi(get_lig_graph_multiple_conformer, zip(ligs,complex_names), n_jobs=self.n_jobs, 306 | max_neighbors=self.lig_max_neighbors, use_rdkit_coords=self.use_rdkit_coords, 307 | radius=self.lig_graph_radius, num_confs=self.num_confs, desc='Convert ligands to graphs') 308 | lig_graphs = [item for sublist in lig_graphs for item in sublist] 309 | else: 310 | lig_graphs = pmap_multi(get_lig_graph_revised, zip(ligs,complex_names), n_jobs=self.n_jobs, 311 | max_neighbors=self.lig_max_neighbors, use_rdkit_coords=self.use_rdkit_coords, 312 | radius=self.lig_graph_radius, desc='Convert ligands to graphs') 313 | 314 | save_graphs(os.path.join(self.processed_dir, self.lig_graph_path), lig_graphs) 315 | else: 316 | log('lig_graphs.pt already exists. 
Using those instead of creating new ones.') 317 | 318 | if not os.path.exists(os.path.join(self.processed_dir, 'lig_structure_graphs.pt')) and self.lig_structure_graph: 319 | log('Convert ligands to structure graphs') 320 | graphs_masks_angles = pmap_multi(get_lig_structure_graph, zip(ligs), n_jobs=self.n_jobs, desc='Get ligand structure graphs with angle information') 321 | graphs, masks, angles = map(list, zip(*graphs_masks_angles)) 322 | torch.save({'masks': masks, 323 | 'angles': angles, 324 | }, os.path.join(self.processed_dir, 'torsion_masks_and_angles.pt')) 325 | save_graphs(os.path.join(self.processed_dir, 'lig_structure_graphs.pt'), graphs) 326 | else: 327 | log('lig_structure_graphs.pt already exists or is not needed.') 328 | 329 | if not os.path.exists(os.path.join(self.processed_dir, 'geometry_regularization.pt')): 330 | log('Convert ligands to geometry graph') 331 | geometry_graphs = [get_geometry_graph(lig) for lig in ligs] 332 | save_graphs(os.path.join(self.processed_dir, 'geometry_regularization.pt'), geometry_graphs) 333 | else: 334 | log('geometry_regularization.pt already exists or is not needed.') 335 | 336 | if not os.path.exists(os.path.join(self.processed_dir, 'geometry_regularization_ring.pt')): 337 | log('Convert ligands to ring geometry graph') 338 | geometry_graphs = [get_geometry_graph_ring(lig) for lig in ligs] 339 | save_graphs(os.path.join(self.processed_dir, 'geometry_regularization_ring.pt'), geometry_graphs) 340 | else: 341 | log('geometry_regularization_ring.pt already exists or is not needed.') 342 | 343 | get_reusable_executor().shutdown(wait=True) 344 | -------------------------------------------------------------------------------- /datasets/samplers.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from copy import copy, deepcopy 3 | from typing import List, Optional 4 | 5 | import torch 6 | from torch.distributions import Categorical 7 | from torch.utils.data import Sampler, RandomSampler, Subset, Dataset 8 | from tqdm import tqdm 9 | 10 | 11 | class HardSampler(Sampler[List[int]]): 12 | def __init__(self, data_source: Dataset, batch_size: int, valid_indices=None, replacement: bool = False, num_hard_samples=2, 13 | num_samples: Optional[int] = None, generator=None, drop_last=False) -> None: 14 | super(Sampler, self).__init__() 15 | self.data_source = data_source 16 | self.valid_indices = valid_indices 17 | self.num_hard_samples = num_hard_samples 18 | self.standard_sampler = RandomSampler(data_source=Subset(self.data_source, valid_indices), replacement=replacement, 19 | num_samples=num_samples, 20 | generator=generator) 21 | self.current_hard_indices = range(len(self.data_source)) 22 | self.next_hard_indices = [] 23 | self.batch_size = batch_size 24 | self.drop_last = drop_last 25 | 26 | def __iter__(self): 27 | batch = [] 28 | for idx in self.standard_sampler: 29 | if len(batch) < self.num_hard_samples and len(self.current_hard_indices) > 0: # fill the first num_hard_samples slots of each batch with hard examples, if any are available 30 | batch.append(self.current_hard_indices[torch.randint(low=0, high=len(self.current_hard_indices),size=(1,))]) 31 | else: 32 | batch.append(idx) 33 | if len(batch) == self.batch_size: 34 | yield batch 35 | batch = [] 36 | if len(batch) > 0 and not self.drop_last: 37 | yield batch 38 | 39 | def add_hard_indices(self, indices): 40 | self.next_hard_indices.extend(indices) 41 | 42 | def set_hard_indices(self): 43 | self.current_hard_indices = deepcopy(self.next_hard_indices) 44 | self.next_hard_indices = [] 45 | 46 | def __len__(self): 47 | # Can only 
be called if self.standard_sampler has __len__ implemented 48 | # We cannot enforce this condition, so we turn off typechecking for the 49 | # implementation below. 50 | # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] 51 | 52 | return len(self.standard_sampler) # type: ignore 53 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: equibind 2 | channels: 3 | - conda-forge 4 | - defaults 5 | - dglteam 6 | - pytorch 7 | dependencies: 8 | - python=3.7 9 | - pytorch=1.10 10 | - torchvision 11 | - cudatoolkit=10.2 12 | - torchaudio 13 | - dgl-cuda10.2 14 | - rdkit 15 | - openbabel 16 | - biopython 17 | - biopandas 18 | - pot 19 | - dgllife 20 | - joblib 21 | - pyaml 22 | - icecream 23 | - matplotlib 24 | - tensorboard 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /environment_cpuonly.yml: -------------------------------------------------------------------------------- 1 | name: equibind 2 | channels: 3 | - conda-forge 4 | - defaults 5 | - dglteam 6 | - pytorch 7 | dependencies: 8 | - python=3.7 9 | - pytorch=1.10 10 | - torchvision 11 | - cpuonly 12 | - torchaudio 13 | - dgl 14 | - rdkit 15 | - openbabel 16 | - biopython 17 | - biopandas 18 | - pot 19 | - dgllife 20 | - joblib 21 | - pyaml 22 | - icecream 23 | - matplotlib 24 | - tensorboard 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | from copy import deepcopy 5 | 6 | import os 7 | 8 | from dgl import load_graphs 9 | 10 | from rdkit import Chem 11 | from rdkit.Chem import RemoveHs 12 | from rdkit.Geometry import Point3D 13 | from tqdm import tqdm 14 | 15 | from commons.geometry_utils import rigid_transform_Kabsch_3D, get_torsions, get_dihedral_vonMises, apply_changes 16 | from commons.logger import Logger 17 | from commons.process_mols import read_molecule, get_lig_graph_revised, \ 18 | get_rec_graph, get_geometry_graph, get_geometry_graph_ring, \ 19 | get_receptor_inference 20 | 21 | from train import load_model 22 | 23 | from datasets.pdbbind import PDBBind 24 | 25 | from commons.utils import seed_all, read_strings_from_txt 26 | 27 | import yaml 28 | 29 | from datasets.custom_collate import * # do not remove 30 | from models import * # do not remove 31 | from torch.nn import * # do not remove 32 | from torch.optim import * # do not remove 33 | from commons.losses import * # do not remove 34 | from torch.optim.lr_scheduler import * # do not remove 35 | 36 | from torch.utils.data import DataLoader 37 | 38 | from trainer.metrics import Rsquared, MeanPredictorLoss, MAE, PearsonR, RMSD, RMSDfraction, CentroidDist, \ 39 | CentroidDistFraction, RMSDmedian, CentroidDistMedian 40 | 41 | # turn on for debugging C code like Segmentation Faults 42 | import faulthandler 43 | 44 | faulthandler.enable() 45 | 46 | 47 | def parse_arguments(arglist = None): 48 | p = argparse.ArgumentParser() 49 | p.add_argument('--config', type=argparse.FileType(mode='r'), default='configs_clean/inference.yml') 50 | p.add_argument('--checkpoint', type=str, help='path to .pt file in a checkpoint directory') 51 | p.add_argument('--output_directory', type=str, default=None, help='path where to put the predicted results') 52 | 
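# NOTE: argparse's type=bool converts via bool(str), and any non-empty string (including 'False') is truthy, so '--run_corrections False' still enables corrections; only omitting the flag keeps the default of False.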
p.add_argument('--run_corrections', type=bool, default=False, 53 | help='whether or not to run the fast point cloud ligand fitting') 54 | p.add_argument('--run_dirs', type=list, default=[], help='path directory with saved runs') 55 | p.add_argument('--fine_tune_dirs', type=list, default=[], help='path directory with saved finetuning runs') 56 | p.add_argument('--inference_path', type=str, help='path to some pdb files for which you want to run inference') 57 | p.add_argument('--experiment_name', type=str, help='name that will be added to the runs folder output') 58 | p.add_argument('--logdir', type=str, default='runs', help='tensorboard logdirectory') 59 | p.add_argument('--num_epochs', type=int, default=2500, help='number of times to iterate through all samples') 60 | p.add_argument('--batch_size', type=int, default=1024, help='samples that will be processed in parallel') 61 | p.add_argument('--patience', type=int, default=20, help='stop training after no improvement in this many epochs') 62 | p.add_argument('--minimum_epochs', type=int, default=0, help='minimum number of epochs to run') 63 | p.add_argument('--dataset_params', type=dict, default={}, 64 | help='parameters with keywords of the dataset') 65 | p.add_argument('--num_train', type=int, default=-1, help='n samples of the model samples to use for train') 66 | p.add_argument('--num_val', type=int, default=None, help='n samples of the model samples to use for validation') 67 | p.add_argument('--seed', type=int, default=1, help='seed for reproducibility') 68 | p.add_argument('--multithreaded_seeds', type=list, default=[], 69 | help='if this is non empty, multiple threads will be started, training the same model but with different seeds') 70 | p.add_argument('--seed_data', type=int, default=1, help='if you want to use a different seed for the datasplit') 71 | p.add_argument('--loss_func', type=str, default='MSELoss', help='Class name of torch.nn like [MSELoss, L1Loss]') 72 | p.add_argument('--loss_params', type=dict, default={}, help='parameters with keywords of the chosen loss function') 73 | p.add_argument('--optimizer', type=str, default='Adam', help='Class name of torch.optim like [Adam, SGD, AdamW]') 74 | p.add_argument('--optimizer_params', type=dict, help='parameters with keywords of the chosen optimizer like lr') 75 | p.add_argument('--clip_grad', type=float, default=None, help='clip gradients if magnitude is greater') 76 | p.add_argument('--lr_scheduler', type=str, 77 | help='Class name of torch.optim.lr_scheduler like [CosineAnnealingLR, ExponentialLR, LambdaLR]') 78 | p.add_argument('--lr_scheduler_params', type=dict, help='parameters with keywords of the chosen lr_scheduler') 79 | p.add_argument('--scheduler_step_per_batch', default=True, type=bool, 80 | help='step every batch if true step every epoch otherwise') 81 | p.add_argument('--log_iterations', type=int, default=-1, 82 | help='log every log_iterations iterations (-1 for only logging after each epoch)') 83 | p.add_argument('--expensive_log_iterations', type=int, default=100, 84 | help='frequency with which to do expensive logging operations') 85 | p.add_argument('--eval_per_epochs', type=int, default=0, 86 | help='frequency with which to run the function run_eval_per_epoch that can do some expensive calculations on the val set or something similar. 
If this is zero, then the function will never be called') 87 | p.add_argument('--metrics', default=[], help='tensorboard metrics [mae, mae_denormalized, qm9_properties ...]') 88 | p.add_argument('--main_metric', default='loss', help='for early stopping etc.') 89 | p.add_argument('--main_metric_goal', type=str, default='min', help='controls early stopping. [max, min]') 90 | p.add_argument('--val_per_batch', type=bool, default=True, 91 | help='run evaluation every batch and then average over the eval results. When running the molhiv benchmark for example, this needs to be False because we need to evaluate on all val data at once since the metric is rocauc') 92 | p.add_argument('--tensorboard_functions', default=[], help='choices of the TENSORBOARD_FUNCTIONS in utils') 93 | p.add_argument('--num_epochs_local_only', type=int, default=1, 94 | help='when training with OptimalTransportTrainer, this specifies for how many epochs only the local predictions will get a loss') 95 | 96 | p.add_argument('--collate_function', default='graph_collate', help='the collate function to use for DataLoader') 97 | p.add_argument('--collate_params', type=dict, default={}, 98 | help='parameters with keywords of the chosen collate function') 99 | p.add_argument('--device', type=str, default='cuda', help='What device to train on: cuda or cpu') 100 | 101 | p.add_argument('--models_to_save', type=list, default=[], 102 | help='specify after which epochs to remember the best model') 103 | 104 | p.add_argument('--model_type', type=str, default='MPNN', help='Classname of one of the models in the models dir') 105 | p.add_argument('--model_parameters', type=dict, help='dictionary of model parameters') 106 | 107 | p.add_argument('--trainer', type=str, default='binding', help='') 108 | p.add_argument('--train_sampler', type=str, default=None, help='any of pytorchs samplers or a custom sampler') 109 | 110 | p.add_argument('--eval_on_test', type=bool, default=True, help='runs evaluation on test set if true') 111 | p.add_argument('--check_se3_invariance', type=bool, default=False, help='check SE(3) invariance instead of generating files') 112 | p.add_argument('--num_confs', type=int, default=1, help='num_confs if using rdkit conformers') 113 | p.add_argument('--use_rdkit_coords', action="store_true", 114 | help='override the rdkit usage behavior of the used model') 115 | p.add_argument('--no_use_rdkit_coords', action="store_false", dest = "use_rdkit_coords", 116 | help='override the rdkit usage behavior of the used model') 117 | 118 | cmdline_parser = deepcopy(p) 119 | args = p.parse_args(arglist) 120 | clear_defaults = {key: argparse.SUPPRESS for key in args.__dict__} 121 | cmdline_parser.set_defaults(**clear_defaults) 122 | cmdline_parser._defaults = {} 123 | cmdline_args = cmdline_parser.parse_args(arglist) 124 | 125 | return args, cmdline_args 126 | 127 | 128 | def inference(args, tune_args=None): 129 | sys.stdout = Logger(logpath=os.path.join(os.path.dirname(args.checkpoint), f'inference.log'), syspart=sys.stdout) 130 | sys.stderr = Logger(logpath=os.path.join(os.path.dirname(args.checkpoint), f'inference.log'), syspart=sys.stderr) 131 | seed_all(args.seed) 132 | device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == 'cuda' else "cpu") 133 | 134 | use_rdkit_coords = args.dataset_params[ 135 | 'use_rdkit_coords'] if 'use_rdkit_coords' in args.dataset_params.keys() else False 136 | args.dataset_params['multiple_rdkit_conformers'] = args.num_confs > 1 137 | args.dataset_params['num_confs'] = args.num_confs 138 | data = 
PDBBind(device=device, complex_names_path=args.test_names, **args.dataset_params) 139 | print('test size: ', len(data)) 140 | model = load_model(args, data_sample=data[0], device=device, save_trajectories=args.save_trajectories) 141 | print('trainable params in model: ', sum(p.numel() for p in model.parameters() if p.requires_grad)) 142 | batch_size = args.batch_size if args.dataset_params['use_rec_atoms'] == False else 2 143 | collate_function = globals()[args.collate_function] if args.collate_params == {} else globals()[ 144 | args.collate_function](**args.collate_params) 145 | loader = DataLoader(data, batch_size=batch_size, collate_fn=collate_function) 146 | 147 | checkpoint = torch.load(args.checkpoint, map_location=device) 148 | 149 | model.load_state_dict({k: v for k, v in checkpoint['model_state_dict'].items() if 'cross_coords' not in k}) 150 | model.load_state_dict(checkpoint['model_state_dict']) 151 | model.to(device) 152 | model.eval() 153 | 154 | for conformer_id in range(args.num_confs): 155 | all_ligs_coords_pred = [] 156 | all_ligs_coords = [] 157 | all_ligs_keypts = [] 158 | all_recs_keypts = [] 159 | all_pocket_coords = [] 160 | all_names = [] 161 | data.conformer_id = conformer_id 162 | for i, batch in tqdm(enumerate(loader)): 163 | with torch.no_grad(): 164 | lig_graphs, rec_graphs, ligs_coords, recs_coords, all_rec_coords, pockets_coords_lig, geometry_graph, names, idx = tuple( 165 | batch) 166 | # if names[0] not in ['2fxs', '2iwx', '2vw5', '2wer', '2yge', ]: continue 167 | ligs_coords_pred, ligs_keypts, recs_keypts, rotations, translations, geom_reg_loss = model(lig_graphs, 168 | rec_graphs, 169 | complex_names=names, 170 | epoch=0, 171 | geometry_graph=geometry_graph.to( 172 | device) if geometry_graph != None else None) 173 | for lig_coords_pred, lig_coords, lig_keypts, rec_keypts, rotation, translation, rec_pocket_coords in zip( 174 | ligs_coords_pred, ligs_coords, ligs_keypts, recs_keypts, rotations, translations, 175 | pockets_coords_lig): 176 | all_ligs_coords_pred.append(lig_coords_pred.detach().cpu()) 177 | all_ligs_coords.append(lig_coords.detach().cpu()) 178 | all_ligs_keypts.append(((rotation @ (lig_keypts).T).T + translation).detach().cpu()) 179 | all_recs_keypts.append(rec_keypts.detach().cpu()) 180 | all_pocket_coords.append(rec_pocket_coords.detach().cpu()) 181 | if translations == []: 182 | for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): 183 | all_ligs_coords_pred.append(lig_coords_pred.detach().cpu()) 184 | all_ligs_coords.append(lig_coords.detach().cpu()) 185 | all_names.extend(names) 186 | 187 | path = os.path.join(os.path.dirname(args.checkpoint), 188 | f'predictions_Tune{tune_args != None}_RDKit{use_rdkit_coords}_confID{conformer_id}.pt') 189 | print(f'Saving predictions to {path}') 190 | results = {'predictions': all_ligs_coords_pred, 'targets': all_ligs_coords, 'lig_keypts': all_ligs_keypts, 191 | 'rec_keypts': all_recs_keypts, 'pocket_coords': all_pocket_coords, 'names': all_names} 192 | torch.save(results, path) 193 | rmsds = [] 194 | centroid_distsH = [] 195 | for i, (prediction, target, lig_keypts, rec_keypts, pocket_coords, name) in tqdm(enumerate( 196 | zip(results['predictions'], results['targets'], results['lig_keypts'], results['rec_keypts'], 197 | results['pocket_coords'], results['names']))): 198 | coords_pred = prediction.numpy() 199 | coords_native = target.numpy() 200 | rmsd = np.sqrt(np.sum((coords_pred - coords_native) ** 2, axis=1).mean()) 201 | 202 | centroid_distance = 
np.linalg.norm(coords_native.mean(axis=0) - coords_pred.mean(axis=0)) 203 | centroid_distsH.append(centroid_distance) 204 | rmsds.append(rmsd) 205 | rmsds = np.array(rmsds) 206 | centroid_distsH = np.array(centroid_distsH) 207 | 208 | print('EquiBind-U with hydrogens included in the loss') 209 | print('mean rmsd: ', rmsds.mean().__round__(2), ' pm ', rmsds.std().__round__(2)) 210 | print('rmsd percentiles: ', np.percentile(rmsds, [25, 50, 75]).round(2)) 211 | print(f'rmsds below 2: {(100 * (rmsds < 2).sum() / len(rmsds)).__round__(2)}%') 212 | print(f'rmsds below 5: {(100 * (rmsds < 5).sum() / len(rmsds)).__round__(2)}%') 213 | print('mean centroid: ', centroid_distsH.mean().__round__(2), ' pm ', 214 | centroid_distsH.std().__round__(2)) 215 | print('centroid percentiles: ', np.percentile(centroid_distsH, [25, 50, 75]).round(2)) 216 | print(f'centroid_distances below 2: {(100 * (centroid_distsH < 2).sum() / len(centroid_distsH)).__round__(2)}%') 217 | print(f'centroid_distances below 5: {(100 * (centroid_distsH < 5).sum() / len(centroid_distsH)).__round__(2)}%') 218 | 219 | if args.run_corrections: 220 | rdkit_graphs, _ = load_graphs( 221 | f'{data.processed_dir}/lig_graphs_rdkit_coords.pt') 222 | kabsch_rmsds = [] 223 | rmsds = [] 224 | centroid_distances = [] 225 | kabsch_rmsds_optimized = [] 226 | rmsds_optimized = [] 227 | centroid_distances_optimized = [] 228 | for i, (prediction, target, lig_keypts, rec_keypts, name) in tqdm(enumerate( 229 | zip(results['predictions'], results['targets'], results['lig_keypts'], results['rec_keypts'], 230 | results['names']))): 231 | lig = read_molecule(os.path.join('data/PDBBind/', name, f'{name}_ligand.sdf'), sanitize=True) 232 | if lig == None: # read mol2 file if sdf file cannot be sanitized 233 | lig = read_molecule(os.path.join('data/PDBBind/', name, f'{name}_ligand.mol2'), sanitize=True) 234 | 235 | lig_rdkit = deepcopy(lig) 236 | rdkit_coords = rdkit_graphs[i].ndata['new_x'].numpy() 237 | conf = lig_rdkit.GetConformer() 238 | for i in range(lig_rdkit.GetNumAtoms()): 239 | x, y, z = rdkit_coords[i] 240 | conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 241 | 242 | lig_rdkit = RemoveHs(lig_rdkit) 243 | 244 | lig = RemoveHs(lig) 245 | lig_equibind = deepcopy(lig) 246 | conf = lig_equibind.GetConformer() 247 | for i in range(lig_equibind.GetNumAtoms()): 248 | x, y, z = prediction.numpy()[i] 249 | conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 250 | 251 | coords_pred = lig_equibind.GetConformer().GetPositions() 252 | coords_native = lig.GetConformer().GetPositions() 253 | rmsdval = np.sqrt(np.sum((coords_pred - coords_native) ** 2, axis=1).mean()) 254 | centroid_distance = np.linalg.norm(coords_native.mean(axis=0) - coords_pred.mean(axis=0)) 255 | R, t = rigid_transform_Kabsch_3D(coords_pred.T, coords_native.T) 256 | moved_coords = (R @ (coords_pred).T).T + t.squeeze() 257 | kabsch_rmsd = np.sqrt(np.sum((moved_coords - coords_native) ** 2, axis=1).mean()) 258 | 259 | Z_pt_cloud = coords_pred 260 | rotable_bonds = get_torsions([lig_rdkit]) 261 | new_dihedrals = np.zeros(len(rotable_bonds)) 262 | for idx, r in enumerate(rotable_bonds): 263 | new_dihedrals[idx] = get_dihedral_vonMises(lig_rdkit, lig_rdkit.GetConformer(), r, Z_pt_cloud) 264 | optimized_mol = apply_changes(lig_rdkit, new_dihedrals, rotable_bonds) 265 | 266 | coords_pred_optimized = optimized_mol.GetConformer().GetPositions() 267 | R, t = rigid_transform_Kabsch_3D(coords_pred_optimized.T, coords_pred.T) 268 | coords_pred_optimized = (R @ 
(coords_pred_optimized).T).T + t.squeeze() 269 | 270 | rmsdval_optimized = np.sqrt(np.sum((coords_pred_optimized - coords_native) ** 2, axis=1).mean()) 271 | centroid_distance_optimized = np.linalg.norm( 272 | coords_native.mean(axis=0) - coords_pred_optimized.mean(axis=0)) 273 | R, t = rigid_transform_Kabsch_3D(coords_pred_optimized.T, coords_native.T) 274 | moved_coords_optimized = (R @ (coords_pred_optimized).T).T + t.squeeze() 275 | kabsch_rmsd_optimized = np.sqrt(np.sum((moved_coords_optimized - coords_native) ** 2, axis=1).mean()) 276 | kabsch_rmsds.append(kabsch_rmsd) 277 | rmsds.append(rmsdval) 278 | centroid_distances.append(centroid_distance) 279 | kabsch_rmsds_optimized.append(kabsch_rmsd_optimized) 280 | rmsds_optimized.append(rmsdval_optimized) 281 | centroid_distances_optimized.append(centroid_distance_optimized) 282 | kabsch_rmsds = np.array(kabsch_rmsds) 283 | rmsdvals = np.array(rmsds) 284 | centroid_distsU = np.array(centroid_distances) 285 | kabsch_rmsds_optimized = np.array(kabsch_rmsds_optimized) 286 | rmsd_optimized = np.array(rmsds_optimized) 287 | centroid_dists = np.array(centroid_distances_optimized) 288 | print('EquiBind-U') 289 | print('mean rmsdval: ', rmsdvals.mean().__round__(2), ' pm ', rmsdvals.std().__round__(2)) 290 | print('rmsd percentiles: ', np.percentile(rmsdvals, [25, 50, 75]).round(2)) 291 | print(f'rmsdvals below 2: {(100 * (rmsdvals < 2).sum() / len(rmsdvals)).__round__(2)}%') 292 | print(f'rmsdvals below 5: {(100 * (rmsdvals < 5).sum() / len(rmsdvals)).__round__(2)}%') 293 | print('mean centroid: ', centroid_distsU.mean().__round__(2), ' pm ', centroid_distsU.std().__round__(2)) 294 | print('centroid percentiles: ', np.percentile(centroid_distsU, [25, 50, 75]).round(2)) 295 | print(f'centroid dist below 2: {(100 * (centroid_distsU < 2).sum() / len(centroid_distsU)).__round__(2)}%') 296 | print(f'centroid dist below 5: {(100 * (centroid_distsU < 5).sum() / len(centroid_distsU)).__round__(2)}%') 297 | print(f'mean kabsch RMSD: ', kabsch_rmsds.mean().__round__(2), ' pm ', kabsch_rmsds.std().__round__(2)) 298 | print('kabsch RMSD percentiles: ', np.percentile(kabsch_rmsds, [25, 50, 75]).round(2)) 299 | 300 | print('EquiBind') 301 | print('mean rmsdval: ', rmsd_optimized.mean().__round__(2), ' pm ', rmsd_optimized.std().__round__(2)) 302 | print('rmsd percentiles: ', np.percentile(rmsd_optimized, [25, 50, 75]).round(2)) 303 | print(f'rmsdvals below 2: {(100 * (rmsd_optimized < 2).sum() / len(rmsd_optimized)).__round__(2)}%') 304 | print(f'rmsdvals below 5: {(100 * (rmsd_optimized < 5).sum() / len(rmsd_optimized)).__round__(2)}%') 305 | print('mean centroid: ', centroid_dists.mean().__round__(2), ' pm ', centroid_dists.std().__round__(2)) 306 | print('centroid percentiles: ', np.percentile(centroid_dists, [25, 50, 75]).round(2)) 307 | print(f'centroid dist below 2: {(100 * (centroid_dists < 2).sum() / len(centroid_dists)).__round__(2)}%') 308 | print(f'centroid dist below 5: {(100 * (centroid_dists < 5).sum() / len(centroid_dists)).__round__(2)}%') 309 | print(f'mean kabsch RMSD: ', kabsch_rmsds_optimized.mean().__round__(2), ' pm ', 310 | kabsch_rmsds_optimized.std().__round__(2)) 311 | print('kabsch RMSD percentiles: ', np.percentile(kabsch_rmsds_optimized, [25, 50, 75]).round(2)) 312 | 313 | 314 | def inference_from_files(args): 315 | seed_all(args.seed) 316 | device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == 'cuda' else "cpu") 317 | checkpoint = torch.load(args.checkpoint, map_location=device) 318 | model = None 319 | 
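# the model is built lazily in the loop below, once the first (lig_graph, rec_graph) pair is available, so that load_model can infer the input edge feature dimensions from a data sample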
all_ligs_coords_corrected = [] 320 | all_intersection_losses = [] 321 | all_intersection_losses_untuned = [] 322 | all_ligs_coords_pred_untuned = [] 323 | all_ligs_coords = [] 324 | all_ligs_keypts = [] 325 | all_recs_keypts = [] 326 | all_names = [] 327 | dp = args.dataset_params 328 | use_rdkit_coords = args.use_rdkit_coords if args.use_rdkit_coords != None else args.dataset_params[ 329 | 'use_rdkit_coords'] 330 | names = os.listdir(args.inference_path) if args.inference_path != None else tqdm(read_strings_from_txt('data/timesplit_test')) 331 | for idx, name in enumerate(names): 332 | print(f'\nProcessing {name}: complex {idx + 1} of {len(names)}') 333 | file_names = os.listdir(os.path.join(args.inference_path, name)) 334 | rec_name = [i for i in file_names if 'rec.pdb' in i or 'protein' in i][0] 335 | lig_names = [i for i in file_names if 'ligand' in i] 336 | rec_path = os.path.join(args.inference_path, name, rec_name) 337 | for lig_name in lig_names: 338 | if not os.path.exists(os.path.join(args.inference_path, name, lig_name)): 339 | raise ValueError(f'Path does not exist: {os.path.join(args.inference_path, name, lig_name)}') 340 | print(f'Trying to load {os.path.join(args.inference_path, name, lig_name)}') 341 | lig = read_molecule(os.path.join(args.inference_path, name, lig_name), sanitize=True) 342 | if lig != None: # stop at the first ligand file that can be read 343 | used_lig = os.path.join(args.inference_path, name, lig_name) 344 | break 345 | if lig_names == []: raise ValueError(f'No ligand files found. The ligand file has to contain \'ligand\'.') 346 | if lig == None: raise ValueError(f'None of the ligand files could be read: {lig_names}') 347 | print(f'Docking the receptor {os.path.join(args.inference_path, name, rec_name)}\nTo the ligand {used_lig}') 348 | 349 | rec, rec_coords, c_alpha_coords, n_coords, c_coords = get_receptor_inference(rec_path) 350 | rec_graph = get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, 351 | use_rec_atoms=dp['use_rec_atoms'], rec_radius=dp['rec_graph_radius'], 352 | surface_max_neighbors=dp['surface_max_neighbors'], 353 | surface_graph_cutoff=dp['surface_graph_cutoff'], 354 | surface_mesh_cutoff=dp['surface_mesh_cutoff'], 355 | c_alpha_max_neighbors=dp['c_alpha_max_neighbors']) 356 | lig_graph = get_lig_graph_revised(lig, name, max_neighbors=dp['lig_max_neighbors'], 357 | use_rdkit_coords=use_rdkit_coords, radius=dp['lig_graph_radius']) 358 | if 'geometry_regularization' in dp and dp['geometry_regularization']: 359 | geometry_graph = get_geometry_graph(lig) 360 | elif 'geometry_regularization_ring' in dp and dp['geometry_regularization_ring']: 361 | geometry_graph = get_geometry_graph_ring(lig) 362 | else: 363 | geometry_graph = None 364 | 365 | start_lig_coords = lig_graph.ndata['x'] 366 | # Randomly rotate and translate the ligand. 
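# The conformer is centered and then given a random rotation plus a translation of up to 5 Angstrom, so that no information from the original ligand pose can leak into the prediction.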
367 | rot_T, rot_b = random_rotation_translation(translation_distance=5) 368 | if (use_rdkit_coords): 369 | lig_coords_to_move = lig_graph.ndata['new_x'] 370 | else: 371 | lig_coords_to_move = lig_graph.ndata['x'] 372 | mean_to_remove = lig_coords_to_move.mean(dim=0, keepdims=True) 373 | input_coords = (rot_T @ (lig_coords_to_move - mean_to_remove).T).T + rot_b 374 | lig_graph.ndata['new_x'] = input_coords 375 | 376 | if model == None: 377 | model = load_model(args, data_sample=(lig_graph, rec_graph), device=device) 378 | model.load_state_dict(checkpoint['model_state_dict']) 379 | model.to(device) 380 | model.eval() 381 | 382 | with torch.no_grad(): 383 | geometry_graph = geometry_graph.to(device) if geometry_graph != None else None 384 | ligs_coords_pred_untuned, ligs_keypts, recs_keypts, rotations, translations, geom_reg_loss = model( 385 | lig_graph.to(device), rec_graph.to(device), geometry_graph, complex_names=[name], epoch=0) 386 | 387 | for lig_coords_pred_untuned, lig_coords, lig_keypts, rec_keypts, rotation, translation in zip( 388 | ligs_coords_pred_untuned, [start_lig_coords], ligs_keypts, recs_keypts, rotations, 389 | translations, ): 390 | all_intersection_losses_untuned.append( 391 | compute_revised_intersection_loss(lig_coords_pred_untuned.detach().cpu(), rec_graph.ndata['x'], 392 | alpha=0.2, beta=8, aggression=0)) 393 | all_ligs_coords_pred_untuned.append(lig_coords_pred_untuned.detach().cpu()) 394 | all_ligs_coords.append(lig_coords.detach().cpu()) 395 | all_ligs_keypts.append(((rotation @ (lig_keypts).T).T + translation).detach().cpu()) 396 | all_recs_keypts.append(rec_keypts.detach().cpu()) 397 | 398 | if args.run_corrections: 399 | prediction = ligs_coords_pred_untuned[0].detach().cpu() 400 | lig_input = deepcopy(lig) 401 | conf = lig_input.GetConformer() 402 | for i in range(lig_input.GetNumAtoms()): 403 | x, y, z = input_coords.numpy()[i] 404 | conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 405 | 406 | lig_equibind = deepcopy(lig) 407 | conf = lig_equibind.GetConformer() 408 | for i in range(lig_equibind.GetNumAtoms()): 409 | x, y, z = prediction.numpy()[i] 410 | conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 411 | 412 | coords_pred = lig_equibind.GetConformer().GetPositions() 413 | 414 | Z_pt_cloud = coords_pred 415 | rotable_bonds = get_torsions([lig_input]) 416 | new_dihedrals = np.zeros(len(rotable_bonds)) 417 | for idx, r in enumerate(rotable_bonds): 418 | new_dihedrals[idx] = get_dihedral_vonMises(lig_input, lig_input.GetConformer(), r, Z_pt_cloud) 419 | optimized_mol = apply_changes(lig_input, new_dihedrals, rotable_bonds) 420 | 421 | coords_pred_optimized = optimized_mol.GetConformer().GetPositions() 422 | R, t = rigid_transform_Kabsch_3D(coords_pred_optimized.T, coords_pred.T) 423 | coords_pred_optimized = (R @ (coords_pred_optimized).T).T + t.squeeze() 424 | all_ligs_coords_corrected.append(coords_pred_optimized) 425 | 426 | if args.output_directory: 427 | if not os.path.exists(f'{args.output_directory}/{name}'): 428 | os.makedirs(f'{args.output_directory}/{name}') 429 | conf = optimized_mol.GetConformer() 430 | for i in range(optimized_mol.GetNumAtoms()): 431 | x, y, z = coords_pred_optimized[i] 432 | conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 433 | block_optimized = Chem.MolToMolBlock(optimized_mol) 434 | print(f'Writing prediction to {args.output_directory}/{name}/lig_equibind_corrected.sdf') 435 | with open(f'{args.output_directory}/{name}/lig_equibind_corrected.sdf', "w") as newfile: 436 | 
newfile.write(block_optimized) 437 | all_names.append(name) 438 | 439 | path = os.path.join(os.path.dirname(args.checkpoint), f'predictions_RDKit{use_rdkit_coords}.pt') 440 | print(f'Saving predictions to {path}') 441 | results = {'corrected_predictions': all_ligs_coords_corrected, 'initial_predictions': all_ligs_coords_pred_untuned, 442 | 'targets': all_ligs_coords, 'lig_keypts': all_ligs_keypts, 'rec_keypts': all_recs_keypts, 443 | 'names': all_names, 'intersection_losses_untuned': all_intersection_losses_untuned, 444 | 'intersection_losses': all_intersection_losses} 445 | torch.save(results, path) 446 | 447 | 448 | if __name__ == '__main__': 449 | args, cmdline_args = parse_arguments() 450 | 451 | if args.config: 452 | config_dict = yaml.load(args.config, Loader=yaml.FullLoader) 453 | arg_dict = args.__dict__ 454 | for key, value in config_dict.items(): 455 | if isinstance(value, list): 456 | for v in value: 457 | arg_dict[key].append(v) 458 | else: 459 | if key in cmdline_args: 460 | continue 461 | arg_dict[key] = value 462 | args.config = args.config.name 463 | else: 464 | config_dict = {} 465 | 466 | for run_dir in args.run_dirs: 467 | args.checkpoint = f'runs/{run_dir}/best_checkpoint.pt' 468 | config_dict['checkpoint'] = f'runs/{run_dir}/best_checkpoint.pt' 469 | # overwrite args with args from checkpoint except for the args that were contained in the config file 470 | arg_dict = args.__dict__ 471 | with open(os.path.join(os.path.dirname(args.checkpoint), 'train_arguments.yaml'), 'r') as arg_file: 472 | checkpoint_dict = yaml.load(arg_file, Loader=yaml.FullLoader) 473 | for key, value in checkpoint_dict.items(): 474 | if (key not in config_dict.keys()) and (key not in cmdline_args): 475 | if isinstance(value, list): 476 | for v in value: 477 | arg_dict[key].append(v) 478 | else: 479 | arg_dict[key] = value 480 | args.model_parameters['noise_initial'] = 0 481 | if args.inference_path == None: 482 | inference(args) 483 | else: 484 | inference_from_files(args) 485 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # ``equibind.py`` is probably the model/file that you are interested in -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass 2 | from pkgutil import iter_modules 3 | from pathlib import Path 4 | from importlib import import_module 5 | 6 | # iterate through the modules in the current package 7 | package_dir = Path(__file__).resolve().parent 8 | for (_, module_name, _) in iter_modules([package_dir]): 9 | 10 | # import the module and iterate through its attributes 11 | module = import_module(f"{__name__}.{module_name}") 12 | for attribute_name in dir(module): 13 | attribute = getattr(module, attribute_name) 14 | 15 | if isclass(attribute): 16 | # Add the class to this package's variables 17 | globals()[attribute_name] = attribute -------------------------------------------------------------------------------- /multiligand_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import sys 4 | 5 | from copy import deepcopy 6 | 7 | import os 8 | 9 | from rdkit import Chem 10 | from rdkit.Geometry import Point3D 11 | 12 | from commons.geometry_utils import rigid_transform_Kabsch_3D, get_torsions, 
get_dihedral_vonMises, apply_changes 13 | from commons.process_mols import get_rec_graph, get_receptor_inference 14 | 15 | #from train import load_model 16 | 17 | from commons.utils import seed_all 18 | 19 | import yaml 20 | 21 | from datasets.custom_collate import * # do not remove 22 | from models import * # do not remove 23 | from torch.nn import * # do not remove 24 | from torch.optim import * # do not remove 25 | from commons.losses import * # do not remove 26 | from torch.optim.lr_scheduler import * # do not remove 27 | from torch.utils.data import DataLoader 28 | 29 | 30 | # turn on for debugging C code like Segmentation Faults 31 | import faulthandler 32 | from datasets import multiple_ligands 33 | 34 | faulthandler.enable() 35 | 36 | from models.equibind import EquiBind 37 | 38 | def parse_arguments(arglist = None): 39 | p = argparse.ArgumentParser() 40 | p.add_argument("-l", "--ligands_sdf", type=str, help = "A single sdf file containing all ligands to be screened when running in screening mode") 41 | p.add_argument("-r", "--rec_pdb", type = str, help = "The receptor to dock the ligands in --ligands_sdf against") 42 | p.add_argument('-o', '--output_directory', type=str, default=None, help='path where to put the predicted results') 43 | p.add_argument('--config', type=argparse.FileType(mode='r'), default=None) 44 | p.add_argument('--checkpoint', '--model', dest = "checkpoint", 45 | type=str, help='path to .pt file containing the model used for inference. ' 46 | 'Defaults to runs/flexible_self_docking/best_checkpoint.pt in the same directory as the file being run') 47 | p.add_argument('--train_args', type = str, help = "Path to a yaml file containing the parameters that were used to train the model. " 48 | "If not supplied, it is assumed that a file named 'train_arguments.yaml' is located in the same directory as the model checkpoint") 49 | p.add_argument('--no_skip', dest = "skip_in_output", action = "store_false", help = 'do not skip input files that already have corresponding folders in the output directory; by default they are skipped so that a large interrupted computation can be resumed') 50 | p.add_argument('--batch_size', type=int, default=8, help='samples that will be processed in parallel') 51 | p.add_argument("--n_workers_data_load", type = int, default = 0, help = "The number of cores used for loading the ligands and generating the graphs used as input to the model. 0 means data loading runs in the main process.") 52 | p.add_argument('--use_rdkit_coords', action="store_true", help='override the rdkit usage behavior of the used model') 53 | p.add_argument('--device', type=str, default='cuda', help='What device to train on: cuda or cpu') 54 | p.add_argument('--seed', type=int, default=1, help='seed for reproducibility') 55 | p.add_argument('--num_confs', type=int, default=1, help='num_confs if using rdkit conformers') 56 | p.add_argument("--lig_slice", help = "Run only a slice of the provided ligand file. Like in python, this slice is HALF-OPEN. Should be provided in the format --lig_slice start,end") 57 | p.add_argument("--lazy_dataload", dest = "lazy_dataload", action="store_true", default = None, help = "Turns on lazy dataloading. If on, will postpone rdkit parsing of each ligand until it is requested.") 58 | p.add_argument("--no_lazy_dataload", dest = "lazy_dataload", action="store_false", default = None, help = "Turns off lazy dataloading. 
If set, each ligand is parsed by rdkit up front instead of when it is first requested.") 59 | p.add_argument("--no_run_corrections", dest = "run_corrections", action = "store_false", help = "possibility of turning off running fast point cloud ligand fitting") 60 | 61 | cmdline_parser = deepcopy(p) 62 | args = p.parse_args(arglist) 63 | clear_defaults = {key: argparse.SUPPRESS for key in args.__dict__} 64 | cmdline_parser.set_defaults(**clear_defaults) 65 | cmdline_parser._defaults = {} 66 | cmdline_args = cmdline_parser.parse_args(arglist) 67 | 68 | return p.parse_args(arglist), set(cmdline_args.__dict__.keys()) 69 | 70 | def get_default_args(args, cmdline_args): 71 | if args.config: 72 | config_dict = yaml.load(args.config, Loader=yaml.FullLoader) 73 | arg_dict = args.__dict__ 74 | for key, value in config_dict.items(): 75 | if isinstance(value, list): 76 | for v in value: 77 | arg_dict[key].append(v) 78 | else: 79 | arg_dict[key] = value 80 | args.config = args.config.name 81 | else: 82 | config_dict = {} 83 | 84 | if args.checkpoint is None: 85 | args.checkpoint = os.path.join(os.path.dirname(__file__), "runs/flexible_self_docking/best_checkpoint.pt") 86 | 87 | config_dict['checkpoint'] = args.checkpoint 88 | # overwrite args with args from checkpoint except for the args that were contained in the config file or provided directly in the commandline 89 | arg_dict = args.__dict__ 90 | 91 | if args.train_args is None: 92 | with open(os.path.join(os.path.dirname(args.checkpoint), 'train_arguments.yaml'), 'r') as arg_file: 93 | checkpoint_dict = yaml.load(arg_file, Loader=yaml.FullLoader) 94 | else: 95 | with open(args.train_args, 'r') as arg_file: 96 | checkpoint_dict = yaml.load(arg_file, Loader=yaml.FullLoader) 97 | 98 | for key, value in checkpoint_dict.items(): 99 | if (key not in config_dict.keys()) and (key not in cmdline_args): 100 | if isinstance(value, list) and key in arg_dict: 101 | for v in value: 102 | arg_dict[key].append(v) 103 | else: 104 | arg_dict[key] = value 105 | args.model_parameters['noise_initial'] = 0 106 | return args 107 | 108 | def load_rec_and_model(args): 109 | device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == 'cuda' else "cpu") 110 | print(f"device = {device}") 111 | # sys.exit() 112 | checkpoint = torch.load(args.checkpoint, map_location=device) 113 | dp = args.dataset_params 114 | 115 | model = EquiBind(device = device, lig_input_edge_feats_dim = 15, rec_input_edge_feats_dim = 27, **args.model_parameters) 116 | model.load_state_dict(checkpoint['model_state_dict']) 117 | model.to(device) 118 | model.eval() 119 | 120 | rec_path = args.rec_pdb 121 | rec, rec_coords, c_alpha_coords, n_coords, c_coords = get_receptor_inference(rec_path) 122 | rec_graph = get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, 123 | use_rec_atoms=dp['use_rec_atoms'], rec_radius=dp['rec_graph_radius'], 124 | surface_max_neighbors=dp['surface_max_neighbors'], 125 | surface_graph_cutoff=dp['surface_graph_cutoff'], 126 | surface_mesh_cutoff=dp['surface_mesh_cutoff'], 127 | c_alpha_max_neighbors=dp['c_alpha_max_neighbors']) 128 | 129 | return rec_graph, model 130 | 131 | def run_batch(model, ligs, lig_coords, lig_graphs, rec_graphs, geometry_graphs, true_indices): 132 | try: 133 | predictions = model(lig_graphs, rec_graphs, geometry_graphs)[0] 134 | out_ligs = ligs 135 | out_lig_coords = lig_coords 136 | names = [lig.GetProp("_Name") for lig in ligs] 137 | successes = list(zip(true_indices, names)) 138 | failures = [] 139 | except AssertionError: 140 | 
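# batched inference failed an internal assertion: unbatch the graphs and retry each ligand on its own, so that a single failing ligand does not discard the predictions of the whole batch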
lig_graphs, rec_graphs, geometry_graphs = (dgl.unbatch(lig_graphs), 141 | dgl.unbatch(rec_graphs), dgl.unbatch(geometry_graphs)) 142 | predictions = [] 143 | out_ligs = [] 144 | out_lig_coords = [] 145 | successes = [] 146 | failures = [] 147 | for lig, lig_coord, lig_graph, rec_graph, geometry_graph, true_index in zip(ligs, lig_coords, lig_graphs, rec_graphs, geometry_graphs, true_indices): 148 | try: 149 | output = model(lig_graph, rec_graph, geometry_graph) 150 | except AssertionError as e: 151 | failures.append((true_index, lig.GetProp("_Name"))) 152 | print(f"Failed for {lig.GetProp('_Name')}") 153 | else: 154 | out_ligs.append(lig) 155 | out_lig_coords.append(lig_coord) 156 | predictions.append(output[0][0]) 157 | successes.append((true_index, lig.GetProp("_Name"))) 158 | assert len(predictions) == len(out_ligs) 159 | return out_ligs, out_lig_coords, predictions, successes, failures 160 | 161 | def run_corrections(lig, lig_coord, ligs_coords_pred_untuned): 162 | input_coords = lig_coord.detach().cpu() 163 | prediction = ligs_coords_pred_untuned.detach().cpu() 164 | lig_input = deepcopy(lig) 165 | conf = lig_input.GetConformer() 166 | for i in range(lig_input.GetNumAtoms()): 167 | x, y, z = input_coords.numpy()[i] 168 | conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 169 | 170 | lig_equibind = deepcopy(lig) 171 | conf = lig_equibind.GetConformer() 172 | for i in range(lig_equibind.GetNumAtoms()): 173 | x, y, z = prediction.numpy()[i] 174 | conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 175 | 176 | coords_pred = lig_equibind.GetConformer().GetPositions() 177 | 178 | Z_pt_cloud = coords_pred 179 | rotable_bonds = get_torsions([lig_input]) 180 | new_dihedrals = np.zeros(len(rotable_bonds)) 181 | for idx, r in enumerate(rotable_bonds): 182 | new_dihedrals[idx] = get_dihedral_vonMises(lig_input, lig_input.GetConformer(), r, Z_pt_cloud) 183 | optimized_mol = apply_changes(lig_input, new_dihedrals, rotable_bonds) 184 | optimized_conf = optimized_mol.GetConformer() 185 | coords_pred_optimized = optimized_conf.GetPositions() 186 | R, t = rigid_transform_Kabsch_3D(coords_pred_optimized.T, coords_pred.T) 187 | coords_pred_optimized = (R @ (coords_pred_optimized).T).T + t.squeeze() 188 | for i in range(optimized_mol.GetNumAtoms()): 189 | x, y, z = coords_pred_optimized[i] 190 | optimized_conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 191 | return optimized_mol 192 | 193 | def write_while_inferring(dataloader, model, args): 194 | 195 | full_output_path = os.path.join(args.output_directory, "output.sdf") 196 | full_failed_path = os.path.join(args.output_directory, "failed.txt") 197 | full_success_path = os.path.join(args.output_directory, "success.txt") 198 | 199 | w_or_a = "a" if args.skip_in_output else "w" 200 | with torch.no_grad(), open(full_output_path, w_or_a) as file, open( 201 | full_failed_path, "a") as failed_file, open(full_success_path, w_or_a) as success_file: 202 | with Chem.SDWriter(file) as writer: 203 | i = 0 204 | total_ligs = len(dataloader.dataset) 205 | for batch in dataloader: 206 | i += args.batch_size 207 | print(f"Entering batch ending in index {min(i, total_ligs)}/{len(dataloader.dataset)}") 208 | ligs, lig_coords, lig_graphs, rec_graphs, geometry_graphs, true_indices, failed_in_batch = batch 209 | for failure in failed_in_batch: 210 | if failure[1] == "Skipped": 211 | continue 212 | failed_file.write(f"{failure[0]} {failure[1]}") 213 | failed_file.write("\n") 214 | if ligs is None: 215 | continue 216 | lig_graphs = 
lig_graphs.to(args.device) 217 | rec_graphs = rec_graphs.to(args.device) 218 | geometry_graphs = geometry_graphs.to(args.device) 219 | 220 | 221 | out_ligs, out_lig_coords, predictions, successes, failures = run_batch(model, ligs, lig_coords, 222 | lig_graphs, rec_graphs, 223 | geometry_graphs, true_indices) 224 | opt_mols = [run_corrections(lig, lig_coord, prediction) for lig, lig_coord, prediction in zip(out_ligs, out_lig_coords, predictions)] 225 | for mol, success in zip(opt_mols, successes): 226 | writer.write(mol) 227 | success_file.write(f"{success[0]} {success[1]}") 228 | success_file.write("\n") 229 | # print(f"written {mol.GetProp('_Name')} to output") 230 | for failure in failures: 231 | failed_file.write(f"{failure[0]} {failure[1]}") 232 | failed_file.write("\n") 233 | 234 | def main(arglist = None): 235 | args, cmdline_args = parse_arguments(arglist) 236 | 237 | args = get_default_args(args, cmdline_args) 238 | assert args.output_directory, "An output directory should be specified" 239 | assert args.ligands_sdf, "No ligand sdf specified" 240 | assert args.rec_pdb, "No protein specified" 241 | seed_all(args.seed) 242 | 243 | os.makedirs(args.output_directory, exist_ok = True) 244 | 245 | success_path = os.path.join(args.output_directory, "success.txt") 246 | failed_path = os.path.join(args.output_directory, "failed.txt") 247 | if os.path.exists(success_path) and os.path.exists(failed_path) and args.skip_in_output: 248 | with open(success_path) as successes, open(failed_path) as failures: 249 | previous_work = successes.readlines() 250 | previous_work += failures.readlines() 251 | previous_work = set(map(lambda tup: int(tup.split(" ")[0]), previous_work)) 252 | print(f"Found {len(previous_work)} previously calculated ligands") 253 | else: 254 | previous_work = None 255 | 256 | 257 | rec_graph, model = load_rec_and_model(args) 258 | if args.lig_slice is not None: 259 | lig_slice = tuple(map(int, args.lig_slice.split(","))) 260 | else: 261 | lig_slice = None 262 | 263 | lig_data = multiple_ligands.Ligands(args.ligands_sdf, rec_graph, args, slice = lig_slice, skips = previous_work, lazy = args.lazy_dataload) 264 | lig_loader = DataLoader(lig_data, batch_size = args.batch_size, collate_fn = lig_data.collate, num_workers = args.n_workers_data_load) 265 | 266 | full_failed_path = os.path.join(args.output_directory, "failed.txt") 267 | with open(full_failed_path, "a" if args.skip_in_output else "w") as failed_file: 268 | for failure in lig_data.failed_ligs: 269 | failed_file.write(f"{failure[0]} {failure[1]}") 270 | failed_file.write("\n") 271 | 272 | write_while_inferring(lig_loader, model, args) 273 | 274 | if __name__ == '__main__': 275 | main() -------------------------------------------------------------------------------- /runs/flexible_self_docking/best_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HannesStark/EquiBind/4cb1b4c562dae914780154518a6b915bb4cba658/runs/flexible_self_docking/best_checkpoint.pt -------------------------------------------------------------------------------- /runs/flexible_self_docking/train_arguments.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 8 2 | checkpoint: 3 | clip_grad: 100 4 | collate_function: graph_collate_revised 5 | collate_params: {} 6 | config: 
runs/3_noLRDKitCoordsLigFlexLigAtomPocketLowerLRAndScheduleGeomRegNoIntersectLoss_layers8_bs8_otL1_iL0_dim64_nAttH30_normBN_normc0_normf0_recAtomsFalse_numtrainNone_date26-01_time04-49-06.466894/3.yml 7 | data_seed: 1 8 | 9 | dataset_params: 10 | bsp_proteins: false 11 | c_alpha_max_neighbors: 10 12 | chain_radius: 10 13 | dataset_size: 14 | geometry_regularization: true 15 | lig_graph_radius: 5 16 | lig_max_neighbors: 17 | min_shell_thickness: 3 18 | n_jobs: 20 19 | only_polar_hydrogens: false 20 | pocket_cutoff: 4 21 | pocket_mode: match_atoms_to_lig 22 | rec_graph_radius: 30 23 | rec_subgraph: false 24 | remove_h: false 25 | subgraph_augmentation: false 26 | subgraph_cutoff: 4 27 | subgraph_max_neigbor: 8 28 | subgraph_radius: 10 29 | surface_graph_cutoff: 5 30 | surface_max_neighbors: 5 31 | surface_mesh_cutoff: 2 32 | translation_distance: 5.0 33 | use_rdkit_coords: true 34 | use_rec_atoms: false 35 | device: cuda 36 | eval_on_test: true 37 | eval_per_epochs: 0 38 | expensive_log_iterations: 100 39 | experiment_name: noLRDKitCoordsLigFlexLigAtomPocketLowerLRAndScheduleGeomRegNoIntersectLoss 40 | log_iterations: 100 41 | logdir: runs 42 | loss_func: BindingLoss 43 | loss_params: 44 | centroid_loss_weight: 0 45 | intersection_loss_weight: 0 46 | intersection_sigma: 8 47 | intersection_surface_ct: 1 48 | kabsch_rmsd_weight: 1 49 | key_point_alignmen_loss_weight: 0 50 | ot_loss_weight: 1 51 | translated_lig_kpt_ot_loss: false 52 | lr_scheduler: ReduceLROnPlateau 53 | lr_scheduler_params: 54 | factor: 0.6 55 | min_lr: 8.0e-06 56 | mode: max 57 | patience: 60 58 | verbose: true 59 | main_metric: rmsd_less_than_2 60 | main_metric_goal: max 61 | metrics: 62 | - pearsonr 63 | - rsquared 64 | - mean_rmsd 65 | - median_rmsd 66 | - median_centroid_distance 67 | - centroid_distance_less_than_2 68 | - mean_centroid_distance 69 | - kabsch_rmsd 70 | - rmsd_less_than_2 71 | - rmsd_less_than_5 72 | minimum_epochs: 0 73 | model_parameters: 74 | centroid_keypts_construction: false 75 | centroid_keypts_construction_lig: false 76 | centroid_keypts_construction_rec: false 77 | cross_msgs: true 78 | debug: false 79 | dropout: 0.1 80 | final_h_layer_norm: 0 81 | geometry_reg_step_size: 0.001 82 | geometry_regularization: true 83 | iegmn_lay_hid_dim: 64 84 | layer_norm: BN 85 | layer_norm_coords: 0 86 | leakyrelu_neg_slope: 0.01 87 | lig_evolve: true 88 | lig_no_softmax: false 89 | move_keypts_back: true 90 | n_lays: 8 91 | noise_decay_rate: 0.5 92 | noise_initial: 1 93 | nonlin: lkyrelu 94 | normalize_Z_lig_directions: false 95 | normalize_Z_rec_directions: false 96 | normalize_coordinate_update: true 97 | num_att_heads: 30 98 | num_lig_feats: 99 | post_crossmsg_norm_type: 0 100 | pre_crossmsg_norm_type: 0 101 | random_vec_dim: 0 102 | random_vec_std: 1 103 | rec_evolve: true 104 | rec_no_softmax: false 105 | rec_square_distance_scale: 10 106 | residue_emb_dim: 64 107 | shared_layers: false 108 | skip_weight_h: 0.5 109 | standard_norm_order: true 110 | unnormalized_kpt_weights: false 111 | use_dist_in_layers: true 112 | use_edge_features_in_gmn: true 113 | use_evolved_lig: true 114 | use_mean_node_features: true 115 | use_rec_atoms: false 116 | use_scalar_features: false 117 | x_connection_init: 0.25 118 | model_type: EquiBind 119 | models_to_save: [] 120 | multithreaded_seeds: [] 121 | num_epochs: 1000000 122 | num_train: 123 | num_val: 124 | num_workers: 0 125 | optimizer: Adam 126 | optimizer_params: 127 | lr: 0.0001 128 | weight_decay: 0.0001 129 | patience: 150 130 | pin_memory: true 131 | 
sampler_parameters: 132 | scheduler_step_per_batch: false 133 | seed: 1 134 | seed_data: 1 135 | tensorboard_functions: [] 136 | test_names: data/new_names 137 | train_names: data/old_no_newL_train 138 | train_predictions_name: 139 | train_sampler: 140 | trainer: revised 141 | val_names: data/old_no_newL_val 142 | val_per_batch: true 143 | val_predictions_name: 144 | -------------------------------------------------------------------------------- /runs/rigid_redocking/best_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HannesStark/EquiBind/4cb1b4c562dae914780154518a6b915bb4cba658/runs/rigid_redocking/best_checkpoint.pt -------------------------------------------------------------------------------- /runs/rigid_redocking/train_arguments.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 8 2 | checkpoint: runs/23_noLDropoutHighLR_layers8_bs8_otL1_iL1_dim64_nAttH30_normBN_normc0_normf0_recAtomsFalse_numtrainNone_date13-01_time04-28-23.818045/last_checkpoint.pt 3 | clip_grad: 100 4 | collate_function: graph_collate_revised 5 | collate_params: {} 6 | config: runs/23_noLDropoutHighLR_layers8_bs8_otL1_iL1_dim64_nAttH30_normBN_normc0_normf0_recAtomsFalse_numtrainNone_date13-01_time04-28-23.818045/23.yml 7 | data_seed: 1 8 | 9 | dataset_params: 10 | bsp_proteins: false 11 | c_alpha_max_neighbors: 10 12 | chain_radius: 10 13 | dataset_size: 14 | lig_graph_radius: 5 15 | lig_max_neighbors: 16 | min_shell_thickness: 3 17 | n_jobs: 20 18 | only_polar_hydrogens: false 19 | pocket_cutoff: 4 20 | pocket_mode: match_atoms 21 | rec_graph_radius: 30 22 | remove_h: false 23 | subgraph_augmentation: false 24 | surface_graph_cutoff: 5 25 | surface_max_neighbors: 5 26 | surface_mesh_cutoff: 2 27 | translation_distance: 5.0 28 | use_rdkit_coords: false 29 | use_rec_atoms: false 30 | device: cuda 31 | eval_on_test: true 32 | eval_per_epochs: 0 33 | expensive_log_iterations: 100 34 | experiment_name: noLDropoutHighLR 35 | log_iterations: 100 36 | logdir: runs 37 | loss_func: BindingLoss 38 | loss_params: 39 | centroid_loss_weight: 0 40 | intersection_loss_weight: 1 41 | intersection_sigma: 8 42 | intersection_surface_ct: 1 43 | key_point_alignmen_loss_weight: 1 44 | ot_loss_weight: 1 45 | lr_scheduler: 46 | lr_scheduler_params: 47 | factor: 0.6 48 | min_lr: 8.0e-06 49 | mode: max 50 | patience: 10 51 | verbose: true 52 | main_metric: rmsd_less_than_2 53 | main_metric_goal: max 54 | metrics: 55 | - pearsonr 56 | - rsquared 57 | - mean_rmsd 58 | - median_rmsd 59 | - median_centroid_distance 60 | - centroid_distance_less_than_2 61 | - mean_centroid_distance 62 | - rmsd_less_than_2 63 | - rmsd_less_than_5 64 | minimum_epochs: 0 65 | model_parameters: 66 | centroid_keypts_construction: false 67 | centroid_keypts_construction_lig: false 68 | centroid_keypts_construction_rec: false 69 | cross_msgs: true 70 | debug: false 71 | dropout: 0.1 72 | final_h_layer_norm: 0 73 | iegmn_lay_hid_dim: 64 74 | layer_norm: BN 75 | layer_norm_coords: 0 76 | leakyrelu_neg_slope: 0.01 77 | lig_evolve: true 78 | lig_no_softmax: false 79 | move_keypts_back: true 80 | n_lays: 8 81 | noise_decay_rate: 0.5 82 | noise_initial: 1 83 | nonlin: lkyrelu 84 | normalize_Z_lig_directions: false 85 | normalize_Z_rec_directions: false 86 | normalize_coordinate_update: true 87 | num_att_heads: 30 88 | num_lig_feats: 89 | random_vec_dim: 0 90 | random_vec_std: 1 91 | rec_evolve: true 92 | rec_no_softmax: false 93 | 
rec_square_distance_scale: 10 94 | residue_emb_dim: 64 95 | shared_layers: false 96 | skip_weight_h: 0.5 97 | unnormalized_kpt_weights: false 98 | use_dist_in_layers: true 99 | use_edge_features_in_gmn: true 100 | use_mean_node_features: true 101 | use_rec_atoms: false 102 | use_scalar_features: false 103 | x_connection_init: 0.25 104 | model_type: EquiBind 105 | models_to_save: [] 106 | multithreaded_seeds: [] 107 | num_epochs: 1000000 108 | num_train: 109 | num_val: 110 | num_workers: 0 111 | optimizer: Adam 112 | optimizer_params: 113 | lr: 0.0001 114 | weight_decay: 0.0001 115 | patience: 150 116 | pin_memory: true 117 | sampler_parameters: 118 | scheduler_step_per_batch: false 119 | seed: 1 120 | seed_data: 1 121 | tensorboard_functions: [] 122 | test_names: data/new_names 123 | train_names: data/old_no_newL_train 124 | train_predictions_name: 125 | train_sampler: 126 | trainer: revised 127 | val_names: data/old_no_newL_val 128 | val_per_batch: true 129 | val_predictions_name: 130 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import concurrent.futures 3 | 4 | import os 5 | import sys 6 | import traceback 7 | from collections import defaultdict 8 | from datetime import datetime 9 | 10 | 11 | from commons.logger import Logger 12 | from datasets.samplers import HardSampler 13 | from trainer.binding_trainer import BindingTrainer 14 | 15 | 16 | from datasets.pdbbind import PDBBind 17 | 18 | from commons.utils import seed_all, get_random_indices, log 19 | 20 | import yaml 21 | 22 | from datasets.custom_collate import * # do not remove 23 | from models import * # do not remove 24 | from torch.nn import * # do not remove 25 | from torch.optim import * # do not remove 26 | from commons.losses import * # do not remove 27 | from torch.optim.lr_scheduler import * # do not remove 28 | 29 | from torch.utils.data import DataLoader, Subset 30 | 31 | from trainer.metrics import Rsquared, MeanPredictorLoss, MAE, PearsonR, RMSD, RMSDfraction, CentroidDist, \ 32 | CentroidDistFraction, RMSDmedian, CentroidDistMedian, KabschRMSD 33 | from trainer.trainer import Trainer 34 | 35 | # turn on for debugging C code like Segmentation Faults 36 | import faulthandler 37 | 38 | faulthandler.enable() 39 | 40 | 41 | def parse_arguments(): 42 | p = argparse.ArgumentParser() 43 | p.add_argument('--config', type=argparse.FileType(mode='r'), default='configs_clean/RDKitCoords_flexible_self_docking.yml') 44 | p.add_argument('--experiment_name', type=str, help='name that will be added to the runs folder output') 45 | p.add_argument('--logdir', type=str, default='runs', help='tensorboard logdirectory') 46 | p.add_argument('--num_epochs', type=int, default=2500, help='number of times to iterate through all samples') 47 | p.add_argument('--batch_size', type=int, default=1024, help='samples that will be processed in parallel') 48 | p.add_argument('--patience', type=int, default=20, help='stop training after no improvement in this many epochs') 49 | p.add_argument('--minimum_epochs', type=int, default=0, help='minimum number of epochs to run') 50 | p.add_argument('--dataset_params', type=dict, default={}, help='parameters with keywords of the dataset') 51 | p.add_argument('--dataset', type=str, default='pdbbind', help='which dataset to use') 52 | p.add_argument('--num_train', type=int, default=-1, help='n samples of the model samples to use for train') 53 | 
53 |     p.add_argument('--num_val', type=int, default=None, help='number of validation samples to use (all samples if None)')
54 |     p.add_argument('--seed', type=int, default=1, help='seed for reproducibility')
55 |     p.add_argument('--multithreaded_seeds', type=list, default=[],
56 |                    help='if this is non-empty, multiple threads will be started to train the same model with the different seeds')
57 |     p.add_argument('--seed_data', type=int, default=1, help='set this to use a different seed for the data split')
58 |     p.add_argument('--loss_func', type=str, default='MSELoss', help='Class name of torch.nn like [MSELoss, L1Loss]')
59 |     p.add_argument('--loss_params', type=dict, default={}, help='parameters with keywords of the chosen loss function')
60 |     p.add_argument('--optimizer', type=str, default='Adam', help='Class name of torch.optim like [Adam, SGD, AdamW]')
61 |     p.add_argument('--optimizer_params', type=dict, help='parameters with keywords of the chosen optimizer like lr')
62 |     p.add_argument('--clip_grad', type=float, default=None, help='clip gradients if their magnitude is greater than this value')
63 |     p.add_argument('--lr_scheduler', type=str,
64 |                    help='Class name of torch.optim.lr_scheduler like [CosineAnnealingLR, ExponentialLR, LambdaLR]')
65 |     p.add_argument('--lr_scheduler_params', type=dict, help='parameters with keywords of the chosen lr_scheduler')
66 |     p.add_argument('--scheduler_step_per_batch', default=True, type=bool,
67 |                    help='step every batch if true, step every epoch otherwise')
68 |     p.add_argument('--log_iterations', type=int, default=-1,
69 |                    help='log every log_iterations iterations (-1 for only logging after each epoch)')
70 |     p.add_argument('--expensive_log_iterations', type=int, default=100,
71 |                    help='frequency with which to do expensive logging operations')
72 |     p.add_argument('--eval_per_epochs', type=int, default=0,
73 |                    help='frequency with which to run the function run_eval_per_epoch, which can do expensive calculations on the val set. If this is zero, the function is never called')
74 |     p.add_argument('--metrics', default=[], help='tensorboard metrics [mae, mae_denormalized, qm9_properties ...]')
75 |     p.add_argument('--main_metric', default='loss', help='for early stopping etc.')
76 |     p.add_argument('--main_metric_goal', type=str, default='min', help='controls early stopping. [max, min]')
77 |     p.add_argument('--val_per_batch', type=bool, default=True,
78 |                    help='run evaluation every batch and then average over the eval results. When running the molhiv benchmark for example, this needs to be False because we need to evaluate on all val data at once since the metric is rocauc')
79 |     p.add_argument('--tensorboard_functions', default=[], help='choices of the TENSORBOARD_FUNCTIONS in utils')
80 |     p.add_argument('--checkpoint', type=str, help='path to directory that contains a checkpoint to continue training')
81 |
82 |     p.add_argument('--collate_function', default='graph_collate', help='the collate function to use for DataLoader')
83 |     p.add_argument('--collate_params', type=dict, default={},
84 |                    help='parameters with keywords of the chosen collate function')
85 |     p.add_argument('--device', type=str, default='cuda', help='what device to train on: cuda or cpu')
86 |
87 |     p.add_argument('--models_to_save', type=list, default=[],
88 |                    help='specify after which epochs to remember the best model')
89 |
90 |     p.add_argument('--model_type', type=str, default='MPNN', help='class name of one of the models in the models dir')
91 |     p.add_argument('--model_parameters', type=dict, help='dictionary of model parameters')
92 |
93 |     p.add_argument('--trainer', type=str, default='binding', help='')
94 |     p.add_argument('--train_sampler', type=str, default=None, help="any of PyTorch's samplers or a custom sampler")
95 |     p.add_argument('--train_predictions_name', type=str, default=None, help='')
96 |     p.add_argument('--val_predictions_name', type=str, default=None, help='')
97 |     p.add_argument('--sampler_parameters', type=dict, help='dictionary of sampler parameters')
98 |
99 |     p.add_argument('--eval_on_test', type=bool, default=True, help='runs evaluation on test set if true')
100 |     p.add_argument('--pin_memory', type=bool, default=True, help='pin memory argument for pytorch dataloaders')
101 |     p.add_argument('--num_workers', type=int, default=0, help='num workers argument of dataloaders')
102 |
103 |     return p.parse_args()
104 |
105 |
106 | def get_trainer(args, model, data, device, metrics, run_dir, sampler=None):
107 |     if args.trainer is None:
108 |         trainer = Trainer
109 |     elif args.trainer == 'binding':
110 |         trainer = BindingTrainer
111 |     else: raise ValueError(f'unknown trainer: {args.trainer}')
112 |     return trainer(model=model, args=args, metrics=metrics, main_metric=args.main_metric,
113 |                    main_metric_goal=args.main_metric_goal, optim=globals()[args.optimizer],
114 |                    loss_func=globals()[args.loss_func](**args.loss_params), device=device, scheduler_step_per_batch=args.scheduler_step_per_batch,
115 |                    run_dir=run_dir, sampler=sampler)
116 |
117 |
118 | def load_model(args, data_sample, device, **kwargs):
119 |     model = globals()[args.model_type](device=device,
120 |                                        lig_input_edge_feats_dim=data_sample[0].edata['feat'].shape[1],
121 |                                        rec_input_edge_feats_dim=data_sample[1].edata['feat'].shape[1],
122 |                                        **args.model_parameters, **kwargs)
123 |     return model
124 |
125 |
126 | def train_wrapper(args):
127 |     mp = args.model_parameters
128 |     lp = args.loss_params
129 |     if args.checkpoint:
130 |         run_dir = os.path.dirname(args.checkpoint)
131 |     else:
132 |         if args.trainer == 'torsion':
133 |             run_dir = f'{args.logdir}/{os.path.splitext(os.path.basename(args.config))[0]}_{args.experiment_name}_layers{mp["n_lays"]}_bs{args.batch_size}_dim{mp["iegmn_lay_hid_dim"]}_nAttH{mp["num_att_heads"]}_norm{mp["layer_norm"]}_normc{mp["layer_norm_coords"]}_normf{mp["final_h_layer_norm"]}_recAtoms{mp["use_rec_atoms"]}_numtrain{args.num_train}_{start_time}'
134 |         else:
135 |             run_dir = f'{args.logdir}/{os.path.splitext(os.path.basename(args.config))[0]}_{args.experiment_name}_layers{mp["n_lays"]}_bs{args.batch_size}_otL{lp["ot_loss_weight"]}_iL{lp["intersection_loss_weight"]}_dim{mp["iegmn_lay_hid_dim"]}_nAttH{mp["num_att_heads"]}_norm{mp["layer_norm"]}_normc{mp["layer_norm_coords"]}_normf{mp["final_h_layer_norm"]}_recAtoms{mp["use_rec_atoms"]}_numtrain{args.num_train}_{start_time}'
136 |     if not os.path.exists(run_dir):
137 |         os.mkdir(run_dir)
138 |
139 |     sys.stdout = Logger(logpath=os.path.join(run_dir, 'log.log'), syspart=sys.stdout)
140 |     sys.stderr = Logger(logpath=os.path.join(run_dir, 'log.log'), syspart=sys.stderr)
141 |     return train(args, run_dir)
142 |
143 |
144 |
145 | def train(args, run_dir):
146 |     seed_all(args.seed)
147 |     device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == 'cuda' else "cpu")
148 |     metrics_dict = {'rsquared': Rsquared(),
149 |                     'mean_rmsd': RMSD(),
150 |                     'mean_centroid_distance': CentroidDist(),
151 |                     'rmsd_less_than_2': RMSDfraction(2),
152 |                     'rmsd_less_than_5': RMSDfraction(5),
153 |                     'rmsd_less_than_10': RMSDfraction(10),
154 |                     'rmsd_less_than_20': RMSDfraction(20),
155 |                     'rmsd_less_than_50': RMSDfraction(50),
156 |                     'median_rmsd': RMSDmedian(),
157 |                     'median_centroid_distance': CentroidDistMedian(),
158 |                     'centroid_distance_less_than_2': CentroidDistFraction(2),
159 |                     'centroid_distance_less_than_5': CentroidDistFraction(5),
160 |                     'centroid_distance_less_than_10': CentroidDistFraction(10),
161 |                     'centroid_distance_less_than_20': CentroidDistFraction(20),
162 |                     'centroid_distance_less_than_50': CentroidDistFraction(50),
163 |                     'kabsch_rmsd': KabschRMSD(),
164 |                     'mae': MAE(),
165 |                     'pearsonr': PearsonR(),
166 |                     'mean_predictor_loss': MeanPredictorLoss(globals()[args.loss_func](**args.loss_params)),
167 |                     }
168 |
169 |     train_data = PDBBind(device=device, complex_names_path=args.train_names, lig_predictions_name=args.train_predictions_name, is_train_data=True, **args.dataset_params)
170 |     val_data = PDBBind(device=device, complex_names_path=args.val_names, lig_predictions_name=args.val_predictions_name, **args.dataset_params)
171 |
172 |     if args.num_train is not None:
173 |         train_data = Subset(train_data, get_random_indices(len(train_data))[:args.num_train])
174 |     if args.num_val is not None:
175 |         val_data = Subset(val_data, get_random_indices(len(val_data))[:args.num_val])
176 |
177 |     log('train size: ', len(train_data))
178 |     log('val size: ', len(val_data))
179 |
180 |     model = load_model(args, data_sample=train_data[0], device=device)
181 |     log('trainable params in model: ', sum(p.numel() for p in model.parameters() if p.requires_grad))
182 |     collate_function = globals()[args.collate_function] if args.collate_params == {} else globals()[
183 |         args.collate_function](**args.collate_params)
184 |     if args.train_sampler is not None:
185 |         sampler = globals()[args.train_sampler](data_source=train_data, batch_size=args.batch_size)
186 |         train_loader = DataLoader(train_data, batch_sampler=sampler, collate_fn=collate_function,
187 |                                   pin_memory=args.pin_memory, num_workers=args.num_workers)
188 |     else:
189 |         sampler = None
190 |         train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, collate_fn=collate_function,
191 |                                   pin_memory=args.pin_memory, num_workers=args.num_workers)
192 |     val_loader = DataLoader(val_data, batch_size=args.batch_size, collate_fn=collate_function,
193 |                             pin_memory=args.pin_memory, num_workers=args.num_workers)
194 |
195 |     metrics = {metric: metrics_dict[metric] for metric in args.metrics}
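    # Editor's sketch (added comment, not part of the original file): args.metrics is the list of
    # metric names taken from the YAML config, and the dict comprehension above selects the
    # corresponding callables from metrics_dict. For example, a config with
    # metrics: [rmsd_less_than_2, mean_rmsd] yields
    # {'rmsd_less_than_2': RMSDfraction(2), 'mean_rmsd': RMSD()}; a name that is not a key of
    # metrics_dict fails fast here with a KeyError instead of surfacing later inside the Trainer.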
196 |     trainer = get_trainer(args=args, model=model, data=train_data, device=device, metrics=metrics, run_dir=run_dir,
197 |                           sampler=sampler)
198 |     val_metrics, _, _ = trainer.train(train_loader, val_loader)
199 |     if args.eval_on_test:
200 |         test_data = PDBBind(device=device, complex_names_path=args.test_names, **args.dataset_params)
201 |         test_loader = DataLoader(test_data, batch_size=args.batch_size, collate_fn=collate_function,
202 |                                  pin_memory=args.pin_memory, num_workers=args.num_workers)
203 |         log('test size: ', len(test_data))
204 |         test_metrics, _, _ = trainer.evaluation(test_loader, data_split='test')
205 |         return val_metrics, test_metrics, trainer.writer.log_dir
206 |     return val_metrics
207 |
208 |
209 | def get_arguments():
210 |     args = parse_arguments()
211 |     if args.config:
212 |         config_dict = yaml.load(args.config, Loader=yaml.FullLoader)
213 |         arg_dict = args.__dict__
214 |         for key, value in config_dict.items():
215 |             if isinstance(value, list):
216 |                 for v in value:
217 |                     arg_dict[key].append(v)
218 |             else:
219 |                 arg_dict[key] = value
220 |         args.config = args.config.name
221 |     else:
222 |         config_dict = {}
223 |
224 |     if args.checkpoint:  # overwrite args with args from checkpoint, except for the args that were contained in the config file
225 |         arg_dict = args.__dict__
226 |         with open(os.path.join(os.path.dirname(args.checkpoint), 'train_arguments.yaml'), 'r') as arg_file:
227 |             checkpoint_dict = yaml.load(arg_file, Loader=yaml.FullLoader)
228 |         for key, value in checkpoint_dict.items():
229 |             if key not in config_dict.keys():
230 |                 if isinstance(value, list):
231 |                     for v in value:
232 |                         arg_dict[key].append(v)
233 |                 else:
234 |                     arg_dict[key] = value
235 |
236 |     return args
237 |
238 |
239 | def main_function():
240 |     args = get_arguments()
241 |
242 |     if args.multithreaded_seeds != []:
243 |         with concurrent.futures.ThreadPoolExecutor() as executor:
244 |             futures = []
245 |             for seed in args.multithreaded_seeds:
246 |                 args_copy = get_arguments()
247 |                 args_copy.seed = seed
248 |                 futures.append(executor.submit(train_wrapper, args_copy))
249 |             results = [f.result() for f in
250 |                        futures]  # list of tuples of dictionaries, with the validation results first and the test results second
251 |         all_val_metrics = defaultdict(list)
252 |         all_test_metrics = defaultdict(list)
253 |         log_dirs = []
254 |         for result in results:
255 |             val_metrics, test_metrics, log_dir = result
256 |             log_dirs.append(log_dir)
257 |             for key in val_metrics.keys():
258 |                 all_val_metrics[key].append(val_metrics[key])
259 |                 all_test_metrics[key].append(test_metrics[key])
260 |         files = [open(os.path.join(log_dir, 'multiple_seed_validation_statistics.txt'), 'w') for log_dir in log_dirs]
261 |         print('Validation results:')
262 |         for key, value in all_val_metrics.items():
263 |             metric = np.array(value)
264 |             for file in files:
265 |                 file.write(f'\n{key}:\n')
266 |                 file.write(f'mean: {metric.mean()}\n')
267 |                 file.write(f'stddev: {metric.std()}\n')
268 |                 file.write(f'stderr: {metric.std() / np.sqrt(len(metric))}\n')
269 |                 file.write(f'values: {value}\n')
270 |             print(f'\n{key}:')
271 |             print(f'mean: {metric.mean()}')
272 |             print(f'stddev: {metric.std()}')
273 |             print(f'stderr: {metric.std() / np.sqrt(len(metric))}')
274 |             print(f'values: {value}')
275 |         for file in files:
276 |             file.close()
277 |         files = [open(os.path.join(log_dir, 'multiple_seed_test_statistics.txt'), 'w') for log_dir in log_dirs]
278 |         print('Test results:')
279 |         for key, value in all_test_metrics.items():
280 |             metric = np.array(value)
281 |             for file in files:
282 |                 file.write(f'\n{key}:\n')
283 |                 file.write(f'mean: {metric.mean()}\n')
284 |                 file.write(f'stddev: {metric.std()}\n')
285 |                 file.write(f'stderr: {metric.std() / np.sqrt(len(metric))}\n')
286 |                 file.write(f'values: {value}\n')
287 |             print(f'\n{key}:')
288 |             print(f'mean: {metric.mean()}')
289 |             print(f'stddev: {metric.std()}')
290 |             print(f'stderr: {metric.std() / np.sqrt(len(metric))}')
291 |             print(f'values: {value}')
292 |         for file in files:
293 |             file.close()
294 |     else:
295 |         train_wrapper(args)
296 |
297 |
298 | if __name__ == '__main__':
299 |     start_time = datetime.now().strftime('date%d-%m_time%H-%M-%S.%f')
300 |     if not os.path.exists('logs'):
301 |         os.mkdir('logs')
302 |     with open(os.path.join('logs', f'{start_time}.log'), "w") as file:
303 |         try:
304 |             main_function()
305 |         except Exception:
306 |             traceback.print_exc(file=file)
307 |             raise
308 |
--------------------------------------------------------------------------------
/trainer/README.md:
--------------------------------------------------------------------------------
1 | # ``binding_trainer.py`` (inherits from ``Trainer``) is the standard trainer that is used
--------------------------------------------------------------------------------
/trainer/binding_trainer.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from datasets.samplers import HardSampler
4 | from trainer.trainer import Trainer
5 |
6 |
7 | class BindingTrainer(Trainer):
8 |     def __init__(self, **kwargs):
9 |         super(BindingTrainer, self).__init__(**kwargs)
10 |
11 |     def forward_pass(self, batch):
12 |         lig_graphs, rec_graphs, ligs_coords, recs_coords, ligs_pocket_coords, recs_pocket_coords, geometry_graphs, complex_names = tuple(
13 |             batch)
14 |         ligs_coords_pred, ligs_keypts, recs_keypts, rotations, translations, geom_reg_loss = self.model(lig_graphs, rec_graphs, geometry_graphs,
15 |                                                                                                         complex_names=complex_names,
16 |                                                                                                         epoch=self.epoch)
17 |         loss, loss_components = self.loss_func(ligs_coords, recs_coords, ligs_coords_pred, ligs_pocket_coords,
18 |                                                recs_pocket_coords, ligs_keypts, recs_keypts, rotations, translations, geom_reg_loss,
19 |                                                self.device)
20 |         return loss, loss_components, ligs_coords_pred, ligs_coords
21 |
22 |     def after_batch(self, ligs_coords_pred, ligs_coords, batch_indices):
23 |         cutoff = 5
24 |         centroid_distances = []
25 |         for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
26 |             centroid_distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0) - lig_coords.mean(dim=0)))
27 |         centroid_distances = torch.tensor(centroid_distances)
28 |         above_cutoff = torch.tensor(batch_indices)[torch.where(centroid_distances > cutoff)[0]]
29 |         if isinstance(self.sampler, HardSampler):
30 |             self.sampler.add_hard_indices(above_cutoff.tolist())
31 |
32 |     def after_epoch(self):
33 |         if isinstance(self.sampler, HardSampler):
34 |             self.sampler.set_hard_indices()
35 |
--------------------------------------------------------------------------------
/trainer/lr_schedulers.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import *
2 | import numpy as np
3 |
4 |
5 | class WarmUpWrapper:
6 |     "Optimizer wrapper that implements gradual learning rate warmup."
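    # Editor's sketch (added comment, not part of the original file): an assumed usage of this
    # wrapper. wrapped_scheduler is the class name of a schedule from torch.optim.lr_scheduler
    # (resolved via globals() thanks to the star import above), and all extra kwargs are
    # forwarded to that schedule:
    #
    #   optim = Adam(model.parameters(), lr=1e-4)
    #   sched = WarmUpWrapper(optim, wrapped_scheduler='ReduceLROnPlateau', warmup_steps=[1000],
    #                         interpolation='cosine', mode='max', factor=0.6, patience=10)
    #   sched.step(metrics=val_score)  # ramps the lr from 0 to 1e-4 over the first 1000 calls
    #
    # Once the warmup steps are used up, step() simply delegates to the wrapped schedule.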
7 |
8 |     def __init__(self, optimizer, wrapped_scheduler, warmup_steps, interpolation='linear',
9 |                  **kwargs):
10 |         '''
11 |
12 |         :param optimizer: the torch optimizer whose learning rate is warmed up
13 |         :param wrapped_scheduler: class name of the torch.optim.lr_scheduler schedule to use after the warmup
14 |         :param warmup_steps: a list containing how many warmup steps should be done for each param group before updating all parameters
15 |         :param interpolation: how to interpolate the learning rate during warmup ('linear' or 'cosine')
16 |         :param kwargs: keyword arguments that are forwarded to the wrapped scheduler
17 |         '''
18 |         self.optim = optimizer
19 |         self._step = 0
20 |         self.interpolation = interpolation
21 |         self.warmup_steps = np.array(warmup_steps)
22 |         self.total_warmup_steps = self.warmup_steps.sum()
23 |         self.wrapped_scheduler = globals()[wrapped_scheduler](self.optim, **kwargs)
24 |         self.start_lrs = []
25 |         self.warmup_phase = 0
26 |         for p in self.optim.param_groups:
27 |             self.start_lrs.append(p['lr'])
28 |             p['lr'] = 0
29 |
30 |     def step(self, metrics=None):
31 |         "Update parameters and lr"
32 |         if self._step < self.total_warmup_steps:
33 |             warmup_phase = 0
34 |             for steps in self.warmup_steps.cumsum():
35 |                 if self._step >= steps:
36 |                     warmup_phase += 1
37 |             for i, p in enumerate(self.optim.param_groups):
38 |                 # update all parameters if there is only one entry specified for the warmup steps, otherwise only update the ones corresponding to the current warmup phase
39 |                 if i <= warmup_phase or len(self.warmup_steps) == 1:
40 |                     # interpolate between 0 and the final starting learning rate
41 |                     interpolation_value = self._step - ([0] + list(self.warmup_steps.cumsum()))[warmup_phase] + 1
42 |                     if self.warmup_steps[warmup_phase] == 0:
43 |                         p['lr'] = self.start_lrs[i]
44 |                     else:
45 |                         if self.interpolation == 'linear':
46 |                             p['lr'] = self.start_lrs[i] * (interpolation_value / self.warmup_steps[warmup_phase])
47 |                         elif self.interpolation == 'cosine':
48 |                             p['lr'] = self.start_lrs[i] * (
49 |                                     (-np.cos((np.pi) * (interpolation_value / self.warmup_steps[warmup_phase])) + 1) * 0.5)
50 |                         else:
51 |                             raise ValueError('interpolation not implemented:', self.interpolation)
52 |
53 |         else:
54 |             if metrics is not None:
55 |                 self.wrapped_scheduler.step(metrics=metrics)
56 |             else:
57 |                 self.wrapped_scheduler.step()
58 |         self._step += 1
59 |
60 |     def state_dict(self):
61 |         """Returns the state of the warmup_steps scheduler as a :class:`dict`.
62 |         It contains an entry for every variable in self.__dict__ which
63 |         is not the optim.
64 |         """
65 |         state_dict = {key: value for key, value in self.__dict__.items() if key != 'optim'}
66 |         state_dict['wrapped_scheduler'] = self.wrapped_scheduler.state_dict()  # overwrite with the state dict
67 |         return state_dict
68 |
69 |     def load_state_dict(self, state_dict):
70 |         """Loads the warmup_steps scheduler's state.
71 |         Arguments:
72 |             state_dict (dict): warmup_steps scheduler state. Should be an object returned
73 |                 from a call to :meth:`state_dict`.
74 |         """
75 |         wrapped_scheduler_state_dict = state_dict['wrapped_scheduler']
76 |         del state_dict['wrapped_scheduler']
77 |         self.wrapped_scheduler.load_state_dict(wrapped_scheduler_state_dict)
78 |         self.__dict__.update(state_dict)
79 |
--------------------------------------------------------------------------------
/trainer/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import functional as F
6 | import torch.nn as nn
7 |
8 | from commons.utils import concat_if_list
9 |
10 |
11 | class PearsonR(nn.Module):
12 |     """
13 |     Computes the Pearson correlation coefficient between predictions and targets
14 |     """
15 |
16 |     def __init__(self):
17 |         super().__init__()
18 |
19 |     def forward(self, preds, targets):
20 |         preds, targets = concat_if_list(preds), concat_if_list(targets)  # concatenate tensors if list of tensors
21 |         shifted_x = preds - torch.mean(preds, dim=0)
22 |         shifted_y = targets - torch.mean(targets, dim=0)
23 |         sigma_x = torch.sqrt(torch.sum(shifted_x ** 2, dim=0))
24 |         sigma_y = torch.sqrt(torch.sum(shifted_y ** 2, dim=0))
25 |
26 |         pearson = torch.sum(shifted_x * shifted_y, dim=0) / (sigma_x * sigma_y + 1e-8)
27 |         pearson = torch.clamp(pearson, min=-1, max=1)
28 |         pearson = pearson.mean()
29 |         return pearson
30 |
31 |
32 | class MAE(nn.Module):
33 |     def __init__(self):
34 |         super().__init__()
35 |
36 |     def forward(self, preds, targets):
37 |         loss = F.l1_loss(preds, targets)
38 |         return loss
39 |
40 |
41 | class Rsquared(nn.Module):
42 |     """
43 |     The coefficient of determination / R squared measures the goodness of fit of our model.
44 |     Rsquared = 1 means that the regression predictions perfectly fit the data.
45 |     If Rsquared is less than 0, our model is worse than the mean predictor.
46 | https://en.wikipedia.org/wiki/Coefficient_of_determination 47 | """ 48 | 49 | def __init__(self): 50 | super().__init__() 51 | 52 | def forward(self, preds, targets): 53 | preds, targets = concat_if_list(preds), concat_if_list(targets) # concatenate tensors if list of tensors 54 | total_SS = ((targets - targets.mean()) ** 2).sum() 55 | residual_SS = ((targets - preds) ** 2).sum() 56 | return 1 - residual_SS / total_SS 57 | 58 | 59 | class RMSD(nn.Module): 60 | def __init__(self) -> None: 61 | super(RMSD, self).__init__() 62 | 63 | def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: 64 | rmsds = [] 65 | for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): 66 | rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1)))) 67 | return torch.tensor(rmsds).mean() 68 | 69 | class KabschRMSD(nn.Module): 70 | def __init__(self) -> None: 71 | super(KabschRMSD, self).__init__() 72 | 73 | def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: 74 | rmsds = [] 75 | for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): 76 | lig_coords_pred_mean = lig_coords_pred.mean(dim=0, keepdim=True) # (1,3) 77 | lig_coords_mean = lig_coords.mean(dim=0, keepdim=True) # (1,3) 78 | 79 | A = (lig_coords_pred - lig_coords_pred_mean).transpose(0, 1) @ (lig_coords - lig_coords_mean) 80 | 81 | U, S, Vt = torch.linalg.svd(A) 82 | 83 | corr_mat = torch.diag(torch.tensor([1, 1, torch.sign(torch.det(A))], device=lig_coords_pred.device)) 84 | rotation = (U @ corr_mat) @ Vt 85 | translation = lig_coords_pred_mean - torch.t(rotation @ lig_coords_mean.t()) # (1,3) 86 | 87 | lig_coords = (rotation @ lig_coords.t()).t() + translation 88 | rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1)))) 89 | return torch.tensor(rmsds).mean() 90 | 91 | 92 | class RMSDmedian(nn.Module): 93 | def __init__(self) -> None: 94 | super(RMSDmedian, self).__init__() 95 | 96 | def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: 97 | rmsds = [] 98 | for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): 99 | rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1)))) 100 | return torch.median(torch.tensor(rmsds)) 101 | 102 | 103 | class RMSDfraction(nn.Module): 104 | def __init__(self, distance) -> None: 105 | super(RMSDfraction, self).__init__() 106 | self.distance = distance 107 | 108 | def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: 109 | rmsds = [] 110 | for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): 111 | rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1)))) 112 | count = torch.tensor(rmsds) < self.distance 113 | return 100 * count.sum() / len(count) 114 | 115 | 116 | class CentroidDist(nn.Module): 117 | def __init__(self) -> None: 118 | super(CentroidDist, self).__init__() 119 | 120 | def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: 121 | distances = [] 122 | for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): 123 | distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0))) 124 | return torch.tensor(distances).mean() 125 | 126 | 127 | class CentroidDistMedian(nn.Module): 128 | def __init__(self) -> None: 129 | super(CentroidDistMedian, self).__init__() 130 | 131 | def forward(self, ligs_coords_pred: 
List[Tensor], ligs_coords: List[Tensor]) -> Tensor: 132 | distances = [] 133 | for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): 134 | distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0))) 135 | return torch.median(torch.tensor(distances)) 136 | 137 | 138 | class CentroidDistFraction(nn.Module): 139 | def __init__(self, distance) -> None: 140 | super(CentroidDistFraction, self).__init__() 141 | self.distance = distance 142 | 143 | def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: 144 | distances = [] 145 | for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): 146 | distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0))) 147 | count = torch.tensor(distances) < self.distance 148 | return 100 * count.sum() / len(count) 149 | 150 | 151 | class MeanPredictorLoss(nn.Module): 152 | 153 | def __init__(self, loss_func) -> None: 154 | super(MeanPredictorLoss, self).__init__() 155 | self.loss_func = loss_func 156 | 157 | def forward(self, x1: Tensor, targets: Tensor) -> Tensor: 158 | return self.loss_func(torch.full_like(targets, targets.mean()), targets) 159 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import inspect 3 | import os 4 | import shutil 5 | from typing import Dict, Callable 6 | 7 | import pyaml 8 | import torch 9 | import numpy as np 10 | 11 | from datasets.samplers import HardSampler 12 | from models import * # do not remove 13 | from trainer.lr_schedulers import WarmUpWrapper # do not remove 14 | 15 | from torch.optim.lr_scheduler import * # For loading optimizer specified in config 16 | 17 | from torch.utils.data import DataLoader 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | from commons.utils import flatten_dict, tensorboard_gradient_magnitude, move_to_device, list_detach, concat_if_list, log 21 | 22 | 23 | class Trainer(): 24 | def __init__(self, model, args, metrics: Dict[str, Callable], main_metric: str, device: torch.device, 25 | tensorboard_functions: Dict[str, Callable] = None, optim=None, main_metric_goal: str = 'min', 26 | loss_func=torch.nn.MSELoss(), scheduler_step_per_batch: bool = True, run_dir='', sampler=None): 27 | 28 | self.args = args 29 | self.device = device 30 | self.model = model.to(self.device) 31 | self.loss_func = loss_func 32 | self.tensorboard_functions = tensorboard_functions 33 | self.metrics = metrics 34 | self.sampler = sampler 35 | self.val_per_batch = args.val_per_batch 36 | self.main_metric = type(self.loss_func).__name__ if main_metric == 'loss' else main_metric 37 | self.main_metric_goal = main_metric_goal 38 | self.scheduler_step_per_batch = scheduler_step_per_batch 39 | self.initialize_optimizer(optim) 40 | self.initialize_scheduler() 41 | if args.checkpoint: 42 | checkpoint = torch.load(args.checkpoint, map_location=self.device) 43 | self.writer = SummaryWriter(os.path.dirname(args.checkpoint)) 44 | self.model.load_state_dict(checkpoint['model_state_dict']) 45 | self.optim.load_state_dict(checkpoint['optimizer_state_dict']) 46 | if self.lr_scheduler != None and checkpoint['scheduler_state_dict'] != None: 47 | self.lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 48 | self.start_epoch = checkpoint['epoch'] 49 | self.best_val_score = checkpoint['best_val_score'] 50 | self.optim_steps = checkpoint['optim_steps'] 51 | 
else:
52 |             self.start_epoch = 1
53 |             self.optim_steps = 0
54 |             self.best_val_score = -np.inf if self.main_metric_goal == 'max' else np.inf  # running score to decide whether or not a new model should be saved
55 |             self.writer = SummaryWriter(run_dir)
56 |             shutil.copyfile(self.args.config, os.path.join(self.writer.log_dir, os.path.basename(self.args.config)))
57 |         # for i, param_group in enumerate(self.optim.param_groups):
58 |         #     param_group['lr'] = 0.0003
59 |         self.epoch = self.start_epoch
60 |         log(f'Log directory: {self.writer.log_dir}')
61 |         self.hparams = copy.copy(args).__dict__
62 |         for key, value in flatten_dict(self.hparams).items():
63 |             log(f'{key}: {value}')
64 |
65 |     def run_per_epoch_evaluations(self, loader):
66 |         pass
67 |
68 |     def train(self, train_loader: DataLoader, val_loader: DataLoader):
69 |         epochs_no_improve = 0  # counts every epoch that the validation accuracy did not improve, for early stopping
70 |         for epoch in range(self.start_epoch, self.args.num_epochs + 1):  # loop over the dataset multiple times
71 |             self.epoch = epoch
72 |             self.model.train()
73 |             self.predict(train_loader, optim=self.optim)
74 |
75 |             self.model.eval()
76 |             with torch.no_grad():
77 |                 metrics, _, _ = self.predict(val_loader)
78 |                 val_score = metrics[self.main_metric]
79 |
80 |                 if self.lr_scheduler is not None and not self.scheduler_step_per_batch:
81 |                     self.step_schedulers(metrics=val_score)
82 |
83 |                 if self.args.eval_per_epochs > 0 and epoch % self.args.eval_per_epochs == 0:
84 |                     self.run_per_epoch_evaluations(val_loader)
85 |
86 |                 self.tensorboard_log(metrics, data_split='val', log_hparam=True, step=self.optim_steps)
87 |                 val_loss = metrics[type(self.loss_func).__name__]
88 |                 log('[Epoch %d] %s: %.6f val loss: %.6f' % (epoch, self.main_metric, val_score, val_loss))
89 |                 # save the model with the best main_metric, depending on whether we want to maximize or minimize the main metric
90 |                 if (val_score >= self.best_val_score and self.main_metric_goal == 'max') or (val_score <= self.best_val_score and self.main_metric_goal == 'min'):
91 |                     epochs_no_improve = 0
92 |                     self.best_val_score = val_score
93 |                     self.save_checkpoint(epoch, checkpoint_name='best_checkpoint.pt')
94 |                 else:
95 |                     epochs_no_improve += 1
96 |                 self.save_checkpoint(epoch, checkpoint_name='last_checkpoint.pt')
97 |                 log('Epochs with no improvement: [', epochs_no_improve, '] and the best ', self.main_metric,
98 |                     ' was in epoch ', epoch - epochs_no_improve)
99 |                 if epochs_no_improve >= self.args.patience and epoch >= self.args.minimum_epochs:  # stopping criterion
100 |                     log(f'Early stopping criterion based on -{self.main_metric}- that should be {self.main_metric_goal}-imized reached after {epoch} epochs. Best model checkpoint was in epoch {epoch - epochs_no_improve}.')
101 |                     break
102 |                 if epoch in self.args.models_to_save:
103 |                     shutil.copyfile(os.path.join(self.writer.log_dir, 'best_checkpoint.pt'),
104 |                                     os.path.join(self.writer.log_dir, f'best_checkpoint_{epoch}epochs.pt'))
105 |             self.after_epoch()
106 |             # if val_loss > 10000:
107 |             #     raise Exception
108 |
109 |         # evaluate on best checkpoint
110 |         checkpoint = torch.load(os.path.join(self.writer.log_dir, 'best_checkpoint.pt'), map_location=self.device)
111 |         self.model.load_state_dict(checkpoint['model_state_dict'])
112 |         return self.evaluation(val_loader, data_split='val_best_checkpoint')
113 |
114 |     def forward_pass(self, batch):
115 |         targets = batch[-1]  # the last entry of the batch tuple is always the targets
116 |         predictions = self.model(*batch[0])  # forward the rest of the batch to the model
117 |         loss, *loss_components = self.loss_func(predictions, targets)
118 |         # if loss_func does not return any loss_components, we turn the empty list into None
119 |         return loss, (loss_components if loss_components != [] else None), predictions, targets
120 |
121 |     def process_batch(self, batch, optim):
122 |         loss, loss_components, predictions, targets = self.forward_pass(batch)
123 |         if optim is not None:  # run backpropagation if an optimizer is provided
124 |             loss.backward()
125 |             if self.args.clip_grad is not None:
126 |                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.args.clip_grad, norm_type=2)
127 |             self.optim.step()
128 |             self.after_optim_step()  # overwrite this function to do stuff before zeroing out grads
129 |             self.optim.zero_grad()
130 |             self.optim_steps += 1
131 |         return loss, loss_components, list_detach(predictions), list_detach(targets)
132 |
133 |     def predict(self, data_loader: DataLoader, optim: torch.optim.Optimizer = None, return_pred=False):
134 |         total_metrics = {k: 0 for k in
135 |                          list(self.metrics.keys()) + [type(self.loss_func).__name__, 'mean_pred', 'std_pred',
136 |                                                       'mean_targets', 'std_targets']}
137 |         epoch_targets = []
138 |         epoch_predictions = []
139 |         epoch_loss = 0
140 |         for i, batch in enumerate(data_loader):
141 |             *batch, batch_indices = move_to_device(list(batch), self.device)
142 |             # loss_components is either None or a dict with the components of the loss function
143 |             loss, loss_components, predictions, targets = self.process_batch(batch, optim)
144 |             with torch.no_grad():
145 |                 if loss_components is not None and i == 0:  # add loss_component keys to total_metrics
146 |                     total_metrics.update({k: 0 for k in loss_components.keys()})
147 |                 if self.optim_steps % self.args.log_iterations == 0 and optim is not None:
148 |                     metrics = self.evaluate_metrics(predictions, targets)
149 |                     metrics[type(self.loss_func).__name__] = loss.item()
150 |                     metrics.update(loss_components)
151 |                     self.tensorboard_log(metrics, data_split='train', step=self.optim_steps)
152 |                     log('[Epoch %d; Iter %5d/%5d] %s: loss: %.7f' % (
153 |                         self.epoch, i + 1, len(data_loader), 'train', loss.item()))
154 |                 if optim is None and self.val_per_batch:  # during validation or testing, when we want to average metrics over all the data in that dataloader
155 |                     metrics = self.evaluate_metrics(predictions, targets, val=True)
156 |                     metrics[type(self.loss_func).__name__] = loss.item()
157 |                     metrics.update(loss_components)
158 |                     for key, value in metrics.items():
159 |                         total_metrics[key] += value
160 |                 if (optim is None and not self.val_per_batch) or return_pred:
161 |                     epoch_loss += loss.item()
162 |                     epoch_targets.extend(targets if isinstance(targets, list) else [targets])
163 |                     epoch_predictions.extend(predictions if isinstance(predictions, list) else [predictions])
164 |             self.after_batch(predictions, targets, batch_indices)
165 |         if optim is None:
166 |             loader_len = len(data_loader) if len(data_loader) != 0 else 1
167 |             if self.val_per_batch:
168 |                 total_metrics = {k: v / loader_len for k, v in total_metrics.items()}
169 |             else:
170 |                 total_metrics = self.evaluate_metrics(epoch_predictions, epoch_targets, val=True)
171 |                 total_metrics[type(self.loss_func).__name__] = epoch_loss / loader_len
172 |             if return_pred:
173 |                 return total_metrics, list_detach(epoch_predictions), list_detach(epoch_targets)
174 |             else:
175 |                 return total_metrics, None, None
176 |
177 |     def after_batch(self, predictions, targets, batch_indices):
178 |         pass
179 |
180 |     def after_epoch(self):
181 |         pass
182 |
183 |     def after_optim_step(self):
184 |         if self.optim_steps % self.args.log_iterations == 0:
185 |             tensorboard_gradient_magnitude(self.optim, self.writer, self.optim_steps)
186 |         if self.lr_scheduler is not None and (self.scheduler_step_per_batch or (isinstance(self.lr_scheduler,
187 |                                               WarmUpWrapper) and self.lr_scheduler.total_warmup_steps > self.lr_scheduler._step)):  # step per batch if that is what we want to do, or if we are using a warmup schedule and are still in the warmup period
188 |             self.step_schedulers()
189 |
190 |     def evaluate_metrics(self, predictions, targets, batch=None, val=False) -> Dict[str, float]:
191 |         metrics = {}
192 |         metrics['mean_pred'] = torch.mean(concat_if_list(predictions)).item()
193 |         metrics['std_pred'] = torch.std(concat_if_list(predictions)).item()
194 |         metrics['mean_targets'] = torch.mean(concat_if_list(targets)).item()
195 |         metrics['std_targets'] = torch.std(concat_if_list(targets)).item()
196 |         for key, metric in self.metrics.items():
197 |             if not hasattr(metric, 'val_only') or val:
198 |                 metrics[key] = metric(predictions, targets).item()
199 |         return metrics
200 |
201 |     def tensorboard_log(self, metrics, data_split: str, step: int, log_hparam: bool = False):
202 |         metrics['epoch'] = self.epoch
203 |         for i, param_group in enumerate(self.optim.param_groups):
204 |             metrics[f'lr_param_group_{i}'] = param_group['lr']
205 |         logs = {}
206 |         for key, metric in metrics.items():
207 |             metric_name = f'{key}/{data_split}'
208 |             logs[metric_name] = metric
209 |             self.writer.add_scalar(metric_name, metric, step)
210 |
211 |     def evaluation(self, data_loader: DataLoader, data_split: str = '', return_pred=False):
212 |         self.model.eval()
213 |         metrics, predictions, targets = self.predict(data_loader, return_pred=return_pred)
214 |
215 |         with open(os.path.join(self.writer.log_dir, 'evaluation_' + data_split + '.txt'), 'w') as file:
216 |             log('Statistics on ', data_split)
217 |             for key, value in metrics.items():
218 |                 file.write(f'{key}: {value}\n')
219 |                 log(f'{key}: {value}')
220 |         return metrics, predictions, targets
221 |
222 |     def initialize_optimizer(self, optim):
223 |         self.optim = optim(self.model.parameters(), **self.args.optimizer_params)
224 |
225 |     def step_schedulers(self, metrics=None):
226 |         try:
227 |             self.lr_scheduler.step(metrics=metrics)
228 |         except TypeError:  # schedulers whose step() does not accept a metrics argument
229 |             self.lr_scheduler.step()
230 |
231 |     def initialize_scheduler(self):
232 |         if self.args.lr_scheduler:  # needs "from torch.optim.lr_scheduler import *" to work
233 |             self.lr_scheduler = globals()[self.args.lr_scheduler](self.optim, **self.args.lr_scheduler_params)
234 |         else:
235 |             self.lr_scheduler = None
236 |
237 |     def save_checkpoint(self, epoch: int, checkpoint_name: str):
238 |         """
239 |         Saves a checkpoint of the model in the run directory that is used by the SummaryWriter
240 |         """
241 |         run_dir = self.writer.log_dir
242 |         self.save_model_state(epoch, checkpoint_name)
243 |         train_args = copy.copy(self.args)
244 |         train_args.config = os.path.join(run_dir, os.path.basename(self.args.config))
245 |         with open(os.path.join(run_dir, 'train_arguments.yaml'), 'w') as yaml_path:
246 |             pyaml.dump(train_args.__dict__, yaml_path)
247 |
248 |         # get the class of the used model (works because "from models import *" runs the __init__.py in the models dir)
249 |         model_class = globals()[type(self.model).__name__]
250 |         source_code = inspect.getsource(model_class)  # get the source code of the class of the model
251 |         file_name = os.path.basename(inspect.getfile(model_class))
252 |         with open(os.path.join(run_dir, file_name), "w") as f:
253 |             f.write(source_code)
254 |
255 |     def save_model_state(self, epoch: int, checkpoint_name: str):
256 |         torch.save({
257 |             'epoch': epoch,
258 |             'best_val_score': self.best_val_score,
259 |             'optim_steps': self.optim_steps,
260 |             'model_state_dict': self.model.state_dict(),
261 |             'optimizer_state_dict': self.optim.state_dict(),
262 |             'scheduler_state_dict': None if self.lr_scheduler is None else self.lr_scheduler.state_dict()
263 |         }, os.path.join(self.writer.log_dir, checkpoint_name))
264 |
--------------------------------------------------------------------------------
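
Editor's addition: the dictionary written by save_model_state above is what ships in the runs/*/best_checkpoint.pt files, and it pairs with the train_arguments.yaml saved next to it. Below is a minimal sketch of restoring the rigid-redocking weights, assuming EquiBind is importable from the models package (as model_type: EquiBind and the load_model function in train.py suggest); the two edge-feature dimensions are placeholders that load_model would normally read from a data sample, not values taken from this repository.

import torch
import yaml

from models import EquiBind  # model_type recorded in train_arguments.yaml

run_dir = 'runs/rigid_redocking'
with open(f'{run_dir}/train_arguments.yaml') as f:
    args = yaml.safe_load(f)

# load_model() in train.py reads these two dims from a sample's edge features;
# the literal values here are placeholders for illustration only.
model = EquiBind(device='cpu',
                 lig_input_edge_feats_dim=15,  # placeholder
                 rec_input_edge_feats_dim=27,  # placeholder
                 **args['model_parameters'])

checkpoint = torch.load(f'{run_dir}/best_checkpoint.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()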