├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── analysis ├── SA_Score │ ├── README.md │ ├── __pycache__ │ │ ├── sascorer.cpython-310.pyc │ │ └── sascorer.cpython-38.pyc │ └── sascorer.py ├── all_pdbs_to_pdbqts.py ├── bond_angle_config.py ├── bond_length_config.py ├── docking.py ├── docking_py27.py ├── eval_bond_angles.py ├── eval_bond_length.py ├── get_atom_types_dist.py ├── get_empirical_angles.py ├── get_empirical_dists.py ├── get_volume.py ├── metrics.py ├── qvina_docking.py ├── reconstruct_mol.py ├── scoring_func.py ├── similarity.py ├── utils.py └── vina_docking.py ├── analyze_generated_pocket_mols.py ├── analyze_scaffolds_generated.py ├── assets ├── .DS_Store ├── movie.gif └── scaffold_optim.png ├── data ├── CROSSDOCK │ ├── __init__.py │ ├── fragment_hierarchy.py │ ├── prepare_fragments.py │ ├── process_crossdock.py │ ├── process_ligands.py │ ├── process_pockets.py │ └── sascorer.py ├── __init__.py └── sascorer.py ├── extend_scaffold_crossdock.py ├── fpscores.pkl.gz ├── generate_pocket_molecules.py ├── notebooks ├── 2z3h.pdb ├── 2z3h_H.pdb ├── 2z3h_out │ ├── 2z3h.pml │ ├── 2z3h.tcl │ ├── 2z3h_PYMOL.sh │ ├── 2z3h_VMD.sh │ ├── 2z3h_info.txt │ ├── 2z3h_out.pdb │ ├── 2z3h_pockets.pqr │ └── pockets │ │ ├── pocket1_atm.pdb │ │ ├── pocket1_vert.pqr │ │ ├── pocket2_atm.pdb │ │ └── pocket2_vert.pqr ├── __init__.py └── sample_for_pocket.ipynb ├── sample_crossdock_mols.py ├── sample_from_pocket.py ├── sampling ├── rejection_sampling.py ├── sample_mols.py └── scaffold_extension.py ├── src ├── __init__.py ├── anchor_gnn.py ├── const.py ├── conv_layer.py ├── datasets.py ├── dropout.py ├── dynamics_gvp.py ├── edm.py ├── egnn.py ├── extension_size.py ├── fragment_size_gnn.py ├── gvp.py ├── gvp_model.py ├── layer_norm.py ├── lightning.py ├── lightning_anchor_gnn.py ├── noise.py └── utils.py ├── train_anchor_predictor.py ├── train_frag_diffuser.py └── utils ├── sample_frag_size.py ├── templates.py ├── visuals.py └── volume_sampling.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 keiserlab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoFragDiff 2 | 3 | This repository is the official implementation of Autoregressive fragment based diffusion model for target-aware ligand design 4 | 5 | 6 | 7 | 8 | # Dependencies 9 | - RDKit 10 | - openbabel 11 | - PyTorch 12 | - biopython 13 | - biopandas 14 | - networkx 15 | - py3dmol 16 | - scikit-learn 17 | - tensorboard 18 | - wandb 19 | - pytorch-lightning 20 | 21 | ## Create conda environment 22 | ``` 23 | conda create -n autofragdiff 24 | pip install rdkit 25 | conda install -c conda-forge openbabel 26 | pip3 install torch torchvision torchaudio 27 | pip install biopython 28 | pip install biopandas 29 | pip install networkx 30 | pip install py3dmol 31 | pip install scikit-learn 32 | pip install tensorboard 33 | pip install wandb 34 | pip install tqdm 35 | pip install pytorch-lightning==1.6.0 36 | ``` 37 | 38 | The model has been tested with the following software versions: 39 | 40 | | Software | Version | 41 | | --------------- | ----------- | 42 | | rdkit | 2023.3.1 | 43 | | openbabel | 3.1.1 | 44 | | pytorch | 2.0.1 | 45 | | biopython | 1.81 | 46 | | biopandas | 0.4.1 | 47 | | networkx | 3.1 | 48 | | py3dmol | 2.0.1. | 49 | | scikit-learn | 1.2.2 | 50 | | tensorboard | 2.13.0 | 51 | | wandb | 0.15.2 | 52 | | pytorch-lightning | 1.6.0 | 53 | 54 | 55 | ## QucikVina2 56 | For Docking with qvina install QuickVina2: 57 | ``` 58 | wget https://github.com/QVina/qvina/raw/master/bin/qvina2.1 59 | chmod +x qvina2.1 60 | ``` 61 | We also need MGLTools for preparing the receptor for docking (pdb->pdbqt) but it can mess up the conda environment, so make a new one. 62 | ``` 63 | conda create -n mgltools -c bioconda mgltools 64 | ``` 65 | 66 | # Data Preparation 67 | 68 | ## CrossDock 69 | Download and extract the dataset as described by the authors of Pocket2Mol: https://github.com/pengxingang/Pocket2Mol/tree/main/data 70 | 71 | process the molecule fragments using a custom fragmentation. 72 | ``` 73 | python process_crossdock.py --rootdir $CROSSDOCK_PATH --outdir $OUT_DIR \ 74 | --dist_cutoff 7. --max-num-frags 8 --split test --max-atoms-single-fragment 22 \ 75 | --add-Vina-score --add-QED-score --add-SA-score --n-cores 16 76 | ``` 77 | - For adding Vina you also need to generate pdbqt files for each receptor and crystallographic ligand. 78 | 79 | # Training 80 | 81 | ## Training AutoFragdiff. 82 | ``` 83 | python train_frag_diffuser.py --data $CROSSDOCK_DIR --exp_name CROSSDOCK_model_1 \ 84 | --lr 0.0001 --n_layers 6 --nf 128 --diffusoin_steps 500 \ 85 | --diffusion_loss_type l2 --n_epochs 1000 --batch_size 4 86 | ``` 87 | 88 | ## Training anchor predictor 89 | ``` 90 | python train_anchor_predictor --data $CROSSDOCK_DIR --exp_name CROSDOCK_anchor_model_1 \ 91 | --n_layers 4 --inv_sublayers 2 --nf 128 --dataset-type CrossDock 92 | ``` 93 | 94 | 95 | # Sampling: 96 | 97 | Firt download the trained models from the google drive in the following link 98 | 99 | https://drive.google.com/drive/folders/1DQwIfibHIoFPGJP6aHBGiYRp87bCZFA0?usp=share_link 100 | 101 | ## CrossDock pocket-based molecule generation: 102 | 103 | To generate molecules from trained pocket-based model, also use anchor-predictor model. fragment sizes are sampled from the data distribution. 104 | 105 | ## CrossDock pocket-based molecule generation (with guidance): 106 | 107 | To generate molecules for crossdock test set: 108 | ``` 109 | python sample_crossdock_mols.py --results-path results/ --data-path $(path-to-crossdock-dataset) --use-anchor-model --anchor-model anchor-model.ckpt --n-samples 20 --exp-name test-crossdock --diff-model pocket-gvp.ckpt --device cuda:0 110 | ``` 111 | 112 | To sample molecules from a pdb file: 113 | first run fpocket and identify the correct pocket using: 114 | ``` 115 | fpocket -f $pdb.pdb 116 | ``` 117 | fpocket gives multiple pockets, you can visualize the identify the right pocket and run sampling 118 | 119 | ``` 120 | python sample_from_pocket.py --result-path results --pdb $pdbname --anchor-model anchor-model.ckpt --n-samples 10 --device cuda:0 --pocket-number 1 121 | ``` 122 | 123 | ## Scaffold-based molecule property optimization 124 | 125 | For scaffold-based optimization you need the pdb file of the pocket and the sdf file of the scaffold molecule (and the original molecule). 126 | 127 | Scaffold-extension for crossdock test set 128 | ``` 129 | python extend_scaffold_crossdock.py --data-path $(path-to-crossdock) --results-path scaffold-gen --anchor-model anchor-model.ckpt --n-samples 20 --exp-name scaffold-gen --diff-model pocket-gvp.ckpt --device cuda:0 130 | ``` 131 | 132 | - In order to select the anchor you can add the `--custom-anchors` argument and provide the ids of custom anchors (starts from 0 and based on atomic ids in the scaffold molecule). 133 |
134 | 135 |
136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/__init__.py -------------------------------------------------------------------------------- /analysis/SA_Score/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | Files taken from [rdkit/rdkit](https://github.com/rdkit/rdkit/tree/master/Contrib/SA_Score) repository on GitHub. 4 | -------------------------------------------------------------------------------- /analysis/SA_Score/__pycache__/sascorer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/analysis/SA_Score/__pycache__/sascorer.cpython-310.pyc -------------------------------------------------------------------------------- /analysis/SA_Score/__pycache__/sascorer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/analysis/SA_Score/__pycache__/sascorer.cpython-38.pyc -------------------------------------------------------------------------------- /analysis/SA_Score/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included 10 | # particularly slightly different formula for marocyclic penalty 11 | # and taking into account also molecule symmetry (fingerprint density) 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | 19 | 20 | from rdkit import Chem 21 | from rdkit.Chem import rdMolDescriptors 22 | import pickle 23 | 24 | import math 25 | from collections import defaultdict 26 | 27 | import os.path as op 28 | 29 | _fscores = None 30 | 31 | 32 | def readFragmentScores(name='fpscores'): 33 | import gzip 34 | global _fscores 35 | # generate the full path filename: 36 | #if name == "fpscores": 37 | # name = op.join(op.dirname(__file__), name) 38 | data = pickle.load(gzip.open('fpscores.pkl.gz')) 39 | outDict = {} 40 | for i in data: 41 | for j in range(1, len(i)): 42 | outDict[i[j]] = float(i[0]) 43 | _fscores = outDict 44 | 45 | 46 | def numBridgeheadsAndSpiro(mol, ri=None): 47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 49 | return nBridgehead, nSpiro 50 | 51 | 52 | def calculateScore(m): 53 | if _fscores is None: 54 | readFragmentScores() 55 | 56 | # fragment score 57 | fp = rdMolDescriptors.GetMorganFingerprint(m, 58 | 2) # <- 2 is the *radius* of the circular fingerprint 59 | fps = fp.GetNonzeroElements() 60 | score1 = 0. 61 | nf = 0 62 | for bitId, v in fps.items(): 63 | nf += v 64 | sfp = bitId 65 | score1 += _fscores.get(sfp, -4) * v 66 | score1 /= nf 67 | 68 | # features score 69 | nAtoms = m.GetNumAtoms() 70 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 71 | ri = m.GetRingInfo() 72 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 73 | nMacrocycles = 0 74 | for x in ri.AtomRings(): 75 | if len(x) > 8: 76 | nMacrocycles += 1 77 | 78 | sizePenalty = nAtoms**1.005 - nAtoms 79 | stereoPenalty = math.log10(nChiralCenters + 1) 80 | spiroPenalty = math.log10(nSpiro + 1) 81 | bridgePenalty = math.log10(nBridgeheads + 1) 82 | macrocyclePenalty = 0. 83 | # --------------------------------------- 84 | # This differs from the paper, which defines: 85 | # macrocyclePenalty = math.log10(nMacrocycles+1) 86 | # This form generates better results when 2 or more macrocycles are present 87 | if nMacrocycles > 0: 88 | macrocyclePenalty = math.log10(2) 89 | 90 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 91 | 92 | # correction for the fingerprint density 93 | # not in the original publication, added in version 1.1 94 | # to make highly symmetrical molecules easier to synthetise 95 | score3 = 0. 96 | if nAtoms > len(fps): 97 | score3 = math.log(float(nAtoms) / len(fps)) * .5 98 | 99 | sascore = score1 + score2 + score3 100 | 101 | # need to transform "raw" value into scale between 1 and 10 102 | min = -4.0 103 | max = 2.5 104 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 105 | # smooth the 10-end 106 | if sascore > 8.: 107 | sascore = 8. + math.log(sascore + 1. - 9.) 108 | if sascore > 10.: 109 | sascore = 10.0 110 | elif sascore < 1.: 111 | sascore = 1.0 112 | 113 | return sascore 114 | 115 | 116 | def processMols(mols): 117 | print('smiles\tName\tsa_score') 118 | for i, m in enumerate(mols): 119 | if m is None: 120 | continue 121 | 122 | s = calculateScore(m) 123 | 124 | smiles = Chem.MolToSmiles(m) 125 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 126 | 127 | 128 | if __name__ == '__main__': 129 | import sys 130 | import time 131 | 132 | t1 = time.time() 133 | readFragmentScores("fpscores") 134 | t2 = time.time() 135 | 136 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 137 | t3 = time.time() 138 | processMols(suppl) 139 | t4 = time.time() 140 | 141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 142 | file=sys.stderr) 143 | 144 | # 145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 146 | # All rights reserved. 147 | # 148 | # Redistribution and use in source and binary forms, with or without 149 | # modification, are permitted provided that the following conditions are 150 | # met: 151 | # 152 | # * Redistributions of source code must retain the above copyright 153 | # notice, this list of conditions and the following disclaimer. 154 | # * Redistributions in binary form must reproduce the above 155 | # copyright notice, this list of conditions and the following 156 | # disclaimer in the documentation and/or other materials provided 157 | # with the distribution. 158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 159 | # nor the names of its contributors may be used to endorse or promote 160 | # products derived from this software without specific prior written permission. 161 | # 162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 173 | # 174 | 175 | def compute_sa_score(rdmol): 176 | rdmol = Chem.MolFromSmiles(Chem.MolToSmiles(rdmol)) 177 | sa = calculateScore(rdmol) 178 | sa_norm = round((10 - sa) / 9, 2) 179 | return sa_norm -------------------------------------------------------------------------------- /analysis/all_pdbs_to_pdbqts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | if __name__ == '__main__': 6 | with open('/srv/home/mahdi.ghorbani/FragDiff/pdb_paths.txt', 'r') as f: 7 | all_files = [line.strip() for line in f.readlines()] 8 | root_dir = '/srv/home/mahdi.ghorbani/FragDiff/crossdock/crossdocked_pocket10/' 9 | for i, file in enumerate(all_files): 10 | if i % 100 == 0: 11 | print(i) 12 | prot_name = root_dir + file 13 | pdbqt_name = prot_name[:-3] + 'pdbqt' 14 | if not os.path.exists(pdbqt_name): 15 | os.system('prepare_receptor4.py -r {} -o {}'.format(prot_name, pdbqt_name)) 16 | -------------------------------------------------------------------------------- /analysis/docking.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | from pathlib import Path 5 | import argparse 6 | 7 | import pandas as pd 8 | from rdkit import Chem 9 | from tqdm import tqdm 10 | 11 | affinity_pattern = r"Affinity:\s+(-?\d+\.\d+)\s+\(kcal/mol\)" 12 | def calculate_smina_score(pdb_file, sdf_file): 13 | # add '-o _smina.sdf' if you want to see the output 14 | out = os.popen(f'smina.static -l {sdf_file} -r {pdb_file} ' 15 | f'--score_only').read() 16 | matches = re.findall( 17 | r"Affinity:[ ]+([+-]?[0-9]*[.]?[0-9]+)[ ]+\(kcal/mol\)", out) 18 | return [float(x) for x in matches] 19 | 20 | def sdf_to_pdbqt(sdf_file, pdbqt_outfile, mol_id): 21 | os.popen(f'obabel {sdf_file} -O {pdbqt_outfile} -f {mol_id + 1} -l {mol_id + 1}').read() 22 | return pdbqt_outfile 23 | 24 | def calculate_qvina2_score(receptor_file, sdf_file, out_dir, size=20, 25 | exhaustiveness=16, return_rdmol=False, score_only=False): 26 | """ 27 | receptor_file: pdbqt file for receptor 28 | sdf_file: sdf file for ligand 29 | out_dir: output directory 30 | 31 | returns: 32 | scores: list of scores for each ligand 33 | rdmols: list of qvina docked ligands 34 | """ 35 | 36 | receptor_pdbqt_file = Path(receptor_file) 37 | sdf_file = Path(sdf_file) 38 | 39 | scores = [] 40 | rdmols = [] # for if return rdmols 41 | suppl = Chem.SDMolSupplier(str(sdf_file), sanitize=False) 42 | 43 | for i, mol in enumerate(suppl): # sdf file may contain several ligands 44 | ligand_name = f'{sdf_file.stem}_{i}' 45 | # prepare ligand 46 | ligand_pdbqt_file = Path(out_dir, ligand_name + '.pdbqt') 47 | out_sdf_file = Path(out_dir, ligand_name + '_out.sdf') 48 | 49 | if out_sdf_file.exists(): 50 | with open(out_sdf_file, 'r') as f: 51 | scores.append(min([float(x.split()[2]) for x in f.readlines() 52 | if x.startswith(' VINA RESULT:')])) 53 | else: 54 | sdf_to_pdbqt(sdf_file, ligand_pdbqt_file, i) 55 | 56 | # center box at ligand's center of mass 57 | cx, cy, cz = mol.GetConformer().GetPositions().mean(0) 58 | 59 | # run QuckVina2 60 | # run QuickVina 2 61 | if not score_only: 62 | 63 | out = os.popen( 64 | f'qvina2.1 --receptor {receptor_pdbqt_file} ' 65 | f'--ligand {ligand_pdbqt_file} ' 66 | f'--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} ' 67 | f'--size_x {size} --size_y {size} --size_z {size} ' 68 | f'--exhaustiveness {exhaustiveness}' 69 | ).read() 70 | out_split = out.splitlines() 71 | best_ids = out_split.index('-----+------------+----------+----------') + 1 72 | best_line = out_split[best_ids].split() 73 | assert best_line[0] == '1' 74 | scores.append(float(best_line[1])) 75 | 76 | out_pdbqt_file = Path(out_dir, ligand_name + '_out.pdbqt') 77 | if out_pdbqt_file.exists(): 78 | os.popen(f'obabel {out_pdbqt_file} -O {out_sdf_file}').read() 79 | 80 | if return_rdmol: 81 | rdmol = Chem.SDMolSupplier(str(out_sdf_file))[0] 82 | rdmols.append(rdmol) 83 | 84 | else: 85 | out = os.popen( 86 | f'qvina2.1 --score_only --receptor {receptor_pdbqt_file} ' 87 | f'--ligand {ligand_pdbqt_file} ' 88 | f'--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} ' 89 | f'--size_x {size} --size_y {size} --size_z {size} ' 90 | ).read() 91 | match = re.search(affinity_pattern, out) 92 | scores = float(match.group(1)) 93 | 94 | if return_rdmol: 95 | return scores, rdmols 96 | else: 97 | return scores 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser('QuickVina evaulation') 101 | parser.add_argument('--pdbqt_dir', type=Path, 102 | help='Receptor files in pdbqt format') 103 | parser.add_argument('--sdf_dir', type=Path, default=None, 104 | help='Ligand files in sdf format') 105 | parser.add_argument('--out_dir', type=Path) 106 | parser.add_argument('--write_csv', action='store_true') 107 | parser.add_argument('--write_dict', action='store_true') 108 | parser.add_argument('--dataset', type=str, default='CROSSDOCK') 109 | args = parser.parse_args() 110 | 111 | assert (args.sdf_dir is not None) 112 | 113 | results = {'receptor': [], 'ligand': [], 'scores':[]} 114 | results_dict = {} 115 | 116 | sdf_files = list(os.listdir(args.sdf_dir)) 117 | pbar = tqdm(sdf_files) 118 | 119 | for sdf_file in pbar: 120 | pbar.set_description(f'Processing {sdf_file}') 121 | 122 | if args.dataset == 'CROSSDOCK': 123 | receptor_name = sdf_file.split('_')[0] + '_pocket' 124 | receptor_file = Path(args.pdbqt_dir, receptor_name + '.pdbqt') 125 | 126 | sdf_path = Path(str(args.sdf_dir) + '/' + sdf_file) 127 | try: 128 | scores, rdmols = calculate_qvina2_score(receptor_file, sdf_path, args.out_dir, return_rdmol=True) 129 | except (ValueError, AttributeError) as e: 130 | print(e) 131 | continue 132 | results['receptor'].append(str(receptor_file)) 133 | results['ligand'].append(str(sdf_file)) 134 | results['scores'].append(scores) 135 | 136 | if args.write_dict: 137 | results_dict[receptor_name] = [scores, rdmols] 138 | 139 | if args.write_csv: 140 | df = pd.DataFrame.from_dict(results) 141 | df.to_csv(Path(args.out_dir, 'qvina2_scores.csv')) 142 | 143 | if args.write_dict: 144 | torch.save(results_dict, Path(args.out_dir, 'qvina2_scores.pt')) -------------------------------------------------------------------------------- /analysis/docking_py27.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | 5 | def pdbs_to_pdbqts(pdb_dir, pdbqt_dir, dataset): 6 | for file in glob.glob(os.path.join(pdb_dir, '*.pdb')): 7 | name = os.path.splitext(os.path.basename(file))[0] 8 | outfile = os.path.join(pdbqt_dir, name + '.pdbqt') 9 | pdb_to_pdbqt(file, outfile, dataset) 10 | print('Wrote converted file to {}'.format(outfile)) 11 | 12 | def pdb_to_pdbqt(pdb_file, pdbqt_file, dataset): 13 | if dataset == 'CROSSDOCK': 14 | os.system('prepare_receptor4.py -r {} -o {}'.format(pdb_file, pdbqt_file)) 15 | 16 | else: 17 | raise NotImplementedError 18 | 19 | return pdbqt_file 20 | 21 | if __name__ == '__main__': 22 | pdbs_to_pdbqts(sys.argv[1], sys.argv[2], sys.argv[3]) 23 | -------------------------------------------------------------------------------- /analysis/eval_bond_angles.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import collections 4 | from typing import Tuple, Sequence, Dict, Optional 5 | 6 | from scipy import spatial as sci_spatial 7 | from .bond_angle_config import ANGLE_DIST_CROSSDOCK, DIHED_DIST_CROSSDOCK 8 | from rdkit.Chem.rdMolTransforms import GetAngleRad, GetDihedralRad, GetAngleDeg, GetDihedralDeg 9 | 10 | def get_distribution(angles, bins): 11 | 12 | bin_counts = collections.Counter(np.searchsorted(bins, angles)) 13 | bin_counts = [bin_counts[i] if i in bin_counts else 0 for i in range(len(bins))] 14 | bin_counts = np.array(bin_counts) / np.sum(bin_counts) 15 | return bin_counts 16 | 17 | def eval_angle_dist_profile(bond_angle_profile, dihedral_angle_profile, frag): 18 | 19 | # frag is the smiles of fragment 20 | # bond_angle_profile -> a dictionary with keys the smiles of fragmenst and values the distribution of angles/dihedrals 21 | metrics = {} 22 | gt_distribution = ANGLE_DIST_CROSSDOCK[frag] 23 | metrics[f'Angle-JSD_{frag}'] = sci_spatial.distance.jensenshannon(gt_distribution, 24 | bond_angle_profile) 25 | 26 | gt_distribution = DIHED_DIST_CROSSDOCK[frag] 27 | metrics[f'Dihedral-JSD_{frag}'] = sci_spatial.distance.jensenshannon(gt_distribution, 28 | dihedral_angle_profile) 29 | return metrics 30 | 31 | 32 | def find_angle_dist(mol, frag): 33 | all_frag_angles = [] 34 | all_frag_dihedrals = [] 35 | 36 | conf = mol.GetConformer() 37 | 38 | matches_frag = mol.GetSubstructMatches(frag) 39 | for match in matches_frag: 40 | match_angles = [] 41 | match_dih = [] 42 | match_set = set(match) 43 | 44 | for atom_index in match: 45 | atom = mol.GetAtomWithIdx(atom_index) 46 | neighbors = [neighbor.GetIdx() for neighbor in atom.GetNeighbors() if neighbor.GetIdx() in match_set] 47 | for i in range(len(neighbors)-1): 48 | for j in range(i+1, len(neighbors)): 49 | angle_deg = GetAngleDeg(conf, neighbors[i], atom_index, neighbors[j]) 50 | 51 | if angle_deg < 0: 52 | angle_deg += 360 53 | match_angles.append(angle_deg) 54 | 55 | for neighbor in neighbors: 56 | next_neighbors = [next_neighbor.GetIdx() for next_neighbor in mol.GetAtomWithIdx(neighbor).GetNeighbors() if next_neighbor.GetIdx() in match_set] 57 | for next_neighbor in next_neighbors: 58 | if next_neighbor != atom_index: # don't want to go to original atom 59 | # calculate and print dihedral angle 60 | dihedral_deg = GetDihedralDeg(conf, neighbor, atom_index, next_neighbor, neighbors[(neighbors.index(neighbor)+1) % len(neighbors)]) 61 | if dihedral_deg < 0: 62 | dihedral_deg += 360 63 | match_dih.append(dihedral_deg) 64 | 65 | all_frag_angles += match_angles 66 | all_frag_dihedrals += match_dih 67 | 68 | return all_frag_angles, all_frag_dihedrals -------------------------------------------------------------------------------- /analysis/eval_bond_length.py: -------------------------------------------------------------------------------- 1 | 2 | # taken from https://github.com/guanjq/targetdiff/blob/main/utils/evaluation/eval_bond_length.py 3 | 4 | import collections 5 | from typing import Tuple, Sequence, Dict, Optional 6 | 7 | import numpy as np 8 | from scipy import spatial as sci_spatial 9 | import matplotlib.pyplot as plt 10 | 11 | from analysis import bond_length_config 12 | from analysis import utils 13 | 14 | BondType = Tuple[int, int, int] # (atomic_num, atomic_num, bond_type) 15 | BondLengthData = Tuple[BondType, float] # (bond_type, bond_length) 16 | BondLengthProfile = Dict[BondType, np.ndarray] # bond_type -> empirical distribution 17 | 18 | def get_distribution(distances: Sequence[float], bins=bond_length_config.DISTANCE_BINS) -> np.ndarray: 19 | """ Get teh distribution of distances. 20 | 21 | Args: 22 | distances: (list) List of distances 23 | bins (list): bins of distances 24 | Returns: 25 | np.array: empirical distribution of distances with length equal to DISTANCE_BINS 26 | """ 27 | bin_counts = collections.Counter(np.searchsorted(bins, distances)) 28 | bin_counts = [bin_counts[i] if i in bin_counts else 0 for i in range(len(bins) + 1)] 29 | bin_counts = np.array(bin_counts) / np.sum(bin_counts) 30 | return bin_counts 31 | 32 | def _format_bond_type(bond_type: BondType) -> BondType: 33 | atom1, atom2, bond_category = bond_type 34 | if atom1 > atom2: 35 | atom1, atom2 = atom2, atom1 36 | return atom1, atom2, bond_category 37 | 38 | def get_bond_length_profile(bond_lengths: Sequence[BondLengthData]) -> BondLengthProfile: 39 | bond_length_profile = collections.defaultdict(list) 40 | for bond_type, bond_length in bond_lengths: 41 | bond_type = _format_bond_type(bond_type) 42 | bond_length_profile[bond_type].append(bond_length) 43 | bond_length_profile = {k: get_distribution(v) for k, v in bond_length_profile.items()} 44 | return bond_length_profile 45 | 46 | def _bond_type_str(bond_type: BondType) -> str: 47 | atom1, atom2, bond_category = bond_type 48 | return f'{atom1}-{atom2}|{bond_category}' 49 | 50 | def eval_bond_length_profile(bond_length_profile: BondLengthProfile) -> Dict[str, Optional[float]]: 51 | # gives the JS divergence of bond distances (different C-(C,O,N) bonds) 52 | metrics = {} 53 | for bond_type, gt_distribution in bond_length_config.EMPIRICAL_DISTRIBUTIONS.items(): 54 | if bond_type not in bond_length_profile: 55 | metrics[f'JSD_{_bond_type_str(bond_type)}'] = None 56 | else: 57 | metrics[f'JSD_{_bond_type_str(bond_type)}'] = sci_spatial.distance.jensenshannon(gt_distribution, 58 | bond_length_profile[bond_type]) 59 | return metrics 60 | 61 | def get_pair_length_profile(pair_lengths): 62 | cc_dist = [d[1] for d in pair_lengths if d[0] == (6,6) and d[1] < 2] 63 | all_dist = [d[1] for d in pair_lengths if d[1] < 12] 64 | pair_length_profile = { 65 | 'CC_2A': get_distribution(cc_dist, bins=np.linspace(0, 2, 100)), # distances of C-C bonds less than 2 A 66 | 'All_12A': get_distribution(all_dist, bins=np.linspace(0, 12, 100)) # all distances less than 12 A 67 | } 68 | return pair_length_profile 69 | 70 | def eval_pair_length_profile(pair_length_profile): 71 | metrics = {} 72 | for k, gt_distribution in bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS.items(): 73 | if k not in pair_length_profile: 74 | metrics[f'JSD_{k}'] = None 75 | else: 76 | metrics[f'JSD_{k}'] = sci_spatial.distance.jensenshannon(gt_distribution, pair_length_profile[k]) 77 | return metrics 78 | 79 | def plot_distance_hist(pair_length_profile, metrics=None, save_path=None): 80 | 81 | gt_profile = bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS 82 | plt.figure(figsize=(6*len(gt_profile), 4)) 83 | for idx, (k, gt_distribution) in enumerate(bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS.items()): 84 | plt.subplot(1, len(gt_profile), idx+1) 85 | x = bond_length_config.PAIR_EMPIRICAL_BINS[k] 86 | plt.step(x, gt_profile[k][1:]) 87 | plt.step(x, pair_length_profile[k][1:]) 88 | plt.legend(['True', 'Learned']) 89 | if metrics is not None: 90 | plt.title(f'{k} JS div: {metrics["JSD_" + k]:.4f}') 91 | else: 92 | plt.title(k) 93 | 94 | if save_path is not None: 95 | plt.savefig(save_path) 96 | else: 97 | plt.show() 98 | plt.close() 99 | 100 | def pair_distance_from_pos_v(pos, elements): 101 | pdist = pos[None, :] - pos[:,None] 102 | pdist = np.sqrt(np.sum(pdist ** 2, axis=-1)) 103 | dist_list = [] 104 | for s in range(len(pos)): 105 | for e in range(s+1, len(pos)): 106 | s_sym = elements[s] 107 | e_sym = elements[e] 108 | d = pdist[s, e] 109 | dist_list.append(((s_sym, e_sym), d)) 110 | return dist_list 111 | 112 | def bond_distance_from_mol(mol): 113 | pos = mol.GetConformer().GetPositions() 114 | pdist = pos[None, :] - pos[:, None] 115 | pdist = np.sqrt(np.sum(pdist ** 2, axis=-1)) 116 | all_distances = [] 117 | for bond in mol.GetBonds(): 118 | s_sym = bond.GetBeginAtom().GetAtomicNum() 119 | e_sym = bond.GetEndAtom().GetAtomicNum() 120 | s_idx, e_idx = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() 121 | bond_type = utils.BOND_TYPES[bond.GetBondType()] 122 | distance = pdist[s_idx, e_idx] 123 | all_distances.append(((s_sym, e_sym, bond_type), distance)) 124 | return all_distances -------------------------------------------------------------------------------- /analysis/get_atom_types_dist.py: -------------------------------------------------------------------------------- 1 | from scipy import spatial as sci_spatial 2 | import numpy as np 3 | from collections import Counter 4 | 5 | 6 | CROSSDOCK_atom_charges = {'C':6, 'N': 7, 'O': 8, 'S': 16, 'B': 5, 'Br': 35, 'Cl': 17, 'P': 15, 'I':53 ,'F':9} 7 | 8 | def get_atom_charges(mol, charge_dict): 9 | atomic_nums = [] 10 | for atom in mol.GetAtoms(): 11 | atomic_nums.append(charge_dict[atom.GetSymbol()]) 12 | 13 | atomic_nums = np.array(atomic_nums) 14 | return atomic_nums 15 | 16 | ATOM_TYPE_DISTRIBUTION = { # atom type distributions in CrossDock 17 | 6: 0.6715020339893559, 18 | 7: 0.11703509510732567, 19 | 8: 0.16956379168491933, 20 | 9: 0.01307879304486639, 21 | 15: 0.01113716146426898, 22 | 16: 0.01123926340861198, 23 | 17: 0.006443861300651673, 24 | } 25 | 26 | ATOM_TYPE_DISTRIBUTION_GEOM = { # atom type distributions in CrossDock 27 | 6: 0.7266496963585743, 28 | 7: 0.11690156566351215, 29 | 8: 0.11619156632264795, 30 | 9: 0.008849559988534103, 31 | 15: 0.0001854777473386173, 32 | 16: 0.022003011957949646, 33 | 17: 0.007286864677748788, 34 | 35: 0.001897001182960629, 35 | } 36 | 37 | def eval_atom_type_distribution(pred_counter: Counter, data_type='GEOM'): 38 | total_num_atoms = sum(pred_counter.values()) 39 | pred_atom_distribution = {} 40 | if data_type == 'GEOM': 41 | for k in ATOM_TYPE_DISTRIBUTION_GEOM: 42 | pred_atom_distribution[k] = pred_counter[k] / total_num_atoms 43 | js = sci_spatial.distance.jensenshannon(np.array(list(ATOM_TYPE_DISTRIBUTION_GEOM.values())), 44 | np.array(list(pred_atom_distribution.values()))) 45 | elif data_type == 'CrossDock': 46 | for k in ATOM_TYPE_DISTRIBUTION: 47 | pred_atom_distribution[k] = pred_counter[k] / total_num_atoms 48 | js = sci_spatial.distance.jensenshannon(np.array(list(ATOM_TYPE_DISTRIBUTION.values())), 49 | np.array(list(pred_atom_distribution.values()))) 50 | 51 | return js -------------------------------------------------------------------------------- /analysis/get_empirical_dists.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | from tqdm import tqdm 4 | import os 5 | from eval_bond_length import pair_distance_from_pos_v, bond_distance_from_mol 6 | from eval_bond_length import get_pair_length_profile, get_bond_length_profile 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | supplier = list(Chem.SDMolSupplier('/srv/ds/set-1/user/mahdi.ghorbani/FragDiff/datasets/geom_conformers.sdf')) 12 | 13 | all_pair_dists = [] 14 | all_bond_dists = [] 15 | for mol_id, mol in enumerate(supplier): 16 | try: 17 | pos = mol.GetConformer().GetPositions() 18 | 19 | atomicnums = [] 20 | for atom in mol.GetAtoms(): 21 | atomicnums.append(atom.GetAtomicNum()) 22 | 23 | all_pair_dists += pair_distance_from_pos_v(pos, atomicnums) 24 | all_bond_dists += bond_distance_from_mol(mol) 25 | except: 26 | print(f'could not process mol {mol_id}') 27 | 28 | empirical_pair_length_profiles = get_pair_length_profile(all_pair_dists) 29 | empirical_bond_length_profiles = get_bond_length_profile(all_bond_dists) 30 | 31 | print(empirical_bond_length_profiles) 32 | print(empirical_pair_length_profiles) -------------------------------------------------------------------------------- /analysis/get_volume.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import tempfile 3 | import numpy as np 4 | import os 5 | 6 | info_dict = {} 7 | root_dir = '/Users/mahdimac/Science/Keiser_lab/diffusion/AutoFragDiff/scaffolds' 8 | with tempfile.TemporaryDirectory() as tmp_dir: 9 | command = f"fpocket -f {root_dir}/1a2g.pdb" 10 | os.chdir(tmp_dir) 11 | 12 | #process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | out = os.popen(command).read() 14 | #stdout, stderr = process.communicate() 15 | 16 | with open(os.path.join('1a2g_out', '1a2g_info.txt'), 'r') as fp: 17 | #file_content = fp.read() 18 | 19 | lines = fp.readlines() 20 | pocket_info_started = False 21 | 22 | for line in lines: 23 | line = line.strip() 24 | if line == "Pocket 1 :": 25 | pocket_info_started = True 26 | continue 27 | if pocket_info_started: 28 | if line == "": 29 | break 30 | key, value = line.split(":") 31 | info_dict[key.strip()] = float(value.strip()) 32 | 33 | print(info_dict) -------------------------------------------------------------------------------- /analysis/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cdist 3 | from tqdm import tqdm 4 | from rdkit import Chem, DataStructs 5 | 6 | from collections import Counter 7 | from copy import deepcopy 8 | 9 | from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski 10 | from rdkit.Chem.FilterCatalog import * 11 | from rdkit.Chem.QED import qed 12 | 13 | from analysis.SA_Score.sascorer import compute_sa_score 14 | 15 | 16 | def is_connected(mol): 17 | try: 18 | mol_frags = Chem.GetMolFrags(mol, asMols=True) 19 | except Chem.rdchem.AtomValenceException: 20 | return False 21 | if len(mol_frags) != 1: 22 | return False 23 | return True 24 | 25 | def is_valid(mol): 26 | try: 27 | Chem.SanitizeMol(mol) 28 | except: 29 | return False 30 | return True 31 | 32 | def obey_lipinski(mol): 33 | mol = deepcopy(mol) 34 | Chem.SanitizeMol(mol) 35 | rule_1 = Descriptors.ExactMolWt(mol) < 500 36 | rule_2 = Lipinski.NumHDonors(mol) <= 5 37 | rule_3 = Lipinski.NumHAcceptors(mol) <= 10 38 | logp = get_logp(mol) 39 | rule_4 = (logp >= -2) & (logp <= 5) 40 | rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10 41 | return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]]) 42 | 43 | def get_basic(mol): 44 | n_atoms = len(mol.GetAtoms()) 45 | n_bonds = len(mol.GetBonds()) 46 | n_rings = len(Chem.GetSymmSSSR(mol)) 47 | weight = Descriptors.ExactMolWt(mol) 48 | return n_atoms, n_bonds, n_rings, weight 49 | 50 | def get_rdkit_rmsd(mol, n_conf=20, random_seed=42, mode='energy'): 51 | """ 52 | Calculate the alignment of generated mol and rdkit predicted mol 53 | Return the rmsd (max, min, median) of the n_conf rdkit conformers 54 | """ 55 | 56 | mol = deepcopy(mol) 57 | Chem.SanitizeMol(mol) 58 | 59 | mol_smiles = Chem.MolToSmiles(mol) 60 | mol_smiles = Chem.MolFromSmiles(mol_smiles) 61 | mol3d = Chem.AddHs(mol) 62 | 63 | rmsd_list = [] 64 | conf_energies = [] 65 | # predict 3d 66 | try: 67 | confIds = AllChem.EmbedMultipleConfs(mol3d, n_conf, randomSeed=random_seed) 68 | for confId in confIds: 69 | AllChem.UFFOptimizeMolecule(mol3d, confId=confId) 70 | rmsd = Chem.rdMolAlign.GetBestRMS(Chem.RemoveHs(mol), Chem.RemoveHs(mol3d), refId=confId) 71 | rmsd_list.append(rmsd) 72 | #conf_energies.append(get_conformer_energies(mol3d)) 73 | 74 | mol_energy = get_conformer_energies(Chem.AddHs(mol, addCoords=True)) 75 | conf_energies = get_conformer_energies(mol3d) 76 | rmsd_list = np.array(rmsd_list) 77 | conf_lowest_en = np.argmin(conf_energies) 78 | 79 | mol = Chem.AddHs(mol) 80 | new_mol = Chem.Mol(mol) 81 | new_mol.RemoveAllConformers() 82 | conf_ids = [conf.GetId() for conf in mol3d.GetConformers()] 83 | conf = mol3d.GetConformer(conf_ids[conf_lowest_en]) 84 | new_mol.AddConformer(conf, assignId=True) 85 | 86 | return rmsd_list[conf_lowest_en], new_mol, conf_energies, mol_energy 87 | except: 88 | return np.nan, np.nan, np.nan, np.nan 89 | 90 | def get_logp(mol): 91 | return Crippen.MolLogP(mol) 92 | 93 | def get_chem(mol): 94 | qed_score = qed(mol) 95 | sa_score = compute_sa_score(mol) 96 | logp_score = get_logp(mol) 97 | lipinski_score = obey_lipinski(mol) 98 | ring_info = mol.GetRingInfo() 99 | ring_size = Counter([len(r) for r in ring_info.AtomRings()]) 100 | 101 | return { 102 | 'qed': qed_score, 103 | 'sa': sa_score, 104 | 'logp': logp_score, 105 | 'lipinski': lipinski_score, 106 | 'ring_size': ring_size 107 | } 108 | 109 | def get_molecule_force_field(mol, conf_id=None, force_field='mmff', **kwargs): 110 | """ 111 | Get a force field for a molecule. 112 | Parameters 113 | ---------- 114 | mol : RDKit Mol 115 | Molecule. 116 | conf_id : int, optional 117 | ID of the conformer to associate with the force field. 118 | force_field : str, optional 119 | Force Field name. 120 | kwargs : dict, optional 121 | Keyword arguments for force field constructor. 122 | """ 123 | if force_field == 'uff': 124 | ff = AllChem.UFFGetMoleculeForceField( 125 | mol, confId=conf_id, **kwargs) 126 | elif force_field.startswith('mmff'): 127 | AllChem.MMFFSanitizeMolecule(mol) 128 | mmff_props = AllChem.MMFFGetMoleculeProperties( 129 | mol, mmffVariant=force_field) 130 | ff = AllChem.MMFFGetMoleculeForceField( 131 | mol, mmff_props, confId=conf_id, **kwargs) 132 | else: 133 | raise ValueError("Invalid force_field {}".format(force_field)) 134 | return ff 135 | 136 | def get_conformer_energies(mol, force_field='mmff'): 137 | """ 138 | Calculate conformer energies. 139 | Parameters 140 | ---------- 141 | mol : RDKit Mol 142 | Molecule. 143 | force_field : str, optional 144 | Force Field name. 145 | Returns 146 | ------- 147 | energies : array_like 148 | Minimized conformer energies. 149 | """ 150 | energies = [] 151 | for conf in mol.GetConformers(): 152 | ff = get_molecule_force_field(mol, conf_id=conf.GetId(), force_field=force_field) 153 | ff.Minimize() 154 | energy = ff.CalcEnergy() 155 | energies.append(energy) 156 | energies = np.asarray(energies, dtype=float) 157 | return energies 158 | -------------------------------------------------------------------------------- /analysis/qvina_docking.py: -------------------------------------------------------------------------------- 1 | from joblib import Parallel, delayed 2 | import os 3 | from pathlib import Path 4 | import random 5 | import shutil 6 | import re 7 | import glob 8 | 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | from rdkit import Chem 13 | from rdkit.Chem import AllChem 14 | from rdkit import RDLogger 15 | import pandas as pd 16 | import torch 17 | 18 | affinity_pattern = r"Affinity:\s+(-?\d+\.\d+)\s+\(kcal/mol\)" 19 | RDLogger.DisableLog('rdApp.*') 20 | 21 | def sdf_to_pdbqt(sdf_file, pdbqt_outfile, mol_id): 22 | os.popen(f'obabel {sdf_file} -O {pdbqt_outfile} -f {0} -l {mol_id} -m').read() 23 | return pdbqt_outfile 24 | 25 | def get_vina_dock_score(receptor_pdbqt_file, ligand_pdbqt_file, cx, cy, cz, size): 26 | # Vina docking and getting the vina score 27 | out = os.popen( 28 | f'qvina2.1 --receptor {receptor_pdbqt_file} ' 29 | f'--ligand {ligand_pdbqt_file} ' 30 | f'--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} ' 31 | f'--size_x {size} --size_y {size} --size_z {size} --exhaustiveness 16' 32 | ).read() 33 | out_split = out.splitlines() 34 | best_idx = out_split.index('-----+------------+----------+----------') + 1 35 | best_line = out_split[best_idx].split() 36 | print('\n best Affinity:', float(best_line[1])) 37 | return float(best_line[1]) 38 | 39 | def get_vina_score(receptor_pdbqt_file, ligand_pdbqt_file, cx, cy, cz, size): 40 | # TODO: using QVina to get vina scores gives weird results. Use Vina 41 | # scores the generated poses without docking them 42 | out = os.popen( 43 | f'qvina2.1 --score_only --receptor {receptor_pdbqt_file} ' 44 | f'--ligand {ligand_pdbqt_file} ' 45 | f'--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} ' 46 | f'--size_x {size} --size_y {size} --size_z {size}' 47 | ).read() 48 | match = re.search(affinity_pattern, out) 49 | affinity_value = float(match.group(1)) 50 | print('vina score is:', affinity_value) 51 | return affinity_value 52 | 53 | def process_vina_iteration(n, save_file, receptor_pdbqt_file, mol_pos, size, result_type='vina_score'): 54 | pdbqt_file = save_file + 'all_mols_' + str(n) + '.pdbqt' 55 | cx, cy, cz = mol_pos 56 | if result_type == 'vina_score': 57 | affinity_value = get_vina_score(receptor_pdbqt_file, pdbqt_file, cx, cy, cz, size) 58 | elif result_type == 'dock_score': 59 | affinity_value = get_vina_dock_score(receptor_pdbqt_file, pdbqt_file, cx, cy, cz, size) 60 | return affinity_value 61 | 62 | 63 | -------------------------------------------------------------------------------- /analysis/scoring_func.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import Counter 3 | from copy import deepcopy 4 | 5 | import numpy as np 6 | from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski 7 | from rdkit.Chem.FilterCatalog import * 8 | from rdkit.Chem.QED import qed 9 | 10 | def obey_lipinski(mol): 11 | # compute the lipinski score 12 | mol = deepcopy(mol) 13 | Chem.SanitizeMol(mol) 14 | rule_1 = Descriptors.ExactMolWt(mol) < 500 15 | rule_2 = Lipinski.NumHDonors(mol) <= 5 16 | rule_3 = Lipinski.NumHAcceptors(mol) <= 10 17 | logp = get_logp(mol) 18 | rule_4 = (logp >= -2) & (logp <= 5) 19 | rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10 20 | return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]]) 21 | 22 | def get_basic(mol): 23 | # return n_atoms, bonds, rings, MW 24 | n_atoms = len(mol.GetAtoms()) 25 | n_bonds = len(mol.GetBonds()) 26 | n_rings = len(Chem.GetSymmSSSR(mol)) 27 | weight = Descriptors.ExactMolWt(mol) 28 | return n_atoms, n_bonds, n_rings, weight 29 | 30 | def get_rdkit_rmsd(mol, n_conf=20, random_seed=42): 31 | # return [max_rmsd, min_rmsd, median_rmsd] 32 | """ 33 | calculate the alignment of generated mol and rdkit predicted mol 34 | Return the rmsd (max, min, median) of the `n_conf` rdkit conformers 35 | """ 36 | mol = deepcopy(mol) 37 | Chem.SanitizeMol(mol) 38 | mol3d = Chem.AddHs(mol) # TODO: may need to add hydrogens in a different way 39 | rmsd_list = [] 40 | # predict 3d 41 | try: 42 | confIds = AllChem.EmbedMultipleConfs(mol3d, n_conf, randomSeed=random_seed) 43 | for confId in confIds: 44 | AllChem.UFFOptimizeMolecule(mol3d, confId=confId) 45 | rmsd = Chem.rdMolAlign.GetBestRMS(mol, mol3d, refId=confId) 46 | rmsd_list.append(rmsd) 47 | rmsd_list = np.array(rmsd_list) 48 | return [np.max(rmsd_list), np.min(rmsd_list), np.median(rmsd_list)] 49 | except: 50 | return [np.nan, np.nan, np.nan] 51 | 52 | def get_logp(mol): 53 | return Crippen.MolLogP(mol) 54 | 55 | def get_molecule_force_field(mol, conf_id=None, force_field='mmff', **kwargs): 56 | """ 57 | Get a force field for a molecule. 58 | Parameters 59 | ---------- 60 | mol : RDKit Mol 61 | Molecule. 62 | conf_id : int, optional 63 | ID of the conformer to associate with the force field. 64 | force_field : str, optional 65 | Force Field name. 66 | kwargs : dict, optional 67 | Keyword arguments for force field constructor. 68 | """ 69 | if force_field == 'uff': 70 | ff = AllChem.UFFGetMoleculeForceField( 71 | mol, confId=conf_id, **kwargs) 72 | elif force_field.startswith('mmff'): 73 | AllChem.MMFFSanitizeMolecule(mol) 74 | mmff_props = AllChem.MMFFGetMoleculeProperties( 75 | mol, mmffVariant=force_field) 76 | ff = AllChem.MMFFGetMoleculeForceField( 77 | mol, mmff_props, confId=conf_id, **kwargs) 78 | else: 79 | raise ValueError("Invalid force_field {}".format(force_field)) 80 | return ff 81 | 82 | def get_conformer_energies(mol, force_field='mmff'): 83 | """ 84 | Calculate conformer energies. 85 | Parameters 86 | ---------- 87 | mol : RDKit Mol 88 | Molecule. 89 | force_field : str, optional 90 | Force Field name. 91 | Returns 92 | ------- 93 | energies : array_like 94 | Minimized conformer energies. 95 | """ 96 | energies = [] 97 | for conf in mol.GetConformers(): 98 | ff = get_molecule_force_field(mol, conf_id=conf.GetId(), force_field=force_field) 99 | energy = ff.CalcEnergy() 100 | energies.append(energy) 101 | energies = np.asarray(energies, dtype=float) 102 | return energies -------------------------------------------------------------------------------- /analysis/similarity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem, DataStructs 3 | 4 | def tanimoto_sim(mol, ref): 5 | fp1 = Chem.RDKFingerprint(ref) 6 | fp2 = Chem.RDKFingerprint(mol) 7 | return DataStructs.TanimotoSimilarity(fp1, fp2) 8 | 9 | def tanimoto_sim_N_to_1(mols, ref): 10 | sim = [tanimoto_sim(m, ref) for m in mols] 11 | return sim 12 | 13 | def batched_number_of_rings(mols): 14 | n = [] 15 | for m in mols: 16 | n.append(Chem.rdMolDescriptors.CalcNumRings(m)) 17 | return np.array(n) -------------------------------------------------------------------------------- /analysis/utils.py: -------------------------------------------------------------------------------- 1 | from rdkit.Chem.rdchem import BondType 2 | from rdkit.Chem import ChemicalFeatures 3 | from rdkit import RDConfig 4 | 5 | ATOM_FAMILIES = ['Acceptor', 'Donor', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe', 'NegIonizable', 'PosIonizable', 6 | 'ZnBinder'] 7 | ATOM_FAMILIES_ID = {s: i for i, s in enumerate(ATOM_FAMILIES)} 8 | BOND_TYPES = { 9 | BondType.UNSPECIFIED: 0, 10 | BondType.SINGLE: 1, 11 | BondType.DOUBLE: 2, 12 | BondType.TRIPLE: 3, 13 | BondType.AROMATIC: 4, 14 | } 15 | -------------------------------------------------------------------------------- /analyze_scaffolds_generated.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import pandas as pd 4 | import os 5 | from tqdm import tqdm 6 | import argparse 7 | import re 8 | 9 | from rdkit import Chem 10 | from rdkit.Chem import AllChem 11 | from rdkit import RDLogger 12 | from openbabel import openbabel 13 | 14 | from analysis import eval_bond_length 15 | from analysis.reconstruct_mol import reconstruct_from_generated, MolReconsError 16 | from analysis.metrics import is_connected, is_valid, get_chem 17 | from analysis.eval_bond_angles import get_distribution, eval_angle_dist_profile, find_angle_dist 18 | from analysis.vina_docking import VinaDockingTask 19 | from joblib import Parallel, delayed 20 | 21 | from src.utils import get_logger 22 | import collections 23 | import torch 24 | 25 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9} 26 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'} 27 | 28 | CROSSDOCK_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'B':5, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15} 29 | 30 | def print_dict(d, logger): 31 | for k, v in d.items(): 32 | if v is not None: 33 | logger.info(f'{k}:\t{v:4f}') 34 | else: 35 | logger.info(f'{k}\tNone') 36 | 37 | def print_ring_ratio(all_ring_sizes, logger): 38 | for ring_size in range(3, 10): 39 | n_mol = 0 40 | for counter in all_ring_sizes: 41 | if ring_size in counter: 42 | n_mol += 1 43 | logger.info(f'ring size: {ring_size} ratio: {n_mol / len(all_ring_sizes):.3f}') 44 | 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--results-path', type=str, default='results_scaffold', 49 | help='path to save the scaffold based optimization') 50 | parser.add_argument('--scaffold-path', type=str, default='scaffolds/1a2g_scaff.sdf', 51 | help='path to sdf of scaffold') 52 | parser.add_argument('--original-path', type=str, default='scaffolds/1a2g_orig.sdf', 53 | help='path to original molecule') 54 | parser.add_argument('--receptor-path', type=str, default='scaffolds/1a2g.pdb', 55 | help='path to pdb file of receptor') 56 | parser.add_argument('--docking_mode', type=str, choices=['qvina', 'vina_score', 'vina_dock', 'None']) 57 | parser.add_argument('--exhaustiveness', type=int, default=16) 58 | parser.add_argument('--verbose', type=eval, default=False) 59 | parser.add_argument('--n-mols-per-file', type=int, default=20, help='number of molecules per each file') 60 | parser.add_argument('--crossdock-dir', type=str, default='/srv/home/mahdi.ghorbani/FragDiff/crossdock') 61 | 62 | args = parser.parse_args() 63 | results_path = args.results_path 64 | n_mols_per_file = args.n_mols_per_file 65 | eval_path = os.path.join(results_path, 'eval_results') 66 | root_dir = args.crossdock_dir 67 | 68 | scaffold_path = args.scaffold_path # sdf file of scaffold 69 | receptor_path = args.receptor_path # pdb file of receptor 70 | 71 | os.makedirs(eval_path, exist_ok=True) 72 | logger = get_logger('evaluate', log_dir=eval_path) 73 | 74 | if not args.verbose: 75 | RDLogger.DisableLog('rdApp.*') 76 | 77 | valid_mols = 0 78 | connected_mols = 0 79 | results = [] 80 | 81 | n_files = 0 82 | n_samples = 0 83 | 84 | scaff_mol = Chem.SDMolSupplier(scaffold_path)[0] 85 | orig_mol = Chem.SDMolSupplier(args.original_path)[0] 86 | 87 | # compute vina score for the scaffold 88 | vina_task = VinaDockingTask.from_generated_mol(orig_mol, protein_path=receptor_path) 89 | score_result = vina_task.run(mode='score_only', exhaustiveness=16) 90 | scaffold_score = score_result[0]['affinity'] 91 | print('------> Vina score for original molecule is : ', scaffold_score) 92 | 93 | # compute vina score for the scaffold 94 | vina_task = VinaDockingTask.from_generated_mol(scaff_mol, protein_path=receptor_path) 95 | score_result = vina_task.run(mode='score_only', exhaustiveness=16) 96 | scaffold_score = score_result[0]['affinity'] 97 | print('------> Vina score for scaffold is : ', scaffold_score) 98 | 99 | 100 | for n in tqdm(range(10), desc='Eval'): 101 | prot_path = receptor_path 102 | if os.path.exists(results_path + 'pocket_' + str(n) + '_coords.npy'): 103 | 104 | n_files += 1 105 | x = np.load(results_path + 'pocket_' + str(n) + '_coords.npy') 106 | h = np.load(results_path + 'pocket_' + str(n) + '_onehot.npy') 107 | mol_masks = np.load(results_path + 'pocket_' + str(n) + '_mol_masks.npy') 108 | 109 | all_mols = [] 110 | for k in range(len(x)): 111 | 112 | mask = mol_masks[k] 113 | h_mol = h[k] 114 | x_mol = x[k][mask.astype(np.bool_)] 115 | 116 | atom_inds = h_mol[mask.astype(np.bool_)].argmax(axis=1) 117 | atom_types = [idx2atom[x] for x in atom_inds] 118 | atomic_nums = [CROSSDOCK_CHARGES[i] for i in atom_types] 119 | 120 | #all_validity_results.append(validity_results) 121 | n_samples += 1 122 | try: 123 | mol_rec = reconstruct_from_generated(x_mol.tolist(), atomic_nums, aromatic=None, basic_mode=True) 124 | smiles = Chem.MolToSmiles(mol_rec) 125 | Chem.SanitizeMol(mol_rec) 126 | 127 | except Exception as e: 128 | print(e) 129 | continue 130 | valid_mols += 1 131 | 132 | if is_connected(mol_rec): 133 | connected_mols += 1 134 | else: 135 | # if the molecule is not connected, then take the largest fragment 136 | m_frags = Chem.GetMolFrags(mol_rec, asMols=True, sanitizeFrags=False) 137 | mol_rec = max(m_frags, default=mol_rec, key=lambda m: m.GetNumAtoms()) 138 | 139 | chem_results = get_chem(mol_rec) # a dictionary with qed, sa, logp, lipinski, ring_size 140 | 141 | # --------------------------- Getting Vina Docking results --------------------------- 142 | try: 143 | if args.docking_mode == 'qvina': 144 | pass # TODO: add the qvina like in TargetDiff 145 | elif args.docking_mode in ['vina_score', 'vina_dock']: 146 | vina_task = VinaDockingTask.from_generated_mol(mol_rec, protein_path=prot_path) 147 | score_only_results = vina_task.run(mode='score_only', exhaustiveness=args.exhaustiveness) 148 | minimize_results = vina_task.run(mode='minimize', exhaustiveness=args.exhaustiveness) 149 | print('score_only: ', score_only_results[0]['affinity']) 150 | print('minimized score: ', minimize_results[0]['affinity']) 151 | vina_results = { 152 | 'score_only': score_only_results, 153 | 'minimize': minimize_results 154 | } 155 | if args.docking_mode == 'vina_dock': 156 | docking_results = vina_task.run(mode='dock', exhaustiveness=args.exhaustiveness) 157 | vina_results['dock'] = docking_results 158 | print('vina dock: ', docking_results[0]['affinity']) 159 | else: 160 | vina_results = None 161 | 162 | except: 163 | if args.verbose: 164 | logger.warning(f'Docking failed for pocket {n} and molecule {k}') 165 | continue 166 | 167 | results.append({ 168 | 'mol': mol_rec, 169 | 'smiles': smiles, 170 | 'chem_results': chem_results, 171 | 'vina' :vina_results 172 | }) 173 | 174 | logger.info(f'Evaluation is done! {n_samples} samples in total') 175 | 176 | fraction_valid = valid_mols / n_samples 177 | fraction_connected = connected_mols / n_samples 178 | 179 | print('fraction_connected is: ', fraction_connected) 180 | print('fraction_valid is :' , fraction_valid) 181 | 182 | qed = [r['chem_results']['qed'] for r in results] 183 | sa = [r['chem_results']['sa'] for r in results] 184 | logger.info('QED: Mean: %.3f Median: %.3f std: %.3f' % (np.mean(qed), np.median(qed), np.std(qed))) 185 | logger.info('SA: Mean: %.3f Median: %.3f std: %.3f' % (np.mean(qed), np.median(sa), np.std(sa))) 186 | 187 | if args.docking_mode == 'qvina': 188 | vina = [r['vina'[0]]['affinity'] for r in results] 189 | logger.info('Vina: Mean: %.3f Median: %.3f Std: %.3f' %(np.mean(vina), np.median(vina), np.std(vina))) 190 | elif args.docking_mode in ['vina_dock', 'vina_score']: 191 | vina_score_only = [r['vina']['score_only'][0]['affinity'] for r in results] 192 | vina_min = [r['vina']['minimize'][0]['affinity'] for r in results] 193 | logger.info('Vina Score : Mean %.3f Median: %.3f Std: %.3f' % (np.mean(vina_score_only), np.median(vina_score_only), np.std(vina_score_only))) 194 | logger.info('Vina minimized : Mean %.3f Median: %.3f Std: %.3f' % (np.mean(vina_min), np.median(vina_min), np.std(vina_min))) 195 | if args.docking_mode == 'vina_dock': 196 | vina_dock = [r['vina']['dock'][0]['affinity'] for r in results] 197 | logger.info('Vina Dock : Mean: %.3f Median: %.3f Std: %.3f' % (np.mean(vina_dock), np.median(vina_dock), np.std(vina_dock))) 198 | 199 | print_ring_ratio([r['chem_results']['ring_size'] for r in results], logger) 200 | 201 | torch.save({ 202 | 'fraction_connected': fraction_connected, 203 | 'fraction_valid': fraction_valid, 204 | 'all_results': results, 205 | }, os.path.join(eval_path, 'metrics.pt')) -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/assets/.DS_Store -------------------------------------------------------------------------------- /assets/movie.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/assets/movie.gif -------------------------------------------------------------------------------- /assets/scaffold_optim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/assets/scaffold_optim.png -------------------------------------------------------------------------------- /data/CROSSDOCK/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/data/CROSSDOCK/__init__.py -------------------------------------------------------------------------------- /data/CROSSDOCK/fragment_hierarchy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | import networkx as nx 4 | from prepare_fragments import * 5 | 6 | def find_neigh_frags_to_neigh_atoms(bonds_broken_frags, bonds_broken_frag_ids): 7 | """ find a mapping between broken bonds (atom tuples) and their corresponding fragment ids (index tuple) 8 | bonds_broken_frags: -> tuple of bonds broken btween fragments 9 | bonds_broken_frag_ids -> ids of bonds broken between fragments 10 | """ 11 | # a dictionary mapping from fragments (tuple) to atom ids (tuple) in the original moleucle 12 | neigh_frags_to_neigh_atoms = {} 13 | for i, bond in enumerate(bonds_broken_frags): 14 | neigh_frags_to_neigh_atoms[bonds_broken_frag_ids[i]] = bond 15 | neigh_frags_to_neigh_atoms[bonds_broken_frag_ids[i][::-1]] = bond[::-1] 16 | 17 | return neigh_frags_to_neigh_atoms 18 | 19 | class FragmentGraph(): 20 | """ class for fragment graph 21 | """ 22 | def __init__(self, mol, fragments, adjacency, frag_atom_ids, frag_to_id_dict, neigh_frags_to_neigh_atoms): 23 | self.original_mol = mol 24 | self.graph = nx.Graph() 25 | self.fragments = fragments 26 | self.frag_atom_ids = frag_atom_ids 27 | self.frag_to_id_dict = frag_to_id_dict 28 | self.conformer = mol.GetConformer().GetPositions() # position of atoms in the original molecule 29 | self.all_mol_atom_symbols = self.get_all_mol_atom_symbols() # array of atom symbols in the molecule 30 | self.all_mol_atom_charges = [] # including charge of atoms may not be wise since we are doint autoregressive generation 31 | self.neigh_frags_to_neigh_atoms = neigh_frags_to_neigh_atoms 32 | 33 | fragment_bonds = np.argwhere(np.triu(adjacency)) 34 | 35 | for i, f in enumerate(fragments): 36 | self.graph.add_node(f, name='frag_' + str(i)) 37 | 38 | for bond in fragment_bonds: 39 | self.graph.add_edge(fragments[bond[0]], fragments[bond[1]]) 40 | 41 | def draw_graph(self, graph): 42 | labels = nx.get_node_attributes(graph, 'name') 43 | nx.draw_circular(graph, labels=labels, node_size=3000) 44 | 45 | def get_bfs_order(self, starting_point=None): 46 | """ returns a list of tuples of fragments that should be connected to traverse the graph 47 | in a BFS order 48 | """ 49 | if starting_point is None: 50 | starting_point = 0 51 | starting_frag = self.fragments[starting_point] 52 | bfs_edges = list(nx.bfs_edges(self.graph, starting_frag)) 53 | return bfs_edges 54 | 55 | def get_dfs_order(self, starting_point=None): 56 | """ return a list of tuples of fragments that should be connected to traverse 57 | in DFS order 58 | """ 59 | if starting_point is None: 60 | starting_point = 0 61 | starting_frag = self.fragments[starting_point] 62 | dfs_edges = list(nx.dfs_edges(self.graph, starting_frag)) 63 | return dfs_edges 64 | 65 | def hierarchical_reconstruct(self, edge_order='BFS', starting_point=None): 66 | """ 67 | Returns the reconstruction of the molecule in the order given in edge_order 68 | 69 | if edge_order is given the reconstruction is based on edge_order 70 | else reconstructoin is based on BFS order with starting point 71 | 72 | Returns: 73 | hierarchical_mol : hierarchical molecule built in BFS order 74 | atom_ids_hierarchical: ids of atoms added in the BFS order 75 | """ 76 | if starting_point is None: 77 | starting_point = 0 78 | 79 | if edge_order == 'BFS': 80 | edge_list = self.get_bfs_order(starting_point) 81 | elif edge_order == 'DFS': 82 | edge_list = self.get_dfs_order(starting_point) 83 | else: 84 | raise ValueError('edge order not found.') 85 | 86 | tmp = edge_list[0][0] # this is a mol 87 | tmp_id = self.frag_to_id_dict[tmp] # id of the fragment 88 | hierarchical_mol = [tmp] # the initial molecule 89 | 90 | # ------------- find the atom ids in hier ------------ 91 | atom_ids_hierarchical = [] # a set of atoms 92 | tmp_frag_atom_ids = self.frag_atom_ids[tmp_id] 93 | atom_ids_hierarchical.append(tmp_frag_atom_ids) 94 | 95 | 96 | # -------------- find the conformer in hier ----------- 97 | hierarchical_conformer = [] # hierarchical conformeration of the molecule 98 | first_frag_conformer = self.transfer_conformer(tmp_frag_atom_ids) 99 | hierarchical_conformer.append(first_frag_conformer) 100 | 101 | 102 | # --------------- find the atom symbols in hier ------------- 103 | hier_atom_symbol = [] 104 | first_frag_symbols = self.all_mol_atom_symbols[list(tmp_frag_atom_ids)] 105 | hier_atom_symbol.append(first_frag_symbols) 106 | 107 | all_anchor_ids = [] 108 | first_frag_id = tmp_id 109 | extensions_atom_ids = [self.frag_atom_ids[first_frag_id]] 110 | 111 | for edge in edge_list: 112 | 113 | tmp = Chem.CombineMols(tmp, edge[1]) 114 | hierarchical_mol.append(tmp) 115 | frag_id = self.frag_to_id_dict[edge[1]] # id of the next fragment to add 116 | index_of_two_frags = (self.frag_to_id_dict[edge[0]], self.frag_to_id_dict[edge[1]]) 117 | tmp_frag_atom_ids = tmp_frag_atom_ids.union(self.frag_atom_ids[frag_id]) 118 | extensions_atom_ids.append(self.frag_atom_ids[frag_id]) 119 | atom_ids_hierarchical.append(tmp_frag_atom_ids) 120 | 121 | anchor_idx = self.neigh_frags_to_neigh_atoms[index_of_two_frags][0] 122 | all_anchor_ids.append(anchor_idx) 123 | 124 | conformer_at_this_step = self.transfer_conformer(tmp_frag_atom_ids) 125 | hierarchical_conformer.append(conformer_at_this_step) 126 | 127 | hier_atom_symbol.append(self.all_mol_atom_symbols[list(tmp_frag_atom_ids)]) 128 | 129 | 130 | return hierarchical_mol, atom_ids_hierarchical, extensions_atom_ids, hierarchical_conformer, hier_atom_symbol, all_anchor_ids 131 | 132 | 133 | def transfer_conformer(self, atom_ids): 134 | 135 | conformer_at_this_step = self.conformer[list(atom_ids)] 136 | return conformer_at_this_step 137 | 138 | def get_all_mol_atom_symbols(self): 139 | mol_atom_symbols = [] 140 | for atom in self.original_mol.GetAtoms(): 141 | mol_atom_symbols.append(atom.GetSymbol()) 142 | mol_atom_symbols = np.array(mol_atom_symbols) 143 | return mol_atom_symbols 144 | 145 | def get_anchor_idx(self): 146 | pass 147 | 148 | 149 | @staticmethod 150 | def draw_fragment_graph(all_frags, fragment_bonds): 151 | 152 | G = nx.Graph() 153 | for i, f in enumerate(all_frags): 154 | img = Draw.MolToImage(f) 155 | img.save('frag_' + str(i)+ '.png') 156 | 157 | for i,f in enumerate(all_frags): 158 | G.add_node(f, name='frag_'+str(i), img=plt.imread('frag_' + str(i) + '.png')) 159 | 160 | for bond in fragment_bonds: 161 | G.add_edge(all_frags[bond[0]], all_frags[bond[1]]) 162 | 163 | pos = nx.circular_layout(G) 164 | fig = plt.figure(figsize=(12,10)) 165 | ax = plt.subplot(111) 166 | ax.set_aspect('equal') 167 | nx.draw_networkx_edges(G, pos, ax=ax, edge_color='black', width=2.) 168 | 169 | #plt.ylim(-4.5,4.5) 170 | trans=ax.transData.transform 171 | trans2=fig.transFigure.inverted().transform 172 | 173 | piesize=0.2 # this is the image size 174 | p2=piesize/2.0 175 | for n in G: 176 | xx,yy=trans(pos[n]) # figure coordinates 177 | xa,ya=trans2((xx,yy)) # axes coordinates 178 | a = plt.axes([xa-p2,ya-p2, piesize, piesize]) 179 | a.set_aspect('equal') 180 | a.imshow(G.nodes[n]['img']) 181 | a.axis('off') 182 | ax.axis('off') 183 | plt.show() 184 | 185 | @staticmethod 186 | def draw_hier_recons(hier): 187 | graph = nx.Graph() 188 | for i, f in enumerate(hier): 189 | img = Draw.MolToImage(f) 190 | img.save('frag_' + str(i)+ '.png') 191 | 192 | for i, f in enumerate(hier): 193 | graph.add_node(f, name='frag_'+str(i), img=plt.imread('frag_' + str(i) + '.png')) 194 | 195 | for i in range(len(hier)-1): 196 | graph.add_edge(hier[i], hier[i+1]) 197 | 198 | pos = nx.circular_layout(graph, dim=2,scale=2.5) 199 | fig = plt.figure(figsize=(12,10)) 200 | ax = plt.subplot(111) 201 | ax.set_aspect('equal') 202 | nx.draw_networkx_edges(graph, pos, ax=ax, node_size=40, edge_color='blue', width=1.5, arrows=True, arrowsize=30) 203 | 204 | 205 | #plt.ylim(-4.5,4.5) 206 | trans=ax.transData.transform 207 | trans2=fig.transFigure.inverted().transform 208 | 209 | piesize=0.22 # this is the image size 210 | p2=piesize/2. 211 | for n in graph: 212 | xx,yy=trans(pos[n]) # figure coordinates 213 | xa,ya=trans2((xx,yy)) # axes coordinates 214 | a = plt.axes([xa-p2,ya-p2, piesize, piesize]) 215 | a.set_aspect('equal') 216 | a.imshow(graph.nodes[n]['img']) 217 | a.axis('off') 218 | ax.axis('off') 219 | plt.show() 220 | 221 | @staticmethod 222 | def get_smiles(mol): 223 | return Chem.MolToSmiles(mol) 224 | 225 | @staticmethod 226 | def get_mol_from_smiles(smiles): 227 | return Chem.MolFromSmiles(smiles) 228 | 229 | -------------------------------------------------------------------------------- /data/CROSSDOCK/process_ligands.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import re 4 | import argparse 5 | 6 | from rdkit import Chem 7 | from tqdm import tqdm 8 | 9 | from fragment_hierarchy import * 10 | from prepare_fragments import * 11 | from sascorer import calculateScore 12 | from rdkit.Chem import rdchem 13 | 14 | import os 15 | from pathlib import Path 16 | 17 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9} 18 | atom_charges = {'C':6, 'N': 7, 'O': 8, 'S': 16, 'B':5, 'Br':35, 'Cl':17, 'P':15, 'I':53, 'F':9} 19 | hybrid_to_onehot = {'SP':0, 'SP2': 1, 'SP3': 2} 20 | 21 | def get_one_hot(atom, atoms_dict): 22 | one_hot = np.zeros(len(atoms_dict)) 23 | one_hot[atoms_dict[atom]] = 1 24 | return one_hot 25 | 26 | def process_ligand(sdffile, max_num_frags=12, num_atoms_cutoff=22, add_QED=True, add_SA=True): 27 | try: 28 | mol = Chem.SDMolSupplier(str(sdffile))[0] 29 | 30 | add_Hs= True 31 | 32 | # get positions and ont-hot 33 | mol_pos = mol.GetConformer().GetPositions() 34 | all_symbols = [] 35 | mol_onehot = [] 36 | mol_onehot = [] 37 | for atom in mol.GetAtoms(): 38 | #all_symbols.append(atom.GetSymbol()) 39 | atom_symb_onehot = get_one_hot(atom.GetSymbol(), atom_dict) 40 | hyb_onehot = np.eye(1,len(hybrid_to_onehot), hybrid_to_onehot[str(atom.GetHybridization())]).squeeze() 41 | aromatic_onehot = float(atom.GetIsAromatic()) 42 | mol_onehot.append(np.concatenate([atom_symb_onehot, hyb_onehot, (aromatic_onehot,)])) 43 | 44 | # NOTE: adding extra node features (aromaticity and hybridization) see if these help 45 | mol_onehot = np.array(mol_onehot) 46 | 47 | mol_charges = [] 48 | for atom in all_symbols: 49 | mol_charges.append(atom_charges[atom]) 50 | 51 | # get charges 52 | mol_charges = np.array(mol_charges) 53 | 54 | output = find_bonds_broken_with_frags(mol, find_single_ring_fragments(mol), max_num_frags=max_num_frags, max_num_atoms_single_frag=num_atoms_cutoff) 55 | if output is not None: 56 | all_frags, bonds_broken_frags, bonds_broken_indices, \ 57 | bonds_broken_frag_ids, all_frag_atom_ids, atom2frag = output 58 | 59 | # -------------- get the smiles of fragments for making a fragment library 60 | du = Chem.MolFromSmiles('*') 61 | frag_smiles_temp = [Chem.MolFromSmiles(Chem.MolToSmiles(all_frags[i])) for i in range(len(all_frags))] 62 | 63 | frag_smiles = [] 64 | frag_n_atoms = [] 65 | for i in range(len(all_frags)): 66 | frag=AllChem.ReplaceSubstructs(frag_smiles_temp[i],du,Chem.MolFromSmiles('[H]'),True)[0] 67 | frag = Chem.RemoveAllHs(frag) 68 | frag_n_atoms.append(frag.GetNumAtoms()) 69 | frag_smiles.append(Chem.MolToSmiles(frag)) 70 | # -------------------------------------------------------------------- 71 | 72 | if len(all_frags) > 1: # more than 1 fragment exists in the molecule 73 | adjacency = find_neighboring_frags(all_frags, atom2frag, bonds_broken_frags) 74 | neigh_frags_to_neigh_atoms = find_neigh_frags_to_neigh_atoms(bonds_broken_frags, bonds_broken_frag_ids) 75 | 76 | frag_to_id_dict = {} 77 | for i,frag in enumerate(all_frags): 78 | frag_to_id_dict[frag] = i 79 | 80 | g = FragmentGraph(mol, 81 | all_frags, 82 | adjacency, 83 | frag_atom_ids=all_frag_atom_ids, 84 | frag_to_id_dict=frag_to_id_dict, 85 | neigh_frags_to_neigh_atoms=neigh_frags_to_neigh_atoms) 86 | 87 | n_frags = len(all_frags) 88 | 89 | assert n_frags <= max_num_frags 90 | 91 | mol_atom_ids = [] 92 | mol_extension_ids = [] 93 | mol_anchor_ids = [] 94 | mol_QED_scores = [] 95 | mol_SA_scores = [] 96 | all_sub_mols = [] 97 | 98 | for order in ['BFS', 'DFS']: 99 | for j in range(n_frags): # 5 different ways to reconstruct the molecule in total 100 | hier, perm_atom_ids, perm_extensions_atom_ids, _, _, perm_anchor_ids = g.hierarchical_reconstruct(edge_order=order, starting_point=j) 101 | # save this hierarchy 102 | 103 | assert len(perm_atom_ids) != 0 104 | assert len(perm_anchor_ids) != 0 105 | assert len(perm_extensions_atom_ids) != 0 106 | 107 | assert len(perm_extensions_atom_ids) == len(perm_atom_ids) 108 | assert (len(perm_extensions_atom_ids) == len(perm_anchor_ids) + 1) 109 | 110 | num_atoms = [len(perm_extensions_atom_ids[i]) for i in range(len(perm_extensions_atom_ids))] 111 | max_num_atoms = max(num_atoms) 112 | if max_num_atoms > num_atoms_cutoff: 113 | break 114 | mol_atom_ids.append(perm_atom_ids) 115 | mol_extension_ids.append(perm_extensions_atom_ids) 116 | mol_anchor_ids.append(perm_anchor_ids) 117 | 118 | QED_scores = [] 119 | SA_scores = [] 120 | if add_QED: 121 | for i in range(len(all_frags)): 122 | atom_indices = list(perm_atom_ids[i]) 123 | sub_mol = rdchem.EditableMol(Chem.Mol()) 124 | atom_map = {} 125 | for atom_idx in atom_indices: 126 | atom = mol.GetAtomWithIdx(atom_idx) 127 | new_idx = sub_mol.AddAtom(atom) 128 | atom_map[atom_idx] = new_idx 129 | 130 | for bond in mol.GetBonds(): 131 | begin_idx = bond.GetBeginAtomIdx() 132 | end_idx = bond.GetEndAtomIdx() 133 | if begin_idx in atom_indices and end_idx in atom_indices: 134 | bond_type = bond.GetBondType() 135 | sub_mol.AddBond(atom_indices.index(begin_idx), atom_indices.index(end_idx), bond_type) 136 | 137 | 138 | sub_mol = sub_mol.GetMol() 139 | # Adding 3d Coordinates to the fragments 140 | try: 141 | Chem.SanitizeMol(sub_mol) 142 | conf = Chem.Conformer(sub_mol.GetNumAtoms()) 143 | for atom_idx, new_atom_idx in atom_map.items(): 144 | conf.SetAtomPosition(new_atom_idx, mol.GetConformer().GetAtomPosition(atom_idx)) 145 | sub_mol.AddConformer(conf) 146 | except: 147 | print('sanitization failed! using smarts instead!') 148 | sub_mol = Chem.MolFromSmarts(Chem.MolToSmarts(sub_mol)) 149 | Chem.SanitizeMol(sub_mol) 150 | conf = Chem.Conformer(sub_mol.GetNumAtoms()) 151 | for atom_idx, new_atom_idx in atom_map.items(): 152 | conf.SetAtomPosition(new_atom_idx, mol.GetConformer().GetAtomPosition(atom_idx)) 153 | sub_mol.AddConformer(conf) 154 | 155 | #sub_mol = Chem.MolFromSmarts(Chem.MolToSmarts(sub_mol)) 156 | #Chem.SanitizeMol(sub_mol) 157 | if add_Hs: 158 | sub_mol_h = Chem.AddHs(sub_mol, addCoords=True) 159 | 160 | all_sub_mols.append(sub_mol_h) 161 | 162 | QED_scores.append(Chem.QED.qed(sub_mol)) 163 | if add_SA: 164 | sa = calculateScore(sub_mol) 165 | sa_as_pocket2mol = round((10-sa)/9, 2) # from pocket2mol 166 | SA_scores.append(sa_as_pocket2mol) 167 | 168 | 169 | mol_QED_scores.append(QED_scores) 170 | mol_SA_scores.append(SA_scores) 171 | 172 | mol_atom_ids = np.array(mol_atom_ids, dtype=object) 173 | mol_extension_ids = np.array(mol_extension_ids, dtype=object) 174 | mol_anchor_ids = np.array(mol_anchor_ids, dtype=object) 175 | 176 | is_single_frag = False 177 | 178 | return mol_pos, mol_onehot, mol_charges, mol_atom_ids, mol_extension_ids, mol_anchor_ids, is_single_frag, frag_smiles, frag_n_atoms, mol_QED_scores, mol_SA_scores, all_sub_mols 179 | else: 180 | print('using single fragment') 181 | is_single_frag = True 182 | mol_H = Chem.AddHs(mol, addCoords=True) 183 | all_sub_mols = [mol_H] 184 | mol_QED_score = [Chem.QED.qed(mol)] 185 | sa = calculateScore(mol) 186 | mol_SA_score = [round((10-sa)/9, 2)] 187 | return mol_pos, mol_onehot, mol_charges, None, None, None, is_single_frag, frag_smiles, frag_n_atoms, mol_QED_score, mol_SA_score, all_sub_mols 188 | 189 | except Exception as e: 190 | print(f'Error {e} for sdffile {sdffile}') 191 | return -------------------------------------------------------------------------------- /data/CROSSDOCK/process_pockets.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from time import time 3 | import argparse 4 | import shutil 5 | import random 6 | import matplotlib.pyplot as plt 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | 11 | from Bio.PDB import PDBParser 12 | from Bio.PDB.Polypeptide import is_aa, three_to_one 13 | from rdkit import Chem 14 | 15 | amino_acid_dict = {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19} 16 | pocket_atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3} # only 4 atoms types for pocket 17 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9} 18 | 19 | 20 | def get_one_hot(atom, atoms_dict): 21 | one_hot = np.zeros(len(atoms_dict)) 22 | one_hot[atoms_dict[atom]] = 1 23 | return one_hot 24 | 25 | def process_pocket(pdbfile, sdffile, atom_dict, pocket_atom_dict, dist_cutoff, remove_H=True, ca_only=False): 26 | 27 | pdb_struct = PDBParser(QUIET=True).get_structure('', pdbfile) 28 | 29 | try: 30 | ligand = Chem.SDMolSupplier(str(sdffile))[0] 31 | except: 32 | raise Exception(f'cannot read sdf mol ({sdffile})') 33 | 34 | # remove H atom if not in atom_dict, other atom types taht aren't allowed 35 | # should stay so that the entire ligand can be removed from the dataset 36 | lig_atoms = [a.GetSymbol() for a in ligand.GetAtoms() 37 | if (a.GetSymbol().capitalize() in atom_dict or a.element !='H')] 38 | lig_coords = np.array([list(ligand.GetConformer(0).GetAtomPosition(idx)) 39 | for idx in range(ligand.GetNumAtoms())]) 40 | 41 | # find interacting pocket residues based on distance cutoff 42 | pocket_residues = [] 43 | for residue in pdb_struct[0].get_residues(): 44 | res_coords = np.array([a.get_coord() for a in residue.get_atoms()]) 45 | if is_aa(residue.get_resname(), standard=True) and \ 46 | (((res_coords[:, None, :] - lig_coords[None, :, :]) ** 2).sum(-1)**0.5).min() < dist_cutoff: 47 | pocket_residues.append(residue) 48 | 49 | 50 | pocket_ids = [f'{res.parent.id}:{res.id[1]}' for res in pocket_residues] 51 | 52 | if ca_only: 53 | try: 54 | pocket_one_hot = [] 55 | pocket_coords = [] 56 | for res in pocket_residues: 57 | for atom in res.get_atoms(): 58 | if atom.name == 'CA': 59 | pocket_one_hot.append(np.eye(1, len(amino_acid_dict), 60 | amino_acid_dict[three_to_one(res.get_resname())]).squeeze()) 61 | pocket_coords.append(atom.coord) 62 | pocket_one_hot = np.stack(pocket_one_hot) 63 | pocket_coords = np.stack(pocket_coords) 64 | except KeyError as e: 65 | raise KeyError(f'{e} not in amino acid dict ({pdbfile}, {sdffile})') 66 | else: 67 | full_atoms = np.concatenate([np.array([atom.element for atom in res.get_atoms()]) for res in pocket_residues], axis=0) 68 | full_coords = np.concatenate([np.array([atom.coord for atom in res.get_atoms()]) for res in pocket_residues], axis=0) 69 | full_atoms_names = np.concatenate([np.array([atom.get_id() for atom in res.get_atoms()]) for res in pocket_residues], axis=0) 70 | pocket_AA = np.concatenate([([three_to_one(atom.get_parent().get_resname()) for atom in res.get_atoms()]) for res in pocket_residues], axis=0) 71 | 72 | # removing Hs if present 73 | if remove_H: 74 | h_mask = full_atoms == 'H' 75 | full_atoms = full_atoms[~h_mask] 76 | pocket_coords = full_coords[~h_mask] 77 | full_atoms_names = full_atoms_names[~h_mask] 78 | pocket_AA = pocket_AA[~h_mask] 79 | 80 | try: 81 | pocket_one_hot = [] 82 | for i in range(len(full_atoms)): 83 | a = full_atoms[i] 84 | aa = pocket_AA[i] 85 | atom_onehot = np.eye(1, len(pocket_atom_dict), pocket_atom_dict[a.capitalize()]).squeeze() 86 | amino_onehot = np.eye(1, len(amino_acid_dict), amino_acid_dict[aa.capitalize()]).squeeze() 87 | is_backbone = 1 if full_atoms_names[i].capitalize() in ['N','CA','C','O'] else 0 88 | pocket_one_hot.append(np.concatenate([atom_onehot, amino_onehot, (is_backbone,)])) 89 | 90 | pocket_one_hot = np.stack(pocket_one_hot) 91 | except KeyError as e: 92 | raise KeyError( 93 | f'{e} not in atom dict ({pdbfile})') 94 | 95 | pocket_one_hot = np.array(pocket_one_hot) 96 | return pocket_one_hot, pocket_coords, lig_coords -------------------------------------------------------------------------------- /data/CROSSDOCK/sascorer.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from rdkit.Chem import rdMolDescriptors 3 | import pickle 4 | 5 | import math 6 | from collections import defaultdict 7 | 8 | import os.path as op 9 | 10 | _fscores = None 11 | 12 | 13 | def readFragmentScores(name='fpscores'): 14 | import gzip 15 | global _fscores 16 | # generate the full path filename: 17 | #if name == "fpscores": 18 | #name = op.join(op.dirname(__file__), name) 19 | #data = pickle.load(gzip.open('%s.pkl.gz' % name)) 20 | data = pickle.load(gzip.open('/srv/home/mahdi.ghorbani/FragDiff/fpscores.pkl.gz')) 21 | outDict = {} 22 | for i in data: 23 | for j in range(1, len(i)): 24 | outDict[i[j]] = float(i[0]) 25 | _fscores = outDict 26 | 27 | 28 | def numBridgeheadsAndSpiro(mol, ri=None): 29 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 30 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 31 | return nBridgehead, nSpiro 32 | 33 | def calculateScore(m): 34 | if _fscores is None: 35 | readFragmentScores() 36 | 37 | # fragment score 38 | fp = rdMolDescriptors.GetMorganFingerprint(m, 39 | 2) # <- 2 is the *radius* of the circular fingerprint 40 | fps = fp.GetNonzeroElements() 41 | score1 = 0. 42 | nf = 0 43 | for bitId, v in fps.items(): 44 | nf += v 45 | sfp = bitId 46 | score1 += _fscores.get(sfp, -4) * v 47 | score1 /= nf 48 | 49 | # features score 50 | nAtoms = m.GetNumAtoms() 51 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 52 | ri = m.GetRingInfo() 53 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 54 | nMacrocycles = 0 55 | for x in ri.AtomRings(): 56 | if len(x) > 8: 57 | nMacrocycles += 1 58 | 59 | sizePenalty = nAtoms**1.005 - nAtoms 60 | stereoPenalty = math.log10(nChiralCenters + 1) 61 | spiroPenalty = math.log10(nSpiro + 1) 62 | bridgePenalty = math.log10(nBridgeheads + 1) 63 | macrocyclePenalty = 0. 64 | # --------------------------------------- 65 | # This differs from the paper, which defines: 66 | # macrocyclePenalty = math.log10(nMacrocycles+1) 67 | # This form generates better results when 2 or more macrocycles are present 68 | if nMacrocycles > 0: 69 | macrocyclePenalty = math.log10(2) 70 | 71 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 72 | 73 | # correction for the fingerprint density 74 | # not in the original publication, added in version 1.1 75 | # to make highly symmetrical molecules easier to synthetise 76 | score3 = 0. 77 | if nAtoms > len(fps): 78 | score3 = math.log(float(nAtoms) / len(fps)) * .5 79 | 80 | sascore = score1 + score2 + score3 81 | 82 | # need to transform "raw" value into scale between 1 and 10 83 | min = -4.0 84 | max = 2.5 85 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 86 | # smooth the 10-end 87 | if sascore > 8.: 88 | sascore = 8. + math.log(sascore + 1. - 9.) 89 | if sascore > 10.: 90 | sascore = 10.0 91 | elif sascore < 1.: 92 | sascore = 1.0 93 | 94 | return sascore 95 | 96 | 97 | def processMols(mols): 98 | print('smiles\tName\tsa_score') 99 | for i, m in enumerate(mols): 100 | if m is None: 101 | continue 102 | 103 | s = calculateScore(m) 104 | 105 | smiles = Chem.MolToSmiles(m) 106 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 107 | 108 | 109 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/data/__init__.py -------------------------------------------------------------------------------- /data/sascorer.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from rdkit.Chem import rdMolDescriptors 3 | import pickle 4 | 5 | import math 6 | 7 | _fscores = None 8 | 9 | def readFragmentScores(name='fpscores'): 10 | import gzip 11 | global _fscores 12 | # generate the full path filename: 13 | #if name == "fpscores": 14 | #name = op.join(op.dirname(__file__), name) 15 | #data = pickle.load(gzip.open('%s.pkl.gz' % name)) 16 | data = pickle.load(gzip.open('/srv/home/mahdi.ghorbani/FragDiff/fpscores.pkl.gz')) 17 | outDict = {} 18 | for i in data: 19 | for j in range(1, len(i)): 20 | outDict[i[j]] = float(i[0]) 21 | _fscores = outDict 22 | 23 | 24 | def numBridgeheadsAndSpiro(mol, ri=None): 25 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 26 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 27 | return nBridgehead, nSpiro 28 | 29 | def calculateScore(m): 30 | if _fscores is None: 31 | readFragmentScores() 32 | 33 | # fragment score 34 | fp = rdMolDescriptors.GetMorganFingerprint(m, 35 | 2) # <- 2 is the *radius* of the circular fingerprint 36 | fps = fp.GetNonzeroElements() 37 | score1 = 0. 38 | nf = 0 39 | for bitId, v in fps.items(): 40 | nf += v 41 | sfp = bitId 42 | score1 += _fscores.get(sfp, -4) * v 43 | score1 /= nf 44 | 45 | # features score 46 | nAtoms = m.GetNumAtoms() 47 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 48 | ri = m.GetRingInfo() 49 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 50 | nMacrocycles = 0 51 | for x in ri.AtomRings(): 52 | if len(x) > 8: 53 | nMacrocycles += 1 54 | 55 | sizePenalty = nAtoms**1.005 - nAtoms 56 | stereoPenalty = math.log10(nChiralCenters + 1) 57 | spiroPenalty = math.log10(nSpiro + 1) 58 | bridgePenalty = math.log10(nBridgeheads + 1) 59 | macrocyclePenalty = 0. 60 | # --------------------------------------- 61 | # This differs from the paper, which defines: 62 | # macrocyclePenalty = math.log10(nMacrocycles+1) 63 | # This form generates better results when 2 or more macrocycles are present 64 | if nMacrocycles > 0: 65 | macrocyclePenalty = math.log10(2) 66 | 67 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 68 | 69 | # correction for the fingerprint density 70 | # not in the original publication, added in version 1.1 71 | # to make highly symmetrical molecules easier to synthetise 72 | score3 = 0. 73 | if nAtoms > len(fps): 74 | score3 = math.log(float(nAtoms) / len(fps)) * .5 75 | 76 | sascore = score1 + score2 + score3 77 | 78 | # need to transform "raw" value into scale between 1 and 10 79 | min = -4.0 80 | max = 2.5 81 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 82 | # smooth the 10-end 83 | if sascore > 8.: 84 | sascore = 8. + math.log(sascore + 1. - 9.) 85 | if sascore > 10.: 86 | sascore = 10.0 87 | elif sascore < 1.: 88 | sascore = 1.0 89 | 90 | return sascore 91 | 92 | 93 | def processMols(mols): 94 | print('smiles\tName\tsa_score') 95 | for i, m in enumerate(mols): 96 | if m is None: 97 | continue 98 | 99 | s = calculateScore(m) 100 | 101 | smiles = Chem.MolToSmiles(m) 102 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 103 | 104 | 105 | -------------------------------------------------------------------------------- /fpscores.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/fpscores.pkl.gz -------------------------------------------------------------------------------- /generate_pocket_molecules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | import argparse 5 | 6 | import torch 7 | import time 8 | import shutil 9 | 10 | from utils.volume_sampling import sample_discrete_number, bin_edges, prob_dist_df 11 | from utils.templates import get_one_hot, get_pocket 12 | 13 | from src.lightning_anchor_gnn import AnchorGNN_pl 14 | from src.lightning import AR_DDPM 15 | from scipy.spatial import distance 16 | 17 | from analysis.reconstruct_mol import reconstruct_from_generated 18 | 19 | from rdkit.Chem import rdmolfiles 20 | from sampling.sample_mols import generate_mols_for_pocket 21 | 22 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9} 23 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'} 24 | CROSSDOCK_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'B':5, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15} 25 | pocket_atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3} # only 4 atoms types for pocket 26 | 27 | vdws = {'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8, 'B': 1.92, 'Br': 1.85, 'Cl': 1.75, 'P': 1.8, 'I': 1.98, 'F': 1.47} 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--results-path', type=str, default='results', 31 | help='path to save the results ') 32 | parser.add_argument('--data-path', action='store', type=str, default='/srv/home/mahdi.ghorbani/FragDiff/crossdock', 33 | help='path to the test data for generating molecules') 34 | parser.add_argument('--use-anchor-model', action='store_true', default=False, 35 | help='Whether to use an anchor prediction model') 36 | parser.add_argument('--anchor-model', type=str, default='anchor_model.ckpt', 37 | help='path to the anchor model. Note that for guidance, the anchor model should incorporate the conditionals') 38 | parser.add_argument('--n-samples', type=int, default=20, 39 | help='total number of ligands to generate per pocket') 40 | parser.add_argument('--exp-name', type=str, default='exp-1', 41 | help='name of the generation experiment') 42 | parser.add_argument('--diff-model', type=str, default='diff-model.ckpt', 43 | help='path to the diffusion model checkpoint') 44 | parser.add_argument('--device', type=str, default='cuda:0') 45 | parser.add_argument('--rejection-sampling', action='store_true', default=False, help='enable rejection sampling') 46 | 47 | if __name__ == '__main__': 48 | args = parser.parse_args() 49 | torch_device = args.device 50 | anchor_checkpoint = args.anchor_model 51 | data_path = args.data_path 52 | diff_model_checkpoint = args.diff_model 53 | 54 | model = AR_DDPM.load_from_checkpoint(diff_model_checkpoint, device=torch_device) # load diffusion model 55 | model = model.to(torch_device) 56 | 57 | if args.use_anchor_model is not None: 58 | anchor_model = AnchorGNN_pl.load_from_checkpoint(anchor_checkpoint, device=torch_device) 59 | anchor_model = anchor_model.to(torch_device) 60 | else: 61 | anchor_model = None # TODO: implement random anchor selection 62 | 63 | split = torch.load(data_path + '/' + 'split_by_name.pt') 64 | prefix = data_path + '/crossdocked_pocket10/' 65 | 66 | if not os.path.exists(args.results_path): 67 | print('creating results directory') 68 | 69 | save_dir = args.results_path + '/' + args.exp_name 70 | if not os.path.exists(save_dir): 71 | os.makedirs(save_dir, exist_ok=True) 72 | 73 | for n in range(100): 74 | prot_name = prefix + split['test'][n][0] 75 | lig_name = prefix + split['test'][n][1] 76 | 77 | pocket_onehot, pocket_coords, lig_coords, _ = get_pocket(prot_name, lig_name, atom_dict, pocket_atom_dict=pocket_atom_dict, dist_cutoff=7) 78 | 79 | # --------------- make a grid box around the pocket ---------------- 80 | min_coords = pocket_coords.min(axis=0) - 2.5 # 81 | max_coords = pocket_coords.max(axis=0) + 2.5 82 | 83 | x_range = slice(min_coords[0], max_coords[0] + 1, 1.5) # spheres of radius 1.2 (vdw radius of H) 84 | y_range = slice(min_coords[1], max_coords[1] + 1, 1.5) 85 | z_range = slice(min_coords[2], max_coords[2] + 1, 1.5) 86 | 87 | grid = np.mgrid[x_range, y_range, z_range] 88 | grid_points = grid.reshape(3, -1).T # This transposes the grid to a list of coordinates 89 | 90 | # remove grids points not in 3.5A neighborhood of original ligand 91 | distances_mol = distance.cdist(grid_points, lig_coords) 92 | mask_mol = (distances_mol < 3.5).any(axis=1) 93 | filtered_mol_points = grid_points[mask_mol] 94 | 95 | # remove grid points that are close to the pocket 96 | pocket_distances = distance.cdist(filtered_mol_points, pocket_coords) 97 | mask_pocket = (pocket_distances < 2).any(axis=1) 98 | grids = filtered_mol_points[~mask_pocket] 99 | 100 | n_samples = args.n_samples 101 | max_mol_sizes = [] 102 | 103 | fpocket_out = prot_name[:-4] + '_out' 104 | shutil.rmtree(fpocket_out, ignore_errors=True) 105 | 106 | #print('running fpocket!') 107 | #try: 108 | # run_fpocket(prot_name) 109 | #except: 110 | # print('Error in running fpocket! using random sizes') 111 | 112 | # NOTE: using original molecule coordinates for making the grid 113 | 114 | grids = torch.tensor(grids) 115 | all_grids = [] # list of grids 116 | for i in range(n_samples): 117 | all_grids.append(grids) 118 | 119 | pocket_vol = len(grids) 120 | #if os.path.exists(fpocket_out): 121 | # filename = prot_name[:-4] + '_out/pockets/pocket1_atm.pdb' 122 | # score, drug_score, pocket_volume = extract_values(filename) 123 | #else: 124 | # print('running fpocket!') 125 | # run_fpocket(prot_name) 126 | # filename = prot_name[:-4] + '_out/pockets/pocket1_atm.pdb' 127 | # score, drug_score, pocket_volume = extract_values(filename) 128 | 129 | #print('pocket_volume', pocket_volume) 130 | 131 | for i in range(n_samples): 132 | max_mol_sizes.append(sample_discrete_number(pocket_vol)) 133 | 134 | pocket_onehot = torch.tensor(pocket_onehot).float() 135 | pocket_coords = torch.tensor(pocket_coords).float() 136 | lig_coords = torch.tensor(lig_coords).float() 137 | pocket_size = len(pocket_coords) 138 | 139 | t1 = time.time() 140 | 141 | max_mol_sizes = np.array(max_mol_sizes) 142 | print('maximum sizes for molecules', max_mol_sizes) 143 | x, h, mol_masks = generate_mols_for_pocket(n_samples=n_samples, 144 | num_frags=8, 145 | pocket_size=pocket_size, 146 | pocket_coords=pocket_coords, 147 | pocket_onehot=pocket_onehot, 148 | lig_coords=lig_coords, 149 | anchor_model=anchor_model, 150 | diff_model=model, 151 | device=torch_device, 152 | return_all=False, 153 | prot_path=prot_name, 154 | max_mol_sizes=max_mol_sizes, 155 | all_grids=all_grids, 156 | rejection_sampling=args.rejection_sampling, 157 | rejection_criteria='clash') 158 | 159 | x = x.cpu().numpy() 160 | h = h.cpu().numpy() 161 | mol_masks = mol_masks.cpu().cpu().numpy() 162 | 163 | # convert to SDF 164 | all_mols = [] 165 | for k in range(len(x)): 166 | mask = mol_masks[k] 167 | h_mol = h[k] 168 | x_mol = x[k][mask.astype(np.bool_)] 169 | 170 | atom_inds = h_mol[mask.astype(np.bool_)].argmax(axis=1) 171 | atom_types = [idx2atom[x] for x in atom_inds] 172 | atomic_nums = [CROSSDOCK_CHARGES[i] for i in atom_types] 173 | 174 | try: 175 | mol_rec = reconstruct_from_generated(x_mol.tolist(), atomic_nums) 176 | all_mols.append(mol_rec) 177 | except: 178 | continue 179 | 180 | t2 = time.time() 181 | print('time to generate one is: ', (t2-t1)/n_samples) 182 | save_path = save_dir + '/' + 'pocket_' + str(n) 183 | 184 | # write sdf file of molecules 185 | with rdmolfiles.SDWriter(save_path + '_mols.sdf') as writer: 186 | for mol in all_mols: 187 | if mol: 188 | writer.write(mol) 189 | 190 | np.save(save_path + '_coords.npy', x) 191 | np.save(save_path + '_onehot.npy', h) 192 | np.save(save_path + '_mol_masks.npy', mol_masks) 193 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/2z3h.pml: -------------------------------------------------------------------------------- 1 | from pymol import cmd,stored 2 | load 2z3h_out.pdb 3 | #select pockets, resn STP 4 | stored.list=[] 5 | cmd.iterate("(resn STP)","stored.list.append(resi)") #read info about residues STP 6 | #print stored.list 7 | lastSTP=stored.list[-1] #get the index of the last residu 8 | hide lines, resn STP 9 | 10 | #show spheres, resn STP 11 | for my_index in range(1,int(lastSTP)+1): cmd.select("pocket"+str(my_index), "resn STP and resi "+str(my_index)) 12 | for my_index in range(2,int(lastSTP)+2): cmd.color(my_index,"pocket"+str(my_index)) 13 | for my_index in range(1,int(lastSTP)+1): cmd.show("spheres","pocket"+str(my_index)) 14 | for my_index in range(1,int(lastSTP)+1): cmd.set("sphere_scale","0.3","pocket"+str(my_index)) 15 | for my_index in range(1,int(lastSTP)+1): cmd.set("sphere_transparency","0.1","pocket"+str(my_index)) 16 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/2z3h.tcl: -------------------------------------------------------------------------------- 1 | proc highlighting { colorId representation id selection } { 2 | puts "highlighting $id" 3 | mol representation $representation 4 | mol material "Diffuse" 5 | mol color $colorId 6 | mol selection $selection 7 | mol addrep $id 8 | } 9 | 10 | set id [mol new 2z3h_out.pdb type pdb] 11 | mol delrep top $id 12 | highlighting Name "Lines" $id "protein" 13 | highlighting Name "Licorice" $id "not protein and not resname STP" 14 | highlighting Element "NewCartoon" $id "protein" 15 | highlighting "ColorID 7" "VdW 0.4" $id "protein and occupancy>0.95" 16 | set id [mol new 2z3h_pockets.pqr type pqr] 17 | mol selection "all" 18 | mol material "Glass3" 19 | mol delrep top $id 20 | mol representation "QuickSurf 0.3" 21 | mol color ResId $id 22 | mol addrep $id 23 | highlighting Index "Points 1" $id "resname STP" 24 | display rendermode GLSL 25 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/2z3h_PYMOL.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pymol 2z3h.pml 3 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/2z3h_VMD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | vmd 2z3h_out.pdb -e 2z3h.tcl 3 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/2z3h_info.txt: -------------------------------------------------------------------------------- 1 | Pocket 1 : 2 | Score : 0.350 3 | Druggability Score : 0.871 4 | Number of Alpha Spheres : 70 5 | Total SASA : 176.938 6 | Polar SASA : 67.856 7 | Apolar SASA : 109.082 8 | Volume : 685.980 9 | Mean local hydrophobic density : 28.632 10 | Mean alpha sphere radius : 3.828 11 | Mean alp. sph. solvent access : 0.460 12 | Apolar alpha sphere proportion : 0.543 13 | Hydrophobicity score: 24.800 14 | Volume score: 4.000 15 | Polarity score: 11 16 | Charge score : 0 17 | Proportion of polar atoms: 37.736 18 | Alpha sphere density : 5.683 19 | Cent. of mass - Alpha Sphere max dist: 13.914 20 | Flexibility : 0.380 21 | 22 | Pocket 2 : 23 | Score : 0.025 24 | Druggability Score : 0.002 25 | Number of Alpha Spheres : 19 26 | Total SASA : 61.984 27 | Polar SASA : 16.094 28 | Apolar SASA : 45.890 29 | Volume : 221.120 30 | Mean local hydrophobic density : 8.000 31 | Mean alpha sphere radius : 4.123 32 | Mean alp. sph. solvent access : 0.669 33 | Apolar alpha sphere proportion : 0.474 34 | Hydrophobicity score: 40.250 35 | Volume score: 3.500 36 | Polarity score: 3 37 | Charge score : 0 38 | Proportion of polar atoms: 43.750 39 | Alpha sphere density : 2.046 40 | Cent. of mass - Alpha Sphere max dist: 5.129 41 | Flexibility : 0.306 42 | 43 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/2z3h_pockets.pqr: -------------------------------------------------------------------------------- 1 | HEADER 2 | HEADER This is a pqr format file writen by the programm fpocket. 3 | HEADER It contains all the pocket vertices found by fpocket. 4 | ATOM 1 C STP 1 27.832 35.872 104.021 0.00 3.69 5 | ATOM 2 C STP 1 27.954 35.449 103.706 0.00 4.07 6 | ATOM 3 O STP 1 25.370 34.420 103.783 0.00 4.01 7 | ATOM 4 O STP 1 27.724 35.111 103.740 0.00 4.18 8 | ATOM 5 C STP 1 24.801 33.618 103.765 0.00 3.82 9 | ATOM 6 C STP 1 21.596 27.720 100.725 0.00 3.87 10 | ATOM 7 O STP 1 23.353 34.262 111.342 0.00 3.64 11 | ATOM 8 O STP 1 23.638 33.858 112.345 0.00 3.55 12 | ATOM 9 O STP 1 26.104 36.231 100.922 0.00 3.67 13 | ATOM 10 O STP 1 26.179 36.489 100.894 0.00 3.56 14 | ATOM 11 O STP 1 26.176 36.334 100.914 0.00 3.63 15 | ATOM 12 O STP 1 24.838 36.294 100.826 0.00 3.60 16 | ATOM 13 O STP 1 24.946 36.155 100.851 0.00 3.66 17 | ATOM 14 C STP 1 28.730 35.819 103.728 0.00 3.92 18 | ATOM 15 C STP 1 26.814 35.345 102.503 0.00 3.69 19 | ATOM 16 O STP 1 26.373 34.656 103.564 0.00 3.96 20 | ATOM 17 O STP 1 27.492 34.995 103.579 0.00 4.14 21 | ATOM 18 C STP 1 22.814 30.645 104.838 0.00 4.62 22 | ATOM 19 C STP 1 22.340 30.784 103.285 0.00 3.67 23 | ATOM 20 O STP 1 21.530 29.239 101.689 0.00 3.78 24 | ATOM 21 O STP 1 21.569 29.345 101.655 0.00 3.71 25 | ATOM 22 C STP 1 21.659 27.794 101.499 0.00 4.33 26 | ATOM 23 C STP 1 21.686 27.719 101.468 0.00 4.35 27 | ATOM 24 C STP 1 28.707 33.958 103.385 0.00 3.72 28 | ATOM 25 C STP 1 21.363 27.197 100.342 0.00 3.70 29 | ATOM 26 O STP 1 22.966 34.609 109.925 0.00 3.95 30 | ATOM 27 O STP 1 22.195 33.840 106.546 0.00 4.43 31 | ATOM 28 O STP 1 22.184 33.812 106.570 0.00 4.45 32 | ATOM 29 O STP 1 24.182 34.117 108.716 0.00 3.47 33 | ATOM 30 O STP 1 23.499 33.586 105.919 0.00 4.32 34 | ATOM 31 O STP 1 23.414 33.518 106.010 0.00 4.37 35 | ATOM 32 O STP 1 24.878 34.569 104.793 0.00 4.02 36 | ATOM 33 O STP 1 26.129 35.033 105.102 0.00 3.79 37 | ATOM 34 O STP 1 25.995 34.250 107.350 0.00 3.45 38 | ATOM 35 O STP 1 24.386 35.825 100.260 0.00 3.46 39 | ATOM 36 C STP 1 29.851 32.841 102.748 0.00 3.60 40 | ATOM 37 C STP 1 24.701 33.358 103.737 0.00 3.76 41 | ATOM 38 C STP 1 23.960 32.096 104.019 0.00 3.78 42 | ATOM 39 C STP 1 29.026 34.688 103.966 0.00 4.07 43 | ATOM 40 C STP 1 29.002 34.303 103.714 0.00 3.89 44 | ATOM 41 C STP 1 29.868 32.882 102.753 0.00 3.63 45 | ATOM 42 C STP 1 29.863 32.869 102.753 0.00 3.62 46 | ATOM 43 C STP 1 29.847 32.902 102.767 0.00 3.62 47 | ATOM 44 C STP 1 29.437 35.468 103.873 0.00 3.94 48 | ATOM 45 C STP 1 29.764 34.817 104.028 0.00 4.11 49 | ATOM 46 C STP 1 29.899 34.795 104.009 0.00 4.11 50 | ATOM 47 C STP 1 31.693 34.249 105.339 0.00 3.73 51 | ATOM 48 C STP 1 31.333 34.241 105.797 0.00 3.49 52 | ATOM 49 C STP 1 28.775 36.216 103.232 0.00 3.53 53 | ATOM 50 C STP 1 28.934 36.340 103.276 0.00 3.49 54 | ATOM 51 C STP 1 29.044 36.027 103.659 0.00 3.84 55 | ATOM 52 C STP 1 29.068 36.054 103.658 0.00 3.83 56 | ATOM 53 C STP 1 29.483 35.777 103.760 0.00 3.87 57 | ATOM 54 C STP 1 29.517 35.981 103.735 0.00 3.80 58 | ATOM 55 C STP 1 29.637 35.694 103.759 0.00 3.86 59 | ATOM 56 O STP 1 29.648 35.983 103.725 0.00 3.76 60 | ATOM 57 C STP 1 30.753 33.263 102.941 0.00 3.77 61 | ATOM 58 C STP 1 30.509 33.161 102.917 0.00 3.72 62 | ATOM 59 O STP 1 31.028 33.632 103.177 0.00 3.90 63 | ATOM 60 C STP 1 30.931 33.710 103.264 0.00 3.90 64 | ATOM 61 O STP 1 24.895 33.716 107.469 0.00 3.69 65 | ATOM 62 C STP 1 31.689 34.254 105.307 0.00 3.74 66 | ATOM 63 O STP 1 31.145 34.459 104.665 0.00 3.82 67 | ATOM 64 O STP 1 31.094 33.808 103.577 0.00 3.82 68 | ATOM 65 C STP 1 31.722 34.224 105.318 0.00 3.73 69 | ATOM 66 C STP 1 31.828 34.074 105.272 0.00 3.68 70 | ATOM 67 O STP 1 32.238 34.042 105.383 0.00 3.47 71 | ATOM 68 O STP 1 31.885 34.027 105.266 0.00 3.65 72 | ATOM 69 O STP 1 31.107 33.629 103.227 0.00 3.87 73 | ATOM 70 O STP 1 31.614 33.384 103.513 0.00 3.58 74 | ATOM 71 C STP 2 32.182 41.416 112.136 0.00 3.53 75 | ATOM 72 O STP 2 27.717 42.628 109.921 0.00 3.86 76 | ATOM 73 O STP 2 30.860 41.923 113.152 0.00 4.55 77 | ATOM 74 O STP 2 27.790 42.644 110.150 0.00 3.99 78 | ATOM 75 C STP 2 28.097 42.799 111.011 0.00 4.54 79 | ATOM 76 C STP 2 27.958 42.614 110.279 0.00 4.05 80 | ATOM 77 C STP 2 28.056 42.649 110.510 0.00 4.19 81 | ATOM 78 C STP 2 27.863 42.546 109.704 0.00 3.69 82 | ATOM 79 O STP 2 28.973 41.271 111.405 0.00 3.78 83 | ATOM 80 C STP 2 28.081 42.825 111.069 0.00 4.58 84 | ATOM 81 C STP 2 28.002 42.857 111.062 0.00 4.58 85 | ATOM 82 C STP 2 28.121 42.766 111.074 0.00 4.55 86 | ATOM 83 O STP 2 28.418 41.993 111.115 0.00 4.06 87 | ATOM 84 O STP 2 30.452 41.910 112.989 0.00 4.53 88 | ATOM 85 O STP 2 30.112 40.927 112.555 0.00 3.67 89 | ATOM 86 C STP 2 29.290 41.849 111.865 0.00 4.19 90 | ATOM 87 O STP 2 29.160 41.443 111.627 0.00 3.92 91 | ATOM 88 O STP 2 29.599 41.639 111.963 0.00 4.11 92 | ATOM 89 O STP 2 29.489 41.459 111.826 0.00 3.98 93 | TER 94 | END 95 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/pockets/pocket1_atm.pdb: -------------------------------------------------------------------------------- 1 | HEADER 2 | HEADER This is a pdb format file writen by the programm fpocket. 3 | HEADER It represents the atoms contacted by the voronoi vertices of the pocket. 4 | HEADER 5 | HEADER Information about the pocket 1: 6 | HEADER 0 - Pocket Score : 0.3504 7 | HEADER 1 - Drug Score : 0.8715 8 | HEADER 2 - Number of alpha spheres : 70 9 | HEADER 3 - Mean alpha-sphere radius : 3.8276 10 | HEADER 4 - Mean alpha-sphere Solvent Acc. : 0.4599 11 | HEADER 5 - Mean B-factor of pocket residues : 0.3795 12 | HEADER 6 - Hydrophobicity Score : 24.8000 13 | HEADER 7 - Polarity Score : 11 14 | HEADER 8 - Amino Acid based volume Score : 4.0000 15 | HEADER 9 - Pocket volume (Monte Carlo) : 685.9797 16 | HEADER 10 - Pocket volume (convex hull) : 245.6618 17 | HEADER 11 - Charge Score : 0 18 | HEADER 12 - Local hydrophobic density Score : 28.6316 19 | HEADER 13 - Number of apolar alpha sphere : 38 20 | HEADER 14 - Proportion of apolar alpha sphere : 0.5429 21 | ATOM 314 CG ASN A 45 26.203 39.184 104.145 0.00 0.00 C 0 22 | ATOM 207 CG2 VAL A 29 28.584 36.516 107.579 0.00 0.00 C 0 23 | ATOM 316 ND2 ASN A 45 25.186 38.429 103.708 0.65 9.84 N 0 24 | ATOM 313 CB ASN A 45 27.324 39.433 103.184 0.00 0.00 C 0 25 | ATOM 384 SG CYS A 54 29.462 35.137 99.942 0.00 0.00 S 0 26 | ATOM 1772 OH TYR B 126 27.265 31.475 105.744 0.59 2.14 O 0 27 | ATOM 2146 CD2 PHE C 49 25.778 32.652 100.203 0.00 0.00 C 0 28 | ATOM 331 CD2 TYR A 47 21.731 35.473 102.458 0.00 0.00 C 0 29 | ATOM 333 CE2 TYR A 47 21.355 34.151 102.211 0.00 0.00 C 0 30 | ATOM 2130 CA HIS C 48 21.690 28.736 96.995 0.00 0.00 C 0 31 | ATOM 2123 CD1 TYR C 47 18.441 28.861 98.802 0.00 0.00 C 0 32 | ATOM 2136 CD2 HIS C 48 24.268 26.784 98.091 0.00 0.00 C 0 33 | ATOM 2139 N PHE C 49 22.569 30.572 98.302 0.00 0.00 N 0 34 | ATOM 541 OD1 ASN A 79 25.473 36.874 112.744 0.46 1.07 O 0 35 | ATOM 568 CZ ARG A 82 26.992 34.126 111.204 0.00 0.00 C 0 36 | ATOM 1793 CH2 TRP B 128 23.122 30.712 110.552 0.00 0.00 C 0 37 | ATOM 569 NH1 ARG A 82 26.778 35.316 110.679 0.00 0.00 N 0 38 | ATOM 567 NE ARG A 82 27.183 34.021 112.514 0.62 1.09 N 0 39 | ATOM 1792 CZ3 TRP B 128 23.775 30.358 111.747 0.00 0.00 C 0 40 | ATOM 371 O GLY A 52 25.612 37.637 97.573 0.39 2.14 O 0 41 | ATOM 380 CA CYS A 54 29.373 37.889 100.182 0.00 0.00 C 0 42 | ATOM 375 O PRO A 53 27.273 39.635 99.639 0.00 0.00 O 0 43 | ATOM 325 CA TYR A 47 22.035 38.536 101.149 0.00 0.00 C 0 44 | ATOM 205 CB VAL A 29 29.300 37.755 107.084 0.00 0.00 C 0 45 | ATOM 2148 CE2 PHE C 49 27.097 32.261 100.497 0.00 0.00 C 0 46 | ATOM 2145 CD1 PHE C 49 24.992 30.463 100.771 0.00 0.00 C 0 47 | ATOM 1791 CZ2 TRP B 128 23.785 30.730 109.352 0.00 0.00 C 0 48 | ATOM 2143 CB PHE C 49 23.308 32.193 100.039 0.00 0.00 C 0 49 | ATOM 335 OH TYR A 47 19.781 32.580 101.365 0.56 7.50 O 0 50 | ATOM 2125 CE1 TYR C 47 18.544 30.055 99.511 0.00 0.00 C 0 51 | ATOM 1771 CZ TYR B 126 28.422 30.835 105.379 0.00 0.00 C 0 52 | ATOM 2120 O TYR C 47 20.227 26.413 96.911 0.21 6.43 O 0 53 | ATOM 200 OG SER A 28 24.600 37.390 107.646 0.00 0.00 O 0 54 | ATOM 199 CB SER A 28 24.374 38.136 108.841 0.00 0.00 C 0 55 | ATOM 169 O GLU A 25 20.412 37.883 106.280 0.36 7.50 O 0 56 | ATOM 570 NH2 ARG A 82 27.008 33.049 110.420 0.46 2.19 N 0 57 | ATOM 2142 O PHE C 49 23.855 33.848 97.464 0.56 1.07 O 0 58 | ATOM 2149 CZ PHE C 49 27.363 30.972 100.930 0.00 0.00 C 0 59 | ATOM 608 CB CYS A 88 31.343 30.542 100.407 0.00 0.00 C 0 60 | ATOM 2144 CG PHE C 49 24.717 31.760 100.332 0.00 0.00 C 0 61 | ATOM 1769 CE1 TYR B 126 29.648 31.187 105.952 0.00 0.00 C 0 62 | ATOM 206 CG1 VAL A 29 30.811 37.488 106.970 0.00 0.00 C 0 63 | ATOM 397 OE1 GLU A 56 33.052 36.868 102.393 0.86 2.14 O 0 64 | ATOM 587 CB LEU A 85 32.872 32.683 108.516 0.00 0.00 C 0 65 | ATOM 522 CB ALA A 76 34.491 36.497 106.363 0.00 0.00 C 0 66 | ATOM 589 CD1 LEU A 85 30.637 33.611 109.160 0.00 0.00 C 0 67 | ATOM 383 CB CYS A 54 30.319 36.707 100.096 0.00 0.00 C 0 68 | ATOM 385 N ALA A 55 30.414 39.195 101.925 0.88 1.09 N 0 69 | ATOM 604 N CYS A 88 33.497 31.222 101.366 0.61 3.28 N 0 70 | ATOM 1767 CD1 TYR B 126 30.818 30.545 105.546 0.00 0.00 C 0 71 | ATOM 1788 NE1 TRP B 128 26.049 30.318 108.329 0.50 2.19 N 0 72 | ATOM 398 OE2 GLU A 56 34.522 35.447 103.175 0.50 2.14 O 0 73 | ATOM 594 O SER A 86 34.168 31.217 104.781 0.51 1.07 O 0 74 | TER 75 | END 76 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/pockets/pocket1_vert.pqr: -------------------------------------------------------------------------------- 1 | HEADER 2 | HEADER This is a pqr format file writen by the programm fpocket. 3 | HEADER It represent the voronoi vertices of a single pocket found by the 4 | HEADER algorithm. 5 | HEADER 6 | HEADER Information about the pocket 1: 7 | HEADER 0 - Pocket Score : 0.3504 8 | HEADER 1 - Drug Score : 0.8715 9 | HEADER 2 - Number of V. Vertices : 70 10 | HEADER 3 - Mean alpha-sphere radius : 3.8276 11 | HEADER 4 - Mean alpha-sphere SA : 0.4599 12 | HEADER 5 - Mean B-factor : 0.3795 13 | HEADER 6 - Hydrophobicity Score : 24.8000 14 | HEADER 7 - Polarity Score : 11 15 | HEADER 8 - Volume Score : 4.0000 16 | HEADER 9 - Real volume (approximation) : 685.9797 17 | HEADER 10 - Charge Score : 0 18 | HEADER 11 - Local hydrophobic density Score : 28.6316 19 | HEADER 12 - Number of apolar alpha sphere : 38 20 | HEADER 13 - Proportion of apolar alpha sphere : 0.5429 21 | ATOM 1 C STP 1 27.832 35.872 104.021 0.00 3.69 22 | ATOM 2 C STP 1 27.954 35.449 103.706 0.00 4.07 23 | ATOM 3 O STP 1 25.370 34.420 103.783 0.00 4.01 24 | ATOM 4 O STP 1 27.724 35.111 103.740 0.00 4.18 25 | ATOM 5 C STP 1 24.801 33.618 103.765 0.00 3.82 26 | ATOM 6 C STP 1 21.596 27.720 100.725 0.00 3.87 27 | ATOM 7 O STP 1 23.353 34.262 111.342 0.00 3.64 28 | ATOM 8 O STP 1 23.638 33.858 112.345 0.00 3.55 29 | ATOM 9 O STP 1 26.104 36.231 100.922 0.00 3.67 30 | ATOM 10 O STP 1 26.179 36.489 100.894 0.00 3.56 31 | ATOM 11 O STP 1 26.176 36.334 100.914 0.00 3.63 32 | ATOM 12 O STP 1 24.838 36.294 100.826 0.00 3.60 33 | ATOM 13 O STP 1 24.946 36.155 100.851 0.00 3.66 34 | ATOM 14 C STP 1 28.730 35.819 103.728 0.00 3.92 35 | ATOM 15 C STP 1 26.814 35.345 102.503 0.00 3.69 36 | ATOM 16 O STP 1 26.373 34.656 103.564 0.00 3.96 37 | ATOM 17 O STP 1 27.492 34.995 103.579 0.00 4.14 38 | ATOM 18 C STP 1 22.814 30.645 104.838 0.00 4.62 39 | ATOM 19 C STP 1 22.340 30.784 103.285 0.00 3.67 40 | ATOM 20 O STP 1 21.530 29.239 101.689 0.00 3.78 41 | ATOM 21 O STP 1 21.569 29.345 101.655 0.00 3.71 42 | ATOM 22 C STP 1 21.659 27.794 101.499 0.00 4.33 43 | ATOM 23 C STP 1 21.686 27.719 101.468 0.00 4.35 44 | ATOM 24 C STP 1 28.707 33.958 103.385 0.00 3.72 45 | ATOM 25 C STP 1 21.363 27.197 100.342 0.00 3.70 46 | ATOM 26 O STP 1 22.966 34.609 109.925 0.00 3.95 47 | ATOM 27 O STP 1 22.195 33.840 106.546 0.00 4.43 48 | ATOM 28 O STP 1 22.184 33.812 106.570 0.00 4.45 49 | ATOM 29 O STP 1 24.182 34.117 108.716 0.00 3.47 50 | ATOM 30 O STP 1 23.499 33.586 105.919 0.00 4.32 51 | ATOM 31 O STP 1 23.414 33.518 106.010 0.00 4.37 52 | ATOM 32 O STP 1 24.878 34.569 104.793 0.00 4.02 53 | ATOM 33 O STP 1 26.129 35.033 105.102 0.00 3.79 54 | ATOM 34 O STP 1 25.995 34.250 107.350 0.00 3.45 55 | ATOM 35 O STP 1 24.386 35.825 100.260 0.00 3.46 56 | ATOM 36 C STP 1 29.851 32.841 102.748 0.00 3.60 57 | ATOM 37 C STP 1 24.701 33.358 103.737 0.00 3.76 58 | ATOM 38 C STP 1 23.960 32.096 104.019 0.00 3.78 59 | ATOM 39 C STP 1 29.026 34.688 103.966 0.00 4.07 60 | ATOM 40 C STP 1 29.002 34.303 103.714 0.00 3.89 61 | ATOM 41 C STP 1 29.868 32.882 102.753 0.00 3.63 62 | ATOM 42 C STP 1 29.863 32.869 102.753 0.00 3.62 63 | ATOM 43 C STP 1 29.847 32.902 102.767 0.00 3.62 64 | ATOM 44 C STP 1 29.437 35.468 103.873 0.00 3.94 65 | ATOM 45 C STP 1 29.764 34.817 104.028 0.00 4.11 66 | ATOM 46 C STP 1 29.899 34.795 104.009 0.00 4.11 67 | ATOM 47 C STP 1 31.693 34.249 105.339 0.00 3.73 68 | ATOM 48 C STP 1 31.333 34.241 105.797 0.00 3.49 69 | ATOM 49 C STP 1 28.775 36.216 103.232 0.00 3.53 70 | ATOM 50 C STP 1 28.934 36.340 103.276 0.00 3.49 71 | ATOM 51 C STP 1 29.044 36.027 103.659 0.00 3.84 72 | ATOM 52 C STP 1 29.068 36.054 103.658 0.00 3.83 73 | ATOM 53 C STP 1 29.483 35.777 103.760 0.00 3.87 74 | ATOM 54 C STP 1 29.517 35.981 103.735 0.00 3.80 75 | ATOM 55 C STP 1 29.637 35.694 103.759 0.00 3.86 76 | ATOM 56 O STP 1 29.648 35.983 103.725 0.00 3.76 77 | ATOM 57 C STP 1 30.753 33.263 102.941 0.00 3.77 78 | ATOM 58 C STP 1 30.509 33.161 102.917 0.00 3.72 79 | ATOM 59 O STP 1 31.028 33.632 103.177 0.00 3.90 80 | ATOM 60 C STP 1 30.931 33.710 103.264 0.00 3.90 81 | ATOM 61 O STP 1 24.895 33.716 107.469 0.00 3.69 82 | ATOM 62 C STP 1 31.689 34.254 105.307 0.00 3.74 83 | ATOM 63 O STP 1 31.145 34.459 104.665 0.00 3.82 84 | ATOM 64 O STP 1 31.094 33.808 103.577 0.00 3.82 85 | ATOM 65 C STP 1 31.722 34.224 105.318 0.00 3.73 86 | ATOM 66 C STP 1 31.828 34.074 105.272 0.00 3.68 87 | ATOM 67 O STP 1 32.238 34.042 105.383 0.00 3.47 88 | ATOM 68 O STP 1 31.885 34.027 105.266 0.00 3.65 89 | ATOM 69 O STP 1 31.107 33.629 103.227 0.00 3.87 90 | ATOM 70 O STP 1 31.614 33.384 103.513 0.00 3.58 91 | TER 92 | END 93 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/pockets/pocket2_atm.pdb: -------------------------------------------------------------------------------- 1 | HEADER 2 | HEADER This is a pdb format file writen by the programm fpocket. 3 | HEADER It represents the atoms contacted by the voronoi vertices of the pocket. 4 | HEADER 5 | HEADER Information about the pocket 2: 6 | HEADER 0 - Pocket Score : 0.0248 7 | HEADER 1 - Drug Score : 0.0018 8 | HEADER 2 - Number of alpha spheres : 19 9 | HEADER 3 - Mean alpha-sphere radius : 4.1232 10 | HEADER 4 - Mean alpha-sphere Solvent Acc. : 0.6695 11 | HEADER 5 - Mean B-factor of pocket residues : 0.3063 12 | HEADER 6 - Hydrophobicity Score : 40.2500 13 | HEADER 7 - Polarity Score : 3 14 | HEADER 8 - Amino Acid based volume Score : 3.5000 15 | HEADER 9 - Pocket volume (Monte Carlo) : 221.1203 16 | HEADER 10 - Pocket volume (convex hull) : 3.2576 17 | HEADER 11 - Charge Score : 0 18 | HEADER 12 - Local hydrophobic density Score : 8.0000 19 | HEADER 13 - Number of apolar alpha sphere : 9 20 | HEADER 14 - Proportion of apolar alpha sphere : 0.4737 21 | ATOM 529 CG2 ILE A 77 33.796 38.445 113.141 0.00 0.00 C 0 22 | ATOM 527 CB ILE A 77 34.293 38.624 111.699 0.00 0.00 C 0 23 | ATOM 526 O ILE A 77 31.607 38.471 110.281 0.67 3.21 O 0 24 | ATOM 212 CB ALA A 30 32.061 42.598 108.814 0.00 0.00 C 0 25 | ATOM 307 CG1AVAL A 44 27.236 43.527 106.200 0.00 0.00 C 0 26 | ATOM 186 O TYR A 27 24.056 41.907 108.938 0.00 0.00 O 0 27 | ATOM 201 N VAL A 29 27.537 39.283 108.006 0.00 0.00 N 0 28 | ATOM 308 CG2AVAL A 44 29.222 44.754 107.075 0.00 0.00 C 0 29 | ATOM 534 O GLY A 78 30.299 37.434 113.651 0.17 4.29 O 0 30 | ATOM 197 C SER A 28 26.711 38.972 109.006 0.00 0.00 C 0 31 | ATOM 208 N ALA A 30 30.839 40.517 108.208 0.00 0.00 N 0 32 | ATOM 203 C VAL A 29 29.730 40.211 107.536 0.00 0.00 C 0 33 | ATOM 539 CB ASN A 79 26.403 39.070 113.085 0.00 0.00 C 0 34 | ATOM 198 O SER A 28 27.059 38.333 109.999 0.46 1.07 O 0 35 | ATOM 535 N ASN A 79 28.376 37.699 112.480 0.84 1.09 N 0 36 | ATOM 536 CA ASN A 79 27.682 38.376 113.568 0.00 0.00 C 0 37 | TER 38 | END 39 | -------------------------------------------------------------------------------- /notebooks/2z3h_out/pockets/pocket2_vert.pqr: -------------------------------------------------------------------------------- 1 | HEADER 2 | HEADER This is a pqr format file writen by the programm fpocket. 3 | HEADER It represent the voronoi vertices of a single pocket found by the 4 | HEADER algorithm. 5 | HEADER 6 | HEADER Information about the pocket 2: 7 | HEADER 0 - Pocket Score : 0.0248 8 | HEADER 1 - Drug Score : 0.0018 9 | HEADER 2 - Number of V. Vertices : 19 10 | HEADER 3 - Mean alpha-sphere radius : 4.1232 11 | HEADER 4 - Mean alpha-sphere SA : 0.6695 12 | HEADER 5 - Mean B-factor : 0.3063 13 | HEADER 6 - Hydrophobicity Score : 40.2500 14 | HEADER 7 - Polarity Score : 3 15 | HEADER 8 - Volume Score : 3.5000 16 | HEADER 9 - Real volume (approximation) : 221.1203 17 | HEADER 10 - Charge Score : 0 18 | HEADER 11 - Local hydrophobic density Score : 8.0000 19 | HEADER 12 - Number of apolar alpha sphere : 9 20 | HEADER 13 - Proportion of apolar alpha sphere : 0.4737 21 | ATOM 1 C STP 2 32.182 41.416 112.136 0.00 3.53 22 | ATOM 2 O STP 2 27.717 42.628 109.921 0.00 3.86 23 | ATOM 3 O STP 2 30.860 41.923 113.152 0.00 4.55 24 | ATOM 4 O STP 2 27.790 42.644 110.150 0.00 3.99 25 | ATOM 5 C STP 2 28.097 42.799 111.011 0.00 4.54 26 | ATOM 6 C STP 2 27.958 42.614 110.279 0.00 4.05 27 | ATOM 7 C STP 2 28.056 42.649 110.510 0.00 4.19 28 | ATOM 8 C STP 2 27.863 42.546 109.704 0.00 3.69 29 | ATOM 9 O STP 2 28.973 41.271 111.405 0.00 3.78 30 | ATOM 10 C STP 2 28.081 42.825 111.069 0.00 4.58 31 | ATOM 11 C STP 2 28.002 42.857 111.062 0.00 4.58 32 | ATOM 12 C STP 2 28.121 42.766 111.074 0.00 4.55 33 | ATOM 13 O STP 2 28.418 41.993 111.115 0.00 4.06 34 | ATOM 14 O STP 2 30.452 41.910 112.989 0.00 4.53 35 | ATOM 15 O STP 2 30.112 40.927 112.555 0.00 3.67 36 | ATOM 16 C STP 2 29.290 41.849 111.865 0.00 4.19 37 | ATOM 17 O STP 2 29.160 41.443 111.627 0.00 3.92 38 | ATOM 18 O STP 2 29.599 41.639 111.963 0.00 4.11 39 | ATOM 19 O STP 2 29.489 41.459 111.826 0.00 3.98 40 | TER 41 | END 42 | -------------------------------------------------------------------------------- /notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/notebooks/__init__.py -------------------------------------------------------------------------------- /sample_crossdock_mols.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | import argparse 5 | 6 | from rdkit import Chem 7 | import torch 8 | import time 9 | import shutil 10 | from scipy.spatial import distance 11 | from rdkit.Chem import rdmolfiles 12 | 13 | from utils.volume_sampling import sample_discrete_number, bin_edges, prob_dist_df 14 | from utils.templates import get_one_hot, get_pocket 15 | from utils.templates import add_hydrogens, extract_hydrogen_coordinates, run_fpocket, extract_values 16 | 17 | from src.lightning_anchor_gnn import AnchorGNN_pl 18 | from src.lightning import AR_DDPM 19 | from src.const import prot_mol_lj_rm, CROSSDOCK_LJ_RM 20 | from src.noise import cosine_beta_schedule 21 | 22 | from analysis.reconstruct_mol import reconstruct_from_generated 23 | #from analysis.vina_docking import VinaDockingTask 24 | from sampling.sample_mols import generate_mols_for_pocket 25 | 26 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9} 27 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'} 28 | CROSSDOCK_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'B':5, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15} 29 | pocket_atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3} # only 4 atoms types for pocket 30 | vdws = {'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8, 'B': 1.92, 'Br': 1.85, 'Cl': 1.75, 'P': 1.8, 'I': 1.98, 'F': 1.47} 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--results-path', type=str, default='results', 34 | help='path to save the results') 35 | parser.add_argument('--data-path', action='store', type=str, default='/srv/home/mahdi.ghorbani/FragDiff/crossdock', 36 | help='path to the test data for generating molecules') 37 | parser.add_argument('--anchor-model', type=str, default='anchor_model.ckpt', 38 | help='path to the anchor model. Note that for guidance, the anchor model should incorporate the conditionals') 39 | parser.add_argument('--n-samples', type=int, default=20, 40 | help='total number of ligands to generate per pocket') 41 | parser.add_argument('--exp-name', type=str, default='exp-1', 42 | help='name of the generation experiment') 43 | parser.add_argument('--diff-model', type=str, default='diff-model.ckpt', 44 | help='path to the diffusion model checkpoint') 45 | parser.add_argument('--device', type=str, default='cuda:0') 46 | parser.add_argument('--rejection-sampling', action='store_true', default=False, help='enable rejection sampling') 47 | 48 | if __name__ == '__main__': 49 | args = parser.parse_args() 50 | torch_device = args.device 51 | anchor_checkpoint = args.anchor_model 52 | data_path = args.data_path 53 | diff_model_checkpoint = args.diff_model 54 | 55 | add_H = True # adding hydrogens to protein for LJ computation 56 | model = AR_DDPM.load_from_checkpoint(diff_model_checkpoint, device=torch_device) 57 | model = model.to(torch_device) 58 | 59 | anchor_model = AnchorGNN_pl.load_from_checkpoint(anchor_checkpoint, device=torch_device) 60 | anchor_model = anchor_model.to(torch_device) 61 | 62 | split = torch.load(data_path + '/' + 'split_by_name.pt') 63 | prefix = data_path + '/crossdocked_pocket10/' 64 | 65 | if not os.path.exists(args.results_path): 66 | print('creating results directory') 67 | 68 | save_dir = args.results_path + '/' + args.exp_name 69 | if not os.path.exists(save_dir): 70 | os.makedirs(save_dir, exist_ok=True) 71 | 72 | for n in range(100): 73 | prot_name = prefix + split['test'][n][0] 74 | lig_name = prefix + split['test'][n][1] 75 | 76 | pocket_onehot, pocket_coords, lig_coords, _ = get_pocket(prot_name, lig_name, atom_dict, pocket_atom_dict=pocket_atom_dict, dist_cutoff=7) 77 | 78 | # --------------- make a grid box around the pocket ---------------- 79 | min_coords = pocket_coords.min(axis=0) - 2.5 # 80 | max_coords = pocket_coords.max(axis=0) + 2.5 81 | 82 | x_range = slice(min_coords[0], max_coords[0] + 1, 1.5) # spheres of radius 1.5 83 | y_range = slice(min_coords[1], max_coords[1] + 1, 1.5) 84 | z_range = slice(min_coords[2], max_coords[2] + 1, 1.5) 85 | 86 | grid = np.mgrid[x_range, y_range, z_range] 87 | grid_points = grid.reshape(3, -1).T # This transposes the grid to a list of coordinates 88 | 89 | # remove grids points not in 3.5A neighborhood of original ligand 90 | distances_mol = distance.cdist(grid_points, lig_coords) 91 | mask_mol = (distances_mol < 3.5).any(axis=1) 92 | filtered_mol_points = grid_points[mask_mol] 93 | 94 | # remove grid points that are close to the pocket 95 | pocket_distances = distance.cdist(filtered_mol_points, pocket_coords) 96 | mask_pocket = (pocket_distances < 2).any(axis=1) 97 | grids = filtered_mol_points[~mask_pocket] 98 | 99 | n_samples = args.n_samples 100 | max_mol_sizes = [] 101 | 102 | fpocket_out = prot_name[:-4] + '_out' 103 | 104 | shutil.rmtree(fpocket_out, ignore_errors=True) 105 | 106 | if add_H: 107 | add_hydrogens(prot_name) 108 | prot_name_with_H = prot_name[:-4] + '_H.pdb' 109 | 110 | H_coords = extract_hydrogen_coordinates(prot_name_with_H) 111 | H_coords = torch.tensor(H_coords).float().to(torch_device) 112 | #print('running fpocket!') 113 | #try: 114 | # run_fpocket(prot_name) 115 | #except: 116 | # print('Error in running fpocket! using random sizes') 117 | # NOTE: using original molecule coordinates for making the grid 118 | 119 | grids = torch.tensor(grids) 120 | all_grids = [] # list of grids 121 | all_H_coords = [] 122 | for i in range(n_samples): 123 | all_grids.append(grids) 124 | all_H_coords.append(H_coords) 125 | 126 | pocket_vol = len(grids) 127 | #if os.path.exists(fpocket_out): 128 | # filename = prot_name[:-4] + '_out/pockets/pocket1_atm.pdb' 129 | # score, drug_score, pocket_volume = extract_values(filename) 130 | #else: 131 | # print('running fpocket!') 132 | # run_fpocket(prot_name) 133 | # filename = prot_name[:-4] + '_out/pockets/pocket1_atm.pdb' 134 | # score, drug_score, pocket_volume = extract_values(filename) 135 | 136 | #print('pocket_volume', pocket_volume) 137 | 138 | for i in range(n_samples): 139 | max_mol_sizes.append(sample_discrete_number(pocket_vol)) 140 | 141 | pocket_onehot = torch.tensor(pocket_onehot).float() 142 | pocket_coords = torch.tensor(pocket_coords).float() 143 | lig_coords = torch.tensor(lig_coords).float() 144 | pocket_size = len(pocket_coords) 145 | 146 | t1 = time.time() 147 | 148 | max_mol_sizes = np.array(max_mol_sizes) 149 | 150 | print('maximum sizes for molecules', max_mol_sizes) 151 | prot_mol_lj_rm = torch.tensor(prot_mol_lj_rm).to(torch_device) 152 | mol_mol_lj_rm = torch.tensor(CROSSDOCK_LJ_RM).to(torch_device) / 100 153 | 154 | lj_weight_scheduler = cosine_beta_schedule(500, s=0.01, raise_to_power=2) 155 | weights = 1 - lj_weight_scheduler 156 | weights = np.clip(weights, a_min=0.1, a_max=1.) 157 | x, h, mol_masks = generate_mols_for_pocket(n_samples=n_samples, 158 | num_frags=8, 159 | pocket_size=pocket_size, 160 | pocket_coords=pocket_coords, 161 | pocket_onehot=pocket_onehot, 162 | lig_coords=lig_coords, 163 | anchor_model=anchor_model, 164 | diff_model=model, 165 | device=torch_device, 166 | return_all=False, 167 | max_mol_sizes=max_mol_sizes, 168 | all_grids=all_grids, 169 | rejection_sampling=args.rejection_sampling, 170 | lj_guidance=True, 171 | prot_mol_lj_rm=prot_mol_lj_rm, 172 | mol_mol_lj_rm=mol_mol_lj_rm, 173 | all_H_coords=all_H_coords, 174 | guidance_weights=weights) 175 | 176 | x = x.cpu().numpy() 177 | h = h.cpu().numpy() 178 | mol_masks = mol_masks.cpu().cpu().numpy() 179 | 180 | # convert to SDF 181 | all_mols = [] 182 | for k in range(len(x)): 183 | mask = mol_masks[k] 184 | h_mol = h[k] 185 | x_mol = x[k][mask.astype(np.bool_)] 186 | 187 | atom_inds = h_mol[mask.astype(np.bool_)].argmax(axis=1) 188 | atom_types = [idx2atom[x] for x in atom_inds] 189 | atomic_nums = [CROSSDOCK_CHARGES[i] for i in atom_types] 190 | 191 | try: 192 | mol_rec = reconstruct_from_generated(x_mol.tolist(), atomic_nums) 193 | Chem.Kekulize(mol_rec) 194 | all_mols.append(mol_rec) 195 | except: 196 | continue 197 | 198 | t2 = time.time() 199 | print('time to generate one is: ', (t2-t1)/n_samples) 200 | save_path = save_dir + '/' + 'pocket_' + str(n) 201 | 202 | # write sdf file of molecules 203 | with rdmolfiles.SDWriter(save_path + '_mols.sdf') as writer: 204 | for mol in all_mols: 205 | if mol: 206 | writer.write(mol) 207 | 208 | np.save(save_path + '_coords.npy', x) 209 | np.save(save_path + '_onehot.npy', h) 210 | np.save(save_path + '_mol_masks.npy', mol_masks) 211 | -------------------------------------------------------------------------------- /sampling/rejection_sampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from rdkit import Chem 4 | import torch 5 | from pathlib import Path 6 | from analysis.docking import calculate_qvina2_score 7 | 8 | from analysis.reconstruct_mol import reconstruct_from_generated 9 | from analysis.metrics import is_connected 10 | 11 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9} 12 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'} 13 | CROSSDOCK_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'B':5, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15} 14 | pocket_atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3} # only 4 atoms types for pocket 15 | vdws = {'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8, 'B': 1.92, 'Br': 1.85, 'Cl': 1.75, 'P': 1.8, 'I': 1.98, 'F': 1.47} 16 | 17 | def compute_number_of_clashes(lig_x, lig_h, pocket_x, pocket_h, pocket_H_coords=None, tolerace=0.5, prot_mol_lj_rm=None): 18 | """ 19 | lig_x and lig_h [n_atoms, 3] and [n_atoms] coordinates and atom types of the ligand (only extension atoms) 20 | pocket_x, pocket_h => [N_pocket, 3 or hp] 21 | pocket_H_coords -> [N_pocket, 3] coordinates of the pocket H atoms 22 | """ 23 | 24 | dists = torch.cdist(lig_x, pocket_x, p=2) # [n_lig_atoms, n_pocket_atoms] 25 | dists = torch.where(dists==0, 1e-5, dists) 26 | inds_lig = torch.argmax(lig_h, dim=1) # [n_lig_atoms] 27 | 28 | inds_pocket = torch.argmax(pocket_h, dim=1).long() # [n_pocket_atoms] 29 | rm = prot_mol_lj_rm[inds_lig][:, inds_pocket] # [n_lig_atoms, n_pocket_atoms] 30 | clashes = ((dists + tolerace ) < rm).sum().item() 31 | 32 | dists_h = torch.cdist(lig_x, pocket_H_coords, p=2) 33 | inds_h = torch.ones(len(pocket_H_coords), device=lig_x.device).long() * 10 34 | rm_h = prot_mol_lj_rm[inds_lig][:, inds_h] # [n_lig_atoms, n_pocket_atoms] 35 | clashes_h = ((dists_h + tolerace ) < rm_h).sum().item() 36 | 37 | total_clashes = clashes + clashes_h 38 | return total_clashes 39 | 40 | 41 | def reject_sample(x, h, pocket_x, pocket_h, prot_path=None, rejection_criteria='clashes'): 42 | # NOTE: x and pocket_x must already be translated to COM 43 | # x :torch.Tensor -> [n_atoms, 3] coordiantes of a single molecule 44 | # h :list-> [n_atoms] atom types (eg. 'C', 'N') of a single molecule 45 | atomic_nums = [CROSSDOCK_CHARGES[a] for a in h] 46 | if rejection_criteria == 'qvina': 47 | try: 48 | mol_rec = reconstruct_from_generated(x.tolist(), atomic_nums) 49 | Chem.SanitizeMol(mol_rec) 50 | 51 | if not is_connected(mol_rec): 52 | m_frags = Chem.GetMolFrags(mol_rec, asMols=True, sanitizeFrags=False) 53 | mol_rec = max(m_frags, key=lambda x: x.GetNumAtoms()) 54 | 55 | prot_pdbqt_file = prot_path[:-4] + '.pdbqt' 56 | out_sdf_file = 'mol.sdf' 57 | with Chem.SDWriter(out_sdf_file) as writer: 58 | writer.write(mol_rec) 59 | sdf_file = Path(out_sdf_file) 60 | if not os.path.exists('qvina-path'): 61 | os.mkdir('qvina-path') 62 | score_result = calculate_qvina2_score(prot_pdbqt_file, sdf_file, out_dir='qvina-path', return_rdmol=False, score_only=True) 63 | print('qvina score: ', score_result) 64 | files = os.listdir('qvina-path') 65 | for file in files: 66 | if file.endswith('.sdf') or file.endswith('.pdbqt'): 67 | os.remove(os.path.join('qvina-path', file)) 68 | 69 | except: 70 | score_result = 100 71 | 72 | return score_result 73 | 74 | elif rejection_criteria == 'clashes': 75 | # pocket_x -> [n_atoms, 3] coordiantes of a single pocket 76 | # pocket_h -> [n_atoms] atom types (eg. 'C', 'N') of a single pocket 77 | # x -> [n_atoms, 3] coordiantes of a single molecule 78 | clashes, clashed_ids, clashed_pocket_ids, n_clashes = compute_number_of_clashes(pocket_x, x, pocket_h, h) 79 | return n_clashes 80 | 81 | def compute_lj(lig_x, lig_h, extension_mask, scaffold_mask, pocket_x, pocket_h, pocket_mask, prot_mol_lj_rm, all_H_coords, mol_mol_lj_rm=None): 82 | """ compute the LJ between protein and ligand 83 | lig_x: [B, N, 3] 84 | lig_h: [B, N, hf] 85 | """ 86 | 87 | num_atoms = extension_mask.sum() 88 | 89 | # ------------- ligand - ligand LJ ---------- 90 | mol_mask = (scaffold_mask.bool() | extension_mask.bool()) 91 | N = mol_mask.sum() 92 | 93 | x_mol = lig_x[mol_mask] # [N_mol, 3] 94 | h_mol = lig_h[mol_mask] # [N_mol, hf] 95 | 96 | x = lig_x[extension_mask.bool()] 97 | h = lig_x[extension_mask.bool()] 98 | 99 | dists_mol = torch.cdist(x, x_mol, p=2) # [N_ext, N_mol] 100 | 101 | inds_mol = torch.argmax(h_mol, dim=1) # [N_mol] 102 | inds_ext = torch.argmax(h, dim=1) # [N_ext] 103 | rm_mol = mol_mol_lj_rm[inds_ext][:, inds_mol] # [N_ext, N_mol] 104 | 105 | 106 | dists_mol = torch.where(dists_mol==0.0, 1, dists_mol) 107 | rm_mol = torch.where(rm_mol==0.0, 1, rm_mol) 108 | 109 | 110 | dists_mol = torch.where(dists_mol < 0.5, 0.5, dists_mol) # clamp the distance to 0.1 111 | lj_mol = ((rm_mol / dists_mol) ** 12 - (rm_mol / dists_mol) ** 6) # [N_mol, N_mol] 112 | 113 | lj_lig_lig = lj_mol.sum() / num_atoms 114 | 115 | # --------------- compute the LJ between protein and ligand -------------- 116 | 117 | 118 | pocket_x = pocket_x[pocket_mask.bool()] # [N_p, 3] 119 | pocket_h = pocket_h[pocket_mask.bool()][:, :4] # [N_p, hf] 120 | h_coords = all_H_coords # [N_p, 3] 121 | 122 | # --------------- compute the LJ between protein and ligand -------------- 123 | dists = torch.cdist(x, pocket_x, p=2) 124 | inds_lig = torch.argmax(h, dim=1) # [N_l] 125 | inds_pocket = torch.argmax(pocket_h, dim=1).long() # [N_p] 126 | 127 | rm = prot_mol_lj_rm[inds_lig][:, inds_pocket] # [N_l, N_p] 128 | lj = ((rm / dists) ** 12 - (rm / dists) ** 6) # [N_l, N_p] 129 | lj[torch.isnan(lj)] = 0 130 | 131 | # ------------- compute the loss for h atoms ---------------- 132 | dists_h = torch.cdist(x, h_coords, p=2) 133 | #dists_h = torch.where(dists_h<0.5, 0.5, dists_h) 134 | inds_H = torch.ones(len(h_coords), device=x.device).long() * 10 # index of H is 10 in the table 135 | rm_h = prot_mol_lj_rm[inds_lig][:, inds_H] 136 | lj_h = ((rm_h / dists_h) ** 12 - (rm_h / dists_h) ** 6) # [N_l, N_p] 137 | 138 | lj_h[torch.isnan(lj_h)] = 0 # remove nan values 139 | 140 | lj = lj.sum() 141 | lj_h = lj_h.sum() 142 | 143 | lj_prot_lig = (lj + lj_h) / num_atoms 144 | return lj_prot_lig, lj_lig_lig -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/src/__init__.py -------------------------------------------------------------------------------- /src/anchor_gnn.py: -------------------------------------------------------------------------------- 1 | from src.egnn import GCL, GaussianSmearing 2 | import torch.nn as nn 3 | import torch 4 | from src.egnn import coord2diff 5 | 6 | class MaskedBCEWithLogitsLoss2(torch.nn.Module): 7 | def __init__(self): 8 | super(MaskedBCEWithLogitsLoss2, self).__init__() 9 | self.loss = torch.nn.BCEWithLogitsLoss(reduction='none') 10 | 11 | def forward(self, input, target, scaffold_mask, pocket_mask, is_first_frag_mask): 12 | # TODO: 13 | """ 14 | if_first_frag_mask -> mask for the first fragment (if the fragment is the first the mask is 1) 15 | """ 16 | masked_loss = self.loss(input, target) 17 | masked_loss_1 = masked_loss * (~is_first_frag_mask.bool()) # only for parts that are not the first fragment 18 | masked_loss_1 = masked_loss_1 * scaffold_mask.float() # only for the scaffold atoms 19 | 20 | masked_loss_2 = masked_loss * is_first_frag_mask.bool() # only for parts that are the first fragment 21 | masked_loss_2 = masked_loss_2 * pocket_mask.float() # only for the pocket atoms 22 | 23 | total_masked_loss = (masked_loss_1.sum() / scaffold_mask.sum().float()) + (masked_loss_2.sum() / pocket_mask.sum().float()) 24 | return total_masked_loss 25 | 26 | class MaskedBCEWithLogitsLoss(torch.nn.Module): 27 | def __init__(self): 28 | super(MaskedBCEWithLogitsLoss, self).__init__() 29 | self.loss = torch.nn.BCEWithLogitsLoss(reduction='none') 30 | 31 | def forward(self, input, target, mask): 32 | masked_loss = self.loss(input, target) 33 | masked_loss = masked_loss * mask.float() 34 | return masked_loss.sum() / mask.sum().float() 35 | 36 | class AnchorGNNPocket(nn.Module): 37 | def __init__(self, 38 | lig_nf, # ligand node features 39 | pocket_nf, # pocket node features 40 | joint_nf, # joint number of features 41 | hidden_nf, 42 | out_node_nf, 43 | n_layers, 44 | normalization, 45 | attention=True, 46 | normalization_factor=100, 47 | aggregation_method='sum', 48 | dist_cutoff=7, 49 | gaussian_expansion=False, 50 | num_gaussians=16, 51 | edge_cutoff_ligand=None, 52 | edge_cutoff_pocket=4.5, 53 | edge_cutoff_interaction=4.5 54 | ): 55 | 56 | super(AnchorGNNPocket, self).__init__() 57 | 58 | #in_node_nf = in_node_nf + context_node_nf # adding the context pocket 59 | if gaussian_expansion: 60 | self.gauss_exp = GaussianSmearing(start=0., stop=7., num_gaussians=16) 61 | in_edge_nf = num_gaussians 62 | else: 63 | in_edge_nf = 1 64 | 65 | self.hidden_nf = hidden_nf 66 | self.out_node_nf = out_node_nf 67 | self.n_layers = n_layers 68 | self.normalization = normalization 69 | self.attention = attention 70 | self.dist_cutoff = dist_cutoff 71 | self.normalization_factor = normalization_factor 72 | self.gaussian_expansion = gaussian_expansion 73 | self.num_gaussians = num_gaussians 74 | self.joint_nf = joint_nf 75 | self.edge_cutoff_l = edge_cutoff_ligand 76 | self.edge_cutoff_p = edge_cutoff_pocket 77 | self.edge_cutoff_i = edge_cutoff_interaction 78 | 79 | self.mol_encoder = nn.Sequential( 80 | nn.Linear(lig_nf, joint_nf), 81 | nn.SiLU() 82 | ) 83 | 84 | self.pocket_encoder = nn.Sequential( 85 | nn.Linear(pocket_nf, joint_nf), 86 | nn.SiLU() 87 | ) 88 | 89 | self.embed_both = nn.Linear(joint_nf, self.hidden_nf) 90 | 91 | self.gcl1 = GCL( 92 | input_nf=self.hidden_nf, 93 | output_nf=self.hidden_nf, 94 | hidden_nf=self.hidden_nf, 95 | normalization_factor=normalization_factor, 96 | aggregation_method=aggregation_method, 97 | edges_in_d=in_edge_nf, 98 | activation=nn.ReLU(), 99 | attention=attention, 100 | normalization=normalization 101 | ) 102 | 103 | layers = [] 104 | layers.append(self.gcl1) 105 | for i in range(n_layers - 1): 106 | layer = GCL( 107 | input_nf=self.hidden_nf, 108 | output_nf=self.hidden_nf, 109 | hidden_nf=self.hidden_nf, 110 | normalization_factor=normalization_factor, 111 | aggregation_method='sum', 112 | edges_in_d=in_edge_nf, 113 | activation=nn.ReLU(), 114 | attention=attention, 115 | normalization=normalization 116 | ) 117 | layers.append(layer) 118 | 119 | self.gcl_layers = nn.ModuleList(layers) 120 | self.embedding_out = nn.Linear(self.hidden_nf, self.out_node_nf) 121 | self.lin_out = nn.Linear(self.out_node_nf, 1) 122 | self.act = nn.ReLU() 123 | #self.bce_loss = MaskedBCEWithLogitsLoss() 124 | 125 | def forward(self, mol_x, mol_h, node_mask, pocket_x, pocket_h, pocket_mask): 126 | """ 127 | input: 128 | mol_x: [B, Ns, 3] coordinates of scaffold 129 | mol_h: [B, Ns, nf] onehot of scaffold 130 | node_mask: [B, Ns] masking on the scaffold 131 | pocket_x: [B, Np] coordinates of pocket 132 | pocket_h: [B, NP, nf_p] onehot of pocket 133 | pocket_mask: [B, Np] masking on pocket atoms 134 | output: 135 | h_out: [B, Ns, 1] logits for the scaffold atoms 136 | """ 137 | bs, n_lig_nodes = mol_x.shape[0], mol_x.shape[1] 138 | n_pocket_nodes = pocket_x.shape[1] 139 | node_mask = node_mask.squeeze() 140 | 141 | N = n_lig_nodes + n_pocket_nodes 142 | mol_x = mol_x[node_mask.bool()] # [N_l, 3] 143 | mol_h = mol_h[node_mask.bool()] # [N_l, nf] 144 | 145 | pocket_x = pocket_x[pocket_mask.bool()] # [N_p, 3] 146 | pocket_h = pocket_h[pocket_mask.bool()] # [N_p, nf] 147 | 148 | mol_h = self.mol_encoder(mol_h) # [N_l, joint_nf] 149 | pocket_h = self.pocket_encoder(pocket_h) # [N_p, joint_nf] 150 | 151 | h = torch.cat([mol_h, pocket_h], dim=0) # [N_l+N_p, joint_nf] 152 | x = torch.cat([mol_x, pocket_x], dim=0) # [N_l+N_p, 3] 153 | 154 | batch_mask_ligand = self.get_batch_mask(node_mask.bool(), device=x.device) # [N_l] 155 | batch_mask_pocket = self.get_batch_mask(pocket_mask.bool(), device=x.device) # [N_p] 156 | 157 | edges = self.get_edges_cutoff(batch_mask_ligand, batch_mask_pocket, mol_x, pocket_x) # [2, num_edges] 158 | 159 | h = self.embed_both(h) # [N_l+N_p, hidden_nf] 160 | 161 | distances, _ = coord2diff(x, edges) 162 | if self.gaussian_expansion: 163 | distances = self.gauss_exp(distances) 164 | 165 | for gcl in self.gcl_layers: 166 | h, _ = gcl(h, edges, edge_attr=distances, node_mask=None, edge_mask=None) # [N_l+N_p, hidden_nf] 167 | 168 | h_atoms = h[:len(batch_mask_ligand)] # [N_l, hidden_nf] 169 | h_atoms = self.act(self.embedding_out(h_atoms)) # [N_l, out_node_nf] 170 | h_out = self.lin_out(h_atoms) # [N_l, 1] 171 | 172 | # convert to batch 173 | num_atoms = node_mask.sum(dim=1).int() # [B] 174 | reshaped_h_out = torch.zeros(bs, n_lig_nodes, 1, dtype=h_out.dtype).to(h_out.device) 175 | positions = torch.zeros_like(batch_mask_ligand).to(h_out.device) 176 | for idx in range(bs): 177 | positions[batch_mask_ligand == idx] = torch.arange(num_atoms[idx]).to(x.device) 178 | reshaped_h_out[batch_mask_ligand, positions] = h_out # [B, n_lig_nodes, 1] 179 | 180 | return reshaped_h_out 181 | 182 | def get_edges_cutoff(self, batch_mask_ligand, batch_mask_pocket, x_ligand, x_pocket): 183 | 184 | adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :] 185 | adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :] 186 | adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :] 187 | 188 | if self.edge_cutoff_l is not None: 189 | adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l) 190 | 191 | if self.edge_cutoff_p is not None: 192 | adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p) 193 | 194 | if self.edge_cutoff_i is not None: 195 | adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i) 196 | 197 | adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1), 198 | torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0) 199 | edges = torch.stack(torch.where(adj), dim=0) 200 | 201 | return edges 202 | 203 | @staticmethod 204 | def get_batch_mask(mask, device): 205 | n_nodes = mask.float().sum(dim=1).int() 206 | batch_size = mask.shape[0] 207 | batch_mask = torch.cat([torch.ones(n_nodes[i]) * i for i in range(batch_size)]).long().to(device) 208 | return batch_mask -------------------------------------------------------------------------------- /src/conv_layer.py: -------------------------------------------------------------------------------- 1 | 2 | # Following DiffHopp implementation of GVP https://github.com/jostorge/diffusion-hopping/tree/main 3 | 4 | from abc import ABC 5 | from functools import partial 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | from torch import nn as nn 10 | from torch.nn import functional as F 11 | from torch_geometric.nn import MessagePassing 12 | 13 | from src.dropout import GVPDropout 14 | from src.gvp import GVP, s_V 15 | from src.layer_norm import GVPLayerNorm 16 | 17 | from abc import ABC 18 | from functools import partial 19 | from typing import Optional, Tuple, Union 20 | 21 | import torch 22 | from torch import nn as nn 23 | from torch.nn import functional as F 24 | from torch_geometric.nn import MessagePassing 25 | 26 | class GVPMessagePassing(MessagePassing, ABC): 27 | def __init__( 28 | self, 29 | in_dims: Tuple[int, int], 30 | out_dims: Tuple[int, int], 31 | edge_dims: Tuple[int, int], 32 | hidden_dims: Optional[Tuple[int, int]] = None, 33 | activations=(F.relu, torch.sigmoid), 34 | vector_gate: bool = False, 35 | attention: bool = True, 36 | aggr: str = "add", 37 | normalization_factor: float = 1.0, 38 | ): 39 | super().__init__(aggr) 40 | if hidden_dims is None: 41 | hidden_dims = out_dims 42 | 43 | in_scalar, in_vector = in_dims 44 | hidden_scalar, hidden_vector = hidden_dims 45 | 46 | edge_scalar, edge_vector = edge_dims 47 | 48 | self.out_scalar, self.out_vector = out_dims 49 | self.in_vector = in_vector 50 | self.hidden_scalar = hidden_scalar 51 | self.hidden_vector = hidden_vector 52 | self.normalization_factor = normalization_factor 53 | 54 | GVP_ = partial(GVP, activations=activations, vector_gate=vector_gate) 55 | self.edge_gvps = nn.Sequential( 56 | GVP_( 57 | (2 * in_scalar + edge_scalar, 2 * in_vector + edge_vector), 58 | hidden_dims, 59 | ), 60 | GVP_(hidden_dims, hidden_dims), 61 | GVP_(hidden_dims, out_dims, activations=(None, None)), 62 | ) 63 | 64 | self.attention = attention 65 | if attention: 66 | self.attention_gvp = GVP_( 67 | out_dims, 68 | (1, 0), 69 | activations=(torch.sigmoid, None), 70 | ) 71 | 72 | def forward(self, x: s_V, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> s_V: 73 | s, V = x 74 | v_dim = V.shape[-1] 75 | V = torch.flatten(V, start_dim=-2, end_dim=-1) 76 | return self.propagate(edge_index, s=s, V=V, edge_attr=edge_attr, v_dim=v_dim) 77 | 78 | def message(self, s_i, s_j, V_i, V_j, edge_attr, v_dim): 79 | V_i = V_i.view(*V_i.shape[:-1], self.in_vector, v_dim) 80 | V_j = V_j.view(*V_j.shape[:-1], self.in_vector, v_dim) 81 | edge_scalar, edge_vector = edge_attr 82 | 83 | s = torch.cat([s_i, s_j, edge_scalar], dim=-1) 84 | V = torch.cat([V_i, V_j, edge_vector], dim=-2) 85 | s, V = self.edge_gvps((s, V)) 86 | 87 | if self.attention: 88 | att = self.attention_gvp((s, V)) 89 | s, V = att * s, att[..., None] * V 90 | return self._combine(s, V) 91 | 92 | def update(self, aggr_out: torch.Tensor) -> s_V: 93 | s_aggr, V_aggr = self._split(aggr_out, self.out_scalar, self.out_vector) 94 | if self.aggr == "add" or self.aggr == "sum": 95 | s_aggr = s_aggr / self.normalization_factor 96 | V_aggr = V_aggr / self.normalization_factor 97 | return s_aggr, V_aggr 98 | 99 | @staticmethod 100 | def _combine(s, V) -> torch.Tensor: 101 | V = torch.flatten(V, start_dim=-2, end_dim=-1) 102 | return torch.cat([s, V], dim=-1) 103 | 104 | @staticmethod 105 | def _split(s_V: torch.Tensor, scalar: int, vector: int) -> s_V: 106 | s = s_V[..., :scalar] 107 | V = s_V[..., scalar:] 108 | V = V.view(*V.shape[:-1], vector, -1) 109 | return s, V 110 | 111 | def reset_parameters(self): 112 | for gvp in self.edge_gvps: 113 | gvp.reset_parameters() 114 | if self.attention: 115 | self.attention_gvp.reset_parameters() 116 | 117 | class GVPConvLayer(GVPMessagePassing, ABC): 118 | def __init__( 119 | self, 120 | node_dims: Tuple[int, int], 121 | edge_dims: Tuple[int, int], 122 | drop_rate: float = 0.0, 123 | activations=(F.relu, torch.sigmoid), 124 | vector_gate: bool = False, 125 | residual: bool = True, 126 | attention: bool = True, 127 | aggr: str = "add", 128 | normalization_factor: float = 1.0, 129 | ): 130 | super().__init__( 131 | node_dims, 132 | node_dims, 133 | edge_dims, 134 | hidden_dims=node_dims, 135 | activations=activations, 136 | vector_gate=vector_gate, 137 | attention=attention, 138 | aggr=aggr, 139 | normalization_factor=normalization_factor, 140 | ) 141 | self.residual = residual 142 | self.drop_rate = drop_rate 143 | GVP_ = partial(GVP, activations=activations, vector_gate=vector_gate) 144 | self.norm = nn.ModuleList([GVPLayerNorm(node_dims) for _ in range(2)]) 145 | self.dropout = nn.ModuleList([GVPDropout(drop_rate) for _ in range(2)]) 146 | 147 | self.ff_func = nn.Sequential( 148 | GVP_(node_dims, node_dims), 149 | GVP_(node_dims, node_dims, activations=(None, None)), 150 | ) 151 | self.residual = residual 152 | 153 | def forward( 154 | self, 155 | x: Union[s_V, torch.Tensor], 156 | edge_index: torch.Tensor, 157 | edge_attr: torch.Tensor, 158 | ) -> s_V: 159 | 160 | s, V = super().forward(x, edge_index, edge_attr) 161 | if self.residual: 162 | s, V = self.dropout[0]((s, V)) 163 | s, V = x[0] + s, x[1] + V 164 | s, V = self.norm[0]((s, V)) 165 | 166 | x = (s, V) 167 | s, V = self.ff_func(x) 168 | 169 | if self.residual: 170 | s, V = self.dropout[1]((s, V)) 171 | s, V = s + x[0], V + x[1] 172 | s, V = self.norm[1]((s, V)) 173 | 174 | return s, V -------------------------------------------------------------------------------- /src/dropout.py: -------------------------------------------------------------------------------- 1 | # Following diffhopp implementation of GVP https://github.com/jostorge/diffusion-hopping/tree/main 2 | 3 | from typing import Union, Tuple 4 | 5 | import torch 6 | from torch import nn as nn 7 | 8 | s_V = Tuple[torch.Tensor, torch.Tensor] 9 | 10 | class GVPDropout(nn.Module): 11 | def __init__(self, p: float=0.5) -> None: 12 | super().__init__() 13 | self.dropout_features = nn.Dropout(p) 14 | self.dropout_vector = nn.Dropout1d(p) 15 | 16 | def forward(self, x: Union[torch.Tensor, s_V]) -> Union[torch.Tensor, s_V]: 17 | if isinstance(x, torch.Tensor): 18 | return self.dropout_features(x) 19 | 20 | s, V = x 21 | s = self.dropout_features(s) 22 | V = self.dropout_vector(V) 23 | return s, V -------------------------------------------------------------------------------- /src/dynamics_gvp.py: -------------------------------------------------------------------------------- 1 | # Following DiffHopp implementation of GVP: https://github.com/jostorge/diffusion-hopping/tree/main 2 | 3 | import torch.nn as nn 4 | import torch 5 | import numpy as np 6 | from src.gvp_model import GVPNetwork 7 | 8 | class DynamicsWithPockets(nn.Module): 9 | def __init__( 10 | self, n_dims, lig_nf, pocket_nf, context_node_nf=3, joint_nf=32, hidden_nf=128, activation=nn.SiLU(), 11 | n_layers=4, attention=False, condition_time=True, tanh=False, normalization_factor=100, model='gvp', 12 | centering=False, edge_cutoff=7, edge_cutoff_interaction=4.5, edge_cutoff_pocket=4.5, edge_cutoff_ligand=None 13 | ): 14 | super().__init__() 15 | 16 | self.edge_cutoff_l = edge_cutoff_ligand 17 | self.edge_cutoff_p = edge_cutoff_pocket 18 | self.edge_cutoff_i = edge_cutoff_interaction 19 | 20 | self.atom_encoder = nn.Sequential( 21 | nn.Linear(lig_nf, joint_nf), 22 | ) 23 | 24 | self.pocket_encoder = nn.Sequential( 25 | nn.Linear(pocket_nf, joint_nf), 26 | ) 27 | 28 | self.atom_decoder = nn.Sequential( 29 | nn.Linear(joint_nf, lig_nf), 30 | ) 31 | 32 | if condition_time: 33 | dynamics_node_nf = joint_nf + 1 34 | else: 35 | print('Warning: dynamics moddel is _not_ conditioned on time') 36 | dynamics_node_nf = joint_nf 37 | 38 | self.dynamics = GVPNetwork( 39 | in_dims=(dynamics_node_nf + context_node_nf, 0), # (scalar_features, vector_features) 40 | out_dims=(joint_nf, 1), 41 | hidden_dims=(hidden_nf, hidden_nf//2), 42 | vector_gate=True, 43 | num_layers=n_layers, 44 | attention=attention, 45 | normalization_factor=normalization_factor, 46 | ) # other parameters are default 47 | 48 | self.n_dims = n_dims 49 | self.condition_time = condition_time 50 | self.centering = centering 51 | self.context_node_nf = context_node_nf 52 | self.edge_cutoff = edge_cutoff 53 | self.model = model 54 | 55 | def forward(self, t, xh, pocket_xh, extension_mask, scaffold_mask, anchors, pocket_anchors, pocket_mask): 56 | """ 57 | input: 58 | t: timestep: [B] 59 | xh: ligand atoms (noised) [B, N_l, h_l+3] 60 | pocket_xh: pocket atoms (no noised added) [B, N_p, h_p + 3] 61 | extension_masks: mask on fragment extension atoms [B, N] 62 | scaffold_masks: mask on scaffold atoms [B, N] 63 | anchor_masks: mask on anchor atoms [B, N] 64 | pocket_masks: masking on all the pocket atoms [B, N_p] 65 | output: 66 | (x_out,h_out) for ligand 67 | """ 68 | bs, n_lig_nodes = xh.shape[0], xh.shape[1] 69 | n_pocket_nodes = pocket_xh.shape[1] 70 | 71 | N = n_lig_nodes + n_pocket_nodes 72 | 73 | node_mask = (scaffold_mask.bool() | extension_mask.bool()) # [B, N_l] 74 | xh = xh[node_mask] # [N_l, h_l+3] 75 | pocket_xh = pocket_xh[pocket_mask.bool()] # [N_p, h_p+3] 76 | 77 | x_atoms = xh[:, :self.n_dims].clone() # [N_l,3] 78 | h_atoms = xh[:, self.n_dims:].clone() # [N_l,nf] 79 | 80 | x_pocket = pocket_xh[:, :self.n_dims].clone() # [N_p, 3] 81 | h_pocket = pocket_xh[:, self.n_dims:].clone() # [N_p, hp] 82 | 83 | h_atoms = self.atom_encoder(h_atoms) # [N_l, joint_nf] 84 | h_pocket = self.pocket_encoder(h_pocket) # [N_p, joint_nf] 85 | 86 | x = torch.cat((x_atoms, x_pocket), dim=0) # [N_l+N_p, 3] 87 | h = torch.cat((h_atoms, h_pocket), dim=0) # [N_l+N_p, joint_nf] 88 | 89 | batch_mask_ligand = self.get_batch_mask(node_mask, device=x.device) # [N_l] 90 | batch_mask_pocket = self.get_batch_mask(pocket_mask, device=x.device) # [N_p] 91 | mask = torch.cat([batch_mask_ligand, batch_mask_pocket], dim=0) # [N_l+N_p] 92 | 93 | new_anchor_mask = torch.cat([anchors[node_mask], pocket_anchors[pocket_mask.bool()]], dim=0).unsqueeze(-1) 94 | new_scaffold_msak = torch.cat([scaffold_mask[node_mask], torch.zeros_like(batch_mask_pocket, device=xh.device)], dim=0).unsqueeze(-1) 95 | new_pocket_mask = torch.cat([torch.zeros_like(batch_mask_ligand, device=xh.device), torch.ones_like(batch_mask_pocket)], dim=0).unsqueeze(-1) 96 | 97 | h = torch.cat([h, new_anchor_mask, new_scaffold_msak, new_pocket_mask], dim=1) # [N_l+N_p, joint_nf+3] 98 | 99 | if self.condition_time: 100 | if np.prod(t.size()) == 1: 101 | # t is the same for all elements in batch. 102 | h_time = torch.empty_like(h[:, 0:1]).fill_(t.item()) 103 | else: 104 | # t is different over the batch dimension. 105 | h_time = t[mask] 106 | h = torch.cat([h, h_time], dim=1) 107 | 108 | edges = self.get_edges_cutoff(batch_mask_ligand, batch_mask_pocket, x_atoms, x_pocket) # [2, num_edges] 109 | assert torch.all(mask[edges[0]] == mask[edges[1]]) 110 | 111 | # --------------- apply the GVP dynamics ---------- 112 | h_final, pos_out = self.dynamics(h, x, edges) # [N_l+N_p, joint_nf], [N_l+N_p, 3] 113 | pos_out = pos_out.reshape(-1,3) # [N_l+N_p, 3] 114 | 115 | # decode atoms 116 | h_final_atoms = self.atom_decoder(h_final[:len(batch_mask_ligand)]) # [N_l, h_l] 117 | 118 | vel_ligand = pos_out[:len(batch_mask_ligand)] # [N_l, 3] 119 | vel_h_ligand = torch.cat([vel_ligand, h_final_atoms], dim=1) # [N_l, h_l+3] 120 | 121 | # convert to batch 122 | num_atoms = node_mask.sum(dim=1).int() # [B] 123 | reshaped_vel_h = torch.zeros(bs, n_lig_nodes, vel_h_ligand.shape[-1]).to(xh.device) 124 | positions = torch.zeros_like(batch_mask_ligand).to(xh.device) 125 | for idx in range(bs): 126 | positions[batch_mask_ligand == idx] = torch.arange(num_atoms[idx]).to(xh.device) 127 | reshaped_vel_h[batch_mask_ligand, positions] = vel_h_ligand 128 | 129 | return reshaped_vel_h # [B, N_l, h_l+3] 130 | 131 | @staticmethod 132 | def get_dist_edges(x, node_mask, batch_mask): 133 | node_mask = node_mask.squeeze().bool() 134 | batch_adj = (batch_mask[:, None] == batch_mask[None, :]) 135 | nodes_adj = (node_mask[:, None] & node_mask[None, :]) 136 | dists_adj = (torch.cdist(x, x) <= 7) 137 | rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device) 138 | adj = batch_adj & nodes_adj & dists_adj & rm_self_loops 139 | edges = torch.stack(torch.where(adj)) 140 | return edges 141 | 142 | def get_edges_cutoff(self, batch_mask_ligand, batch_mask_pocket, x_ligand, x_pocket): 143 | 144 | adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :] 145 | adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :] 146 | adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :] 147 | 148 | if self.edge_cutoff_l is not None: 149 | adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l) 150 | 151 | if self.edge_cutoff_p is not None: 152 | adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p) 153 | 154 | if self.edge_cutoff_i is not None: 155 | adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i) 156 | 157 | adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1), 158 | torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0) 159 | edges = torch.stack(torch.where(adj), dim=0) 160 | return edges 161 | 162 | @staticmethod 163 | def get_batch_mask(mask, device): 164 | n_nodes = mask.float().sum(dim=1).int() 165 | batch_size = mask.shape[0] 166 | batch_mask = torch.cat([torch.ones(n_nodes[i]) * i for i in range(batch_size)]).long().to(device) 167 | return batch_mask -------------------------------------------------------------------------------- /src/extension_size.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.distributions.categorical import Categorical 6 | 7 | class DistributionNodes: 8 | def __init__(self, histogram): 9 | 10 | self.n_nodes = [] 11 | prob = [] 12 | self.keys = {} 13 | for i, nodes in enumerate(histogram): 14 | self.n_nodes.append(nodes) 15 | self.keys[nodes] = i 16 | prob.append(histogram[nodes]) 17 | self.n_nodes = torch.tensor(self.n_nodes) 18 | prob = np.array(prob) 19 | prob = prob / np.sum(prob) 20 | 21 | self.prob = torch.from_numpy(prob).float() 22 | self.m = Categorical(torch.tensor(prob)) 23 | 24 | def sample(self, n_samples=1): 25 | idx = self.m.sample((n_samples,)) 26 | return self.n_nodes[idx] 27 | 28 | def log_prob(self, batch_n_nodes): 29 | assert len(batch_n_nodes.size()) == 1 30 | 31 | idcs = [self.keys[i.item()] for i in batch_n_nodes] 32 | idcs = torch.tensor(idcs).to(batch_n_nodes.device) 33 | 34 | log_p = torch.log(self.prob + 1e-30) 35 | log_p = log_p.to(batch_n_nodes.device) 36 | 37 | log_probs = log_p[idcs] 38 | 39 | return log_probs 40 | 41 | 42 | -------------------------------------------------------------------------------- /src/fragment_size_gnn.py: -------------------------------------------------------------------------------- 1 | from src.egnn import GCL, GaussianSmearing 2 | import torch.nn as nn 3 | import torch 4 | from src.egnn import coord2diff 5 | from torch_scatter import scatter_mean 6 | 7 | class FragSizeGNN(nn.Module): 8 | def __init__(self, 9 | lig_nf, 10 | pocket_nf, 11 | joint_nf, 12 | hidden_nf, 13 | out_node_nf, # number of classes (fragment sizes) 14 | n_layers, 15 | normalization=True, 16 | attention=True, 17 | normalization_factor=100, 18 | aggregation_method='sum', 19 | edge_cutoff_ligand=None, 20 | edge_cutoff_pocket=5, 21 | edge_cutoff_interaction=5, 22 | dataset_type='CrossDock', 23 | gaussian_expansion=True, 24 | num_gaussians=16): 25 | super(FragSizeGNN, self).__init__() 26 | 27 | self.dataset_type = dataset_type 28 | if self.dataset_type == 'CrossDock': 29 | context_node_nf = 3 # mask on the pocket atoms and anchor points 30 | 31 | if gaussian_expansion: 32 | self.gauss_exp = GaussianSmearing(start=0., stop=5., num_gaussians=num_gaussians) 33 | in_edge_nf = num_gaussians 34 | 35 | self.hidden_nf = hidden_nf 36 | self.out_node_nf = out_node_nf 37 | self.n_layers = n_layers 38 | self.normalization = normalization 39 | self.attention = attention 40 | self.normalization_factor = normalization_factor 41 | self.gaussian_expansion = gaussian_expansion 42 | self.edge_cutoff_l = edge_cutoff_ligand 43 | self.edge_cutoff_p = edge_cutoff_pocket 44 | self.edge_cutoff_i = edge_cutoff_interaction 45 | 46 | self.mol_encoder = nn.Sequential( 47 | nn.Linear(lig_nf, joint_nf), 48 | ) 49 | 50 | self.pocket_encoder = nn.Sequential( 51 | nn.Linear(pocket_nf, joint_nf), 52 | ) 53 | 54 | self.embed_both = nn.Linear(joint_nf+context_node_nf, hidden_nf) # concatenate the context features to joint space 55 | 56 | self.gcl1 = GCL( 57 | input_nf=self.hidden_nf, 58 | output_nf=self.hidden_nf, 59 | hidden_nf=self.hidden_nf, 60 | normalization_factor=normalization_factor, 61 | aggregation_method=aggregation_method, 62 | edges_in_d=in_edge_nf, 63 | activation=nn.ReLU(), 64 | attention=attention, 65 | normalization=normalization 66 | ) 67 | 68 | layers = [] 69 | for i in range(n_layers - 1): 70 | layer = GCL( 71 | input_nf=self.hidden_nf, 72 | output_nf=self.hidden_nf, 73 | hidden_nf=self.hidden_nf, 74 | normalization_factor=normalization_factor, 75 | aggregation_method=aggregation_method, 76 | edges_in_d=in_edge_nf, 77 | activation=nn.ReLU(), 78 | attention=attention, 79 | normalization=normalization 80 | ) 81 | layers.append(layer) 82 | 83 | self.gcl_layers = nn.ModuleList(layers) 84 | self.embedding_out = nn.Linear(self.hidden_nf, self.out_node_nf) 85 | self.act = nn.ReLU() 86 | 87 | self.edge_cache = {} 88 | #self.lin_out = nn.Linear(self.out_node_nf, 1) 89 | 90 | def forward(self, mol_x, mol_h, node_mask, pocket_x, pocket_h, pocket_mask, anchors, pocket_anchors): 91 | """ 92 | mol_x: [B, N, 3] positions of scaffold atoms 93 | mol_h: [B, N, nf] onehot of scaffold atoms 94 | node_mask: [B, N] only for scaffold-based 95 | pocket_x: [B, N, 3] positions of pocket atoms 96 | pocket_h: [B, N, nf] onehot of pocket atoms 97 | anchors: [B, N, 3] positions of anchor points 98 | pocket_anchors: [B, N, 3] positions of anchor points 99 | """ 100 | bs, n_nodes_lig = mol_x.shape[0], mol_x.shape[1] 101 | n_nodes_pocket = pocket_x.shape[1] 102 | node_mask = node_mask.squeeze() 103 | 104 | N = n_nodes_lig + n_nodes_pocket 105 | mol_x = mol_x[node_mask.bool()] # [N_l, 3] 106 | mol_h = mol_h[node_mask.bool()] # [N_l, nf] 107 | 108 | pocket_x = pocket_x[pocket_mask.bool()] # [N_p, 3] 109 | pocket_h = pocket_h[pocket_mask.bool()] # [N_p, nf] 110 | 111 | mol_h = self.mol_encoder(mol_h) # [N_l, joint_nf] 112 | pocket_h = self.pocket_encoder(pocket_h) 113 | 114 | h = torch.cat([mol_h, pocket_h], dim=0) # [N, joint_nf] 115 | 116 | batch_mask_ligand = self.get_batch_mask(node_mask, device=mol_x.device) # [N_l] 117 | batch_mask_pocket = self.get_batch_mask(pocket_mask, device=mol_x.device) # [N_p] 118 | new_anchor_mask = torch.cat([anchors[node_mask.bool()], pocket_anchors[pocket_mask.bool()]], dim=0).unsqueeze(-1) 119 | new_scaffold_mask = torch.cat([torch.ones_like(batch_mask_ligand, device=mol_x.device), torch.zeros_like(batch_mask_pocket)], dim=0).unsqueeze(-1) 120 | new_pocket_mask = torch.cat([torch.zeros_like(batch_mask_ligand), torch.ones_like(batch_mask_pocket)], dim=0).unsqueeze(-1) 121 | 122 | h = torch.cat([h, new_anchor_mask, new_scaffold_mask, new_pocket_mask], dim=1) # [N, joint_nf+2] 123 | x = torch.cat([mol_x, pocket_x], dim=0) # [N, 3] 124 | 125 | mask = torch.cat([batch_mask_ligand, batch_mask_pocket], dim=0) # [N] 126 | device = mol_x.device 127 | 128 | h = self.embed_both(h) 129 | edges = self.get_edges_cutoff(batch_mask_ligand, batch_mask_pocket, mol_x, pocket_x) # [2, E] 130 | 131 | # selected only edges based on a 7A distance (all protein and scaffold atoms considered) 132 | distances, _ = coord2diff(x, edges) # TODO: consider adding more edge info such as the type of bond 133 | if self.gaussian_expansion: 134 | distances = self.gauss_exp(distances) 135 | 136 | for gcl in self.gcl_layers: 137 | h, _ = gcl(h, edges, edge_attr=distances, node_mask=None, edge_mask=None) 138 | 139 | h_final = self.act(self.embedding_out(h)) # [N, out_node_nf] 140 | 141 | # convert to batch 142 | #out = scatter_mean(h_final, mask, dim=0, dim_size=bs) # [B, out_node_nf] 143 | num_atoms = node_mask.sum(dim=1).int() + pocket_mask.sum(dim=1).int() 144 | reshaped_out = torch.zeros(bs, N, h_final.shape[-1], dtype=h.dtype, device=h.device) 145 | positions = torch.zeros_like(mask).to(h.device) 146 | for idx in range(bs): 147 | positions[mask == idx] = torch.arange(num_atoms[idx], device=h.device) 148 | reshaped_out[mask, positions] = h_final 149 | return reshaped_out # [B, N, out_node_nf] 150 | 151 | def get_edges_cutoff(self, batch_mask_ligand, batch_mask_pocket, x_ligand, x_pocket): 152 | 153 | adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :] 154 | adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :] 155 | adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :] 156 | 157 | if self.edge_cutoff_l is not None: 158 | adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l) 159 | 160 | if self.edge_cutoff_p is not None: 161 | adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p) 162 | 163 | if self.edge_cutoff_i is not None: 164 | adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i) 165 | 166 | adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1), 167 | torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0) 168 | edges = torch.stack(torch.where(adj), dim=0) 169 | return edges 170 | 171 | @staticmethod 172 | def get_batch_mask(mask, device): 173 | n_nodes = mask.float().sum(dim=1).int() 174 | batch_size = mask.shape[0] 175 | batch_mask = torch.cat([torch.ones(n_nodes[i]) * i for i in range(batch_size)]).long().to(device) 176 | return batch_mask -------------------------------------------------------------------------------- /src/gvp.py: -------------------------------------------------------------------------------- 1 | # GVP implementation from DiffHopp https://github.com/jostorge/diffusion-hopping/tree/main 2 | 3 | import math 4 | from typing import Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | s_V = Tuple[torch.Tensor, torch.Tensor] 11 | 12 | 13 | # Relevant papers: 14 | # Learning from Protein Structure with Geometric Vector Perceptrons, 15 | # Equivariant Graph Neural Networks for 3D Macromolecular Structure, 16 | class GVP(nn.Module): 17 | def __init__( 18 | self, 19 | in_dims: Tuple[int, int], 20 | out_dims: Tuple[int, int], 21 | activations=(F.relu, torch.sigmoid), 22 | vector_gate: bool = False, 23 | eps: float = 1e-4, 24 | ) -> None: 25 | super().__init__() 26 | in_scalar, in_vector = in_dims 27 | out_scalar, out_vector = out_dims 28 | self.sigma, self.sigma_plus = activations 29 | 30 | if self.sigma is None: 31 | self.sigma = nn.Identity() 32 | if self.sigma_plus is None: 33 | self.sigma_plus = nn.Identity() 34 | 35 | self.h = max(in_vector, out_vector) 36 | self.W_h = nn.Parameter(torch.empty((self.h, in_vector))) 37 | self.W_mu = nn.Parameter(torch.empty((out_vector, self.h))) 38 | 39 | self.W_m = nn.Linear(self.h + in_scalar, out_scalar) 40 | self.v = in_vector 41 | self.mu = out_vector 42 | self.n = in_scalar 43 | self.m = out_scalar 44 | self.vector_gate = vector_gate 45 | 46 | if vector_gate: 47 | self.sigma_g = nn.Sigmoid() 48 | self.W_g = nn.Linear(out_scalar, out_vector) 49 | 50 | self.eps = eps 51 | self.reset_parameters() 52 | 53 | def reset_parameters(self): 54 | torch.nn.init.kaiming_uniform_(self.W_h, a=math.sqrt(5)) 55 | torch.nn.init.kaiming_uniform_(self.W_mu, a=math.sqrt(5)) 56 | self.W_m.reset_parameters() 57 | if self.vector_gate: 58 | self.W_g.reset_parameters() 59 | 60 | def forward(self, x: Union[torch.Tensor, s_V]) -> Union[torch.Tensor, s_V]: 61 | """Geometric vector perceptron""" 62 | s, V = ( 63 | x if self.v > 0 else (x, torch.empty((x.shape[0], 0, 3), device=x.device)) 64 | ) 65 | 66 | assert ( 67 | s.shape[-1] == self.n 68 | ), f"{s.shape[-1]} != {self.n} Scalar dimension mismatch" 69 | assert ( 70 | V.shape[-2] == self.v 71 | ), f" {V.shape[-2]} != {self.v} Vector dimension mismatch" 72 | assert V.shape[0] == s.shape[0], "Batch size mismatch" 73 | 74 | V_h = self.W_h @ V 75 | V_mu = self.W_mu @ V_h 76 | s_h = torch.clip(torch.norm(V_h, dim=-1), min=self.eps) 77 | s_hn = torch.cat([s, s_h], dim=-1) 78 | s_m = self.W_m(s_hn) 79 | s_dash = self.sigma(s_m) 80 | if self.vector_gate: 81 | V_dash = self.sigma_g(self.W_g(self.sigma_plus(s_m)))[..., None] * V_mu 82 | else: 83 | v_mu = torch.clip(torch.norm(V_mu, dim=-1, keepdim=True), min=self.eps) 84 | V_dash = self.sigma_plus(v_mu) * V_mu 85 | return (s_dash, V_dash) if self.mu > 0 else s_dash -------------------------------------------------------------------------------- /src/gvp_model.py: -------------------------------------------------------------------------------- 1 | # GVP implementation from DiffHopp https://github.com/jostorge/diffusion-hopping/tree/main 2 | from typing import Tuple, Union, Optional 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from src.conv_layer import GVPConvLayer 9 | from src.gvp import GVP, s_V 10 | from src.layer_norm import GVPLayerNorm 11 | 12 | class GVPNetwork(nn.Module): 13 | def __init__( 14 | self, 15 | in_dims: Tuple[int, int], 16 | out_dims: Tuple[int, int], 17 | hidden_dims: Tuple[int, int], 18 | num_layers: int, 19 | attention: bool = False, 20 | normalization_factor: float=100.0, 21 | aggr: str = "add", 22 | activations=(F.silu, None), 23 | vector_gate: bool = True, 24 | eps=1e-4 25 | ) -> None: 26 | super().__init__() 27 | edge_dims = (1,1) 28 | 29 | self.eps = eps 30 | self.embedding_in = nn.Sequential( 31 | GVPLayerNorm(in_dims), 32 | GVP( 33 | in_dims, 34 | hidden_dims, 35 | activations=(None,None), 36 | vector_gate=vector_gate 37 | ), 38 | ) 39 | self.embedding_out = nn.Sequential( 40 | GVPLayerNorm(hidden_dims), 41 | GVP( 42 | hidden_dims, 43 | out_dims, 44 | activations=activations, 45 | vector_gate=vector_gate 46 | ), 47 | ) 48 | self.edge_embedding = nn.Sequential( 49 | GVPLayerNorm(edge_dims), 50 | GVP( 51 | edge_dims, 52 | (hidden_dims[0],1), 53 | activations=(None, None), 54 | vector_gate=vector_gate 55 | ) 56 | ) 57 | 58 | self.layers = nn.ModuleList( 59 | [ 60 | GVPConvLayer( 61 | hidden_dims, 62 | (hidden_dims[0], 1), 63 | activations=activations, 64 | vector_gate=vector_gate, 65 | residual=True, 66 | attention=attention, 67 | aggr=aggr, 68 | normalization_factor=normalization_factor, 69 | ) 70 | for _ in range(num_layers) 71 | ] 72 | ) 73 | 74 | def get_edge_attr(self, edge_index, pos) -> s_V: 75 | V = pos[edge_index[0]] - pos[edge_index[1]] # [n_edges, 3] 76 | s = torch.linalg.norm(V, dim=-1, keepdim=True) # [n_edges, 1] 77 | V = (V / torch.clip(s, min=self.eps))[..., None, :] # [n_edges, 1, 3] 78 | return s, V 79 | 80 | def forward(self, h, pos, edge_index) -> s_V: 81 | edge_attr = self.get_edge_attr(edge_index, pos) 82 | edge_attr = self.edge_embedding(edge_attr) 83 | 84 | h = self.embedding_in(h) 85 | for layer in self.layers: 86 | h = layer(h, edge_index, edge_attr) 87 | 88 | return self.embedding_out(h) -------------------------------------------------------------------------------- /src/layer_norm.py: -------------------------------------------------------------------------------- 1 | # GVP implementation from DiffHopp https://github.com/jostorge/diffusion-hopping/tree/main 2 | 3 | import math 4 | from typing import Tuple, Optional, Union 5 | 6 | import torch 7 | from torch import nn as nn 8 | s_V = Tuple[torch.Tensor, torch.Tensor] 9 | 10 | class GVPLayerNorm(nn.Module): 11 | def __init__(self, dims: Tuple[int, int], eps: float=0.00001) ->None: 12 | super().__init__() 13 | self.eps = math.sqrt(eps) 14 | self.scalar_size, self.vector_size = dims 15 | self.feature_layer_norm = nn.LayerNorm(self.scalar_size, eps=eps) 16 | 17 | def forward(self, x:Union[torch.Tensor, s_V]) -> Union[torch.Tensor, s_V]: 18 | if self.vector_size == 0: 19 | return self.feature_layer_norm(x) 20 | 21 | s, V = x 22 | s = self.feature_layer_norm(s) 23 | norm = torch.clip( 24 | torch.linalg.vector_norm(V, dim=(-1,-2), keepdim=True) 25 | / math.sqrt(self.vector_size), 26 | min=self.eps 27 | ) 28 | 29 | V = V / norm 30 | return s, V -------------------------------------------------------------------------------- /src/lightning_anchor_gnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytorch_lightning as pl 4 | 5 | from torch.nn.functional import sigmoid 6 | from src.datasets import HierCrossDockDataset, get_dataloader, collate_pocket_aux 7 | from src.anchor_gnn import AnchorGNNPocket 8 | 9 | from typing import Dict, List, Optional 10 | from tqdm import tqdm 11 | import os 12 | import torch.nn as nn 13 | 14 | def get_activation(activation): 15 | if activation == 'silu': 16 | return torch.nn.SiLU() 17 | else: 18 | raise Exception('activation fn not found. add it here') 19 | 20 | class MaskedBCEWithLogitsLoss(torch.nn.Module): 21 | """ masks the pocket atoms for anchor prediction loss calculation """ 22 | def __init__(self): 23 | super(MaskedBCEWithLogitsLoss, self).__init__() 24 | self.loss = torch.nn.BCEWithLogitsLoss(reduction='none') 25 | 26 | def forward(self, input, target, mask=None, return_mean=False): 27 | masked_loss = self.loss(input, target) 28 | 29 | if mask is not None: 30 | masked_loss = masked_loss * mask.float() 31 | if return_mean: 32 | if mask is not None: 33 | return masked_loss.sum() / mask.sum().float() 34 | else: 35 | return masked_loss.mean() 36 | else: 37 | return masked_loss 38 | 39 | class AnchorGNN_pl(pl.LightningModule): 40 | train_dataset = None 41 | val_dataset = None 42 | starting_epoch = None 43 | metrics: Dict[str, List[float]] = {} 44 | 45 | def __init__( 46 | self, 47 | lig_node_nf, 48 | pocket_node_nf, 49 | joint_nf, 50 | n_dims, 51 | hidden_nf, 52 | activation, 53 | tanh, 54 | n_layers, 55 | attention, 56 | norm_constant, 57 | data_path, 58 | train_data_prefix, 59 | val_data_prefix, 60 | batch_size, 61 | lr, 62 | test_epochs, 63 | dataset_type, 64 | normalization_factor, 65 | gaussian_expansion=False, 66 | normalization=None, 67 | include_charges=False, 68 | samples_dir=None, 69 | train_dataframe_path='paths_train.csv', 70 | val_dataframe_path='paths_val.csv', 71 | num_workers=0, 72 | ): 73 | 74 | super(AnchorGNN_pl, self).__init__() 75 | self.save_hyperparameters() 76 | self.data_path = data_path 77 | self.train_data_prefix = train_data_prefix 78 | self.val_data_prefix = val_data_prefix 79 | self.batch_size = batch_size 80 | self.lr = lr 81 | self.test_epochs = test_epochs 82 | self.samples_dir = samples_dir 83 | self.n_dims = n_dims 84 | self.num_classes = lig_node_nf - include_charges 85 | self.include_charges = include_charges 86 | self.train_dataframe_path = train_dataframe_path 87 | self.val_dataframe_path = val_dataframe_path 88 | self.num_workers = num_workers 89 | self.n_layers = n_layers 90 | self.attention = attention 91 | self.normalization_factor = normalization_factor 92 | 93 | self.joint_nf = joint_nf 94 | self.lig_node_nf = lig_node_nf 95 | self.pocket_node_nf = pocket_node_nf 96 | 97 | self.norm_constant = norm_constant 98 | self.tanh = tanh 99 | self.dataset_type = dataset_type 100 | self.gaussian_expansion = gaussian_expansion 101 | #self.bce_loss = MaskedBCEWithLogitsLoss() 102 | 103 | if self.dataset_type == 'GEOM': 104 | self.bce_loss = nn.BCEWithLogitsLoss(reduction='none') 105 | elif self.dataset_type == 'CrossDock': 106 | self.bce_loss = MaskedBCEWithLogitsLoss() 107 | 108 | if type(activation) is str: 109 | activation = get_activation(activation) 110 | 111 | self.anchor_predictor = AnchorGNNPocket( 112 | lig_nf=lig_node_nf, 113 | pocket_nf=pocket_node_nf, 114 | joint_nf=joint_nf, 115 | hidden_nf=hidden_nf, 116 | out_node_nf=hidden_nf, 117 | n_layers=4, 118 | normalization_factor=normalization_factor, 119 | normalization=normalization, 120 | attention=True, 121 | aggregation_method='sum', 122 | dist_cutoff=7, 123 | gaussian_expansion=gaussian_expansion, 124 | edge_cutoff_ligand=None, 125 | edge_cutoff_pocket=4.5, 126 | edge_cutoff_interaction=4.5 127 | ) 128 | 129 | def setup(self, stage: Optional[str]=None): 130 | if stage == 'fit': 131 | self.train_dataset = HierCrossDockDataset( 132 | data_path=self.data_path, 133 | prefix=self.train_data_prefix, 134 | device=self.device, 135 | dataframe_path=self.train_dataframe_path 136 | ) 137 | print('loaded train data') 138 | self.val_dataset = HierCrossDockDataset( 139 | data_path=self.data_path, 140 | prefix=self.val_data_prefix, 141 | device=self.device, 142 | dataframe_path=self.val_dataframe_path 143 | ) 144 | print('loaded validation data') 145 | 146 | elif stage == 'val': 147 | self.val_dataset = HierCrossDockDataset( 148 | data_path=self.data_path, 149 | prefix=self.val_data_prefix, 150 | device=self.device, 151 | dataframe_path=self.val_dataframe_path 152 | ) 153 | else: 154 | raise NotImplementedError 155 | 156 | def train_dataloader(self): 157 | return get_dataloader(self.train_dataset, self.batch_size, num_workers=self.num_workers, collate_fn=collate_pocket_aux, shuffle=True) 158 | 159 | def val_dataloader(self): 160 | return get_dataloader(self.val_dataset, self.batch_size, num_workers=self.num_workers, collate_fn=collate_pocket_aux) 161 | 162 | def test_dataloader(self): 163 | return get_dataloader(self.test_dataset, self.batch_size, num_workers=self.num_workers, collate_fn=collate_pocket_aux) 164 | 165 | def forward(self, data, training): 166 | 167 | scaff_x = data['position_aux'].to(self.device) # [B, Ns, 3] 168 | scaff_h = data['onehot_aux'].to(self.device) # [B, Ns, nf] 169 | scaffold_masks = data['scaffold_masks_aux'].to(self.device) # [B, Ns] 170 | pocket_masks = data['pocket_mask_aux'].to(self.device) # [B, Np] 171 | scaffold_anchors = data['anchors_aux'].to(self.device) # [B,Ns] 172 | pocket_x = data['pocket_coords_aux'].to(self.device) 173 | pocket_h = data['pocket_onehot_aux'].to(self.device) 174 | 175 | B, N = scaff_x.shape[0], scaff_x.shape[1] 176 | 177 | B = scaff_x.shape[0] 178 | N_s = scaff_x.shape[1] 179 | N_p = pocket_x.shape[1] 180 | N = N_s+N_p 181 | 182 | anchor_out = self.anchor_predictor.forward(mol_x=scaff_x, # [B, Ns, 3] 183 | mol_h=scaff_h, # [B, Ns, nf] 184 | pocket_x=pocket_x, # [B, Np, 3] 185 | pocket_h=pocket_h, # [B, Np, hp] 186 | node_mask=scaffold_masks, # [B, Np] # mask on both pocket and scaffold 187 | pocket_mask=pocket_masks, 188 | ) # [B, Np] masks only on pocket atoms) 189 | 190 | anchor_loss = self.bce_loss(anchor_out.view(B*N_s, 1), scaffold_anchors.view(B*N_s, 1), scaffold_masks.view(B*N_s, 1), return_mean=True) 191 | #anchor_loss = anchor_loss[not_first_frag_mask].mean() 192 | return anchor_out, anchor_loss 193 | 194 | def training_step(self, data, *args): 195 | _, loss = self.forward(data, training=True) 196 | training_metrics = { 197 | 'loss': loss 198 | } 199 | for metric_name, metric in training_metrics.items(): 200 | self.metrics.setdefault(f'{metric_name}/train', []).append(metric) 201 | self.log(f'{metric_name}/train', metric, on_step=True, on_epoch=True, batch_size=self.batch_size, prog_bar=True) 202 | self.metrics.clear() 203 | return training_metrics 204 | 205 | def validation_step(self, data, *args): 206 | _, loss = self.forward(data, training=False) 207 | validation_metrics = { 208 | 'loss': loss 209 | } 210 | return validation_metrics 211 | 212 | def training_epoch_end(self, training_step_outputs): 213 | for metric in training_step_outputs[0].keys(): 214 | avg_metric = self.aggregate_metric(training_step_outputs, metric) 215 | self.metrics.setdefault(f'{metric}/train', []).append(avg_metric) 216 | self.log(f'{metric}/train', avg_metric, prog_bar=True) 217 | 218 | self.metrics.clear() # free up memory 219 | 220 | def validation_epoch_end(self, validation_step_outputs): 221 | for metric in validation_step_outputs[0].keys(): 222 | avg_metric = self.aggregate_metric(validation_step_outputs, metric) 223 | self.metrics.setdefault(f'{metric}/val', []).append(avg_metric) 224 | self.log(f'{metric}/val', avg_metric, prog_bar=True) 225 | 226 | self.metrics.clear() 227 | 228 | def configure_optimizers(self): 229 | return torch.optim.AdamW(self.anchor_predictor.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12) 230 | 231 | @staticmethod 232 | def aggregate_metric(step_outputs, metric): 233 | return torch.tensor([out[metric] for out in step_outputs]).mean() 234 | -------------------------------------------------------------------------------- /src/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | import numpy as np 5 | 6 | def clip_noise_schedule(alphas2, clip_value=0.001): 7 | """ 8 | For a noise schedule given by alpha^2, this clips alpha_t / alpha_t-1. This may help improve stability during 9 | sampling. 10 | """ 11 | alphas2 = np.concatenate([np.ones(1), alphas2], axis=0) 12 | 13 | alphas_step = (alphas2[1:] / alphas2[:-1]) 14 | 15 | alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.) 16 | alphas2 = np.cumprod(alphas_step, axis=0) 17 | return alphas2 18 | 19 | def polynomial_schedule(timesteps: int, s=1e-4, power=3.): 20 | """ 21 | A noise schedule based on a simple polynomial equation: 1 - x^power. 22 | """ 23 | steps = timesteps + 1 24 | x = np.linspace(0, steps, steps) 25 | alphas2 = (1 - np.power(x / steps, power)) ** 2 26 | 27 | alphas2 = clip_noise_schedule(alphas2, clip_value=0.001) 28 | 29 | precision = 1 - 2 * s 30 | 31 | alphas2 = precision * alphas2 + s 32 | 33 | return alphas2 34 | 35 | 36 | def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1): 37 | """ 38 | cosine schedule 39 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 40 | """ 41 | steps = timesteps + 2 42 | x = np.linspace(0, steps, steps) 43 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 44 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 45 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 46 | betas = np.clip(betas, a_min=0, a_max=0.999) 47 | alphas = 1. - betas 48 | alphas_cumprod = np.cumprod(alphas, axis=0) 49 | 50 | if raise_to_power != 1: 51 | alphas_cumprod = np.power(alphas_cumprod, raise_to_power) 52 | 53 | return alphas_cumprod 54 | 55 | 56 | class PositiveLinear(torch.nn.Module): 57 | """Linear layer with weights forced to be positive.""" 58 | 59 | def __init__(self, in_features: int, out_features: int, bias: bool = True, 60 | weight_init_offset: int = -2): 61 | super(PositiveLinear, self).__init__() 62 | self.in_features = in_features 63 | self.out_features = out_features 64 | self.weight = torch.nn.Parameter( 65 | torch.empty((out_features, in_features))) 66 | if bias: 67 | self.bias = torch.nn.Parameter(torch.empty(out_features)) 68 | else: 69 | self.register_parameter('bias', None) 70 | self.weight_init_offset = weight_init_offset 71 | self.reset_parameters() 72 | 73 | def reset_parameters(self) -> None: 74 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 75 | 76 | with torch.no_grad(): 77 | self.weight.add_(self.weight_init_offset) 78 | 79 | if self.bias is not None: 80 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) 81 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 82 | torch.nn.init.uniform_(self.bias, -bound, bound) 83 | 84 | def forward(self, x): 85 | positive_weight = F.softplus(self.weight) 86 | return F.linear(x, positive_weight, self.bias) 87 | 88 | 89 | class PredefinedNoiseSchedule(torch.nn.Module): 90 | """ 91 | Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules. 92 | """ 93 | 94 | def __init__(self, noise_schedule, timesteps, precision): 95 | super(PredefinedNoiseSchedule, self).__init__() 96 | self.timesteps = timesteps 97 | 98 | if noise_schedule == 'cosine': 99 | alphas2 = cosine_beta_schedule(timesteps) 100 | elif 'polynomial' in noise_schedule: 101 | splits = noise_schedule.split('_') 102 | assert len(splits) == 2 103 | power = float(splits[1]) 104 | alphas2 = polynomial_schedule(timesteps, s=precision, power=power) 105 | else: 106 | raise ValueError(noise_schedule) 107 | 108 | # print('alphas2', alphas2) 109 | 110 | sigmas2 = 1 - alphas2 111 | 112 | log_alphas2 = np.log(alphas2) 113 | log_sigmas2 = np.log(sigmas2) 114 | 115 | log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2 116 | 117 | # print('gamma', -log_alphas2_to_sigmas2) 118 | 119 | self.gamma = torch.nn.Parameter( 120 | torch.from_numpy(-log_alphas2_to_sigmas2).float(), 121 | requires_grad=False) 122 | 123 | def forward(self, t): 124 | t_int = torch.round(t * self.timesteps).long() 125 | return self.gamma[t_int] 126 | 127 | class GammaNetwork(torch.nn.Module): 128 | """The gamma network models a monotonic increasing function. Construction as in the VDM paper.""" 129 | 130 | def __init__(self): 131 | super().__init__() 132 | 133 | self.l1 = PositiveLinear(1, 1) 134 | self.l2 = PositiveLinear(1, 1024) 135 | self.l3 = PositiveLinear(1024, 1) 136 | 137 | self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.])) 138 | self.gamma_1 = torch.nn.Parameter(torch.tensor([10.])) 139 | self.show_schedule() 140 | 141 | def show_schedule(self, num_steps=50): 142 | t = torch.linspace(0, 1, num_steps).view(num_steps, 1) 143 | gamma = self.forward(t) 144 | print('Gamma schedule:') 145 | print(gamma.detach().cpu().numpy().reshape(num_steps)) 146 | 147 | def gamma_tilde(self, t): 148 | l1_t = self.l1(t) 149 | return l1_t + self.l3(torch.sigmoid(self.l2(l1_t))) 150 | 151 | def forward(self, t): 152 | zeros, ones = torch.zeros_like(t), torch.ones_like(t) 153 | # Not super efficient. 154 | gamma_tilde_0 = self.gamma_tilde(zeros) 155 | gamma_tilde_1 = self.gamma_tilde(ones) 156 | gamma_tilde_t = self.gamma_tilde(t) 157 | 158 | # Normalize to [0, 1] 159 | normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / ( 160 | gamma_tilde_1 - gamma_tilde_0) 161 | 162 | # Rescale to [gamma_0, gamma_1] 163 | gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma 164 | 165 | return gamma -------------------------------------------------------------------------------- /train_anchor_predictor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from src.lightning_anchor_gnn import AnchorGNN_pl 5 | from src.utils import disable_rdkit_logging, Logger 6 | from pytorch_lightning import Trainer, callbacks, loggers 7 | from pytorch_lightning.loggers import TensorBoardLogger 8 | 9 | def find_last_checkpoint(checkpoints_dir): 10 | epoch2fname = [ 11 | (int(fname.split('=')[1].split('.')[0]), fname) 12 | for fname in os.listdir(checkpoints_dir) 13 | if fname.endswith('.ckpt') 14 | ] 15 | latest_fname = max(epoch2fname, key=lambda t: t[0])[1] 16 | return os.path.join(checkpoints_dir, latest_fname) 17 | 18 | def main(args): 19 | run_name = args.exp_name 20 | experiment = run_name if args.resume is None else args.resume 21 | checkpoints_dir = os.path.join(args.checkpoints, experiment) 22 | os.makedirs(os.path.join(args.logs, 'general_logs', experiment), exist_ok=True) 23 | sys.stdout = Logger(logpath=os.path.join(args.logs, "general_logs", experiment, f'log.log'), syspart=sys.stdout) 24 | sys.stderr = Logger(logpath=os.path.join(args.logs, "general_logs", experiment, f'log.log'), syspart=sys.stderr) 25 | 26 | os.makedirs(checkpoints_dir, exist_ok=True) 27 | os.makedirs(args.logs, exist_ok=True) 28 | samples_dir = os.path.join(args.logs, 'samples', experiment) 29 | 30 | TB_Logger = TensorBoardLogger('tb_logs', name=experiment) 31 | wandb_logger = loggers.WandbLogger( 32 | save_dir=args.logs, 33 | project='diffusion-anchor-pred', 34 | name=experiment, 35 | id=experiment, 36 | resume='must' if args.resume is not None else 'allow' 37 | ) 38 | 39 | if args.gaussian_expansion is not None: 40 | gaussian_expansion = True 41 | else: 42 | gaussian_expansion = False 43 | 44 | if args.use_guidance: 45 | use_guidance = True 46 | else: 47 | use_guidance = False 48 | 49 | if args.guidance_feature == 'QED' or args.guidance_feature == 'SA': 50 | guidance_classes = 6 51 | elif args.guidance_feature == 'Vina': 52 | guidance_classes = 6 53 | else: 54 | raise ValueError 55 | 56 | # --------------------------------------------------------- 57 | lig_nf = 10 # atom types 58 | pocket_nf = 25 # node features (4) + AA type (20) + BB (1) 59 | #context_node_nf = 3 # context is (anchors + scaffold_masks + pocket_masks ) 60 | joint_nf = 32 61 | 62 | anchor_predictor = AnchorGNN_pl( 63 | lig_node_nf=lig_nf, 64 | pocket_node_nf=pocket_nf, 65 | joint_nf=joint_nf, # TODO: change this? 66 | n_dims=3, 67 | hidden_nf=args.nf, 68 | activation=args.activation, 69 | tanh=args.tanh, 70 | n_layers=args.n_layers, 71 | attention=args.attention, 72 | norm_constant=args.norm_constant, 73 | data_path=args.data, 74 | train_data_prefix=args.train_data_prefix, 75 | val_data_prefix=args.val_data_prefix, 76 | batch_size=args.batch_size, 77 | lr=args.lr, 78 | test_epochs=args.test_epochs, 79 | normalization_factor=args.normalization_factor, 80 | normalization=args.normalization, 81 | include_charges=False, 82 | samples_dir=None, 83 | train_dataframe_path='paths_train.csv', 84 | val_dataframe_path='paths_val.csv', 85 | num_workers=0, 86 | dataset_type=args.dataset_type, 87 | use_guidance=use_guidance, 88 | guidance_classes=guidance_classes, 89 | guidance_feature=args.guidance_feature, 90 | gaussian_expansion=gaussian_expansion) 91 | 92 | checkpoint_callback = callbacks.ModelCheckpoint( 93 | dirpath=checkpoints_dir, 94 | filename=experiment+'_{epoch:02}', 95 | monitor='loss/val', 96 | save_top_k=10 97 | ) 98 | 99 | trainer = Trainer( 100 | max_epochs=args.n_epochs, 101 | logger=wandb_logger, 102 | callbacks=checkpoint_callback, 103 | accelerator='gpu', 104 | devices=[0,1], 105 | num_sanity_val_steps=0, 106 | enable_progress_bar=True, 107 | strategy='ddp', 108 | precision=16 109 | ) 110 | 111 | if args.resume is None: 112 | last_checkpoint = None 113 | else: 114 | last_checkpoint = find_last_checkpoint(checkpoints_dir) 115 | print(f'Training will be resumed from the last checkpoint {last_checkpoint}') 116 | print('Start training') 117 | trainer.fit(model=anchor_predictor, ckpt_path=last_checkpoint) 118 | 119 | if __name__ == '__main__': 120 | p = argparse.ArgumentParser(description='anchor_predictor') 121 | p.add_argument('--data', action='store', type=str, default="") 122 | p.add_argument('--train-dataframe-path', action='store', type=str, default='paths_train.csv') 123 | p.add_argument('--valid-dataframe-path', action='store', type=str, default='paths_val.csv') 124 | p.add_argument('--train_data_prefix', action='store', type=str, default='train_data') 125 | p.add_argument('--val_data_prefix', action='store', type=str, default='val_data') 126 | p.add_argument('--checkpoints', action='store', type=str, default='checkpoints') 127 | p.add_argument('--logs', action='store', type=str, default='logs') 128 | p.add_argument('--device', action='store', type=str, default='cuda:1') 129 | p.add_argument('--trainer_params', type=dict, help='parameters with keywords of the lightning trainer') 130 | p.add_argument('--log_iterations', action='store', type=str, default=20) 131 | p.add_argument('--exp_name', type=str, default='test_1') 132 | 133 | p.add_argument('--n_epochs', type=int, default=400) 134 | p.add_argument('--batch_size', type=int, default=16) 135 | p.add_argument('--lr', type=float, default=5e-4) 136 | 137 | p.add_argument('--activation', type=str, default='silu', help='activation function') 138 | p.add_argument('--n_layers', type=int, default=4, help='number of layers') 139 | p.add_argument('--inv_sublayers', type=int, default=2, help='number of layers') 140 | p.add_argument('--nf', type=int, default=128, help='number of layers') 141 | p.add_argument('--tanh', type=eval, default=False, help='use tanh in the coord_mlp') 142 | p.add_argument('--attention', type=eval, default=False, help='use attention in the EGNN') 143 | p.add_argument('--norm_constant', type=float, default=100, help='diff/(|diff| + norm_constant)') 144 | 145 | p.add_argument('--resume', type=str, default=None, help='') 146 | p.add_argument('--start_epoch', type=int, default=0, help='') 147 | p.add_argument('--ema_decay', type=float, default=0.999, help='Amount of EMA decay, 0 means off. A reasonable value is 0.999.') 148 | p.add_argument('--test_epochs', type=int, default=100) 149 | p.add_argument('--aggregation_method', type=str, default='sum',help='"sum" or "mean"') 150 | p.add_argument('--normalization', type=str, default='batch_norm', help='batch_norm') 151 | p.add_argument('--normalization_factor', type=float, default=100, help="Normalize the sum aggregation of EGNN") 152 | p.add_argument('--dataset-type', type=str, default='GEOM', help='dataset-type can be GEOM or CrossDock for now') 153 | 154 | p.add_argument('--gaussian-expansion', action='store_true', default=False, help='whether to use gaussian expansion of distances') 155 | p.add_argument('--use-guidance', action='store_true', default=False, help='whether to train anchor-predictor for a specific guidance feature') 156 | p.add_argument('--guidance-feature', type=str, default='QED', help='guidance feature for adding to anchor predictor') 157 | args = p.parse_args() 158 | main(args=args) -------------------------------------------------------------------------------- /utils/sample_frag_size.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | def sample_fragment_size(new_score, bin_edges, distributions): 5 | # Find which bin the new score belongs to 6 | bin_idx = np.digitize(new_score, bin_edges) 7 | # Get the probability distribution for the bin 8 | probabilities = distributions.loc[bin_idx].values 9 | discrete_values = distributions.columns.values 10 | # Sample a discrete number from the distribution 11 | return np.random.choice(discrete_values, p=probabilities) 12 | 13 | bounds = [4.1, 8.1, 12.1, 16.1] 14 | fragsize_prob = np.array([[7.11770964e-01, 1.53752812e-01, 1.06553153e-01, 1.95619411e-02, 15 | 2.82028835e-03, 5.54084154e-03, 0.00000000e+00, 0.00000000e+00, 16 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 17 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 18 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], 19 | [1.20153410e-01, 4.57824749e-02, 3.31473162e-02, 2.52469754e-02, 20 | 1.16709580e-01, 4.48883000e-01, 1.22096176e-01, 4.48883000e-02, 21 | 2.69329800e-02, 1.61597880e-02, 0.00000000e+00, 0.00000000e+00, 22 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 23 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], 24 | [7.74508034e-02, 3.24180141e-02, 1.07572679e-02, 7.19736364e-03, 25 | 1.01711551e-01, 2.79706766e-01, 1.27139439e-01, 7.62836633e-02, 26 | 7.62836633e-02, 1.14425495e-01, 5.08557756e-02, 2.54278878e-02, 27 | 1.27139439e-02, 3.81418317e-03, 2.54278878e-03, 1.27139439e-03, 28 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], 29 | [4.83709778e-02, 1.47227584e-02, 3.28787372e-03, 4.42784727e-03, 30 | 6.65155568e-02, 2.08995151e-01, 1.13997355e-01, 1.13997355e-01, 31 | 1.04497576e-01, 1.42496694e-01, 6.64984572e-02, 4.74988980e-02, 32 | 2.84993388e-02, 1.89995592e-02, 9.49977961e-03, 4.74988980e-03, 33 | 2.84993388e-03, 9.49977961e-05, 0.00000000e+00, 0.00000000e+00], 34 | [1.77003281e-02, 2.15758183e-02, 8.85185643e-03, 3.94403048e-03, 35 | 5.92476134e-02, 1.69235378e-01, 1.26926533e-01, 1.10002995e-01, 36 | 8.46176888e-02, 1.52311840e-01, 1.18464764e-01, 5.92323822e-02, 37 | 3.38470755e-02, 1.69235378e-02, 8.46176888e-03, 5.07706133e-03, 38 | 3.38470755e-03, 1.69235378e-04, 8.46176888e-06, 1.69235378e-05]]) 39 | 40 | fragsize_prob_df = pd.DataFrame(fragsize_prob, index=[0,1,2,3,4], columns=np.arange(1,21)) -------------------------------------------------------------------------------- /utils/visuals.py: -------------------------------------------------------------------------------- 1 | import py3Dmol 2 | from rdkit.Chem import AllChem 3 | from rdkit import Chem 4 | from rdkit.Geometry import Point3D 5 | from openbabel import openbabel 6 | import numpy as np 7 | from openbabel import openbabel 8 | import tempfile 9 | 10 | 11 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9} 12 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'} 13 | 14 | def write_xyz_file(coords, atom_types, filename): 15 | out = f"{len(coords)}\n\n" 16 | assert len(coords) == len(atom_types) 17 | for i in range(len(coords)): 18 | out += f"{atom_types[i]} {coords[i, 0]:.3f} {coords[i, 1]:.3f} {coords[i, 2]:.3f}\n" 19 | with open(filename, 'w') as f: 20 | f.write(out) 21 | 22 | def visualize_molecules_grid(mols, grid_size=(3, 3), spacing=5.0, spin=True): 23 | viewer = py3Dmol.view(width=900, height=900) 24 | 25 | for i, mol in enumerate(mols): 26 | try: 27 | Chem.SanitizeMol(mol) 28 | except: 29 | print('couldnt sanitize') 30 | #AllChem.EmbedMolecule(mol) # Generate 3D coordinates 31 | #AllChem.MMFFOptimizeMolecule(mol, maxIters=500) # Optimize the geometry using MMFF94 force field 32 | 33 | # Calculate the grid position 34 | grid_x = i % grid_size[0] 35 | grid_y = i // grid_size[0] 36 | 37 | # Translate the molecule according to its position in the grid 38 | conf = mol.GetConformer() 39 | translation_vector = Point3D((grid_x * spacing) + (spacing / 2), (grid_y * spacing) + (spacing / 2), 0.0) 40 | for atom_idx in range(mol.GetNumAtoms()): 41 | atom_position = conf.GetAtomPosition(atom_idx) 42 | atom_position += translation_vector 43 | conf.SetAtomPosition(atom_idx, atom_position) 44 | 45 | mb = Chem.MolToMolBlock(mol) 46 | viewer.addModel(mb, 'sdf') 47 | 48 | #if spin: 49 | # viewer.spin({'x': 0, 'y': 1, 'z': 0}) 50 | 51 | # Draw separating lines 52 | for i in range(grid_size[0] - 1): 53 | x = (i + 1) * spacing 54 | viewer.addLine({'start': {'x': x, 'y': 0, 'z': 0}, 55 | 'end': {'x': x, 'y': grid_size[1] * spacing, 'z': 0}, 56 | 'color': 'gray'}) 57 | for i in range(grid_size[1] - 1): 58 | y = (i + 1) * spacing 59 | viewer.addLine({'start': {'x': 0, 'y': y, 'z': 0}, 60 | 'end': {'x': grid_size[0] * spacing, 'y': y, 'z': 0}, 61 | 'color': 'gray'}) 62 | 63 | #viewer.spin({'x': 0, 'y': 1, 'z': 0}, origin=(grid_size[0] * spacing / 2, grid_size[1] * spacing / 2, 0)) 64 | viewer.setStyle({}, {'stick': {'colorscheme': ['silverCarbon', 'redOxygen', 'blueNitrogen'], 'radius': 0.15, 'opacity': 1}, 65 | 'sphere': {'colorscheme': ['silverCarbon', 'redOxygen', 'blueNitrogen'], 'radius': 0.35, 'opacity': 1}}) 66 | viewer.zoomTo() 67 | viewer.show() 68 | 69 | def get_pocket_mol(pocket_coords, pocket_onehot): 70 | with tempfile.NamedTemporaryFile() as tmp: 71 | tmp_file = tmp.name 72 | 73 | atom_inds= pocket_onehot.argmax(1) 74 | atom_types = [idx2atom[x] for x in atom_inds] 75 | # write xyz file 76 | write_xyz_file(pocket_coords, atom_types, tmp_file) 77 | 78 | obConversion = openbabel.OBConversion() 79 | obConversion.SetInAndOutFormats('xyz', 'sdf') 80 | ob_mol = openbabel.OBMol() 81 | obConversion.ReadFile(ob_mol, tmp_file) 82 | 83 | obConversion.WriteFile(ob_mol, tmp_file) 84 | pocket_mol = Chem.SDMolSupplier(tmp_file, sanitize=False)[0] 85 | 86 | return pocket_mol 87 | 88 | def visualize_3d_pocket_molecule(pocket_mol, mol=None, spin=False, optimize_coords=False, sphere_positions1=None, sphere_positions2=None, rotate=None): 89 | viewer = py3Dmol.view() 90 | 91 | pocket_mol = Chem.RemoveHs(pocket_mol) 92 | pocket_mb = Chem.MolToMolBlock(pocket_mol) 93 | viewer.addModel(pocket_mb, 'sdf') 94 | viewer.setStyle({'model': -1}, {"sphere": {'color': 'grey', 'opacity': 0.8, 'radius':0.9}}) 95 | #viewer.setStyle({'model': 0}, {'stick': {'colorscheme': ['whiteCarbon', 'redOxygen', 'blueNitrogen'], 'radius': 0.2, 'opacity': 1}, 96 | # 'sphere': {'colorscheme': ['whiteCarbon', 'redOxygen', 'blueNitrogen'], 'radius': 0.3, 'opacity': 1}}) 97 | 98 | viewer.zoomTo() 99 | #viewer.setStyle({'model': 0}, {'cartoon': {'color': 'spectrum'}}) # Updated style for cartoon representation 100 | #viewer.addSurface(py3Dmol.SAS, {'opacity': 0.9, 'radius': 0.5}) 101 | 102 | if mol is not None: 103 | try: 104 | Chem.SanitizeMol(mol) 105 | except: 106 | print('Problem with the molecule') 107 | return 108 | 109 | mol = Chem.RemoveHs(mol) 110 | mol_mb = Chem.MolToMolBlock(mol) 111 | viewer.addModel(mol_mb, 'sdf') 112 | viewer.setStyle({'model': 1}, {'stick': {'colorscheme': 'cyanCarbon', 'radius': 0.15, 'opacity': 1}, 113 | 'sphere': {'colorscheme': 'cyanCarbon', 'radius': 0.35, 'opacity': 1}}) 114 | 115 | if sphere_positions1 is not None: 116 | for pos in sphere_positions1: 117 | sphere_spec = {'center': {'x': float(pos[0]), 'y': float(pos[1]), 'z': float(pos[2])}, 'radius': 1, 'color': 'green', 'opacity': 0.75} 118 | viewer.addSphere(sphere_spec) 119 | 120 | if sphere_positions2 is not None: 121 | for pos in sphere_positions2: 122 | sphere_spec = {'center': {'x': float(pos[0]), 'y': float(pos[1]), 'z': float(pos[2])}, 'radius': 0.3, 'color': 'yellow', 'opacity': 0.75} 123 | viewer.addSphere(sphere_spec) 124 | 125 | 126 | if spin: 127 | viewer.spin({'x': 0, 'y': 1, 'z': 0}) 128 | 129 | if rotate: 130 | viewer.rotate(rotate,'y',1); 131 | return viewer 132 | --------------------------------------------------------------------------------