├── .DS_Store
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── analysis
│   ├── SA_Score
│   │   ├── README.md
│   │   ├── __pycache__
│   │   │   ├── sascorer.cpython-310.pyc
│   │   │   └── sascorer.cpython-38.pyc
│   │   └── sascorer.py
│   ├── all_pdbs_to_pdbqts.py
│   ├── bond_angle_config.py
│   ├── bond_length_config.py
│   ├── docking.py
│   ├── docking_py27.py
│   ├── eval_bond_angles.py
│   ├── eval_bond_length.py
│   ├── get_atom_types_dist.py
│   ├── get_empirical_angles.py
│   ├── get_empirical_dists.py
│   ├── get_volume.py
│   ├── metrics.py
│   ├── qvina_docking.py
│   ├── reconstruct_mol.py
│   ├── scoring_func.py
│   ├── similarity.py
│   ├── utils.py
│   └── vina_docking.py
├── analyze_generated_pocket_mols.py
├── analyze_scaffolds_generated.py
├── assets
│   ├── .DS_Store
│   ├── movie.gif
│   └── scaffold_optim.png
├── data
│   ├── CROSSDOCK
│   │   ├── __init__.py
│   │   ├── fragment_hierarchy.py
│   │   ├── prepare_fragments.py
│   │   ├── process_crossdock.py
│   │   ├── process_ligands.py
│   │   ├── process_pockets.py
│   │   └── sascorer.py
│   ├── __init__.py
│   └── sascorer.py
├── extend_scaffold_crossdock.py
├── fpscores.pkl.gz
├── generate_pocket_molecules.py
├── notebooks
│   ├── 2z3h.pdb
│   ├── 2z3h_H.pdb
│   ├── 2z3h_out
│   │   ├── 2z3h.pml
│   │   ├── 2z3h.tcl
│   │   ├── 2z3h_PYMOL.sh
│   │   ├── 2z3h_VMD.sh
│   │   ├── 2z3h_info.txt
│   │   ├── 2z3h_out.pdb
│   │   ├── 2z3h_pockets.pqr
│   │   └── pockets
│   │       ├── pocket1_atm.pdb
│   │       ├── pocket1_vert.pqr
│   │       ├── pocket2_atm.pdb
│   │       └── pocket2_vert.pqr
│   ├── __init__.py
│   └── sample_for_pocket.ipynb
├── sample_crossdock_mols.py
├── sample_from_pocket.py
├── sampling
│   ├── rejection_sampling.py
│   ├── sample_mols.py
│   └── scaffold_extension.py
├── src
│   ├── __init__.py
│   ├── anchor_gnn.py
│   ├── const.py
│   ├── conv_layer.py
│   ├── datasets.py
│   ├── dropout.py
│   ├── dynamics_gvp.py
│   ├── edm.py
│   ├── egnn.py
│   ├── extension_size.py
│   ├── fragment_size_gnn.py
│   ├── gvp.py
│   ├── gvp_model.py
│   ├── layer_norm.py
│   ├── lightning.py
│   ├── lightning_anchor_gnn.py
│   ├── noise.py
│   └── utils.py
├── train_anchor_predictor.py
├── train_frag_diffuser.py
└── utils
    ├── sample_frag_size.py
    ├── templates.py
    ├── visuals.py
    └── volume_sampling.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 keiserlab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AutoFragDiff
2 |
3 | This repository is the official implementation of AutoFragDiff, an autoregressive fragment-based diffusion model for target-aware ligand design.
4 |
5 |
6 |
7 |
8 | # Dependencies
9 | - RDKit
10 | - openbabel
11 | - PyTorch
12 | - biopython
13 | - biopandas
14 | - networkx
15 | - py3dmol
16 | - scikit-learn
17 | - tensorboard
18 | - wandb
19 | - pytorch-lightning
20 |
21 | ## Create conda environment
22 | ```
23 | conda create -n autofragdiff && conda activate autofragdiff
24 | pip install rdkit
25 | conda install -c conda-forge openbabel
26 | pip3 install torch torchvision torchaudio
27 | pip install biopython
28 | pip install biopandas
29 | pip install networkx
30 | pip install py3dmol
31 | pip install scikit-learn
32 | pip install tensorboard
33 | pip install wandb
34 | pip install tqdm
35 | pip install pytorch-lightning==1.6.0
36 | ```
37 |
38 | The model has been tested with the following software versions:
39 |
40 | | Software | Version |
41 | | --------------- | ----------- |
42 | | rdkit | 2023.3.1 |
43 | | openbabel | 3.1.1 |
44 | | pytorch | 2.0.1 |
45 | | biopython | 1.81 |
46 | | biopandas | 0.4.1 |
47 | | networkx | 3.1 |
48 | | py3dmol      | 2.0.1     |
49 | | scikit-learn | 1.2.2 |
50 | | tensorboard | 2.13.0 |
51 | | wandb | 0.15.2 |
52 | | pytorch-lightning | 1.6.0 |
53 |
54 |
55 | ## QuickVina2
56 | For docking with QVina, install QuickVina2:
57 | ```
58 | wget https://github.com/QVina/qvina/raw/master/bin/qvina2.1
59 | chmod +x qvina2.1
60 | ```
61 | We also need MGLTools to prepare receptors for docking (pdb -> pdbqt). Since it can break the main conda environment, install it in a separate one:
62 | ```
63 | conda create -n mgltools -c bioconda mgltools
64 | ```
65 |
66 | # Data Preparation
67 |
68 | ## CrossDock
69 | Download and extract the dataset as described by the authors of Pocket2Mol: https://github.com/pengxingang/Pocket2Mol/tree/main/data
70 |
71 | Process the molecules into fragments using the custom fragmentation scheme:
72 | ```
73 | python process_crossdock.py --rootdir $CROSSDOCK_PATH --outdir $OUT_DIR \
74 | --dist_cutoff 7. --max-num-frags 8 --split test --max-atoms-single-fragment 22 \
75 | --add-Vina-score --add-QED-score --add-SA-score --n-cores 16
76 | ```
77 | - To add Vina scores you also need to generate pdbqt files for each receptor and crystallographic ligand.
78 |
79 | # Training
80 |
81 | ## Training AutoFragDiff
82 | ```
83 | python train_frag_diffuser.py --data $CROSSDOCK_DIR --exp_name CROSSDOCK_model_1 \
84 | --lr 0.0001 --n_layers 6 --nf 128 --diffusion_steps 500 \
85 | --diffusion_loss_type l2 --n_epochs 1000 --batch_size 4
86 | ```
87 |
88 | ## Training anchor predictor
89 | ```
90 | python train_anchor_predictor.py --data $CROSSDOCK_DIR --exp_name CROSSDOCK_anchor_model_1 \
91 | --n_layers 4 --inv_sublayers 2 --nf 128 --dataset-type CrossDock
92 | ```
93 |
94 |
95 | # Sampling:
96 |
97 | First, download the trained models from Google Drive at the following link:
98 |
99 | https://drive.google.com/drive/folders/1DQwIfibHIoFPGJP6aHBGiYRp87bCZFA0?usp=share_link
100 |
101 | ## CrossDock pocket-based molecule generation:
102 |
103 | To generate molecules with the trained pocket-based model, also use the anchor-predictor model. Fragment sizes are sampled from the data distribution.
104 |
105 | ## CrossDock pocket-based molecule generation (with guidance):
106 |
107 | To generate molecules for crossdock test set:
108 | ```
109 | python sample_crossdock_mols.py --results-path results/ --data-path $(path-to-crossdock-dataset) --use-anchor-model --anchor-model anchor-model.ckpt --n-samples 20 --exp-name test-crossdock --diff-model pocket-gvp.ckpt --device cuda:0
110 | ```
111 |
112 | To sample molecules from a pdb file:
114 | First, run fpocket to detect candidate pockets:
114 | ```
115 | fpocket -f $pdb.pdb
116 | ```
117 | fpocket returns multiple pockets; visualize them to identify the right one, then run sampling:
118 |
119 | ```
120 | python sample_from_pocket.py --result-path results --pdb $pdbname --anchor-model anchor-model.ckpt --n-samples 10 --device cuda:0 --pocket-number 1
121 | ```
122 |
123 | ## Scaffold-based molecule property optimization
124 |
125 | For scaffold-based optimization, you need the pdb file of the pocket and the sdf files of the scaffold molecule and the original molecule.
126 |
127 | Scaffold extension for the CrossDock test set:
128 | ```
129 | python extend_scaffold_crossdock.py --data-path $(path-to-crossdock) --results-path scaffold-gen --anchor-model anchor-model.ckpt --n-samples 20 --exp-name scaffold-gen --diff-model pocket-gvp.ckpt --device cuda:0
130 | ```
131 |
132 | - To select anchors manually, add the `--custom-anchors` argument and provide the ids of the custom anchors (0-indexed atom ids in the scaffold molecule), as shown in the example below.
133 |
134 |
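135 | For example, to extend a scaffold from user-chosen anchor atoms (the anchor ids below are illustrative, and the space-separated id format is an assumption):
136 | ```
137 | python extend_scaffold_crossdock.py --data-path $(path-to-crossdock) --results-path scaffold-gen --anchor-model anchor-model.ckpt --diff-model pocket-gvp.ckpt --n-samples 20 --exp-name scaffold-gen --device cuda:0 --custom-anchors 2 7
138 | ```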
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/__init__.py
--------------------------------------------------------------------------------
/analysis/SA_Score/README.md:
--------------------------------------------------------------------------------
1 | # README
2 |
3 | Files taken from [rdkit/rdkit](https://github.com/rdkit/rdkit/tree/master/Contrib/SA_Score) repository on GitHub.
4 |
--------------------------------------------------------------------------------
/analysis/SA_Score/__pycache__/sascorer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/analysis/SA_Score/__pycache__/sascorer.cpython-310.pyc
--------------------------------------------------------------------------------
/analysis/SA_Score/__pycache__/sascorer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/analysis/SA_Score/__pycache__/sascorer.cpython-38.pyc
--------------------------------------------------------------------------------
/analysis/SA_Score/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly slightly different formula for macrocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 |
19 |
20 | from rdkit import Chem
21 | from rdkit.Chem import rdMolDescriptors
22 | import pickle
23 |
24 | import math
25 | from collections import defaultdict
26 |
27 | import os.path as op
28 |
29 | _fscores = None
30 |
31 |
32 | def readFragmentScores(name='fpscores'):
33 | import gzip
34 | global _fscores
35 | # generate the full path filename:
36 | #if name == "fpscores":
37 | # name = op.join(op.dirname(__file__), name)
38 | data = pickle.load(gzip.open('fpscores.pkl.gz'))
39 | outDict = {}
40 | for i in data:
41 | for j in range(1, len(i)):
42 | outDict[i[j]] = float(i[0])
43 | _fscores = outDict
44 |
45 |
46 | def numBridgeheadsAndSpiro(mol, ri=None):
47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
49 | return nBridgehead, nSpiro
50 |
51 |
52 | def calculateScore(m):
53 | if _fscores is None:
54 | readFragmentScores()
55 |
56 | # fragment score
57 | fp = rdMolDescriptors.GetMorganFingerprint(m,
58 | 2) # <- 2 is the *radius* of the circular fingerprint
59 | fps = fp.GetNonzeroElements()
60 | score1 = 0.
61 | nf = 0
62 | for bitId, v in fps.items():
63 | nf += v
64 | sfp = bitId
65 | score1 += _fscores.get(sfp, -4) * v
66 | score1 /= nf
67 |
68 | # features score
69 | nAtoms = m.GetNumAtoms()
70 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
71 | ri = m.GetRingInfo()
72 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
73 | nMacrocycles = 0
74 | for x in ri.AtomRings():
75 | if len(x) > 8:
76 | nMacrocycles += 1
77 |
78 | sizePenalty = nAtoms**1.005 - nAtoms
79 | stereoPenalty = math.log10(nChiralCenters + 1)
80 | spiroPenalty = math.log10(nSpiro + 1)
81 | bridgePenalty = math.log10(nBridgeheads + 1)
82 | macrocyclePenalty = 0.
83 | # ---------------------------------------
84 | # This differs from the paper, which defines:
85 | # macrocyclePenalty = math.log10(nMacrocycles+1)
86 | # This form generates better results when 2 or more macrocycles are present
87 | if nMacrocycles > 0:
88 | macrocyclePenalty = math.log10(2)
89 |
90 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
91 |
92 | # correction for the fingerprint density
93 | # not in the original publication, added in version 1.1
94 | # to make highly symmetrical molecules easier to synthesise
95 | score3 = 0.
96 | if nAtoms > len(fps):
97 | score3 = math.log(float(nAtoms) / len(fps)) * .5
98 |
99 | sascore = score1 + score2 + score3
100 |
101 | # need to transform "raw" value into scale between 1 and 10
102 | min = -4.0
103 | max = 2.5
104 | sascore = 11. - (sascore - min + 1) / (max - min) * 9.
105 | # smooth the 10-end
106 | if sascore > 8.:
107 | sascore = 8. + math.log(sascore + 1. - 9.)
108 | if sascore > 10.:
109 | sascore = 10.0
110 | elif sascore < 1.:
111 | sascore = 1.0
112 |
113 | return sascore
114 |
115 |
116 | def processMols(mols):
117 | print('smiles\tName\tsa_score')
118 | for i, m in enumerate(mols):
119 | if m is None:
120 | continue
121 |
122 | s = calculateScore(m)
123 |
124 | smiles = Chem.MolToSmiles(m)
125 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
126 |
127 |
128 | if __name__ == '__main__':
129 | import sys
130 | import time
131 |
132 | t1 = time.time()
133 | readFragmentScores("fpscores")
134 | t2 = time.time()
135 |
136 | suppl = Chem.SmilesMolSupplier(sys.argv[1])
137 | t3 = time.time()
138 | processMols(suppl)
139 | t4 = time.time()
140 |
141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
142 | file=sys.stderr)
143 |
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | # * Redistributions of source code must retain the above copyright
153 | # notice, this list of conditions and the following disclaimer.
154 | # * Redistributions in binary form must reproduce the above
155 | # copyright notice, this list of conditions and the following
156 | # disclaimer in the documentation and/or other materials provided
157 | # with the distribution.
158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | # nor the names of its contributors may be used to endorse or promote
160 | # products derived from this software without specific prior written permission.
161 | #
162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173 | #
174 |
175 | def compute_sa_score(rdmol):
176 | rdmol = Chem.MolFromSmiles(Chem.MolToSmiles(rdmol))
177 | sa = calculateScore(rdmol)
178 | sa_norm = round((10 - sa) / 9, 2)
179 | return sa_norm
--------------------------------------------------------------------------------
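Usage note: a minimal sketch of calling the scorer above, assuming RDKit is installed and `fpscores.pkl.gz` sits in the working directory (this copy of `readFragmentScores` loads it by that relative name):

```
# minimal sketch, assuming analysis/SA_Score is on sys.path and
# fpscores.pkl.gz is in the current working directory
from rdkit import Chem
import sascorer

mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin as a stand-in
print(sascorer.calculateScore(mol))    # raw SA score, ~1 (easy) to 10 (hard)
print(sascorer.compute_sa_score(mol))  # normalized to [0, 1]; higher = easier
```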
/analysis/all_pdbs_to_pdbqts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 |
5 | if __name__ == '__main__':
6 | with open('/srv/home/mahdi.ghorbani/FragDiff/pdb_paths.txt', 'r') as f:
7 | all_files = [line.strip() for line in f.readlines()]
8 | root_dir = '/srv/home/mahdi.ghorbani/FragDiff/crossdock/crossdocked_pocket10/'
9 | for i, file in enumerate(all_files):
10 | if i % 100 == 0:
11 | print(i)
12 | prot_name = root_dir + file
13 | pdbqt_name = prot_name[:-3] + 'pdbqt'
14 | if not os.path.exists(pdbqt_name):
15 | os.system('prepare_receptor4.py -r {} -o {}'.format(prot_name, pdbqt_name))
16 |
--------------------------------------------------------------------------------
/analysis/docking.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import torch
4 | from pathlib import Path
5 | import argparse
6 |
7 | import pandas as pd
8 | from rdkit import Chem
9 | from tqdm import tqdm
10 |
11 | affinity_pattern = r"Affinity:\s+(-?\d+\.\d+)\s+\(kcal/mol\)"
12 | def calculate_smina_score(pdb_file, sdf_file):
13 | # add '-o _smina.sdf' if you want to see the output
14 | out = os.popen(f'smina.static -l {sdf_file} -r {pdb_file} '
15 | f'--score_only').read()
16 | matches = re.findall(
17 | r"Affinity:[ ]+([+-]?[0-9]*[.]?[0-9]+)[ ]+\(kcal/mol\)", out)
18 | return [float(x) for x in matches]
19 |
20 | def sdf_to_pdbqt(sdf_file, pdbqt_outfile, mol_id):
21 | os.popen(f'obabel {sdf_file} -O {pdbqt_outfile} -f {mol_id + 1} -l {mol_id + 1}').read()
22 | return pdbqt_outfile
23 |
24 | def calculate_qvina2_score(receptor_file, sdf_file, out_dir, size=20,
25 | exhaustiveness=16, return_rdmol=False, score_only=False):
26 | """
27 | receptor_file: pdbqt file for receptor
28 | sdf_file: sdf file for ligand
29 | out_dir: output directory
30 |
31 | returns:
32 | scores: list of scores for each ligand
33 | rdmols: list of qvina docked ligands
34 | """
35 |
36 | receptor_pdbqt_file = Path(receptor_file)
37 | sdf_file = Path(sdf_file)
38 |
39 | scores = []
40 | rdmols = [] # populated only if return_rdmol is True
41 | suppl = Chem.SDMolSupplier(str(sdf_file), sanitize=False)
42 |
43 | for i, mol in enumerate(suppl): # sdf file may contain several ligands
44 | ligand_name = f'{sdf_file.stem}_{i}'
45 | # prepare ligand
46 | ligand_pdbqt_file = Path(out_dir, ligand_name + '.pdbqt')
47 | out_sdf_file = Path(out_dir, ligand_name + '_out.sdf')
48 |
49 | if out_sdf_file.exists():
50 | with open(out_sdf_file, 'r') as f:
51 | scores.append(min([float(x.split()[2]) for x in f.readlines()
52 | if x.startswith(' VINA RESULT:')]))
53 | else:
54 | sdf_to_pdbqt(sdf_file, ligand_pdbqt_file, i)
55 |
56 | # center box at ligand's center of mass
57 | cx, cy, cz = mol.GetConformer().GetPositions().mean(0)
58 |
59 |
60 | # run QuickVina 2
61 | if not score_only:
62 |
63 | out = os.popen(
64 | f'qvina2.1 --receptor {receptor_pdbqt_file} '
65 | f'--ligand {ligand_pdbqt_file} '
66 | f'--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} '
67 | f'--size_x {size} --size_y {size} --size_z {size} '
68 | f'--exhaustiveness {exhaustiveness}'
69 | ).read()
70 | out_split = out.splitlines()
71 | best_idx = out_split.index('-----+------------+----------+----------') + 1
72 | best_line = out_split[best_idx].split()
73 | assert best_line[0] == '1'
74 | scores.append(float(best_line[1]))
75 |
76 | out_pdbqt_file = Path(out_dir, ligand_name + '_out.pdbqt')
77 | if out_pdbqt_file.exists():
78 | os.popen(f'obabel {out_pdbqt_file} -O {out_sdf_file}').read()
79 |
80 | if return_rdmol:
81 | rdmol = Chem.SDMolSupplier(str(out_sdf_file))[0]
82 | rdmols.append(rdmol)
83 |
84 | else:
85 | out = os.popen(
86 | f'qvina2.1 --score_only --receptor {receptor_pdbqt_file} '
87 | f'--ligand {ligand_pdbqt_file} '
88 | f'--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} '
89 | f'--size_x {size} --size_y {size} --size_z {size} '
90 | ).read()
91 | match = re.search(affinity_pattern, out)
92 | scores = float(match.group(1))
93 |
94 | if return_rdmol:
95 | return scores, rdmols
96 | else:
97 | return scores
98 |
99 | if __name__ == '__main__':
100 | parser = argparse.ArgumentParser('QuickVina evaluation')
101 | parser.add_argument('--pdbqt_dir', type=Path,
102 | help='Receptor files in pdbqt format')
103 | parser.add_argument('--sdf_dir', type=Path, default=None,
104 | help='Ligand files in sdf format')
105 | parser.add_argument('--out_dir', type=Path)
106 | parser.add_argument('--write_csv', action='store_true')
107 | parser.add_argument('--write_dict', action='store_true')
108 | parser.add_argument('--dataset', type=str, default='CROSSDOCK')
109 | args = parser.parse_args()
110 |
111 | assert (args.sdf_dir is not None)
112 |
113 | results = {'receptor': [], 'ligand': [], 'scores':[]}
114 | results_dict = {}
115 |
116 | sdf_files = list(os.listdir(args.sdf_dir))
117 | pbar = tqdm(sdf_files)
118 |
119 | for sdf_file in pbar:
120 | pbar.set_description(f'Processing {sdf_file}')
121 |
122 | if args.dataset == 'CROSSDOCK':
123 | receptor_name = sdf_file.split('_')[0] + '_pocket'
124 | receptor_file = Path(args.pdbqt_dir, receptor_name + '.pdbqt')
125 |
126 | sdf_path = Path(str(args.sdf_dir) + '/' + sdf_file)
127 | try:
128 | scores, rdmols = calculate_qvina2_score(receptor_file, sdf_path, args.out_dir, return_rdmol=True)
129 | except (ValueError, AttributeError) as e:
130 | print(e)
131 | continue
132 | results['receptor'].append(str(receptor_file))
133 | results['ligand'].append(str(sdf_file))
134 | results['scores'].append(scores)
135 |
136 | if args.write_dict:
137 | results_dict[receptor_name] = [scores, rdmols]
138 |
139 | if args.write_csv:
140 | df = pd.DataFrame.from_dict(results)
141 | df.to_csv(Path(args.out_dir, 'qvina2_scores.csv'))
142 |
143 | if args.write_dict:
144 | torch.save(results_dict, Path(args.out_dir, 'qvina2_scores.pt'))
--------------------------------------------------------------------------------
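Usage note: an illustrative call of `calculate_qvina2_score` above; it assumes the `qvina2.1` and `obabel` binaries are on PATH, the receptor has already been converted to pdbqt, and the output directory exists. File names are hypothetical:

```
# illustrative sketch; receptor/ligand file names are placeholders
from pathlib import Path
from analysis.docking import calculate_qvina2_score

out_dir = Path('qvina_out')
out_dir.mkdir(exist_ok=True)
scores, rdmols = calculate_qvina2_score(
    receptor_file='receptor_pocket.pdbqt',   # hypothetical receptor pdbqt
    sdf_file='generated_ligands.sdf',        # hypothetical multi-ligand sdf
    out_dir=out_dir,
    size=20, exhaustiveness=16, return_rdmol=True)
print(scores)  # best Vina affinity (kcal/mol) per ligand
```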
/analysis/docking_py27.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 |
5 | def pdbs_to_pdbqts(pdb_dir, pdbqt_dir, dataset):
6 | for file in glob.glob(os.path.join(pdb_dir, '*.pdb')):
7 | name = os.path.splitext(os.path.basename(file))[0]
8 | outfile = os.path.join(pdbqt_dir, name + '.pdbqt')
9 | pdb_to_pdbqt(file, outfile, dataset)
10 | print('Wrote converted file to {}'.format(outfile))
11 |
12 | def pdb_to_pdbqt(pdb_file, pdbqt_file, dataset):
13 | if dataset == 'CROSSDOCK':
14 | os.system('prepare_receptor4.py -r {} -o {}'.format(pdb_file, pdbqt_file))
15 |
16 | else:
17 | raise NotImplementedError
18 |
19 | return pdbqt_file
20 |
21 | if __name__ == '__main__':
22 | pdbs_to_pdbqts(sys.argv[1], sys.argv[2], sys.argv[3])
23 |
--------------------------------------------------------------------------------
/analysis/eval_bond_angles.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import collections
4 | from typing import Tuple, Sequence, Dict, Optional
5 |
6 | from scipy import spatial as sci_spatial
7 | from .bond_angle_config import ANGLE_DIST_CROSSDOCK, DIHED_DIST_CROSSDOCK
8 | from rdkit.Chem.rdMolTransforms import GetAngleRad, GetDihedralRad, GetAngleDeg, GetDihedralDeg
9 |
10 | def get_distribution(angles, bins):
11 |
12 | bin_counts = collections.Counter(np.searchsorted(bins, angles))
13 | bin_counts = [bin_counts[i] if i in bin_counts else 0 for i in range(len(bins))]
14 | bin_counts = np.array(bin_counts) / np.sum(bin_counts)
15 | return bin_counts
16 |
17 | def eval_angle_dist_profile(bond_angle_profile, dihedral_angle_profile, frag):
18 |
19 | # frag is the SMILES of the fragment
20 | # bond_angle_profile -> a dictionary keyed by fragment SMILES, with values the angle/dihedral distributions
21 | metrics = {}
22 | gt_distribution = ANGLE_DIST_CROSSDOCK[frag]
23 | metrics[f'Angle-JSD_{frag}'] = sci_spatial.distance.jensenshannon(gt_distribution,
24 | bond_angle_profile)
25 |
26 | gt_distribution = DIHED_DIST_CROSSDOCK[frag]
27 | metrics[f'Dihedral-JSD_{frag}'] = sci_spatial.distance.jensenshannon(gt_distribution,
28 | dihedral_angle_profile)
29 | return metrics
30 |
31 |
32 | def find_angle_dist(mol, frag):
33 | all_frag_angles = []
34 | all_frag_dihedrals = []
35 |
36 | conf = mol.GetConformer()
37 |
38 | matches_frag = mol.GetSubstructMatches(frag)
39 | for match in matches_frag:
40 | match_angles = []
41 | match_dih = []
42 | match_set = set(match)
43 |
44 | for atom_index in match:
45 | atom = mol.GetAtomWithIdx(atom_index)
46 | neighbors = [neighbor.GetIdx() for neighbor in atom.GetNeighbors() if neighbor.GetIdx() in match_set]
47 | for i in range(len(neighbors)-1):
48 | for j in range(i+1, len(neighbors)):
49 | angle_deg = GetAngleDeg(conf, neighbors[i], atom_index, neighbors[j])
50 |
51 | if angle_deg < 0:
52 | angle_deg += 360
53 | match_angles.append(angle_deg)
54 |
55 | for neighbor in neighbors:
56 | next_neighbors = [next_neighbor.GetIdx() for next_neighbor in mol.GetAtomWithIdx(neighbor).GetNeighbors() if next_neighbor.GetIdx() in match_set]
57 | for next_neighbor in next_neighbors:
58 | if next_neighbor != atom_index: # don't want to go to original atom
59 | # calculate and print dihedral angle
60 | dihedral_deg = GetDihedralDeg(conf, neighbor, atom_index, next_neighbor, neighbors[(neighbors.index(neighbor)+1) % len(neighbors)])
61 | if dihedral_deg < 0:
62 | dihedral_deg += 360
63 | match_dih.append(dihedral_deg)
64 |
65 | all_frag_angles += match_angles
66 | all_frag_dihedrals += match_dih
67 |
68 | return all_frag_angles, all_frag_dihedrals
--------------------------------------------------------------------------------
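Usage note: a sketch of the intended flow for `find_angle_dist` and `eval_angle_dist_profile` above. The bin edges here are an assumption and must match those used to build `ANGLE_DIST_CROSSDOCK`/`DIHED_DIST_CROSSDOCK`; the fragment SMILES must be a key of those tables:

```
# hedged sketch; bins and file names are assumptions
import numpy as np
from rdkit import Chem
from analysis.eval_bond_angles import (find_angle_dist, get_distribution,
                                       eval_angle_dist_profile)

frag_smiles = 'c1ccccc1'                      # fragment assumed present in the reference tables
frag = Chem.MolFromSmiles(frag_smiles)
mol = Chem.SDMolSupplier('generated.sdf')[0]  # hypothetical molecule with a 3D conformer

angles, dihedrals = find_angle_dist(mol, frag)
bins = np.linspace(0, 360, 100)               # assumed binning
metrics = eval_angle_dist_profile(get_distribution(angles, bins),
                                  get_distribution(dihedrals, bins),
                                  frag_smiles)
print(metrics)  # {'Angle-JSD_c1ccccc1': ..., 'Dihedral-JSD_c1ccccc1': ...}
```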
/analysis/eval_bond_length.py:
--------------------------------------------------------------------------------
1 |
2 | # taken from https://github.com/guanjq/targetdiff/blob/main/utils/evaluation/eval_bond_length.py
3 |
4 | import collections
5 | from typing import Tuple, Sequence, Dict, Optional
6 |
7 | import numpy as np
8 | from scipy import spatial as sci_spatial
9 | import matplotlib.pyplot as plt
10 |
11 | from analysis import bond_length_config
12 | from analysis import utils
13 |
14 | BondType = Tuple[int, int, int] # (atomic_num, atomic_num, bond_type)
15 | BondLengthData = Tuple[BondType, float] # (bond_type, bond_length)
16 | BondLengthProfile = Dict[BondType, np.ndarray] # bond_type -> empirical distribution
17 |
18 | def get_distribution(distances: Sequence[float], bins=bond_length_config.DISTANCE_BINS) -> np.ndarray:
19 | """ Get the distribution of distances.
20 |
21 | Args:
22 | distances: (list) List of distances
23 | bins (list): bins of distances
24 | Returns:
25 | np.array: empirical distribution of distances with length equal to DISTANCE_BINS
26 | """
27 | bin_counts = collections.Counter(np.searchsorted(bins, distances))
28 | bin_counts = [bin_counts[i] if i in bin_counts else 0 for i in range(len(bins) + 1)]
29 | bin_counts = np.array(bin_counts) / np.sum(bin_counts)
30 | return bin_counts
31 |
32 | def _format_bond_type(bond_type: BondType) -> BondType:
33 | atom1, atom2, bond_category = bond_type
34 | if atom1 > atom2:
35 | atom1, atom2 = atom2, atom1
36 | return atom1, atom2, bond_category
37 |
38 | def get_bond_length_profile(bond_lengths: Sequence[BondLengthData]) -> BondLengthProfile:
39 | bond_length_profile = collections.defaultdict(list)
40 | for bond_type, bond_length in bond_lengths:
41 | bond_type = _format_bond_type(bond_type)
42 | bond_length_profile[bond_type].append(bond_length)
43 | bond_length_profile = {k: get_distribution(v) for k, v in bond_length_profile.items()}
44 | return bond_length_profile
45 |
46 | def _bond_type_str(bond_type: BondType) -> str:
47 | atom1, atom2, bond_category = bond_type
48 | return f'{atom1}-{atom2}|{bond_category}'
49 |
50 | def eval_bond_length_profile(bond_length_profile: BondLengthProfile) -> Dict[str, Optional[float]]:
51 | # gives the JS divergence of bond distances (different C-(C,O,N) bonds)
52 | metrics = {}
53 | for bond_type, gt_distribution in bond_length_config.EMPIRICAL_DISTRIBUTIONS.items():
54 | if bond_type not in bond_length_profile:
55 | metrics[f'JSD_{_bond_type_str(bond_type)}'] = None
56 | else:
57 | metrics[f'JSD_{_bond_type_str(bond_type)}'] = sci_spatial.distance.jensenshannon(gt_distribution,
58 | bond_length_profile[bond_type])
59 | return metrics
60 |
61 | def get_pair_length_profile(pair_lengths):
62 | cc_dist = [d[1] for d in pair_lengths if d[0] == (6,6) and d[1] < 2]
63 | all_dist = [d[1] for d in pair_lengths if d[1] < 12]
64 | pair_length_profile = {
65 | 'CC_2A': get_distribution(cc_dist, bins=np.linspace(0, 2, 100)), # distances of C-C bonds less than 2 A
66 | 'All_12A': get_distribution(all_dist, bins=np.linspace(0, 12, 100)) # all distances less than 12 A
67 | }
68 | return pair_length_profile
69 |
70 | def eval_pair_length_profile(pair_length_profile):
71 | metrics = {}
72 | for k, gt_distribution in bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS.items():
73 | if k not in pair_length_profile:
74 | metrics[f'JSD_{k}'] = None
75 | else:
76 | metrics[f'JSD_{k}'] = sci_spatial.distance.jensenshannon(gt_distribution, pair_length_profile[k])
77 | return metrics
78 |
79 | def plot_distance_hist(pair_length_profile, metrics=None, save_path=None):
80 |
81 | gt_profile = bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS
82 | plt.figure(figsize=(6*len(gt_profile), 4))
83 | for idx, (k, gt_distribution) in enumerate(bond_length_config.PAIR_EMPIRICAL_DISTRIBUTIONS.items()):
84 | plt.subplot(1, len(gt_profile), idx+1)
85 | x = bond_length_config.PAIR_EMPIRICAL_BINS[k]
86 | plt.step(x, gt_profile[k][1:])
87 | plt.step(x, pair_length_profile[k][1:])
88 | plt.legend(['True', 'Learned'])
89 | if metrics is not None:
90 | plt.title(f'{k} JS div: {metrics["JSD_" + k]:.4f}')
91 | else:
92 | plt.title(k)
93 |
94 | if save_path is not None:
95 | plt.savefig(save_path)
96 | else:
97 | plt.show()
98 | plt.close()
99 |
100 | def pair_distance_from_pos_v(pos, elements):
101 | pdist = pos[None, :] - pos[:,None]
102 | pdist = np.sqrt(np.sum(pdist ** 2, axis=-1))
103 | dist_list = []
104 | for s in range(len(pos)):
105 | for e in range(s+1, len(pos)):
106 | s_sym = elements[s]
107 | e_sym = elements[e]
108 | d = pdist[s, e]
109 | dist_list.append(((s_sym, e_sym), d))
110 | return dist_list
111 |
112 | def bond_distance_from_mol(mol):
113 | pos = mol.GetConformer().GetPositions()
114 | pdist = pos[None, :] - pos[:, None]
115 | pdist = np.sqrt(np.sum(pdist ** 2, axis=-1))
116 | all_distances = []
117 | for bond in mol.GetBonds():
118 | s_sym = bond.GetBeginAtom().GetAtomicNum()
119 | e_sym = bond.GetEndAtom().GetAtomicNum()
120 | s_idx, e_idx = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
121 | bond_type = utils.BOND_TYPES[bond.GetBondType()]
122 | distance = pdist[s_idx, e_idx]
123 | all_distances.append(((s_sym, e_sym, bond_type), distance))
124 | return all_distances
--------------------------------------------------------------------------------
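Usage note: a minimal sketch for the profile functions above, comparing one generated molecule's bond lengths to the empirical CrossDock distributions:

```
# minimal sketch; the sdf path is a placeholder
from rdkit import Chem
from analysis.eval_bond_length import (bond_distance_from_mol,
                                       get_bond_length_profile,
                                       eval_bond_length_profile)

mol = Chem.SDMolSupplier('generated.sdf', sanitize=False)[0]  # hypothetical file
bond_lengths = bond_distance_from_mol(mol)   # [((atomic_num, atomic_num, bond_type), distance), ...]
profile = get_bond_length_profile(bond_lengths)
print(eval_bond_length_profile(profile))     # JSD per bond type; None if a type is missing
```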
/analysis/get_atom_types_dist.py:
--------------------------------------------------------------------------------
1 | from scipy import spatial as sci_spatial
2 | import numpy as np
3 | from collections import Counter
4 |
5 |
6 | CROSSDOCK_atom_charges = {'C':6, 'N': 7, 'O': 8, 'S': 16, 'B': 5, 'Br': 35, 'Cl': 17, 'P': 15, 'I':53 ,'F':9}
7 |
8 | def get_atom_charges(mol, charge_dict):
9 | atomic_nums = []
10 | for atom in mol.GetAtoms():
11 | atomic_nums.append(charge_dict[atom.GetSymbol()])
12 |
13 | atomic_nums = np.array(atomic_nums)
14 | return atomic_nums
15 |
16 | ATOM_TYPE_DISTRIBUTION = { # atom type distributions in CrossDock
17 | 6: 0.6715020339893559,
18 | 7: 0.11703509510732567,
19 | 8: 0.16956379168491933,
20 | 9: 0.01307879304486639,
21 | 15: 0.01113716146426898,
22 | 16: 0.01123926340861198,
23 | 17: 0.006443861300651673,
24 | }
25 |
26 | ATOM_TYPE_DISTRIBUTION_GEOM = { # atom type distributions in GEOM
27 | 6: 0.7266496963585743,
28 | 7: 0.11690156566351215,
29 | 8: 0.11619156632264795,
30 | 9: 0.008849559988534103,
31 | 15: 0.0001854777473386173,
32 | 16: 0.022003011957949646,
33 | 17: 0.007286864677748788,
34 | 35: 0.001897001182960629,
35 | }
36 |
37 | def eval_atom_type_distribution(pred_counter: Counter, data_type='GEOM'):
38 | total_num_atoms = sum(pred_counter.values())
39 | pred_atom_distribution = {}
40 | if data_type == 'GEOM':
41 | for k in ATOM_TYPE_DISTRIBUTION_GEOM:
42 | pred_atom_distribution[k] = pred_counter[k] / total_num_atoms
43 | js = sci_spatial.distance.jensenshannon(np.array(list(ATOM_TYPE_DISTRIBUTION_GEOM.values())),
44 | np.array(list(pred_atom_distribution.values())))
45 | elif data_type == 'CrossDock':
46 | for k in ATOM_TYPE_DISTRIBUTION:
47 | pred_atom_distribution[k] = pred_counter[k] / total_num_atoms
48 | js = sci_spatial.distance.jensenshannon(np.array(list(ATOM_TYPE_DISTRIBUTION.values())),
49 | np.array(list(pred_atom_distribution.values())))
50 |
51 | return js
--------------------------------------------------------------------------------
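Usage note: a sketch of comparing generated molecules' atom-type counts against the CrossDock reference distribution above:

```
# sketch; the sdf path is a placeholder
from collections import Counter
from rdkit import Chem
from analysis.get_atom_types_dist import (get_atom_charges, CROSSDOCK_atom_charges,
                                          eval_atom_type_distribution)

pred_counter = Counter()
for mol in Chem.SDMolSupplier('generated.sdf'):
    if mol is not None:
        pred_counter.update(get_atom_charges(mol, CROSSDOCK_atom_charges).tolist())

print(eval_atom_type_distribution(pred_counter, data_type='CrossDock'))  # JS divergence
```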
/analysis/get_empirical_dists.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rdkit import Chem
3 | from tqdm import tqdm
4 | import os
5 | from eval_bond_length import pair_distance_from_pos_v, bond_distance_from_mol
6 | from eval_bond_length import get_pair_length_profile, get_bond_length_profile
7 |
8 |
9 | if __name__ == '__main__':
10 |
11 | supplier = list(Chem.SDMolSupplier('/srv/ds/set-1/user/mahdi.ghorbani/FragDiff/datasets/geom_conformers.sdf'))
12 |
13 | all_pair_dists = []
14 | all_bond_dists = []
15 | for mol_id, mol in enumerate(supplier):
16 | try:
17 | pos = mol.GetConformer().GetPositions()
18 |
19 | atomicnums = []
20 | for atom in mol.GetAtoms():
21 | atomicnums.append(atom.GetAtomicNum())
22 |
23 | all_pair_dists += pair_distance_from_pos_v(pos, atomicnums)
24 | all_bond_dists += bond_distance_from_mol(mol)
25 | except:
26 | print(f'could not process mol {mol_id}')
27 |
28 | empirical_pair_length_profiles = get_pair_length_profile(all_pair_dists)
29 | empirical_bond_length_profiles = get_bond_length_profile(all_bond_dists)
30 |
31 | print(empirical_bond_length_profiles)
32 | print(empirical_pair_length_profiles)
--------------------------------------------------------------------------------
/analysis/get_volume.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import tempfile
3 | import numpy as np
4 | import os
5 |
6 | info_dict = {}
7 | root_dir = '/Users/mahdimac/Science/Keiser_lab/diffusion/AutoFragDiff/scaffolds'
8 | with tempfile.TemporaryDirectory() as tmp_dir:
9 | command = f"fpocket -f {root_dir}/1a2g.pdb"
10 | os.chdir(tmp_dir)
11 |
12 | #process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
13 | out = os.popen(command).read()
14 | #stdout, stderr = process.communicate()
15 |
16 | with open(os.path.join('1a2g_out', '1a2g_info.txt'), 'r') as fp:
17 | #file_content = fp.read()
18 |
19 | lines = fp.readlines()
20 | pocket_info_started = False
21 |
22 | for line in lines:
23 | line = line.strip()
24 | if line == "Pocket 1 :":
25 | pocket_info_started = True
26 | continue
27 | if pocket_info_started:
28 | if line == "":
29 | break
30 | key, value = line.split(":")
31 | info_dict[key.strip()] = float(value.strip())
32 |
33 | print(info_dict)
--------------------------------------------------------------------------------
/analysis/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.distance import cdist
3 | from tqdm import tqdm
4 | from rdkit import Chem, DataStructs
5 |
6 | from collections import Counter
7 | from copy import deepcopy
8 |
9 | from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski
10 | from rdkit.Chem.FilterCatalog import *
11 | from rdkit.Chem.QED import qed
12 |
13 | from analysis.SA_Score.sascorer import compute_sa_score
14 |
15 |
16 | def is_connected(mol):
17 | try:
18 | mol_frags = Chem.GetMolFrags(mol, asMols=True)
19 | except Chem.rdchem.AtomValenceException:
20 | return False
21 | if len(mol_frags) != 1:
22 | return False
23 | return True
24 |
25 | def is_valid(mol):
26 | try:
27 | Chem.SanitizeMol(mol)
28 | except:
29 | return False
30 | return True
31 |
32 | def obey_lipinski(mol):
33 | mol = deepcopy(mol)
34 | Chem.SanitizeMol(mol)
35 | rule_1 = Descriptors.ExactMolWt(mol) < 500
36 | rule_2 = Lipinski.NumHDonors(mol) <= 5
37 | rule_3 = Lipinski.NumHAcceptors(mol) <= 10
38 | logp = get_logp(mol)
39 | rule_4 = (logp >= -2) & (logp <= 5)
40 | rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
41 | return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
42 |
43 | def get_basic(mol):
44 | n_atoms = len(mol.GetAtoms())
45 | n_bonds = len(mol.GetBonds())
46 | n_rings = len(Chem.GetSymmSSSR(mol))
47 | weight = Descriptors.ExactMolWt(mol)
48 | return n_atoms, n_bonds, n_rings, weight
49 |
50 | def get_rdkit_rmsd(mol, n_conf=20, random_seed=42, mode='energy'):
51 | """
52 | Calculate the alignment of generated mol and rdkit predicted mol
53 | Return the rmsd (max, min, median) of the n_conf rdkit conformers
54 | """
55 |
56 | mol = deepcopy(mol)
57 | Chem.SanitizeMol(mol)
58 |
59 | mol_smiles = Chem.MolToSmiles(mol)
60 | mol_smiles = Chem.MolFromSmiles(mol_smiles)
61 | mol3d = Chem.AddHs(mol)
62 |
63 | rmsd_list = []
64 | conf_energies = []
65 | # predict 3d
66 | try:
67 | confIds = AllChem.EmbedMultipleConfs(mol3d, n_conf, randomSeed=random_seed)
68 | for confId in confIds:
69 | AllChem.UFFOptimizeMolecule(mol3d, confId=confId)
70 | rmsd = Chem.rdMolAlign.GetBestRMS(Chem.RemoveHs(mol), Chem.RemoveHs(mol3d), refId=confId)
71 | rmsd_list.append(rmsd)
72 | #conf_energies.append(get_conformer_energies(mol3d))
73 |
74 | mol_energy = get_conformer_energies(Chem.AddHs(mol, addCoords=True))
75 | conf_energies = get_conformer_energies(mol3d)
76 | rmsd_list = np.array(rmsd_list)
77 | conf_lowest_en = np.argmin(conf_energies)
78 |
79 | mol = Chem.AddHs(mol)
80 | new_mol = Chem.Mol(mol)
81 | new_mol.RemoveAllConformers()
82 | conf_ids = [conf.GetId() for conf in mol3d.GetConformers()]
83 | conf = mol3d.GetConformer(conf_ids[conf_lowest_en])
84 | new_mol.AddConformer(conf, assignId=True)
85 |
86 | return rmsd_list[conf_lowest_en], new_mol, conf_energies, mol_energy
87 | except:
88 | return np.nan, np.nan, np.nan, np.nan
89 |
90 | def get_logp(mol):
91 | return Crippen.MolLogP(mol)
92 |
93 | def get_chem(mol):
94 | qed_score = qed(mol)
95 | sa_score = compute_sa_score(mol)
96 | logp_score = get_logp(mol)
97 | lipinski_score = obey_lipinski(mol)
98 | ring_info = mol.GetRingInfo()
99 | ring_size = Counter([len(r) for r in ring_info.AtomRings()])
100 |
101 | return {
102 | 'qed': qed_score,
103 | 'sa': sa_score,
104 | 'logp': logp_score,
105 | 'lipinski': lipinski_score,
106 | 'ring_size': ring_size
107 | }
108 |
109 | def get_molecule_force_field(mol, conf_id=None, force_field='mmff', **kwargs):
110 | """
111 | Get a force field for a molecule.
112 | Parameters
113 | ----------
114 | mol : RDKit Mol
115 | Molecule.
116 | conf_id : int, optional
117 | ID of the conformer to associate with the force field.
118 | force_field : str, optional
119 | Force Field name.
120 | kwargs : dict, optional
121 | Keyword arguments for force field constructor.
122 | """
123 | if force_field == 'uff':
124 | ff = AllChem.UFFGetMoleculeForceField(
125 | mol, confId=conf_id, **kwargs)
126 | elif force_field.startswith('mmff'):
127 | AllChem.MMFFSanitizeMolecule(mol)
128 | mmff_props = AllChem.MMFFGetMoleculeProperties(
129 | mol, mmffVariant=force_field)
130 | ff = AllChem.MMFFGetMoleculeForceField(
131 | mol, mmff_props, confId=conf_id, **kwargs)
132 | else:
133 | raise ValueError("Invalid force_field {}".format(force_field))
134 | return ff
135 |
136 | def get_conformer_energies(mol, force_field='mmff'):
137 | """
138 | Calculate conformer energies.
139 | Parameters
140 | ----------
141 | mol : RDKit Mol
142 | Molecule.
143 | force_field : str, optional
144 | Force Field name.
145 | Returns
146 | -------
147 | energies : array_like
148 | Minimized conformer energies.
149 | """
150 | energies = []
151 | for conf in mol.GetConformers():
152 | ff = get_molecule_force_field(mol, conf_id=conf.GetId(), force_field=force_field)
153 | ff.Minimize()
154 | energy = ff.CalcEnergy()
155 | energies.append(energy)
156 | energies = np.asarray(energies, dtype=float)
157 | return energies
158 |
--------------------------------------------------------------------------------
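Usage note: a small sketch of the validity and chemistry metrics above (`get_chem` calls `compute_sa_score`, so `fpscores.pkl.gz` must be reachable from the working directory):

```
# sketch with a stand-in molecule
from rdkit import Chem
from analysis.metrics import is_valid, is_connected, get_chem

mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin as a stand-in
if is_valid(mol) and is_connected(mol):
    print(get_chem(mol))  # {'qed': ..., 'sa': ..., 'logp': ..., 'lipinski': ..., 'ring_size': Counter(...)}
```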
/analysis/qvina_docking.py:
--------------------------------------------------------------------------------
1 | from joblib import Parallel, delayed
2 | import os
3 | from pathlib import Path
4 | import random
5 | import shutil
6 | import re
7 | import glob
8 |
9 | from tqdm import tqdm
10 | import numpy as np
11 |
12 | from rdkit import Chem
13 | from rdkit.Chem import AllChem
14 | from rdkit import RDLogger
15 | import pandas as pd
16 | import torch
17 |
18 | affinity_pattern = r"Affinity:\s+(-?\d+\.\d+)\s+\(kcal/mol\)"
19 | RDLogger.DisableLog('rdApp.*')
20 |
21 | def sdf_to_pdbqt(sdf_file, pdbqt_outfile, mol_id):
22 | os.popen(f'obabel {sdf_file} -O {pdbqt_outfile} -f {0} -l {mol_id} -m').read()
23 | return pdbqt_outfile
24 |
25 | def get_vina_dock_score(receptor_pdbqt_file, ligand_pdbqt_file, cx, cy, cz, size):
26 | # Vina docking and getting the vina score
27 | out = os.popen(
28 | f'qvina2.1 --receptor {receptor_pdbqt_file} '
29 | f'--ligand {ligand_pdbqt_file} '
30 | f'--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} '
31 | f'--size_x {size} --size_y {size} --size_z {size} --exhaustiveness 16'
32 | ).read()
33 | out_split = out.splitlines()
34 | best_idx = out_split.index('-----+------------+----------+----------') + 1
35 | best_line = out_split[best_idx].split()
36 | print('\n best Affinity:', float(best_line[1]))
37 | return float(best_line[1])
38 |
39 | def get_vina_score(receptor_pdbqt_file, ligand_pdbqt_file, cx, cy, cz, size):
40 | # TODO: using QVina to get vina scores gives weird results. Use Vina
41 | # scores the generated poses without docking them
42 | out = os.popen(
43 | f'qvina2.1 --score_only --receptor {receptor_pdbqt_file} '
44 | f'--ligand {ligand_pdbqt_file} '
45 | f'--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} '
46 | f'--size_x {size} --size_y {size} --size_z {size}'
47 | ).read()
48 | match = re.search(affinity_pattern, out)
49 | affinity_value = float(match.group(1))
50 | print('vina score is:', affinity_value)
51 | return affinity_value
52 |
53 | def process_vina_iteration(n, save_file, receptor_pdbqt_file, mol_pos, size, result_type='vina_score'):
54 | pdbqt_file = save_file + 'all_mols_' + str(n) + '.pdbqt'
55 | cx, cy, cz = mol_pos
56 | if result_type == 'vina_score':
57 | affinity_value = get_vina_score(receptor_pdbqt_file, pdbqt_file, cx, cy, cz, size)
58 | elif result_type == 'dock_score':
59 | affinity_value = get_vina_dock_score(receptor_pdbqt_file, pdbqt_file, cx, cy, cz, size)
60 | return affinity_value
61 |
62 |
63 |
--------------------------------------------------------------------------------
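Usage note: the module above imports `Parallel`/`delayed` but never uses them; a plausible (assumed) pattern is to fan `process_vina_iteration` out over ligands, e.g.:

```
# assumed usage; file layout and names are hypothetical
from joblib import Parallel, delayed
from analysis.qvina_docking import process_vina_iteration

save_file = 'results/'                    # prefix: expects results/all_mols_<n>.pdbqt files
receptor = 'receptor_pocket.pdbqt'        # hypothetical receptor pdbqt
centers = [(10.0, 12.5, -3.2), (11.1, 12.0, -2.8)]  # one box center per ligand

scores = Parallel(n_jobs=4)(
    delayed(process_vina_iteration)(n, save_file, receptor, pos, size=20,
                                    result_type='vina_score')
    for n, pos in enumerate(centers))
print(scores)
```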
/analysis/scoring_func.py:
--------------------------------------------------------------------------------
1 |
2 | from collections import Counter
3 | from copy import deepcopy
4 |
5 | import numpy as np
6 | from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski
7 | from rdkit.Chem.FilterCatalog import *
8 | from rdkit.Chem.QED import qed
9 |
10 | def obey_lipinski(mol):
11 | # compute the lipinski score
12 | mol = deepcopy(mol)
13 | Chem.SanitizeMol(mol)
14 | rule_1 = Descriptors.ExactMolWt(mol) < 500
15 | rule_2 = Lipinski.NumHDonors(mol) <= 5
16 | rule_3 = Lipinski.NumHAcceptors(mol) <= 10
17 | logp = get_logp(mol)
18 | rule_4 = (logp >= -2) & (logp <= 5)
19 | rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
20 | return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
21 |
22 | def get_basic(mol):
23 | # return n_atoms, bonds, rings, MW
24 | n_atoms = len(mol.GetAtoms())
25 | n_bonds = len(mol.GetBonds())
26 | n_rings = len(Chem.GetSymmSSSR(mol))
27 | weight = Descriptors.ExactMolWt(mol)
28 | return n_atoms, n_bonds, n_rings, weight
29 |
30 | def get_rdkit_rmsd(mol, n_conf=20, random_seed=42):
31 | # return [max_rmsd, min_rmsd, median_rmsd]
32 | """
33 | calculate the alignment of generated mol and rdkit predicted mol
34 | Return the rmsd (max, min, median) of the `n_conf` rdkit conformers
35 | """
36 | mol = deepcopy(mol)
37 | Chem.SanitizeMol(mol)
38 | mol3d = Chem.AddHs(mol) # TODO: may need to add hydrogens in a different way
39 | rmsd_list = []
40 | # predict 3d
41 | try:
42 | confIds = AllChem.EmbedMultipleConfs(mol3d, n_conf, randomSeed=random_seed)
43 | for confId in confIds:
44 | AllChem.UFFOptimizeMolecule(mol3d, confId=confId)
45 | rmsd = Chem.rdMolAlign.GetBestRMS(mol, mol3d, refId=confId)
46 | rmsd_list.append(rmsd)
47 | rmsd_list = np.array(rmsd_list)
48 | return [np.max(rmsd_list), np.min(rmsd_list), np.median(rmsd_list)]
49 | except:
50 | return [np.nan, np.nan, np.nan]
51 |
52 | def get_logp(mol):
53 | return Crippen.MolLogP(mol)
54 |
55 | def get_molecule_force_field(mol, conf_id=None, force_field='mmff', **kwargs):
56 | """
57 | Get a force field for a molecule.
58 | Parameters
59 | ----------
60 | mol : RDKit Mol
61 | Molecule.
62 | conf_id : int, optional
63 | ID of the conformer to associate with the force field.
64 | force_field : str, optional
65 | Force Field name.
66 | kwargs : dict, optional
67 | Keyword arguments for force field constructor.
68 | """
69 | if force_field == 'uff':
70 | ff = AllChem.UFFGetMoleculeForceField(
71 | mol, confId=conf_id, **kwargs)
72 | elif force_field.startswith('mmff'):
73 | AllChem.MMFFSanitizeMolecule(mol)
74 | mmff_props = AllChem.MMFFGetMoleculeProperties(
75 | mol, mmffVariant=force_field)
76 | ff = AllChem.MMFFGetMoleculeForceField(
77 | mol, mmff_props, confId=conf_id, **kwargs)
78 | else:
79 | raise ValueError("Invalid force_field {}".format(force_field))
80 | return ff
81 |
82 | def get_conformer_energies(mol, force_field='mmff'):
83 | """
84 | Calculate conformer energies.
85 | Parameters
86 | ----------
87 | mol : RDKit Mol
88 | Molecule.
89 | force_field : str, optional
90 | Force Field name.
91 | Returns
92 | -------
93 | energies : array_like
94 | Minimized conformer energies.
95 | """
96 | energies = []
97 | for conf in mol.GetConformers():
98 | ff = get_molecule_force_field(mol, conf_id=conf.GetId(), force_field=force_field)
99 | energy = ff.CalcEnergy()
100 | energies.append(energy)
101 | energies = np.asarray(energies, dtype=float)
102 | return energies
--------------------------------------------------------------------------------
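Usage note: unlike `metrics.get_conformer_energies`, the variant above does not minimize before `CalcEnergy`, so it returns single-point energies. A short sketch:

```
# sketch: single-point MMFF energies over a few embedded conformers
from rdkit import Chem
from rdkit.Chem import AllChem
from analysis.scoring_func import get_conformer_energies

mol = Chem.AddHs(Chem.MolFromSmiles('CCO'))
AllChem.EmbedMultipleConfs(mol, numConfs=5, randomSeed=42)
print(get_conformer_energies(mol, force_field='mmff'))  # one energy per conformer
```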
/analysis/similarity.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rdkit import Chem, DataStructs
3 |
4 | def tanimoto_sim(mol, ref):
5 | fp1 = Chem.RDKFingerprint(ref)
6 | fp2 = Chem.RDKFingerprint(mol)
7 | return DataStructs.TanimotoSimilarity(fp1, fp2)
8 |
9 | def tanimoto_sim_N_to_1(mols, ref):
10 | sim = [tanimoto_sim(m, ref) for m in mols]
11 | return sim
12 |
13 | def batched_number_of_rings(mols):
14 | n = []
15 | for m in mols:
16 | n.append(Chem.rdMolDescriptors.CalcNumRings(m))
17 | return np.array(n)
--------------------------------------------------------------------------------
/analysis/utils.py:
--------------------------------------------------------------------------------
1 | from rdkit.Chem.rdchem import BondType
2 | from rdkit.Chem import ChemicalFeatures
3 | from rdkit import RDConfig
4 |
5 | ATOM_FAMILIES = ['Acceptor', 'Donor', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe', 'NegIonizable', 'PosIonizable',
6 | 'ZnBinder']
7 | ATOM_FAMILIES_ID = {s: i for i, s in enumerate(ATOM_FAMILIES)}
8 | BOND_TYPES = {
9 | BondType.UNSPECIFIED: 0,
10 | BondType.SINGLE: 1,
11 | BondType.DOUBLE: 2,
12 | BondType.TRIPLE: 3,
13 | BondType.AROMATIC: 4,
14 | }
15 |
--------------------------------------------------------------------------------
/analyze_scaffolds_generated.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import pandas as pd
4 | import os
5 | from tqdm import tqdm
6 | import argparse
7 | import re
8 |
9 | from rdkit import Chem
10 | from rdkit.Chem import AllChem
11 | from rdkit import RDLogger
12 | from openbabel import openbabel
13 |
14 | from analysis import eval_bond_length
15 | from analysis.reconstruct_mol import reconstruct_from_generated, MolReconsError
16 | from analysis.metrics import is_connected, is_valid, get_chem
17 | from analysis.eval_bond_angles import get_distribution, eval_angle_dist_profile, find_angle_dist
18 | from analysis.vina_docking import VinaDockingTask
19 | from joblib import Parallel, delayed
20 |
21 | from src.utils import get_logger
22 | import collections
23 | import torch
24 |
25 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9}
26 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'}
27 |
28 | CROSSDOCK_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'B':5, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15}
29 |
30 | def print_dict(d, logger):
31 | for k, v in d.items():
32 | if v is not None:
33 | logger.info(f'{k}:\t{v:4f}')
34 | else:
35 | logger.info(f'{k}\tNone')
36 |
37 | def print_ring_ratio(all_ring_sizes, logger):
38 | for ring_size in range(3, 10):
39 | n_mol = 0
40 | for counter in all_ring_sizes:
41 | if ring_size in counter:
42 | n_mol += 1
43 | logger.info(f'ring size: {ring_size} ratio: {n_mol / len(all_ring_sizes):.3f}')
44 |
45 |
46 | if __name__ == '__main__':
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--results-path', type=str, default='results_scaffold',
49 | help='path to save the scaffold based optimization')
50 | parser.add_argument('--scaffold-path', type=str, default='scaffolds/1a2g_scaff.sdf',
51 | help='path to sdf of scaffold')
52 | parser.add_argument('--original-path', type=str, default='scaffolds/1a2g_orig.sdf',
53 | help='path to original molecule')
54 | parser.add_argument('--receptor-path', type=str, default='scaffolds/1a2g.pdb',
55 | help='path to pdb file of receptor')
56 | parser.add_argument('--docking_mode', type=str, choices=['qvina', 'vina_score', 'vina_dock', 'None'])
57 | parser.add_argument('--exhaustiveness', type=int, default=16)
58 | parser.add_argument('--verbose', type=eval, default=False)
59 | parser.add_argument('--n-mols-per-file', type=int, default=20, help='number of molecules per file')
60 | parser.add_argument('--crossdock-dir', type=str, default='/srv/home/mahdi.ghorbani/FragDiff/crossdock')
61 |
62 | args = parser.parse_args()
63 | results_path = args.results_path
64 | n_mols_per_file = args.n_mols_per_file
65 | eval_path = os.path.join(results_path, 'eval_results')
66 | root_dir = args.crossdock_dir
67 |
68 | scaffold_path = args.scaffold_path # sdf file of scaffold
69 | receptor_path = args.receptor_path # pdb file of receptor
70 |
71 | os.makedirs(eval_path, exist_ok=True)
72 | logger = get_logger('evaluate', log_dir=eval_path)
73 |
74 | if not args.verbose:
75 | RDLogger.DisableLog('rdApp.*')
76 |
77 | valid_mols = 0
78 | connected_mols = 0
79 | results = []
80 |
81 | n_files = 0
82 | n_samples = 0
83 |
84 | scaff_mol = Chem.SDMolSupplier(scaffold_path)[0]
85 | orig_mol = Chem.SDMolSupplier(args.original_path)[0]
86 |
87 | # compute vina score for the original molecule
88 | vina_task = VinaDockingTask.from_generated_mol(orig_mol, protein_path=receptor_path)
89 | score_result = vina_task.run(mode='score_only', exhaustiveness=16)
90 | orig_score = score_result[0]['affinity']
91 | print('------> Vina score for original molecule is : ', orig_score)
92 |
93 | # compute vina score for the scaffold
94 | vina_task = VinaDockingTask.from_generated_mol(scaff_mol, protein_path=receptor_path)
95 | score_result = vina_task.run(mode='score_only', exhaustiveness=16)
96 | scaffold_score = score_result[0]['affinity']
97 | print('------> Vina score for scaffold is : ', scaffold_score)
98 |
99 |
100 | for n in tqdm(range(10), desc='Eval'):
101 | prot_path = receptor_path
102 |         prefix = os.path.join(results_path, 'pocket_' + str(n))
103 |         if os.path.exists(prefix + '_coords.npy'):
104 |             n_files += 1
105 |             x = np.load(prefix + '_coords.npy')
106 |             h = np.load(prefix + '_onehot.npy')
107 |             mol_masks = np.load(prefix + '_mol_masks.npy')
108 |
109 | all_mols = []
110 | for k in range(len(x)):
111 |
112 | mask = mol_masks[k]
113 | h_mol = h[k]
114 | x_mol = x[k][mask.astype(np.bool_)]
115 |
116 | atom_inds = h_mol[mask.astype(np.bool_)].argmax(axis=1)
117 | atom_types = [idx2atom[x] for x in atom_inds]
118 | atomic_nums = [CROSSDOCK_CHARGES[i] for i in atom_types]
119 |
120 | #all_validity_results.append(validity_results)
121 | n_samples += 1
122 | try:
123 |                     mol_rec = reconstruct_from_generated(x_mol.tolist(), atomic_nums, aromatic=None, basic_mode=True)
124 |                     Chem.SanitizeMol(mol_rec)
125 |                     smiles = Chem.MolToSmiles(mol_rec)
126 |
127 | except Exception as e:
128 | print(e)
129 | continue
130 | valid_mols += 1
131 |
132 | if is_connected(mol_rec):
133 | connected_mols += 1
134 | else:
135 | # if the molecule is not connected, then take the largest fragment
136 | m_frags = Chem.GetMolFrags(mol_rec, asMols=True, sanitizeFrags=False)
137 | mol_rec = max(m_frags, default=mol_rec, key=lambda m: m.GetNumAtoms())
138 |
139 | chem_results = get_chem(mol_rec) # a dictionary with qed, sa, logp, lipinski, ring_size
140 |
141 | # --------------------------- Getting Vina Docking results ---------------------------
142 | try:
143 | if args.docking_mode == 'qvina':
144 | pass # TODO: add the qvina like in TargetDiff
145 | elif args.docking_mode in ['vina_score', 'vina_dock']:
146 | vina_task = VinaDockingTask.from_generated_mol(mol_rec, protein_path=prot_path)
147 | score_only_results = vina_task.run(mode='score_only', exhaustiveness=args.exhaustiveness)
148 | minimize_results = vina_task.run(mode='minimize', exhaustiveness=args.exhaustiveness)
149 | print('score_only: ', score_only_results[0]['affinity'])
150 | print('minimized score: ', minimize_results[0]['affinity'])
151 | vina_results = {
152 | 'score_only': score_only_results,
153 | 'minimize': minimize_results
154 | }
155 | if args.docking_mode == 'vina_dock':
156 | docking_results = vina_task.run(mode='dock', exhaustiveness=args.exhaustiveness)
157 | vina_results['dock'] = docking_results
158 | print('vina dock: ', docking_results[0]['affinity'])
159 | else:
160 | vina_results = None
161 |
162 |                 except Exception:
163 | if args.verbose:
164 | logger.warning(f'Docking failed for pocket {n} and molecule {k}')
165 | continue
166 |
167 | results.append({
168 | 'mol': mol_rec,
169 | 'smiles': smiles,
170 | 'chem_results': chem_results,
171 |                 'vina': vina_results
172 | })
173 |
174 | logger.info(f'Evaluation is done! {n_samples} samples in total')
175 |
176 | fraction_valid = valid_mols / n_samples
177 | fraction_connected = connected_mols / n_samples
178 |
179 | print('fraction_connected is: ', fraction_connected)
180 |     print('fraction_valid is: ', fraction_valid)
181 |
182 | qed = [r['chem_results']['qed'] for r in results]
183 | sa = [r['chem_results']['sa'] for r in results]
184 | logger.info('QED: Mean: %.3f Median: %.3f std: %.3f' % (np.mean(qed), np.median(qed), np.std(qed)))
185 |     logger.info('SA: Mean: %.3f Median: %.3f std: %.3f' % (np.mean(sa), np.median(sa), np.std(sa)))
186 |
187 | if args.docking_mode == 'qvina':
188 |         vina = [r['vina'][0]['affinity'] for r in results]
189 | logger.info('Vina: Mean: %.3f Median: %.3f Std: %.3f' %(np.mean(vina), np.median(vina), np.std(vina)))
190 | elif args.docking_mode in ['vina_dock', 'vina_score']:
191 | vina_score_only = [r['vina']['score_only'][0]['affinity'] for r in results]
192 | vina_min = [r['vina']['minimize'][0]['affinity'] for r in results]
193 | logger.info('Vina Score : Mean %.3f Median: %.3f Std: %.3f' % (np.mean(vina_score_only), np.median(vina_score_only), np.std(vina_score_only)))
194 | logger.info('Vina minimized : Mean %.3f Median: %.3f Std: %.3f' % (np.mean(vina_min), np.median(vina_min), np.std(vina_min)))
195 | if args.docking_mode == 'vina_dock':
196 | vina_dock = [r['vina']['dock'][0]['affinity'] for r in results]
197 | logger.info('Vina Dock : Mean: %.3f Median: %.3f Std: %.3f' % (np.mean(vina_dock), np.median(vina_dock), np.std(vina_dock)))
198 |
199 | print_ring_ratio([r['chem_results']['ring_size'] for r in results], logger)
200 |
201 | torch.save({
202 | 'fraction_connected': fraction_connected,
203 | 'fraction_valid': fraction_valid,
204 | 'all_results': results,
205 | }, os.path.join(eval_path, 'metrics.pt'))
--------------------------------------------------------------------------------
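Note on the evaluation loop above: each generated molecule is reconstructed from coordinates and atom types, sanitized, reduced to its largest fragment when disconnected, and only then scored. A minimal sketch of that post-processing, not part of the repository; `postprocess_generated` is an illustrative name, while `reconstruct_from_generated` (analysis/reconstruct_mol.py) and `get_chem` (analysis/metrics.py) are the repo's own helpers:

from rdkit import Chem

def postprocess_generated(x_mol, atomic_nums, reconstruct_from_generated, get_chem):
    try:
        mol = reconstruct_from_generated(x_mol.tolist(), atomic_nums,
                                         aromatic=None, basic_mode=True)
        Chem.SanitizeMol(mol)
    except Exception:
        return None  # reconstruction or sanitization failed; skip this sample
    # keep only the largest fragment if the molecule is disconnected
    frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    mol = max(frags, default=mol, key=lambda m: m.GetNumAtoms())
    return mol, Chem.MolToSmiles(mol), get_chem(mol)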
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/movie.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/assets/movie.gif
--------------------------------------------------------------------------------
/assets/scaffold_optim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/assets/scaffold_optim.png
--------------------------------------------------------------------------------
/data/CROSSDOCK/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/data/CROSSDOCK/__init__.py
--------------------------------------------------------------------------------
/data/CROSSDOCK/fragment_hierarchy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import networkx as nx, matplotlib.pyplot as plt
3 | from rdkit import Chem
4 | from rdkit.Chem import Draw  # Draw and plt are used by the draw_* helpers below
5 | from prepare_fragments import *
6 | def find_neigh_frags_to_neigh_atoms(bonds_broken_frags, bonds_broken_frag_ids):
7 |     """ find a mapping between broken bonds (atom tuples) and their corresponding fragment ids (index tuples)
8 |     bonds_broken_frags -> tuples of bonds broken between fragments
9 |     bonds_broken_frag_ids -> ids of the fragments on either side of each broken bond
10 | """
11 |     # a dictionary mapping from fragments (tuple) to atom ids (tuple) in the original molecule
12 | neigh_frags_to_neigh_atoms = {}
13 | for i, bond in enumerate(bonds_broken_frags):
14 | neigh_frags_to_neigh_atoms[bonds_broken_frag_ids[i]] = bond
15 | neigh_frags_to_neigh_atoms[bonds_broken_frag_ids[i][::-1]] = bond[::-1]
16 |
17 | return neigh_frags_to_neigh_atoms
18 |
19 | class FragmentGraph():
20 | """ class for fragment graph
21 | """
22 | def __init__(self, mol, fragments, adjacency, frag_atom_ids, frag_to_id_dict, neigh_frags_to_neigh_atoms):
23 | self.original_mol = mol
24 | self.graph = nx.Graph()
25 | self.fragments = fragments
26 | self.frag_atom_ids = frag_atom_ids
27 | self.frag_to_id_dict = frag_to_id_dict
28 | self.conformer = mol.GetConformer().GetPositions() # position of atoms in the original molecule
29 | self.all_mol_atom_symbols = self.get_all_mol_atom_symbols() # array of atom symbols in the molecule
30 |         self.all_mol_atom_charges = [] # including charges of atoms may not be wise since we are doing autoregressive generation
31 | self.neigh_frags_to_neigh_atoms = neigh_frags_to_neigh_atoms
32 |
33 | fragment_bonds = np.argwhere(np.triu(adjacency))
34 |
35 | for i, f in enumerate(fragments):
36 | self.graph.add_node(f, name='frag_' + str(i))
37 |
38 | for bond in fragment_bonds:
39 | self.graph.add_edge(fragments[bond[0]], fragments[bond[1]])
40 |
41 | def draw_graph(self, graph):
42 | labels = nx.get_node_attributes(graph, 'name')
43 | nx.draw_circular(graph, labels=labels, node_size=3000)
44 |
45 | def get_bfs_order(self, starting_point=None):
46 | """ returns a list of tuples of fragments that should be connected to traverse the graph
47 | in a BFS order
48 | """
49 | if starting_point is None:
50 | starting_point = 0
51 | starting_frag = self.fragments[starting_point]
52 | bfs_edges = list(nx.bfs_edges(self.graph, starting_frag))
53 | return bfs_edges
54 |
55 | def get_dfs_order(self, starting_point=None):
56 | """ return a list of tuples of fragments that should be connected to traverse
57 | in DFS order
58 | """
59 | if starting_point is None:
60 | starting_point = 0
61 | starting_frag = self.fragments[starting_point]
62 | dfs_edges = list(nx.dfs_edges(self.graph, starting_frag))
63 | return dfs_edges
64 |
65 | def hierarchical_reconstruct(self, edge_order='BFS', starting_point=None):
66 | """
67 | Returns the reconstruction of the molecule in the order given in edge_order
68 |
69 |         if edge_order is given, the reconstruction is based on edge_order;
70 |         otherwise reconstruction is based on BFS order from the starting point
71 |
72 | Returns:
73 | hierarchical_mol : hierarchical molecule built in BFS order
74 | atom_ids_hierarchical: ids of atoms added in the BFS order
75 | """
76 | if starting_point is None:
77 | starting_point = 0
78 |
79 | if edge_order == 'BFS':
80 | edge_list = self.get_bfs_order(starting_point)
81 | elif edge_order == 'DFS':
82 | edge_list = self.get_dfs_order(starting_point)
83 | else:
84 | raise ValueError('edge order not found.')
85 |
86 | tmp = edge_list[0][0] # this is a mol
87 | tmp_id = self.frag_to_id_dict[tmp] # id of the fragment
88 | hierarchical_mol = [tmp] # the initial molecule
89 |
90 | # ------------- find the atom ids in hier ------------
91 | atom_ids_hierarchical = [] # a set of atoms
92 | tmp_frag_atom_ids = self.frag_atom_ids[tmp_id]
93 | atom_ids_hierarchical.append(tmp_frag_atom_ids)
94 |
95 |
96 | # -------------- find the conformer in hier -----------
97 |         hierarchical_conformer = [] # hierarchical conformer coordinates of the molecule
98 | first_frag_conformer = self.transfer_conformer(tmp_frag_atom_ids)
99 | hierarchical_conformer.append(first_frag_conformer)
100 |
101 |
102 | # --------------- find the atom symbols in hier -------------
103 | hier_atom_symbol = []
104 | first_frag_symbols = self.all_mol_atom_symbols[list(tmp_frag_atom_ids)]
105 | hier_atom_symbol.append(first_frag_symbols)
106 |
107 | all_anchor_ids = []
108 | first_frag_id = tmp_id
109 | extensions_atom_ids = [self.frag_atom_ids[first_frag_id]]
110 |
111 | for edge in edge_list:
112 |
113 | tmp = Chem.CombineMols(tmp, edge[1])
114 | hierarchical_mol.append(tmp)
115 | frag_id = self.frag_to_id_dict[edge[1]] # id of the next fragment to add
116 | index_of_two_frags = (self.frag_to_id_dict[edge[0]], self.frag_to_id_dict[edge[1]])
117 | tmp_frag_atom_ids = tmp_frag_atom_ids.union(self.frag_atom_ids[frag_id])
118 | extensions_atom_ids.append(self.frag_atom_ids[frag_id])
119 | atom_ids_hierarchical.append(tmp_frag_atom_ids)
120 |
121 | anchor_idx = self.neigh_frags_to_neigh_atoms[index_of_two_frags][0]
122 | all_anchor_ids.append(anchor_idx)
123 |
124 | conformer_at_this_step = self.transfer_conformer(tmp_frag_atom_ids)
125 | hierarchical_conformer.append(conformer_at_this_step)
126 |
127 | hier_atom_symbol.append(self.all_mol_atom_symbols[list(tmp_frag_atom_ids)])
128 |
129 |
130 | return hierarchical_mol, atom_ids_hierarchical, extensions_atom_ids, hierarchical_conformer, hier_atom_symbol, all_anchor_ids
131 |
132 |
133 | def transfer_conformer(self, atom_ids):
134 |
135 | conformer_at_this_step = self.conformer[list(atom_ids)]
136 | return conformer_at_this_step
137 |
138 | def get_all_mol_atom_symbols(self):
139 | mol_atom_symbols = []
140 | for atom in self.original_mol.GetAtoms():
141 | mol_atom_symbols.append(atom.GetSymbol())
142 | mol_atom_symbols = np.array(mol_atom_symbols)
143 | return mol_atom_symbols
144 |
145 | def get_anchor_idx(self):
146 | pass
147 |
148 |
149 | @staticmethod
150 | def draw_fragment_graph(all_frags, fragment_bonds):
151 |
152 | G = nx.Graph()
153 | for i, f in enumerate(all_frags):
154 | img = Draw.MolToImage(f)
155 | img.save('frag_' + str(i)+ '.png')
156 |
157 | for i,f in enumerate(all_frags):
158 | G.add_node(f, name='frag_'+str(i), img=plt.imread('frag_' + str(i) + '.png'))
159 |
160 | for bond in fragment_bonds:
161 | G.add_edge(all_frags[bond[0]], all_frags[bond[1]])
162 |
163 | pos = nx.circular_layout(G)
164 | fig = plt.figure(figsize=(12,10))
165 | ax = plt.subplot(111)
166 | ax.set_aspect('equal')
167 | nx.draw_networkx_edges(G, pos, ax=ax, edge_color='black', width=2.)
168 |
169 | #plt.ylim(-4.5,4.5)
170 | trans=ax.transData.transform
171 | trans2=fig.transFigure.inverted().transform
172 |
173 | piesize=0.2 # this is the image size
174 | p2=piesize/2.0
175 | for n in G:
176 | xx,yy=trans(pos[n]) # figure coordinates
177 | xa,ya=trans2((xx,yy)) # axes coordinates
178 | a = plt.axes([xa-p2,ya-p2, piesize, piesize])
179 | a.set_aspect('equal')
180 | a.imshow(G.nodes[n]['img'])
181 | a.axis('off')
182 | ax.axis('off')
183 | plt.show()
184 |
185 | @staticmethod
186 | def draw_hier_recons(hier):
187 | graph = nx.Graph()
188 | for i, f in enumerate(hier):
189 | img = Draw.MolToImage(f)
190 | img.save('frag_' + str(i)+ '.png')
191 |
192 | for i, f in enumerate(hier):
193 | graph.add_node(f, name='frag_'+str(i), img=plt.imread('frag_' + str(i) + '.png'))
194 |
195 | for i in range(len(hier)-1):
196 | graph.add_edge(hier[i], hier[i+1])
197 |
198 | pos = nx.circular_layout(graph, dim=2,scale=2.5)
199 | fig = plt.figure(figsize=(12,10))
200 | ax = plt.subplot(111)
201 | ax.set_aspect('equal')
202 | nx.draw_networkx_edges(graph, pos, ax=ax, node_size=40, edge_color='blue', width=1.5, arrows=True, arrowsize=30)
203 |
204 |
205 | #plt.ylim(-4.5,4.5)
206 | trans=ax.transData.transform
207 | trans2=fig.transFigure.inverted().transform
208 |
209 | piesize=0.22 # this is the image size
210 | p2=piesize/2.
211 | for n in graph:
212 | xx,yy=trans(pos[n]) # figure coordinates
213 | xa,ya=trans2((xx,yy)) # axes coordinates
214 | a = plt.axes([xa-p2,ya-p2, piesize, piesize])
215 | a.set_aspect('equal')
216 | a.imshow(graph.nodes[n]['img'])
217 | a.axis('off')
218 | ax.axis('off')
219 | plt.show()
220 |
221 | @staticmethod
222 | def get_smiles(mol):
223 | return Chem.MolToSmiles(mol)
224 |
225 | @staticmethod
226 | def get_mol_from_smiles(smiles):
227 | return Chem.MolFromSmiles(smiles)
228 |
229 |
--------------------------------------------------------------------------------
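The traversal behind FragmentGraph.get_bfs_order / get_dfs_order is plain networkx; a self-contained toy illustration (the `frag_*` names here are placeholders, not repository objects). bfs_edges/dfs_edges yield the (parent, child) pairs in which hierarchical_reconstruct attaches fragments to the growing molecule:

import networkx as nx

# fragment connectivity as a graph: nodes are fragments, edges are broken bonds
g = nx.Graph()
g.add_edges_from([('frag_0', 'frag_1'), ('frag_1', 'frag_2'), ('frag_0', 'frag_3')])

print(list(nx.bfs_edges(g, 'frag_0')))  # [('frag_0', 'frag_1'), ('frag_0', 'frag_3'), ('frag_1', 'frag_2')]
print(list(nx.dfs_edges(g, 'frag_0')))  # [('frag_0', 'frag_1'), ('frag_1', 'frag_2'), ('frag_0', 'frag_3')]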
/data/CROSSDOCK/process_ligands.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import re
4 | import argparse
5 |
6 | from rdkit import Chem
7 | from tqdm import tqdm
8 |
9 | from fragment_hierarchy import *
10 | from prepare_fragments import *
11 | from sascorer import calculateScore
12 | from rdkit.Chem import rdchem, AllChem, QED
13 |
14 | import os
15 | from pathlib import Path
16 |
17 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9}
18 | atom_charges = {'C':6, 'N': 7, 'O': 8, 'S': 16, 'B':5, 'Br':35, 'Cl':17, 'P':15, 'I':53, 'F':9}
19 | hybrid_to_onehot = {'SP':0, 'SP2': 1, 'SP3': 2}
20 |
21 | def get_one_hot(atom, atoms_dict):
22 | one_hot = np.zeros(len(atoms_dict))
23 | one_hot[atoms_dict[atom]] = 1
24 | return one_hot
25 |
26 | def process_ligand(sdffile, max_num_frags=12, num_atoms_cutoff=22, add_QED=True, add_SA=True):
27 | try:
28 | mol = Chem.SDMolSupplier(str(sdffile))[0]
29 |
30 |         add_Hs = True
31 |
32 |         # get positions and one-hot encodings
33 |         mol_pos = mol.GetConformer().GetPositions()
34 |         all_symbols = []
35 |         mol_onehot = []
36 |
37 |         for atom in mol.GetAtoms():
38 |             all_symbols.append(atom.GetSymbol())
39 | atom_symb_onehot = get_one_hot(atom.GetSymbol(), atom_dict)
40 | hyb_onehot = np.eye(1,len(hybrid_to_onehot), hybrid_to_onehot[str(atom.GetHybridization())]).squeeze()
41 | aromatic_onehot = float(atom.GetIsAromatic())
42 | mol_onehot.append(np.concatenate([atom_symb_onehot, hyb_onehot, (aromatic_onehot,)]))
43 |
44 | # NOTE: adding extra node features (aromaticity and hybridization) see if these help
45 | mol_onehot = np.array(mol_onehot)
46 |
47 | mol_charges = []
48 | for atom in all_symbols:
49 | mol_charges.append(atom_charges[atom])
50 |
51 | # get charges
52 | mol_charges = np.array(mol_charges)
53 |
54 | output = find_bonds_broken_with_frags(mol, find_single_ring_fragments(mol), max_num_frags=max_num_frags, max_num_atoms_single_frag=num_atoms_cutoff)
55 | if output is not None:
56 | all_frags, bonds_broken_frags, bonds_broken_indices, \
57 | bonds_broken_frag_ids, all_frag_atom_ids, atom2frag = output
58 |
59 | # -------------- get the smiles of fragments for making a fragment library
60 | du = Chem.MolFromSmiles('*')
61 | frag_smiles_temp = [Chem.MolFromSmiles(Chem.MolToSmiles(all_frags[i])) for i in range(len(all_frags))]
62 |
63 | frag_smiles = []
64 | frag_n_atoms = []
65 | for i in range(len(all_frags)):
66 | frag=AllChem.ReplaceSubstructs(frag_smiles_temp[i],du,Chem.MolFromSmiles('[H]'),True)[0]
67 | frag = Chem.RemoveAllHs(frag)
68 | frag_n_atoms.append(frag.GetNumAtoms())
69 | frag_smiles.append(Chem.MolToSmiles(frag))
70 | # --------------------------------------------------------------------
71 |
72 | if len(all_frags) > 1: # more than 1 fragment exists in the molecule
73 | adjacency = find_neighboring_frags(all_frags, atom2frag, bonds_broken_frags)
74 | neigh_frags_to_neigh_atoms = find_neigh_frags_to_neigh_atoms(bonds_broken_frags, bonds_broken_frag_ids)
75 |
76 | frag_to_id_dict = {}
77 | for i,frag in enumerate(all_frags):
78 | frag_to_id_dict[frag] = i
79 |
80 | g = FragmentGraph(mol,
81 | all_frags,
82 | adjacency,
83 | frag_atom_ids=all_frag_atom_ids,
84 | frag_to_id_dict=frag_to_id_dict,
85 | neigh_frags_to_neigh_atoms=neigh_frags_to_neigh_atoms)
86 |
87 | n_frags = len(all_frags)
88 |
89 | assert n_frags <= max_num_frags
90 |
91 | mol_atom_ids = []
92 | mol_extension_ids = []
93 | mol_anchor_ids = []
94 | mol_QED_scores = []
95 | mol_SA_scores = []
96 | all_sub_mols = []
97 |
98 | for order in ['BFS', 'DFS']:
99 |                 for j in range(n_frags): # 2 * n_frags different ways to reconstruct the molecule in total
100 | hier, perm_atom_ids, perm_extensions_atom_ids, _, _, perm_anchor_ids = g.hierarchical_reconstruct(edge_order=order, starting_point=j)
101 | # save this hierarchy
102 |
103 | assert len(perm_atom_ids) != 0
104 | assert len(perm_anchor_ids) != 0
105 | assert len(perm_extensions_atom_ids) != 0
106 |
107 | assert len(perm_extensions_atom_ids) == len(perm_atom_ids)
108 | assert (len(perm_extensions_atom_ids) == len(perm_anchor_ids) + 1)
109 |
110 | num_atoms = [len(perm_extensions_atom_ids[i]) for i in range(len(perm_extensions_atom_ids))]
111 | max_num_atoms = max(num_atoms)
112 | if max_num_atoms > num_atoms_cutoff:
113 | break
114 | mol_atom_ids.append(perm_atom_ids)
115 | mol_extension_ids.append(perm_extensions_atom_ids)
116 | mol_anchor_ids.append(perm_anchor_ids)
117 |
118 | QED_scores = []
119 | SA_scores = []
120 | if add_QED:
121 | for i in range(len(all_frags)):
122 | atom_indices = list(perm_atom_ids[i])
123 | sub_mol = rdchem.EditableMol(Chem.Mol())
124 | atom_map = {}
125 | for atom_idx in atom_indices:
126 | atom = mol.GetAtomWithIdx(atom_idx)
127 | new_idx = sub_mol.AddAtom(atom)
128 | atom_map[atom_idx] = new_idx
129 |
130 | for bond in mol.GetBonds():
131 | begin_idx = bond.GetBeginAtomIdx()
132 | end_idx = bond.GetEndAtomIdx()
133 | if begin_idx in atom_indices and end_idx in atom_indices:
134 | bond_type = bond.GetBondType()
135 |                                 sub_mol.AddBond(atom_map[begin_idx], atom_map[end_idx], bond_type)
136 |
137 |
138 | sub_mol = sub_mol.GetMol()
139 | # Adding 3d Coordinates to the fragments
140 | try:
141 | Chem.SanitizeMol(sub_mol)
142 | conf = Chem.Conformer(sub_mol.GetNumAtoms())
143 | for atom_idx, new_atom_idx in atom_map.items():
144 | conf.SetAtomPosition(new_atom_idx, mol.GetConformer().GetAtomPosition(atom_idx))
145 | sub_mol.AddConformer(conf)
146 |                         except Exception:
147 |                             print('sanitization failed! falling back to a SMARTS round-trip')
148 | sub_mol = Chem.MolFromSmarts(Chem.MolToSmarts(sub_mol))
149 | Chem.SanitizeMol(sub_mol)
150 | conf = Chem.Conformer(sub_mol.GetNumAtoms())
151 | for atom_idx, new_atom_idx in atom_map.items():
152 | conf.SetAtomPosition(new_atom_idx, mol.GetConformer().GetAtomPosition(atom_idx))
153 | sub_mol.AddConformer(conf)
154 |
155 | #sub_mol = Chem.MolFromSmarts(Chem.MolToSmarts(sub_mol))
156 | #Chem.SanitizeMol(sub_mol)
157 | if add_Hs:
158 | sub_mol_h = Chem.AddHs(sub_mol, addCoords=True)
159 |
160 | all_sub_mols.append(sub_mol_h)
161 |
162 | QED_scores.append(Chem.QED.qed(sub_mol))
163 | if add_SA:
164 | sa = calculateScore(sub_mol)
165 | sa_as_pocket2mol = round((10-sa)/9, 2) # from pocket2mol
166 | SA_scores.append(sa_as_pocket2mol)
167 |
168 |
169 | mol_QED_scores.append(QED_scores)
170 | mol_SA_scores.append(SA_scores)
171 |
172 | mol_atom_ids = np.array(mol_atom_ids, dtype=object)
173 | mol_extension_ids = np.array(mol_extension_ids, dtype=object)
174 | mol_anchor_ids = np.array(mol_anchor_ids, dtype=object)
175 |
176 | is_single_frag = False
177 |
178 | return mol_pos, mol_onehot, mol_charges, mol_atom_ids, mol_extension_ids, mol_anchor_ids, is_single_frag, frag_smiles, frag_n_atoms, mol_QED_scores, mol_SA_scores, all_sub_mols
179 | else:
180 | print('using single fragment')
181 | is_single_frag = True
182 | mol_H = Chem.AddHs(mol, addCoords=True)
183 | all_sub_mols = [mol_H]
184 | mol_QED_score = [Chem.QED.qed(mol)]
185 | sa = calculateScore(mol)
186 | mol_SA_score = [round((10-sa)/9, 2)]
187 | return mol_pos, mol_onehot, mol_charges, None, None, None, is_single_frag, frag_smiles, frag_n_atoms, mol_QED_score, mol_SA_score, all_sub_mols
188 |
189 | except Exception as e:
190 | print(f'Error {e} for sdffile {sdffile}')
191 | return
--------------------------------------------------------------------------------
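The fragment-scoring section of process_ligand copies a subset of atoms and bonds into a fresh molecule and transfers the 3D coordinates before computing QED/SA. A self-contained sketch of that pattern (`extract_submol` is an illustrative name, not from the repo):

from rdkit import Chem
from rdkit.Chem import rdchem

def extract_submol(mol, atom_indices):
    """Copy the atoms/bonds in atom_indices into a new mol with coordinates."""
    atom_indices = list(atom_indices)
    em = rdchem.EditableMol(Chem.Mol())
    atom_map = {}  # old atom index -> new atom index
    for idx in atom_indices:
        atom_map[idx] = em.AddAtom(mol.GetAtomWithIdx(idx))
    for bond in mol.GetBonds():
        b, e = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        if b in atom_map and e in atom_map:
            em.AddBond(atom_map[b], atom_map[e], bond.GetBondType())
    sub = em.GetMol()
    # may raise for open-valence fragments; process_ligand falls back to a
    # SMARTS round-trip in that case
    Chem.SanitizeMol(sub)
    conf = Chem.Conformer(sub.GetNumAtoms())
    for old_idx, new_idx in atom_map.items():
        conf.SetAtomPosition(new_idx, mol.GetConformer().GetAtomPosition(old_idx))
    sub.AddConformer(conf)
    return sub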
/data/CROSSDOCK/process_pockets.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from time import time
3 | import argparse
4 | import shutil
5 | import random
6 | import matplotlib.pyplot as plt
7 |
8 | from tqdm import tqdm
9 | import numpy as np
10 |
11 | from Bio.PDB import PDBParser
12 | from Bio.PDB.Polypeptide import is_aa, three_to_one
13 | from rdkit import Chem
14 |
15 | amino_acid_dict = {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19}
16 | pocket_atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3} # only 4 atoms types for pocket
17 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9}
18 |
19 |
20 | def get_one_hot(atom, atoms_dict):
21 | one_hot = np.zeros(len(atoms_dict))
22 | one_hot[atoms_dict[atom]] = 1
23 | return one_hot
24 |
25 | def process_pocket(pdbfile, sdffile, atom_dict, pocket_atom_dict, dist_cutoff, remove_H=True, ca_only=False):
26 |
27 | pdb_struct = PDBParser(QUIET=True).get_structure('', pdbfile)
28 |
29 | try:
30 | ligand = Chem.SDMolSupplier(str(sdffile))[0]
31 | except:
32 | raise Exception(f'cannot read sdf mol ({sdffile})')
33 |
34 |     # drop H atoms; other atom types that aren't allowed should stay,
35 |     # so that the entire ligand can later be removed from the dataset
36 |     lig_atoms = [a.GetSymbol() for a in ligand.GetAtoms()
37 |                  if (a.GetSymbol().capitalize() in atom_dict or a.GetSymbol() != 'H')]
38 | lig_coords = np.array([list(ligand.GetConformer(0).GetAtomPosition(idx))
39 | for idx in range(ligand.GetNumAtoms())])
40 |
41 | # find interacting pocket residues based on distance cutoff
42 | pocket_residues = []
43 | for residue in pdb_struct[0].get_residues():
44 | res_coords = np.array([a.get_coord() for a in residue.get_atoms()])
45 | if is_aa(residue.get_resname(), standard=True) and \
46 | (((res_coords[:, None, :] - lig_coords[None, :, :]) ** 2).sum(-1)**0.5).min() < dist_cutoff:
47 | pocket_residues.append(residue)
48 |
49 |
50 | pocket_ids = [f'{res.parent.id}:{res.id[1]}' for res in pocket_residues]
51 |
52 | if ca_only:
53 | try:
54 | pocket_one_hot = []
55 | pocket_coords = []
56 | for res in pocket_residues:
57 | for atom in res.get_atoms():
58 | if atom.name == 'CA':
59 | pocket_one_hot.append(np.eye(1, len(amino_acid_dict),
60 | amino_acid_dict[three_to_one(res.get_resname())]).squeeze())
61 | pocket_coords.append(atom.coord)
62 | pocket_one_hot = np.stack(pocket_one_hot)
63 | pocket_coords = np.stack(pocket_coords)
64 | except KeyError as e:
65 | raise KeyError(f'{e} not in amino acid dict ({pdbfile}, {sdffile})')
66 | else:
67 | full_atoms = np.concatenate([np.array([atom.element for atom in res.get_atoms()]) for res in pocket_residues], axis=0)
68 | full_coords = np.concatenate([np.array([atom.coord for atom in res.get_atoms()]) for res in pocket_residues], axis=0)
69 | full_atoms_names = np.concatenate([np.array([atom.get_id() for atom in res.get_atoms()]) for res in pocket_residues], axis=0)
70 | pocket_AA = np.concatenate([([three_to_one(atom.get_parent().get_resname()) for atom in res.get_atoms()]) for res in pocket_residues], axis=0)
71 |
72 |         # mask out hydrogens when remove_H is set (keep everything otherwise)
73 |         h_mask = (full_atoms == 'H') if remove_H else np.zeros(len(full_atoms), dtype=bool)
74 |         full_atoms = full_atoms[~h_mask]
75 |         pocket_coords = full_coords[~h_mask]
76 |         full_atoms_names = full_atoms_names[~h_mask]
77 |         pocket_AA = pocket_AA[~h_mask]
78 |
79 |
80 | try:
81 | pocket_one_hot = []
82 | for i in range(len(full_atoms)):
83 | a = full_atoms[i]
84 | aa = pocket_AA[i]
85 | atom_onehot = np.eye(1, len(pocket_atom_dict), pocket_atom_dict[a.capitalize()]).squeeze()
86 | amino_onehot = np.eye(1, len(amino_acid_dict), amino_acid_dict[aa.capitalize()]).squeeze()
87 | is_backbone = 1 if full_atoms_names[i].capitalize() in ['N','CA','C','O'] else 0
88 | pocket_one_hot.append(np.concatenate([atom_onehot, amino_onehot, (is_backbone,)]))
89 |
90 | pocket_one_hot = np.stack(pocket_one_hot)
91 | except KeyError as e:
92 | raise KeyError(
93 | f'{e} not in atom dict ({pdbfile})')
94 |
95 | pocket_one_hot = np.array(pocket_one_hot)
96 | return pocket_one_hot, pocket_coords, lig_coords
--------------------------------------------------------------------------------
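process_pocket keeps a residue when any of its atoms lies within dist_cutoff Å of any ligand atom. The vectorized check above amounts to this sketch (`min_residue_ligand_distance` is an illustrative name):

import numpy as np

def min_residue_ligand_distance(res_coords, lig_coords):
    """Smallest atom-atom distance between a residue and the ligand (Å)."""
    # res_coords: (n_res_atoms, 3), lig_coords: (n_lig_atoms, 3)
    d = np.linalg.norm(res_coords[:, None, :] - lig_coords[None, :, :], axis=-1)
    return d.min()

# a residue belongs to the pocket iff
# min_residue_ligand_distance(res_coords, lig_coords) < dist_cutoff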
/data/CROSSDOCK/sascorer.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 | from rdkit.Chem import rdMolDescriptors
3 | import pickle
4 |
5 | import math
6 | from collections import defaultdict
7 |
8 | import os.path as op
9 |
10 | _fscores = None
11 |
12 |
13 | def readFragmentScores(name='fpscores'):
14 |     import gzip
15 |     global _fscores
16 |     # resolve fpscores.pkl.gz relative to the repo root rather than a
17 |     # machine-specific absolute path (this file sits in data/CROSSDOCK/)
18 |     if name == 'fpscores':
19 |         name = op.join(op.dirname(__file__), '..', '..', name)
20 |     data = pickle.load(gzip.open('%s.pkl.gz' % name))
21 | outDict = {}
22 | for i in data:
23 | for j in range(1, len(i)):
24 | outDict[i[j]] = float(i[0])
25 | _fscores = outDict
26 |
27 |
28 | def numBridgeheadsAndSpiro(mol, ri=None):
29 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
30 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
31 | return nBridgehead, nSpiro
32 |
33 | def calculateScore(m):
34 | if _fscores is None:
35 | readFragmentScores()
36 |
37 | # fragment score
38 | fp = rdMolDescriptors.GetMorganFingerprint(m,
39 | 2) # <- 2 is the *radius* of the circular fingerprint
40 | fps = fp.GetNonzeroElements()
41 | score1 = 0.
42 | nf = 0
43 | for bitId, v in fps.items():
44 | nf += v
45 | sfp = bitId
46 | score1 += _fscores.get(sfp, -4) * v
47 | score1 /= nf
48 |
49 | # features score
50 | nAtoms = m.GetNumAtoms()
51 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
52 | ri = m.GetRingInfo()
53 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
54 | nMacrocycles = 0
55 | for x in ri.AtomRings():
56 | if len(x) > 8:
57 | nMacrocycles += 1
58 |
59 | sizePenalty = nAtoms**1.005 - nAtoms
60 | stereoPenalty = math.log10(nChiralCenters + 1)
61 | spiroPenalty = math.log10(nSpiro + 1)
62 | bridgePenalty = math.log10(nBridgeheads + 1)
63 | macrocyclePenalty = 0.
64 | # ---------------------------------------
65 | # This differs from the paper, which defines:
66 | # macrocyclePenalty = math.log10(nMacrocycles+1)
67 | # This form generates better results when 2 or more macrocycles are present
68 | if nMacrocycles > 0:
69 | macrocyclePenalty = math.log10(2)
70 |
71 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
72 |
73 | # correction for the fingerprint density
74 | # not in the original publication, added in version 1.1
75 | # to make highly symmetrical molecules easier to synthetise
76 | score3 = 0.
77 | if nAtoms > len(fps):
78 | score3 = math.log(float(nAtoms) / len(fps)) * .5
79 |
80 | sascore = score1 + score2 + score3
81 |
82 | # need to transform "raw" value into scale between 1 and 10
83 | min = -4.0
84 | max = 2.5
85 | sascore = 11. - (sascore - min + 1) / (max - min) * 9.
86 | # smooth the 10-end
87 | if sascore > 8.:
88 | sascore = 8. + math.log(sascore + 1. - 9.)
89 | if sascore > 10.:
90 | sascore = 10.0
91 | elif sascore < 1.:
92 | sascore = 1.0
93 |
94 | return sascore
95 |
96 |
97 | def processMols(mols):
98 | print('smiles\tName\tsa_score')
99 | for i, m in enumerate(mols):
100 | if m is None:
101 | continue
102 |
103 | s = calculateScore(m)
104 |
105 | smiles = Chem.MolToSmiles(m)
106 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
107 |
108 |
109 |
--------------------------------------------------------------------------------
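calculateScore above returns a raw synthetic-accessibility value in [1, 10] (1 = easy, 10 = hard). process_ligands.py rescales it with (10 - sa) / 9, a convention it credits to Pocket2Mol, so that 1 means easy on a [0, 1] scale. A tiny sketch of that mapping (`normalize_sa` is an illustrative name):

def normalize_sa(sa: float) -> float:
    """Map raw SA (1 = easy ... 10 = hard) onto [0, 1] with 1 = easy."""
    return round((10.0 - sa) / 9.0, 2)

assert normalize_sa(1.0) == 1.0   # easiest to synthesize
assert normalize_sa(10.0) == 0.0  # hardest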
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/data/__init__.py
--------------------------------------------------------------------------------
/data/sascorer.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 | from rdkit.Chem import rdMolDescriptors
3 | import pickle
4 |
5 | import math
6 |
7 | _fscores = None
8 |
9 | def readFragmentScores(name='fpscores'):
10 |     import gzip, os.path as op
11 |     global _fscores
12 |     # resolve fpscores.pkl.gz relative to the repo root rather than a
13 |     # machine-specific absolute path (this file sits in data/)
14 |     if name == 'fpscores':
15 |         name = op.join(op.dirname(__file__), '..', name)
16 |     data = pickle.load(gzip.open('%s.pkl.gz' % name))
17 | outDict = {}
18 | for i in data:
19 | for j in range(1, len(i)):
20 | outDict[i[j]] = float(i[0])
21 | _fscores = outDict
22 |
23 |
24 | def numBridgeheadsAndSpiro(mol, ri=None):
25 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
26 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
27 | return nBridgehead, nSpiro
28 |
29 | def calculateScore(m):
30 | if _fscores is None:
31 | readFragmentScores()
32 |
33 | # fragment score
34 | fp = rdMolDescriptors.GetMorganFingerprint(m,
35 | 2) # <- 2 is the *radius* of the circular fingerprint
36 | fps = fp.GetNonzeroElements()
37 | score1 = 0.
38 | nf = 0
39 | for bitId, v in fps.items():
40 | nf += v
41 | sfp = bitId
42 | score1 += _fscores.get(sfp, -4) * v
43 | score1 /= nf
44 |
45 | # features score
46 | nAtoms = m.GetNumAtoms()
47 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
48 | ri = m.GetRingInfo()
49 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
50 | nMacrocycles = 0
51 | for x in ri.AtomRings():
52 | if len(x) > 8:
53 | nMacrocycles += 1
54 |
55 | sizePenalty = nAtoms**1.005 - nAtoms
56 | stereoPenalty = math.log10(nChiralCenters + 1)
57 | spiroPenalty = math.log10(nSpiro + 1)
58 | bridgePenalty = math.log10(nBridgeheads + 1)
59 | macrocyclePenalty = 0.
60 | # ---------------------------------------
61 | # This differs from the paper, which defines:
62 | # macrocyclePenalty = math.log10(nMacrocycles+1)
63 | # This form generates better results when 2 or more macrocycles are present
64 | if nMacrocycles > 0:
65 | macrocyclePenalty = math.log10(2)
66 |
67 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
68 |
69 | # correction for the fingerprint density
70 | # not in the original publication, added in version 1.1
71 | # to make highly symmetrical molecules easier to synthetise
72 | score3 = 0.
73 | if nAtoms > len(fps):
74 | score3 = math.log(float(nAtoms) / len(fps)) * .5
75 |
76 | sascore = score1 + score2 + score3
77 |
78 | # need to transform "raw" value into scale between 1 and 10
79 | min = -4.0
80 | max = 2.5
81 | sascore = 11. - (sascore - min + 1) / (max - min) * 9.
82 | # smooth the 10-end
83 | if sascore > 8.:
84 | sascore = 8. + math.log(sascore + 1. - 9.)
85 | if sascore > 10.:
86 | sascore = 10.0
87 | elif sascore < 1.:
88 | sascore = 1.0
89 |
90 | return sascore
91 |
92 |
93 | def processMols(mols):
94 | print('smiles\tName\tsa_score')
95 | for i, m in enumerate(mols):
96 | if m is None:
97 | continue
98 |
99 | s = calculateScore(m)
100 |
101 | smiles = Chem.MolToSmiles(m)
102 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
103 |
104 |
105 |
--------------------------------------------------------------------------------
/fpscores.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/fpscores.pkl.gz
--------------------------------------------------------------------------------
/generate_pocket_molecules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import argparse
5 |
6 | import torch
7 | import time
8 | import shutil
9 |
10 | from utils.volume_sampling import sample_discrete_number, bin_edges, prob_dist_df
11 | from utils.templates import get_one_hot, get_pocket
12 |
13 | from src.lightning_anchor_gnn import AnchorGNN_pl
14 | from src.lightning import AR_DDPM
15 | from scipy.spatial import distance
16 |
17 | from analysis.reconstruct_mol import reconstruct_from_generated
18 |
19 | from rdkit.Chem import rdmolfiles
20 | from sampling.sample_mols import generate_mols_for_pocket
21 |
22 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9}
23 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'}
24 | CROSSDOCK_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'B':5, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15}
25 | pocket_atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3} # only 4 atoms types for pocket
26 |
27 | vdws = {'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8, 'B': 1.92, 'Br': 1.85, 'Cl': 1.75, 'P': 1.8, 'I': 1.98, 'F': 1.47}
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--results-path', type=str, default='results',
31 | help='path to save the results ')
32 | parser.add_argument('--data-path', action='store', type=str, default='/srv/home/mahdi.ghorbani/FragDiff/crossdock',
33 | help='path to the test data for generating molecules')
34 | parser.add_argument('--use-anchor-model', action='store_true', default=False,
35 | help='Whether to use an anchor prediction model')
36 | parser.add_argument('--anchor-model', type=str, default='anchor_model.ckpt',
37 | help='path to the anchor model. Note that for guidance, the anchor model should incorporate the conditionals')
38 | parser.add_argument('--n-samples', type=int, default=20,
39 | help='total number of ligands to generate per pocket')
40 | parser.add_argument('--exp-name', type=str, default='exp-1',
41 | help='name of the generation experiment')
42 | parser.add_argument('--diff-model', type=str, default='diff-model.ckpt',
43 | help='path to the diffusion model checkpoint')
44 | parser.add_argument('--device', type=str, default='cuda:0')
45 | parser.add_argument('--rejection-sampling', action='store_true', default=False, help='enable rejection sampling')
46 |
47 | if __name__ == '__main__':
48 | args = parser.parse_args()
49 | torch_device = args.device
50 | anchor_checkpoint = args.anchor_model
51 | data_path = args.data_path
52 | diff_model_checkpoint = args.diff_model
53 |
54 | model = AR_DDPM.load_from_checkpoint(diff_model_checkpoint, device=torch_device) # load diffusion model
55 | model = model.to(torch_device)
56 |
57 |     if args.use_anchor_model:
58 | anchor_model = AnchorGNN_pl.load_from_checkpoint(anchor_checkpoint, device=torch_device)
59 | anchor_model = anchor_model.to(torch_device)
60 | else:
61 | anchor_model = None # TODO: implement random anchor selection
62 |
63 | split = torch.load(data_path + '/' + 'split_by_name.pt')
64 | prefix = data_path + '/crossdocked_pocket10/'
65 |
66 |     if not os.path.exists(args.results_path):
67 |         os.makedirs(args.results_path) # create the results directory
68 |
69 | save_dir = args.results_path + '/' + args.exp_name
70 | if not os.path.exists(save_dir):
71 | os.makedirs(save_dir, exist_ok=True)
72 |
73 | for n in range(100):
74 | prot_name = prefix + split['test'][n][0]
75 | lig_name = prefix + split['test'][n][1]
76 |
77 | pocket_onehot, pocket_coords, lig_coords, _ = get_pocket(prot_name, lig_name, atom_dict, pocket_atom_dict=pocket_atom_dict, dist_cutoff=7)
78 |
79 | # --------------- make a grid box around the pocket ----------------
80 | min_coords = pocket_coords.min(axis=0) - 2.5 #
81 | max_coords = pocket_coords.max(axis=0) + 2.5
82 |
83 |     x_range = slice(min_coords[0], max_coords[0] + 1, 1.5) # 1.5 A grid spacing
84 | y_range = slice(min_coords[1], max_coords[1] + 1, 1.5)
85 | z_range = slice(min_coords[2], max_coords[2] + 1, 1.5)
86 |
87 | grid = np.mgrid[x_range, y_range, z_range]
88 | grid_points = grid.reshape(3, -1).T # This transposes the grid to a list of coordinates
89 |
90 | # remove grids points not in 3.5A neighborhood of original ligand
91 | distances_mol = distance.cdist(grid_points, lig_coords)
92 | mask_mol = (distances_mol < 3.5).any(axis=1)
93 | filtered_mol_points = grid_points[mask_mol]
94 |
95 | # remove grid points that are close to the pocket
96 | pocket_distances = distance.cdist(filtered_mol_points, pocket_coords)
97 | mask_pocket = (pocket_distances < 2).any(axis=1)
98 | grids = filtered_mol_points[~mask_pocket]
99 |
100 | n_samples = args.n_samples
101 | max_mol_sizes = []
102 |
103 | fpocket_out = prot_name[:-4] + '_out'
104 | shutil.rmtree(fpocket_out, ignore_errors=True)
105 |
106 | #print('running fpocket!')
107 | #try:
108 | # run_fpocket(prot_name)
109 | #except:
110 | # print('Error in running fpocket! using random sizes')
111 |
112 | # NOTE: using original molecule coordinates for making the grid
113 |
114 | grids = torch.tensor(grids)
115 | all_grids = [] # list of grids
116 | for i in range(n_samples):
117 | all_grids.append(grids)
118 |
119 | pocket_vol = len(grids)
120 | #if os.path.exists(fpocket_out):
121 | # filename = prot_name[:-4] + '_out/pockets/pocket1_atm.pdb'
122 | # score, drug_score, pocket_volume = extract_values(filename)
123 | #else:
124 | # print('running fpocket!')
125 | # run_fpocket(prot_name)
126 | # filename = prot_name[:-4] + '_out/pockets/pocket1_atm.pdb'
127 | # score, drug_score, pocket_volume = extract_values(filename)
128 |
129 | #print('pocket_volume', pocket_volume)
130 |
131 | for i in range(n_samples):
132 | max_mol_sizes.append(sample_discrete_number(pocket_vol))
133 |
134 | pocket_onehot = torch.tensor(pocket_onehot).float()
135 | pocket_coords = torch.tensor(pocket_coords).float()
136 | lig_coords = torch.tensor(lig_coords).float()
137 | pocket_size = len(pocket_coords)
138 |
139 | t1 = time.time()
140 |
141 | max_mol_sizes = np.array(max_mol_sizes)
142 | print('maximum sizes for molecules', max_mol_sizes)
143 | x, h, mol_masks = generate_mols_for_pocket(n_samples=n_samples,
144 | num_frags=8,
145 | pocket_size=pocket_size,
146 | pocket_coords=pocket_coords,
147 | pocket_onehot=pocket_onehot,
148 | lig_coords=lig_coords,
149 | anchor_model=anchor_model,
150 | diff_model=model,
151 | device=torch_device,
152 | return_all=False,
153 | prot_path=prot_name,
154 | max_mol_sizes=max_mol_sizes,
155 | all_grids=all_grids,
156 | rejection_sampling=args.rejection_sampling,
157 | rejection_criteria='clash')
158 |
159 | x = x.cpu().numpy()
160 | h = h.cpu().numpy()
161 |     mol_masks = mol_masks.cpu().numpy()
162 |
163 | # convert to SDF
164 | all_mols = []
165 | for k in range(len(x)):
166 | mask = mol_masks[k]
167 | h_mol = h[k]
168 | x_mol = x[k][mask.astype(np.bool_)]
169 |
170 | atom_inds = h_mol[mask.astype(np.bool_)].argmax(axis=1)
171 | atom_types = [idx2atom[x] for x in atom_inds]
172 | atomic_nums = [CROSSDOCK_CHARGES[i] for i in atom_types]
173 |
174 | try:
175 | mol_rec = reconstruct_from_generated(x_mol.tolist(), atomic_nums)
176 | all_mols.append(mol_rec)
177 | except:
178 | continue
179 |
180 | t2 = time.time()
181 | print('time to generate one is: ', (t2-t1)/n_samples)
182 | save_path = save_dir + '/' + 'pocket_' + str(n)
183 |
184 | # write sdf file of molecules
185 | with rdmolfiles.SDWriter(save_path + '_mols.sdf') as writer:
186 | for mol in all_mols:
187 | if mol:
188 | writer.write(mol)
189 |
190 | np.save(save_path + '_coords.npy', x)
191 | np.save(save_path + '_onehot.npy', h)
192 | np.save(save_path + '_mol_masks.npy', mol_masks)
193 |
--------------------------------------------------------------------------------
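The grid construction in generate_pocket_molecules.py amounts to: lay a regular 1.5 Å lattice over the padded pocket bounding box, keep points within 3.5 Å of the reference ligand, then discard points within 2 Å of any pocket atom. A sketch under those assumptions (`make_pocket_grid` is an illustrative name; the script uses np.mgrid with slice steps, np.meshgrid here is equivalent):

import numpy as np
from scipy.spatial import distance

def make_pocket_grid(pocket_coords, lig_coords, spacing=1.5,
                     pad=2.5, lig_cutoff=3.5, clash_cutoff=2.0):
    """Lattice points inside the pocket box, near the ligand, not clashing."""
    lo = pocket_coords.min(axis=0) - pad
    hi = pocket_coords.max(axis=0) + pad
    axes = [np.arange(lo[i], hi[i] + 1, spacing) for i in range(3)]
    pts = np.stack(np.meshgrid(*axes, indexing='ij'), axis=-1).reshape(-1, 3)
    # keep points within lig_cutoff of the reference ligand ...
    pts = pts[(distance.cdist(pts, lig_coords) < lig_cutoff).any(axis=1)]
    # ... and drop points within clash_cutoff of any pocket atom
    pts = pts[~(distance.cdist(pts, pocket_coords) < clash_cutoff).any(axis=1)]
    return pts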
/notebooks/2z3h_out/2z3h.pml:
--------------------------------------------------------------------------------
1 | from pymol import cmd,stored
2 | load 2z3h_out.pdb
3 | #select pockets, resn STP
4 | stored.list=[]
5 | cmd.iterate("(resn STP)","stored.list.append(resi)") #read info about residues STP
6 | #print stored.list
7 | lastSTP=stored.list[-1] #get the index of the last residu
8 | hide lines, resn STP
9 |
10 | #show spheres, resn STP
11 | for my_index in range(1,int(lastSTP)+1): cmd.select("pocket"+str(my_index), "resn STP and resi "+str(my_index))
12 | for my_index in range(2,int(lastSTP)+2): cmd.color(my_index,"pocket"+str(my_index))
13 | for my_index in range(1,int(lastSTP)+1): cmd.show("spheres","pocket"+str(my_index))
14 | for my_index in range(1,int(lastSTP)+1): cmd.set("sphere_scale","0.3","pocket"+str(my_index))
15 | for my_index in range(1,int(lastSTP)+1): cmd.set("sphere_transparency","0.1","pocket"+str(my_index))
16 |
--------------------------------------------------------------------------------
/notebooks/2z3h_out/2z3h.tcl:
--------------------------------------------------------------------------------
1 | proc highlighting { colorId representation id selection } {
2 | puts "highlighting $id"
3 | mol representation $representation
4 | mol material "Diffuse"
5 | mol color $colorId
6 | mol selection $selection
7 | mol addrep $id
8 | }
9 |
10 | set id [mol new 2z3h_out.pdb type pdb]
11 | mol delrep top $id
12 | highlighting Name "Lines" $id "protein"
13 | highlighting Name "Licorice" $id "not protein and not resname STP"
14 | highlighting Element "NewCartoon" $id "protein"
15 | highlighting "ColorID 7" "VdW 0.4" $id "protein and occupancy>0.95"
16 | set id [mol new 2z3h_pockets.pqr type pqr]
17 | mol selection "all"
18 | mol material "Glass3"
19 | mol delrep top $id
20 | mol representation "QuickSurf 0.3"
21 | mol color ResId $id
22 | mol addrep $id
23 | highlighting Index "Points 1" $id "resname STP"
24 | display rendermode GLSL
25 |
--------------------------------------------------------------------------------
/notebooks/2z3h_out/2z3h_PYMOL.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | pymol 2z3h.pml
3 |
--------------------------------------------------------------------------------
/notebooks/2z3h_out/2z3h_VMD.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | vmd 2z3h_out.pdb -e 2z3h.tcl
3 |
--------------------------------------------------------------------------------
/notebooks/2z3h_out/2z3h_info.txt:
--------------------------------------------------------------------------------
1 | Pocket 1 :
2 | Score : 0.350
3 | Druggability Score : 0.871
4 | Number of Alpha Spheres : 70
5 | Total SASA : 176.938
6 | Polar SASA : 67.856
7 | Apolar SASA : 109.082
8 | Volume : 685.980
9 | Mean local hydrophobic density : 28.632
10 | Mean alpha sphere radius : 3.828
11 | Mean alp. sph. solvent access : 0.460
12 | Apolar alpha sphere proportion : 0.543
13 | Hydrophobicity score: 24.800
14 | Volume score: 4.000
15 | Polarity score: 11
16 | Charge score : 0
17 | Proportion of polar atoms: 37.736
18 | Alpha sphere density : 5.683
19 | Cent. of mass - Alpha Sphere max dist: 13.914
20 | Flexibility : 0.380
21 |
22 | Pocket 2 :
23 | Score : 0.025
24 | Druggability Score : 0.002
25 | Number of Alpha Spheres : 19
26 | Total SASA : 61.984
27 | Polar SASA : 16.094
28 | Apolar SASA : 45.890
29 | Volume : 221.120
30 | Mean local hydrophobic density : 8.000
31 | Mean alpha sphere radius : 4.123
32 | Mean alp. sph. solvent access : 0.669
33 | Apolar alpha sphere proportion : 0.474
34 | Hydrophobicity score: 40.250
35 | Volume score: 3.500
36 | Polarity score: 3
37 | Charge score : 0
38 | Proportion of polar atoms: 43.750
39 | Alpha sphere density : 2.046
40 | Cent. of mass - Alpha Sphere max dist: 5.129
41 | Flexibility : 0.306
42 |
43 |
--------------------------------------------------------------------------------
/notebooks/2z3h_out/2z3h_pockets.pqr:
--------------------------------------------------------------------------------
1 | HEADER
2 | HEADER This is a pqr format file writen by the programm fpocket.
3 | HEADER It contains all the pocket vertices found by fpocket.
4 | ATOM 1 C STP 1 27.832 35.872 104.021 0.00 3.69
5 | ATOM 2 C STP 1 27.954 35.449 103.706 0.00 4.07
6 | ATOM 3 O STP 1 25.370 34.420 103.783 0.00 4.01
7 | ATOM 4 O STP 1 27.724 35.111 103.740 0.00 4.18
8 | ATOM 5 C STP 1 24.801 33.618 103.765 0.00 3.82
9 | ATOM 6 C STP 1 21.596 27.720 100.725 0.00 3.87
10 | ATOM 7 O STP 1 23.353 34.262 111.342 0.00 3.64
11 | ATOM 8 O STP 1 23.638 33.858 112.345 0.00 3.55
12 | ATOM 9 O STP 1 26.104 36.231 100.922 0.00 3.67
13 | ATOM 10 O STP 1 26.179 36.489 100.894 0.00 3.56
14 | ATOM 11 O STP 1 26.176 36.334 100.914 0.00 3.63
15 | ATOM 12 O STP 1 24.838 36.294 100.826 0.00 3.60
16 | ATOM 13 O STP 1 24.946 36.155 100.851 0.00 3.66
17 | ATOM 14 C STP 1 28.730 35.819 103.728 0.00 3.92
18 | ATOM 15 C STP 1 26.814 35.345 102.503 0.00 3.69
19 | ATOM 16 O STP 1 26.373 34.656 103.564 0.00 3.96
20 | ATOM 17 O STP 1 27.492 34.995 103.579 0.00 4.14
21 | ATOM 18 C STP 1 22.814 30.645 104.838 0.00 4.62
22 | ATOM 19 C STP 1 22.340 30.784 103.285 0.00 3.67
23 | ATOM 20 O STP 1 21.530 29.239 101.689 0.00 3.78
24 | ATOM 21 O STP 1 21.569 29.345 101.655 0.00 3.71
25 | ATOM 22 C STP 1 21.659 27.794 101.499 0.00 4.33
26 | ATOM 23 C STP 1 21.686 27.719 101.468 0.00 4.35
27 | ATOM 24 C STP 1 28.707 33.958 103.385 0.00 3.72
28 | ATOM 25 C STP 1 21.363 27.197 100.342 0.00 3.70
29 | ATOM 26 O STP 1 22.966 34.609 109.925 0.00 3.95
30 | ATOM 27 O STP 1 22.195 33.840 106.546 0.00 4.43
31 | ATOM 28 O STP 1 22.184 33.812 106.570 0.00 4.45
32 | ATOM 29 O STP 1 24.182 34.117 108.716 0.00 3.47
33 | ATOM 30 O STP 1 23.499 33.586 105.919 0.00 4.32
34 | ATOM 31 O STP 1 23.414 33.518 106.010 0.00 4.37
35 | ATOM 32 O STP 1 24.878 34.569 104.793 0.00 4.02
36 | ATOM 33 O STP 1 26.129 35.033 105.102 0.00 3.79
37 | ATOM 34 O STP 1 25.995 34.250 107.350 0.00 3.45
38 | ATOM 35 O STP 1 24.386 35.825 100.260 0.00 3.46
39 | ATOM 36 C STP 1 29.851 32.841 102.748 0.00 3.60
40 | ATOM 37 C STP 1 24.701 33.358 103.737 0.00 3.76
41 | ATOM 38 C STP 1 23.960 32.096 104.019 0.00 3.78
42 | ATOM 39 C STP 1 29.026 34.688 103.966 0.00 4.07
43 | ATOM 40 C STP 1 29.002 34.303 103.714 0.00 3.89
44 | ATOM 41 C STP 1 29.868 32.882 102.753 0.00 3.63
45 | ATOM 42 C STP 1 29.863 32.869 102.753 0.00 3.62
46 | ATOM 43 C STP 1 29.847 32.902 102.767 0.00 3.62
47 | ATOM 44 C STP 1 29.437 35.468 103.873 0.00 3.94
48 | ATOM 45 C STP 1 29.764 34.817 104.028 0.00 4.11
49 | ATOM 46 C STP 1 29.899 34.795 104.009 0.00 4.11
50 | ATOM 47 C STP 1 31.693 34.249 105.339 0.00 3.73
51 | ATOM 48 C STP 1 31.333 34.241 105.797 0.00 3.49
52 | ATOM 49 C STP 1 28.775 36.216 103.232 0.00 3.53
53 | ATOM 50 C STP 1 28.934 36.340 103.276 0.00 3.49
54 | ATOM 51 C STP 1 29.044 36.027 103.659 0.00 3.84
55 | ATOM 52 C STP 1 29.068 36.054 103.658 0.00 3.83
56 | ATOM 53 C STP 1 29.483 35.777 103.760 0.00 3.87
57 | ATOM 54 C STP 1 29.517 35.981 103.735 0.00 3.80
58 | ATOM 55 C STP 1 29.637 35.694 103.759 0.00 3.86
59 | ATOM 56 O STP 1 29.648 35.983 103.725 0.00 3.76
60 | ATOM 57 C STP 1 30.753 33.263 102.941 0.00 3.77
61 | ATOM 58 C STP 1 30.509 33.161 102.917 0.00 3.72
62 | ATOM 59 O STP 1 31.028 33.632 103.177 0.00 3.90
63 | ATOM 60 C STP 1 30.931 33.710 103.264 0.00 3.90
64 | ATOM 61 O STP 1 24.895 33.716 107.469 0.00 3.69
65 | ATOM 62 C STP 1 31.689 34.254 105.307 0.00 3.74
66 | ATOM 63 O STP 1 31.145 34.459 104.665 0.00 3.82
67 | ATOM 64 O STP 1 31.094 33.808 103.577 0.00 3.82
68 | ATOM 65 C STP 1 31.722 34.224 105.318 0.00 3.73
69 | ATOM 66 C STP 1 31.828 34.074 105.272 0.00 3.68
70 | ATOM 67 O STP 1 32.238 34.042 105.383 0.00 3.47
71 | ATOM 68 O STP 1 31.885 34.027 105.266 0.00 3.65
72 | ATOM 69 O STP 1 31.107 33.629 103.227 0.00 3.87
73 | ATOM 70 O STP 1 31.614 33.384 103.513 0.00 3.58
74 | ATOM 71 C STP 2 32.182 41.416 112.136 0.00 3.53
75 | ATOM 72 O STP 2 27.717 42.628 109.921 0.00 3.86
76 | ATOM 73 O STP 2 30.860 41.923 113.152 0.00 4.55
77 | ATOM 74 O STP 2 27.790 42.644 110.150 0.00 3.99
78 | ATOM 75 C STP 2 28.097 42.799 111.011 0.00 4.54
79 | ATOM 76 C STP 2 27.958 42.614 110.279 0.00 4.05
80 | ATOM 77 C STP 2 28.056 42.649 110.510 0.00 4.19
81 | ATOM 78 C STP 2 27.863 42.546 109.704 0.00 3.69
82 | ATOM 79 O STP 2 28.973 41.271 111.405 0.00 3.78
83 | ATOM 80 C STP 2 28.081 42.825 111.069 0.00 4.58
84 | ATOM 81 C STP 2 28.002 42.857 111.062 0.00 4.58
85 | ATOM 82 C STP 2 28.121 42.766 111.074 0.00 4.55
86 | ATOM 83 O STP 2 28.418 41.993 111.115 0.00 4.06
87 | ATOM 84 O STP 2 30.452 41.910 112.989 0.00 4.53
88 | ATOM 85 O STP 2 30.112 40.927 112.555 0.00 3.67
89 | ATOM 86 C STP 2 29.290 41.849 111.865 0.00 4.19
90 | ATOM 87 O STP 2 29.160 41.443 111.627 0.00 3.92
91 | ATOM 88 O STP 2 29.599 41.639 111.963 0.00 4.11
92 | ATOM 89 O STP 2 29.489 41.459 111.826 0.00 3.98
93 | TER
94 | END
95 |
--------------------------------------------------------------------------------
/notebooks/2z3h_out/pockets/pocket1_atm.pdb:
--------------------------------------------------------------------------------
1 | HEADER
2 | HEADER This is a pdb format file writen by the programm fpocket.
3 | HEADER It represents the atoms contacted by the voronoi vertices of the pocket.
4 | HEADER
5 | HEADER Information about the pocket 1:
6 | HEADER 0 - Pocket Score : 0.3504
7 | HEADER 1 - Drug Score : 0.8715
8 | HEADER 2 - Number of alpha spheres : 70
9 | HEADER 3 - Mean alpha-sphere radius : 3.8276
10 | HEADER 4 - Mean alpha-sphere Solvent Acc. : 0.4599
11 | HEADER 5 - Mean B-factor of pocket residues : 0.3795
12 | HEADER 6 - Hydrophobicity Score : 24.8000
13 | HEADER 7 - Polarity Score : 11
14 | HEADER 8 - Amino Acid based volume Score : 4.0000
15 | HEADER 9 - Pocket volume (Monte Carlo) : 685.9797
16 | HEADER 10 - Pocket volume (convex hull) : 245.6618
17 | HEADER 11 - Charge Score : 0
18 | HEADER 12 - Local hydrophobic density Score : 28.6316
19 | HEADER 13 - Number of apolar alpha sphere : 38
20 | HEADER 14 - Proportion of apolar alpha sphere : 0.5429
21 | ATOM 314 CG ASN A 45 26.203 39.184 104.145 0.00 0.00 C 0
22 | ATOM 207 CG2 VAL A 29 28.584 36.516 107.579 0.00 0.00 C 0
23 | ATOM 316 ND2 ASN A 45 25.186 38.429 103.708 0.65 9.84 N 0
24 | ATOM 313 CB ASN A 45 27.324 39.433 103.184 0.00 0.00 C 0
25 | ATOM 384 SG CYS A 54 29.462 35.137 99.942 0.00 0.00 S 0
26 | ATOM 1772 OH TYR B 126 27.265 31.475 105.744 0.59 2.14 O 0
27 | ATOM 2146 CD2 PHE C 49 25.778 32.652 100.203 0.00 0.00 C 0
28 | ATOM 331 CD2 TYR A 47 21.731 35.473 102.458 0.00 0.00 C 0
29 | ATOM 333 CE2 TYR A 47 21.355 34.151 102.211 0.00 0.00 C 0
30 | ATOM 2130 CA HIS C 48 21.690 28.736 96.995 0.00 0.00 C 0
31 | ATOM 2123 CD1 TYR C 47 18.441 28.861 98.802 0.00 0.00 C 0
32 | ATOM 2136 CD2 HIS C 48 24.268 26.784 98.091 0.00 0.00 C 0
33 | ATOM 2139 N PHE C 49 22.569 30.572 98.302 0.00 0.00 N 0
34 | ATOM 541 OD1 ASN A 79 25.473 36.874 112.744 0.46 1.07 O 0
35 | ATOM 568 CZ ARG A 82 26.992 34.126 111.204 0.00 0.00 C 0
36 | ATOM 1793 CH2 TRP B 128 23.122 30.712 110.552 0.00 0.00 C 0
37 | ATOM 569 NH1 ARG A 82 26.778 35.316 110.679 0.00 0.00 N 0
38 | ATOM 567 NE ARG A 82 27.183 34.021 112.514 0.62 1.09 N 0
39 | ATOM 1792 CZ3 TRP B 128 23.775 30.358 111.747 0.00 0.00 C 0
40 | ATOM 371 O GLY A 52 25.612 37.637 97.573 0.39 2.14 O 0
41 | ATOM 380 CA CYS A 54 29.373 37.889 100.182 0.00 0.00 C 0
42 | ATOM 375 O PRO A 53 27.273 39.635 99.639 0.00 0.00 O 0
43 | ATOM 325 CA TYR A 47 22.035 38.536 101.149 0.00 0.00 C 0
44 | ATOM 205 CB VAL A 29 29.300 37.755 107.084 0.00 0.00 C 0
45 | ATOM 2148 CE2 PHE C 49 27.097 32.261 100.497 0.00 0.00 C 0
46 | ATOM 2145 CD1 PHE C 49 24.992 30.463 100.771 0.00 0.00 C 0
47 | ATOM 1791 CZ2 TRP B 128 23.785 30.730 109.352 0.00 0.00 C 0
48 | ATOM 2143 CB PHE C 49 23.308 32.193 100.039 0.00 0.00 C 0
49 | ATOM 335 OH TYR A 47 19.781 32.580 101.365 0.56 7.50 O 0
50 | ATOM 2125 CE1 TYR C 47 18.544 30.055 99.511 0.00 0.00 C 0
51 | ATOM 1771 CZ TYR B 126 28.422 30.835 105.379 0.00 0.00 C 0
52 | ATOM 2120 O TYR C 47 20.227 26.413 96.911 0.21 6.43 O 0
53 | ATOM 200 OG SER A 28 24.600 37.390 107.646 0.00 0.00 O 0
54 | ATOM 199 CB SER A 28 24.374 38.136 108.841 0.00 0.00 C 0
55 | ATOM 169 O GLU A 25 20.412 37.883 106.280 0.36 7.50 O 0
56 | ATOM 570 NH2 ARG A 82 27.008 33.049 110.420 0.46 2.19 N 0
57 | ATOM 2142 O PHE C 49 23.855 33.848 97.464 0.56 1.07 O 0
58 | ATOM 2149 CZ PHE C 49 27.363 30.972 100.930 0.00 0.00 C 0
59 | ATOM 608 CB CYS A 88 31.343 30.542 100.407 0.00 0.00 C 0
60 | ATOM 2144 CG PHE C 49 24.717 31.760 100.332 0.00 0.00 C 0
61 | ATOM 1769 CE1 TYR B 126 29.648 31.187 105.952 0.00 0.00 C 0
62 | ATOM 206 CG1 VAL A 29 30.811 37.488 106.970 0.00 0.00 C 0
63 | ATOM 397 OE1 GLU A 56 33.052 36.868 102.393 0.86 2.14 O 0
64 | ATOM 587 CB LEU A 85 32.872 32.683 108.516 0.00 0.00 C 0
65 | ATOM 522 CB ALA A 76 34.491 36.497 106.363 0.00 0.00 C 0
66 | ATOM 589 CD1 LEU A 85 30.637 33.611 109.160 0.00 0.00 C 0
67 | ATOM 383 CB CYS A 54 30.319 36.707 100.096 0.00 0.00 C 0
68 | ATOM 385 N ALA A 55 30.414 39.195 101.925 0.88 1.09 N 0
69 | ATOM 604 N CYS A 88 33.497 31.222 101.366 0.61 3.28 N 0
70 | ATOM 1767 CD1 TYR B 126 30.818 30.545 105.546 0.00 0.00 C 0
71 | ATOM 1788 NE1 TRP B 128 26.049 30.318 108.329 0.50 2.19 N 0
72 | ATOM 398 OE2 GLU A 56 34.522 35.447 103.175 0.50 2.14 O 0
73 | ATOM 594 O SER A 86 34.168 31.217 104.781 0.51 1.07 O 0
74 | TER
75 | END
76 |
--------------------------------------------------------------------------------
/notebooks/2z3h_out/pockets/pocket1_vert.pqr:
--------------------------------------------------------------------------------
1 | HEADER
2 | HEADER This is a pqr format file writen by the programm fpocket.
3 | HEADER It represent the voronoi vertices of a single pocket found by the
4 | HEADER algorithm.
5 | HEADER
6 | HEADER Information about the pocket 1:
7 | HEADER 0 - Pocket Score : 0.3504
8 | HEADER 1 - Drug Score : 0.8715
9 | HEADER 2 - Number of V. Vertices : 70
10 | HEADER 3 - Mean alpha-sphere radius : 3.8276
11 | HEADER 4 - Mean alpha-sphere SA : 0.4599
12 | HEADER 5 - Mean B-factor : 0.3795
13 | HEADER 6 - Hydrophobicity Score : 24.8000
14 | HEADER 7 - Polarity Score : 11
15 | HEADER 8 - Volume Score : 4.0000
16 | HEADER 9 - Real volume (approximation) : 685.9797
17 | HEADER 10 - Charge Score : 0
18 | HEADER 11 - Local hydrophobic density Score : 28.6316
19 | HEADER 12 - Number of apolar alpha sphere : 38
20 | HEADER 13 - Proportion of apolar alpha sphere : 0.5429
21 | ATOM 1 C STP 1 27.832 35.872 104.021 0.00 3.69
22 | ATOM 2 C STP 1 27.954 35.449 103.706 0.00 4.07
23 | ATOM 3 O STP 1 25.370 34.420 103.783 0.00 4.01
24 | ATOM 4 O STP 1 27.724 35.111 103.740 0.00 4.18
25 | ATOM 5 C STP 1 24.801 33.618 103.765 0.00 3.82
26 | ATOM 6 C STP 1 21.596 27.720 100.725 0.00 3.87
27 | ATOM 7 O STP 1 23.353 34.262 111.342 0.00 3.64
28 | ATOM 8 O STP 1 23.638 33.858 112.345 0.00 3.55
29 | ATOM 9 O STP 1 26.104 36.231 100.922 0.00 3.67
30 | ATOM 10 O STP 1 26.179 36.489 100.894 0.00 3.56
31 | ATOM 11 O STP 1 26.176 36.334 100.914 0.00 3.63
32 | ATOM 12 O STP 1 24.838 36.294 100.826 0.00 3.60
33 | ATOM 13 O STP 1 24.946 36.155 100.851 0.00 3.66
34 | ATOM 14 C STP 1 28.730 35.819 103.728 0.00 3.92
35 | ATOM 15 C STP 1 26.814 35.345 102.503 0.00 3.69
36 | ATOM 16 O STP 1 26.373 34.656 103.564 0.00 3.96
37 | ATOM 17 O STP 1 27.492 34.995 103.579 0.00 4.14
38 | ATOM 18 C STP 1 22.814 30.645 104.838 0.00 4.62
39 | ATOM 19 C STP 1 22.340 30.784 103.285 0.00 3.67
40 | ATOM 20 O STP 1 21.530 29.239 101.689 0.00 3.78
41 | ATOM 21 O STP 1 21.569 29.345 101.655 0.00 3.71
42 | ATOM 22 C STP 1 21.659 27.794 101.499 0.00 4.33
43 | ATOM 23 C STP 1 21.686 27.719 101.468 0.00 4.35
44 | ATOM 24 C STP 1 28.707 33.958 103.385 0.00 3.72
45 | ATOM 25 C STP 1 21.363 27.197 100.342 0.00 3.70
46 | ATOM 26 O STP 1 22.966 34.609 109.925 0.00 3.95
47 | ATOM 27 O STP 1 22.195 33.840 106.546 0.00 4.43
48 | ATOM 28 O STP 1 22.184 33.812 106.570 0.00 4.45
49 | ATOM 29 O STP 1 24.182 34.117 108.716 0.00 3.47
50 | ATOM 30 O STP 1 23.499 33.586 105.919 0.00 4.32
51 | ATOM 31 O STP 1 23.414 33.518 106.010 0.00 4.37
52 | ATOM 32 O STP 1 24.878 34.569 104.793 0.00 4.02
53 | ATOM 33 O STP 1 26.129 35.033 105.102 0.00 3.79
54 | ATOM 34 O STP 1 25.995 34.250 107.350 0.00 3.45
55 | ATOM 35 O STP 1 24.386 35.825 100.260 0.00 3.46
56 | ATOM 36 C STP 1 29.851 32.841 102.748 0.00 3.60
57 | ATOM 37 C STP 1 24.701 33.358 103.737 0.00 3.76
58 | ATOM 38 C STP 1 23.960 32.096 104.019 0.00 3.78
59 | ATOM 39 C STP 1 29.026 34.688 103.966 0.00 4.07
60 | ATOM 40 C STP 1 29.002 34.303 103.714 0.00 3.89
61 | ATOM 41 C STP 1 29.868 32.882 102.753 0.00 3.63
62 | ATOM 42 C STP 1 29.863 32.869 102.753 0.00 3.62
63 | ATOM 43 C STP 1 29.847 32.902 102.767 0.00 3.62
64 | ATOM 44 C STP 1 29.437 35.468 103.873 0.00 3.94
65 | ATOM 45 C STP 1 29.764 34.817 104.028 0.00 4.11
66 | ATOM 46 C STP 1 29.899 34.795 104.009 0.00 4.11
67 | ATOM 47 C STP 1 31.693 34.249 105.339 0.00 3.73
68 | ATOM 48 C STP 1 31.333 34.241 105.797 0.00 3.49
69 | ATOM 49 C STP 1 28.775 36.216 103.232 0.00 3.53
70 | ATOM 50 C STP 1 28.934 36.340 103.276 0.00 3.49
71 | ATOM 51 C STP 1 29.044 36.027 103.659 0.00 3.84
72 | ATOM 52 C STP 1 29.068 36.054 103.658 0.00 3.83
73 | ATOM 53 C STP 1 29.483 35.777 103.760 0.00 3.87
74 | ATOM 54 C STP 1 29.517 35.981 103.735 0.00 3.80
75 | ATOM 55 C STP 1 29.637 35.694 103.759 0.00 3.86
76 | ATOM 56 O STP 1 29.648 35.983 103.725 0.00 3.76
77 | ATOM 57 C STP 1 30.753 33.263 102.941 0.00 3.77
78 | ATOM 58 C STP 1 30.509 33.161 102.917 0.00 3.72
79 | ATOM 59 O STP 1 31.028 33.632 103.177 0.00 3.90
80 | ATOM 60 C STP 1 30.931 33.710 103.264 0.00 3.90
81 | ATOM 61 O STP 1 24.895 33.716 107.469 0.00 3.69
82 | ATOM 62 C STP 1 31.689 34.254 105.307 0.00 3.74
83 | ATOM 63 O STP 1 31.145 34.459 104.665 0.00 3.82
84 | ATOM 64 O STP 1 31.094 33.808 103.577 0.00 3.82
85 | ATOM 65 C STP 1 31.722 34.224 105.318 0.00 3.73
86 | ATOM 66 C STP 1 31.828 34.074 105.272 0.00 3.68
87 | ATOM 67 O STP 1 32.238 34.042 105.383 0.00 3.47
88 | ATOM 68 O STP 1 31.885 34.027 105.266 0.00 3.65
89 | ATOM 69 O STP 1 31.107 33.629 103.227 0.00 3.87
90 | ATOM 70 O STP 1 31.614 33.384 103.513 0.00 3.58
91 | TER
92 | END
93 |
--------------------------------------------------------------------------------
/notebooks/2z3h_out/pockets/pocket2_atm.pdb:
--------------------------------------------------------------------------------
1 | HEADER
2 | HEADER This is a pdb format file writen by the programm fpocket.
3 | HEADER It represents the atoms contacted by the voronoi vertices of the pocket.
4 | HEADER
5 | HEADER Information about the pocket 2:
6 | HEADER 0 - Pocket Score : 0.0248
7 | HEADER 1 - Drug Score : 0.0018
8 | HEADER 2 - Number of alpha spheres : 19
9 | HEADER 3 - Mean alpha-sphere radius : 4.1232
10 | HEADER 4 - Mean alpha-sphere Solvent Acc. : 0.6695
11 | HEADER 5 - Mean B-factor of pocket residues : 0.3063
12 | HEADER 6 - Hydrophobicity Score : 40.2500
13 | HEADER 7 - Polarity Score : 3
14 | HEADER 8 - Amino Acid based volume Score : 3.5000
15 | HEADER 9 - Pocket volume (Monte Carlo) : 221.1203
16 | HEADER 10 - Pocket volume (convex hull) : 3.2576
17 | HEADER 11 - Charge Score : 0
18 | HEADER 12 - Local hydrophobic density Score : 8.0000
19 | HEADER 13 - Number of apolar alpha sphere : 9
20 | HEADER 14 - Proportion of apolar alpha sphere : 0.4737
21 | ATOM 529 CG2 ILE A 77 33.796 38.445 113.141 0.00 0.00 C 0
22 | ATOM 527 CB ILE A 77 34.293 38.624 111.699 0.00 0.00 C 0
23 | ATOM 526 O ILE A 77 31.607 38.471 110.281 0.67 3.21 O 0
24 | ATOM 212 CB ALA A 30 32.061 42.598 108.814 0.00 0.00 C 0
25 | ATOM 307 CG1AVAL A 44 27.236 43.527 106.200 0.00 0.00 C 0
26 | ATOM 186 O TYR A 27 24.056 41.907 108.938 0.00 0.00 O 0
27 | ATOM 201 N VAL A 29 27.537 39.283 108.006 0.00 0.00 N 0
28 | ATOM 308 CG2AVAL A 44 29.222 44.754 107.075 0.00 0.00 C 0
29 | ATOM 534 O GLY A 78 30.299 37.434 113.651 0.17 4.29 O 0
30 | ATOM 197 C SER A 28 26.711 38.972 109.006 0.00 0.00 C 0
31 | ATOM 208 N ALA A 30 30.839 40.517 108.208 0.00 0.00 N 0
32 | ATOM 203 C VAL A 29 29.730 40.211 107.536 0.00 0.00 C 0
33 | ATOM 539 CB ASN A 79 26.403 39.070 113.085 0.00 0.00 C 0
34 | ATOM 198 O SER A 28 27.059 38.333 109.999 0.46 1.07 O 0
35 | ATOM 535 N ASN A 79 28.376 37.699 112.480 0.84 1.09 N 0
36 | ATOM 536 CA ASN A 79 27.682 38.376 113.568 0.00 0.00 C 0
37 | TER
38 | END
39 |
--------------------------------------------------------------------------------
/notebooks/2z3h_out/pockets/pocket2_vert.pqr:
--------------------------------------------------------------------------------
1 | HEADER
2 | HEADER This is a pqr format file writen by the programm fpocket.
3 | HEADER It represent the voronoi vertices of a single pocket found by the
4 | HEADER algorithm.
5 | HEADER
6 | HEADER Information about the pocket 2:
7 | HEADER 0 - Pocket Score : 0.0248
8 | HEADER 1 - Drug Score : 0.0018
9 | HEADER 2 - Number of V. Vertices : 19
10 | HEADER 3 - Mean alpha-sphere radius : 4.1232
11 | HEADER 4 - Mean alpha-sphere SA : 0.6695
12 | HEADER 5 - Mean B-factor : 0.3063
13 | HEADER 6 - Hydrophobicity Score : 40.2500
14 | HEADER 7 - Polarity Score : 3
15 | HEADER 8 - Volume Score : 3.5000
16 | HEADER 9 - Real volume (approximation) : 221.1203
17 | HEADER 10 - Charge Score : 0
18 | HEADER 11 - Local hydrophobic density Score : 8.0000
19 | HEADER 12 - Number of apolar alpha sphere : 9
20 | HEADER 13 - Proportion of apolar alpha sphere : 0.4737
21 | ATOM 1 C STP 2 32.182 41.416 112.136 0.00 3.53
22 | ATOM 2 O STP 2 27.717 42.628 109.921 0.00 3.86
23 | ATOM 3 O STP 2 30.860 41.923 113.152 0.00 4.55
24 | ATOM 4 O STP 2 27.790 42.644 110.150 0.00 3.99
25 | ATOM 5 C STP 2 28.097 42.799 111.011 0.00 4.54
26 | ATOM 6 C STP 2 27.958 42.614 110.279 0.00 4.05
27 | ATOM 7 C STP 2 28.056 42.649 110.510 0.00 4.19
28 | ATOM 8 C STP 2 27.863 42.546 109.704 0.00 3.69
29 | ATOM 9 O STP 2 28.973 41.271 111.405 0.00 3.78
30 | ATOM 10 C STP 2 28.081 42.825 111.069 0.00 4.58
31 | ATOM 11 C STP 2 28.002 42.857 111.062 0.00 4.58
32 | ATOM 12 C STP 2 28.121 42.766 111.074 0.00 4.55
33 | ATOM 13 O STP 2 28.418 41.993 111.115 0.00 4.06
34 | ATOM 14 O STP 2 30.452 41.910 112.989 0.00 4.53
35 | ATOM 15 O STP 2 30.112 40.927 112.555 0.00 3.67
36 | ATOM 16 C STP 2 29.290 41.849 111.865 0.00 4.19
37 | ATOM 17 O STP 2 29.160 41.443 111.627 0.00 3.92
38 | ATOM 18 O STP 2 29.599 41.639 111.963 0.00 4.11
39 | ATOM 19 O STP 2 29.489 41.459 111.826 0.00 3.98
40 | TER
41 | END
42 |
--------------------------------------------------------------------------------
/notebooks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/notebooks/__init__.py
--------------------------------------------------------------------------------
/sample_crossdock_mols.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import argparse
5 |
6 | from rdkit import Chem
7 | import torch
8 | import time
9 | import shutil
10 | from scipy.spatial import distance
11 | from rdkit.Chem import rdmolfiles
12 |
13 | from utils.volume_sampling import sample_discrete_number, bin_edges, prob_dist_df
14 | from utils.templates import get_one_hot, get_pocket
15 | from utils.templates import add_hydrogens, extract_hydrogen_coordinates, run_fpocket, extract_values
16 |
17 | from src.lightning_anchor_gnn import AnchorGNN_pl
18 | from src.lightning import AR_DDPM
19 | from src.const import prot_mol_lj_rm, CROSSDOCK_LJ_RM
20 | from src.noise import cosine_beta_schedule
21 |
22 | from analysis.reconstruct_mol import reconstruct_from_generated
23 | #from analysis.vina_docking import VinaDockingTask
24 | from sampling.sample_mols import generate_mols_for_pocket
25 |
26 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9}
27 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'}
28 | CROSSDOCK_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'B':5, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15}
29 | pocket_atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3} # only 4 atoms types for pocket
30 | vdws = {'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8, 'B': 1.92, 'Br': 1.85, 'Cl': 1.75, 'P': 1.8, 'I': 1.98, 'F': 1.47}
31 |
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--results-path', type=str, default='results',
34 | help='path to save the results')
35 | parser.add_argument('--data-path', action='store', type=str, default='/srv/home/mahdi.ghorbani/FragDiff/crossdock',
36 | help='path to the test data for generating molecules')
37 | parser.add_argument('--anchor-model', type=str, default='anchor_model.ckpt',
38 | help='path to the anchor model. Note that for guidance, the anchor model should incorporate the conditionals')
39 | parser.add_argument('--n-samples', type=int, default=20,
40 | help='total number of ligands to generate per pocket')
41 | parser.add_argument('--exp-name', type=str, default='exp-1',
42 | help='name of the generation experiment')
43 | parser.add_argument('--diff-model', type=str, default='diff-model.ckpt',
44 | help='path to the diffusion model checkpoint')
45 | parser.add_argument('--device', type=str, default='cuda:0')
46 | parser.add_argument('--rejection-sampling', action='store_true', default=False, help='enable rejection sampling')
47 |
48 | if __name__ == '__main__':
49 | args = parser.parse_args()
50 | torch_device = args.device
51 | anchor_checkpoint = args.anchor_model
52 | data_path = args.data_path
53 | diff_model_checkpoint = args.diff_model
54 |
55 | add_H = True # adding hydrogens to protein for LJ computation
56 | model = AR_DDPM.load_from_checkpoint(diff_model_checkpoint, device=torch_device)
57 | model = model.to(torch_device)
58 |
59 | anchor_model = AnchorGNN_pl.load_from_checkpoint(anchor_checkpoint, device=torch_device)
60 | anchor_model = anchor_model.to(torch_device)
61 |
62 | split = torch.load(data_path + '/' + 'split_by_name.pt')
63 | prefix = data_path + '/crossdocked_pocket10/'
64 |
65 | if not os.path.exists(args.results_path):
66 |     print('creating results directory')
67 |     os.makedirs(args.results_path, exist_ok=True)
68 | save_dir = args.results_path + '/' + args.exp_name
69 | if not os.path.exists(save_dir):
70 | os.makedirs(save_dir, exist_ok=True)
71 |
72 | for n in range(100):
73 | prot_name = prefix + split['test'][n][0]
74 | lig_name = prefix + split['test'][n][1]
75 |
76 | pocket_onehot, pocket_coords, lig_coords, _ = get_pocket(prot_name, lig_name, atom_dict, pocket_atom_dict=pocket_atom_dict, dist_cutoff=7)
77 |
78 | # --------------- make a grid box around the pocket ----------------
79 | min_coords = pocket_coords.min(axis=0) - 2.5
80 | max_coords = pocket_coords.max(axis=0) + 2.5
81 |
82 | x_range = slice(min_coords[0], max_coords[0] + 1, 1.5) # spheres of radius 1.5
83 | y_range = slice(min_coords[1], max_coords[1] + 1, 1.5)
84 | z_range = slice(min_coords[2], max_coords[2] + 1, 1.5)
85 |
86 | grid = np.mgrid[x_range, y_range, z_range]
87 | grid_points = grid.reshape(3, -1).T # This transposes the grid to a list of coordinates
88 |
89 | # remove grid points that are not within 3.5A of the original ligand
90 | distances_mol = distance.cdist(grid_points, lig_coords)
91 | mask_mol = (distances_mol < 3.5).any(axis=1)
92 | filtered_mol_points = grid_points[mask_mol]
93 |
94 | # remove grid points that are close to the pocket
95 | pocket_distances = distance.cdist(filtered_mol_points, pocket_coords)
96 | mask_pocket = (pocket_distances < 2).any(axis=1)
97 | grids = filtered_mol_points[~mask_pocket]
98 |
99 | n_samples = args.n_samples
100 | max_mol_sizes = []
101 |
102 | fpocket_out = prot_name[:-4] + '_out'
103 |
104 | shutil.rmtree(fpocket_out, ignore_errors=True)
105 |
106 | if add_H:
107 | add_hydrogens(prot_name)
108 | prot_name_with_H = prot_name[:-4] + '_H.pdb'
109 |
110 | H_coords = extract_hydrogen_coordinates(prot_name_with_H)
111 | H_coords = torch.tensor(H_coords).float().to(torch_device)
112 | #print('running fpocket!')
113 | #try:
114 | # run_fpocket(prot_name)
115 | #except:
116 | # print('Error in running fpocket! using random sizes')
117 | # NOTE: using original molecule coordinates for making the grid
118 |
119 | grids = torch.tensor(grids)
120 | all_grids = [] # list of grids
121 | all_H_coords = []
122 | for i in range(n_samples):
123 | all_grids.append(grids)
124 | all_H_coords.append(H_coords)
125 |
126 | pocket_vol = len(grids)
127 | #if os.path.exists(fpocket_out):
128 | # filename = prot_name[:-4] + '_out/pockets/pocket1_atm.pdb'
129 | # score, drug_score, pocket_volume = extract_values(filename)
130 | #else:
131 | # print('running fpocket!')
132 | # run_fpocket(prot_name)
133 | # filename = prot_name[:-4] + '_out/pockets/pocket1_atm.pdb'
134 | # score, drug_score, pocket_volume = extract_values(filename)
135 |
136 | #print('pocket_volume', pocket_volume)
137 |
138 | for i in range(n_samples):
139 | max_mol_sizes.append(sample_discrete_number(pocket_vol))
140 |
141 | pocket_onehot = torch.tensor(pocket_onehot).float()
142 | pocket_coords = torch.tensor(pocket_coords).float()
143 | lig_coords = torch.tensor(lig_coords).float()
144 | pocket_size = len(pocket_coords)
145 |
146 | t1 = time.time()
147 |
148 | max_mol_sizes = np.array(max_mol_sizes)
149 |
150 | print('maximum sizes for molecules', max_mol_sizes)
151 | prot_mol_lj_rm = torch.tensor(prot_mol_lj_rm).to(torch_device)
152 | mol_mol_lj_rm = torch.tensor(CROSSDOCK_LJ_RM).to(torch_device) / 100
153 |
154 | lj_weight_scheduler = cosine_beta_schedule(500, s=0.01, raise_to_power=2)
155 | weights = 1 - lj_weight_scheduler
156 | weights = np.clip(weights, a_min=0.1, a_max=1.)
157 | x, h, mol_masks = generate_mols_for_pocket(n_samples=n_samples,
158 | num_frags=8,
159 | pocket_size=pocket_size,
160 | pocket_coords=pocket_coords,
161 | pocket_onehot=pocket_onehot,
162 | lig_coords=lig_coords,
163 | anchor_model=anchor_model,
164 | diff_model=model,
165 | device=torch_device,
166 | return_all=False,
167 | max_mol_sizes=max_mol_sizes,
168 | all_grids=all_grids,
169 | rejection_sampling=args.rejection_sampling,
170 | lj_guidance=True,
171 | prot_mol_lj_rm=prot_mol_lj_rm,
172 | mol_mol_lj_rm=mol_mol_lj_rm,
173 | all_H_coords=all_H_coords,
174 | guidance_weights=weights)
175 |
176 | x = x.cpu().numpy()
177 | h = h.cpu().numpy()
178 | mol_masks = mol_masks.cpu().numpy()
179 |
180 | # convert to SDF
181 | all_mols = []
182 | for k in range(len(x)):
183 | mask = mol_masks[k]
184 | h_mol = h[k]
185 | x_mol = x[k][mask.astype(np.bool_)]
186 |
187 | atom_inds = h_mol[mask.astype(np.bool_)].argmax(axis=1)
188 | atom_types = [idx2atom[x] for x in atom_inds]
189 | atomic_nums = [CROSSDOCK_CHARGES[i] for i in atom_types]
190 |
191 | try:
192 | mol_rec = reconstruct_from_generated(x_mol.tolist(), atomic_nums)
193 | Chem.Kekulize(mol_rec)
194 | all_mols.append(mol_rec)
195 | except Exception:  # skip molecules that fail reconstruction or kekulization
196 | continue
197 |
198 | t2 = time.time()
199 | print('average time to generate one molecule: ', (t2-t1)/n_samples)
200 | save_path = save_dir + '/' + 'pocket_' + str(n)
201 |
202 | # write sdf file of molecules
203 | with rdmolfiles.SDWriter(save_path + '_mols.sdf') as writer:
204 | for mol in all_mols:
205 | if mol:
206 | writer.write(mol)
207 |
208 | np.save(save_path + '_coords.npy', x)
209 | np.save(save_path + '_onehot.npy', h)
210 | np.save(save_path + '_mol_masks.npy', mol_masks)
211 |
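212 | # A minimal illustrative sketch (toy coordinates, never called by the script)
213 | # of the grid-box filtering above: grid points within 3.5A of the reference
214 | # ligand are kept, points within 2A of any pocket atom are dropped, and the
215 | # number of survivors is the pocket-volume proxy fed to sample_discrete_number.
216 | def _grid_volume_sketch():
217 |     pocket = np.random.rand(50, 3) * 10  # toy pocket coordinates
218 |     ligand = pocket.mean(axis=0, keepdims=True) + np.random.randn(8, 3)  # toy ligand
219 |     lo, hi = pocket.min(axis=0) - 2.5, pocket.max(axis=0) + 2.5
220 |     grid = np.mgrid[lo[0]:hi[0]:1.5, lo[1]:hi[1]:1.5, lo[2]:hi[2]:1.5].reshape(3, -1).T
221 |     near_lig = (distance.cdist(grid, ligand) < 3.5).any(axis=1)
222 |     kept = grid[near_lig]
223 |     clash = (distance.cdist(kept, pocket) < 2.0).any(axis=1)
224 |     return int((~clash).sum())  # pocket-volume proxy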
--------------------------------------------------------------------------------
/sampling/rejection_sampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from rdkit import Chem
4 | import torch
5 | from pathlib import Path
6 | from analysis.docking import calculate_qvina2_score
7 |
8 | from analysis.reconstruct_mol import reconstruct_from_generated
9 | from analysis.metrics import is_connected
10 |
11 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9}
12 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'}
13 | CROSSDOCK_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'B':5, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15}
14 | pocket_atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3} # only 4 atoms types for pocket
15 | vdws = {'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8, 'B': 1.92, 'Br': 1.85, 'Cl': 1.75, 'P': 1.8, 'I': 1.98, 'F': 1.47}
16 |
17 | def compute_number_of_clashes(lig_x, lig_h, pocket_x, pocket_h, pocket_H_coords=None, tolerance=0.5, prot_mol_lj_rm=None):
18 | """
19 |     lig_x, lig_h -> [n_atoms, 3] coordinates and [n_atoms, nf] one-hot atom types of the ligand (only extension atoms)
20 |     pocket_x, pocket_h => [N_pocket, 3 or hp]
21 |     pocket_H_coords -> [N_H, 3] coordinates of the protein hydrogen atoms
22 | """
23 |
24 | dists = torch.cdist(lig_x, pocket_x, p=2) # [n_lig_atoms, n_pocket_atoms]
25 | dists = torch.where(dists==0, 1e-5, dists)
26 | inds_lig = torch.argmax(lig_h, dim=1) # [n_lig_atoms]
27 |
28 | inds_pocket = torch.argmax(pocket_h, dim=1).long() # [n_pocket_atoms]
29 | rm = prot_mol_lj_rm[inds_lig][:, inds_pocket] # [n_lig_atoms, n_pocket_atoms]
30 |     clashes = ((dists + tolerance) < rm).sum().item()
31 |
32 | dists_h = torch.cdist(lig_x, pocket_H_coords, p=2)
33 |     inds_h = torch.ones(len(pocket_H_coords), device=lig_x.device).long() * 10  # H has index 10 in the rm table
34 |     rm_h = prot_mol_lj_rm[inds_lig][:, inds_h] # [n_lig_atoms, n_H_atoms]
35 |     clashes_h = ((dists_h + tolerance) < rm_h).sum().item()
36 |
37 | total_clashes = clashes + clashes_h
38 | return total_clashes
39 |
40 |
41 | def reject_sample(x, h, pocket_x, pocket_h, prot_path=None, rejection_criteria='clashes'):
42 | # NOTE: x and pocket_x must already be translated to COM
43 |     # x :torch.Tensor -> [n_atoms, 3] coordinates of a single molecule
44 | # h :list-> [n_atoms] atom types (eg. 'C', 'N') of a single molecule
45 | atomic_nums = [CROSSDOCK_CHARGES[a] for a in h]
46 | if rejection_criteria == 'qvina':
47 | try:
48 | mol_rec = reconstruct_from_generated(x.tolist(), atomic_nums)
49 | Chem.SanitizeMol(mol_rec)
50 |
51 | if not is_connected(mol_rec):
52 | m_frags = Chem.GetMolFrags(mol_rec, asMols=True, sanitizeFrags=False)
53 | mol_rec = max(m_frags, key=lambda x: x.GetNumAtoms())
54 |
55 | prot_pdbqt_file = prot_path[:-4] + '.pdbqt'
56 | out_sdf_file = 'mol.sdf'
57 | with Chem.SDWriter(out_sdf_file) as writer:
58 | writer.write(mol_rec)
59 | sdf_file = Path(out_sdf_file)
60 | if not os.path.exists('qvina-path'):
61 | os.mkdir('qvina-path')
62 | score_result = calculate_qvina2_score(prot_pdbqt_file, sdf_file, out_dir='qvina-path', return_rdmol=False, score_only=True)
63 | print('qvina score: ', score_result)
64 | files = os.listdir('qvina-path')
65 | for file in files:
66 | if file.endswith('.sdf') or file.endswith('.pdbqt'):
67 | os.remove(os.path.join('qvina-path', file))
68 |
69 |     except Exception:
70 |         score_result = 100  # large penalty score when reconstruction or docking fails
71 |
72 | return score_result
73 |
74 | elif rejection_criteria == 'clashes':
75 | # pocket_x -> [n_atoms, 3] coordinates of a single pocket
76 | # pocket_h -> [n_atoms, nf] one-hot atom types of a single pocket
77 | # x -> [n_atoms, 3] coordinates of a single molecule
78 | n_clashes = compute_number_of_clashes(x, h, pocket_x, pocket_h)  # NOTE: ligand args come first, h must be one-hot, and a prot_mol_lj_rm table must be supplied
79 | return n_clashes
80 |
81 | def compute_lj(lig_x, lig_h, extension_mask, scaffold_mask, pocket_x, pocket_h, pocket_mask, prot_mol_lj_rm, all_H_coords, mol_mol_lj_rm=None):
82 | """ compute the LJ between protein and ligand
83 | lig_x: [B, N, 3]
84 | lig_h: [B, N, hf]
85 | """
86 |
87 | num_atoms = extension_mask.sum()
88 |
89 | # ------------- ligand - ligand LJ ----------
90 | mol_mask = (scaffold_mask.bool() | extension_mask.bool())
91 | N = mol_mask.sum()
92 |
93 | x_mol = lig_x[mol_mask] # [N_mol, 3]
94 | h_mol = lig_h[mol_mask] # [N_mol, hf]
95 |
96 | x = lig_x[extension_mask.bool()]
97 |     h = lig_h[extension_mask.bool()]
98 |
99 | dists_mol = torch.cdist(x, x_mol, p=2) # [N_ext, N_mol]
100 |
101 | inds_mol = torch.argmax(h_mol, dim=1) # [N_mol]
102 | inds_ext = torch.argmax(h, dim=1) # [N_ext]
103 | rm_mol = mol_mol_lj_rm[inds_ext][:, inds_mol] # [N_ext, N_mol]
104 |
105 |
106 | dists_mol = torch.where(dists_mol==0.0, 1, dists_mol)
107 | rm_mol = torch.where(rm_mol==0.0, 1, rm_mol)
108 |
109 |
110 |     dists_mol = torch.where(dists_mol < 0.5, 0.5, dists_mol) # clamp distances below 0.5 to avoid LJ blow-up
111 |     lj_mol = ((rm_mol / dists_mol) ** 12 - (rm_mol / dists_mol) ** 6) # [N_ext, N_mol]
112 |
113 | lj_lig_lig = lj_mol.sum() / num_atoms
114 |
115 |     # --------------- select the masked pocket atoms and their hydrogens --------------
116 |
117 |
118 | pocket_x = pocket_x[pocket_mask.bool()] # [N_p, 3]
119 | pocket_h = pocket_h[pocket_mask.bool()][:, :4] # [N_p, hf]
120 | h_coords = all_H_coords # [N_p, 3]
121 |
122 | # --------------- compute the LJ between protein and ligand --------------
123 | dists = torch.cdist(x, pocket_x, p=2)
124 | inds_lig = torch.argmax(h, dim=1) # [N_l]
125 | inds_pocket = torch.argmax(pocket_h, dim=1).long() # [N_p]
126 |
127 | rm = prot_mol_lj_rm[inds_lig][:, inds_pocket] # [N_l, N_p]
128 | lj = ((rm / dists) ** 12 - (rm / dists) ** 6) # [N_l, N_p]
129 | lj[torch.isnan(lj)] = 0
130 |
131 | # ------------- compute the loss for h atoms ----------------
132 | dists_h = torch.cdist(x, h_coords, p=2)
133 | #dists_h = torch.where(dists_h<0.5, 0.5, dists_h)
134 | inds_H = torch.ones(len(h_coords), device=x.device).long() * 10 # index of H is 10 in the table
135 | rm_h = prot_mol_lj_rm[inds_lig][:, inds_H]
136 | lj_h = ((rm_h / dists_h) ** 12 - (rm_h / dists_h) ** 6) # [N_l, N_p]
137 |
138 | lj_h[torch.isnan(lj_h)] = 0 # remove nan values
139 |
140 | lj = lj.sum()
141 | lj_h = lj_h.sum()
142 |
143 | lj_prot_lig = (lj + lj_h) / num_atoms
144 | return lj_prot_lig, lj_lig_lig
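145 |
146 | # A minimal illustrative sketch (toy shapes, made-up rm_table, never called by
147 | # the pipeline) of the 12-6 Lennard-Jones term used in compute_lj: rm is looked
148 | # up per atom-type pair, distances are clamped from below to avoid blow-up,
149 | # and the summed energy is normalized by the number of ligand atoms.
150 | def _lj_sketch():
151 |     lig_x = torch.randn(5, 3) * 3
152 |     lig_h = torch.eye(2)[torch.randint(0, 2, (5,))]  # toy one-hot atom types
153 |     prot_x = torch.randn(7, 3) * 3
154 |     prot_h = torch.eye(2)[torch.randint(0, 2, (7,))]
155 |     rm_table = torch.tensor([[3.4, 3.3], [3.3, 3.1]])  # toy rm per type pair
156 |     d = torch.cdist(lig_x, prot_x).clamp(min=0.5)
157 |     rm = rm_table[lig_h.argmax(1)][:, prot_h.argmax(1)]  # [5, 7]
158 |     return (((rm / d) ** 12 - (rm / d) ** 6).sum() / len(lig_x)).item()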
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keiserlab/autofragdiff/84f0885cb12e6ac4abc7558870f8d304c78c8a38/src/__init__.py
--------------------------------------------------------------------------------
/src/anchor_gnn.py:
--------------------------------------------------------------------------------
1 | from src.egnn import GCL, GaussianSmearing
2 | import torch.nn as nn
3 | import torch
4 | from src.egnn import coord2diff
5 |
6 | class MaskedBCEWithLogitsLoss2(torch.nn.Module):
7 | def __init__(self):
8 | super(MaskedBCEWithLogitsLoss2, self).__init__()
9 | self.loss = torch.nn.BCEWithLogitsLoss(reduction='none')
10 |
11 | def forward(self, input, target, scaffold_mask, pocket_mask, is_first_frag_mask):
12 |         """
13 |         is_first_frag_mask -> mask for the first fragment
14 |         (the mask is 1 where the atom belongs to the first fragment)
15 |         """
16 | masked_loss = self.loss(input, target)
17 | masked_loss_1 = masked_loss * (~is_first_frag_mask.bool()) # only for parts that are not the first fragment
18 | masked_loss_1 = masked_loss_1 * scaffold_mask.float() # only for the scaffold atoms
19 |
20 | masked_loss_2 = masked_loss * is_first_frag_mask.bool() # only for parts that are the first fragment
21 | masked_loss_2 = masked_loss_2 * pocket_mask.float() # only for the pocket atoms
22 |
23 | total_masked_loss = (masked_loss_1.sum() / scaffold_mask.sum().float()) + (masked_loss_2.sum() / pocket_mask.sum().float())
24 | return total_masked_loss
25 |
26 | class MaskedBCEWithLogitsLoss(torch.nn.Module):
27 | def __init__(self):
28 | super(MaskedBCEWithLogitsLoss, self).__init__()
29 | self.loss = torch.nn.BCEWithLogitsLoss(reduction='none')
30 |
31 | def forward(self, input, target, mask):
32 | masked_loss = self.loss(input, target)
33 | masked_loss = masked_loss * mask.float()
34 | return masked_loss.sum() / mask.sum().float()
35 |
36 | class AnchorGNNPocket(nn.Module):
37 | def __init__(self,
38 | lig_nf, # ligand node features
39 | pocket_nf, # pocket node features
40 | joint_nf, # joint number of features
41 | hidden_nf,
42 | out_node_nf,
43 | n_layers,
44 | normalization,
45 | attention=True,
46 | normalization_factor=100,
47 | aggregation_method='sum',
48 | dist_cutoff=7,
49 | gaussian_expansion=False,
50 | num_gaussians=16,
51 | edge_cutoff_ligand=None,
52 | edge_cutoff_pocket=4.5,
53 | edge_cutoff_interaction=4.5
54 | ):
55 |
56 | super(AnchorGNNPocket, self).__init__()
57 |
58 | #in_node_nf = in_node_nf + context_node_nf # adding the context pocket
59 | if gaussian_expansion:
60 |             self.gauss_exp = GaussianSmearing(start=0., stop=7., num_gaussians=num_gaussians)
61 | in_edge_nf = num_gaussians
62 | else:
63 | in_edge_nf = 1
64 |
65 | self.hidden_nf = hidden_nf
66 | self.out_node_nf = out_node_nf
67 | self.n_layers = n_layers
68 | self.normalization = normalization
69 | self.attention = attention
70 | self.dist_cutoff = dist_cutoff
71 | self.normalization_factor = normalization_factor
72 | self.gaussian_expansion = gaussian_expansion
73 | self.num_gaussians = num_gaussians
74 | self.joint_nf = joint_nf
75 | self.edge_cutoff_l = edge_cutoff_ligand
76 | self.edge_cutoff_p = edge_cutoff_pocket
77 | self.edge_cutoff_i = edge_cutoff_interaction
78 |
79 | self.mol_encoder = nn.Sequential(
80 | nn.Linear(lig_nf, joint_nf),
81 | nn.SiLU()
82 | )
83 |
84 | self.pocket_encoder = nn.Sequential(
85 | nn.Linear(pocket_nf, joint_nf),
86 | nn.SiLU()
87 | )
88 |
89 | self.embed_both = nn.Linear(joint_nf, self.hidden_nf)
90 |
91 | self.gcl1 = GCL(
92 | input_nf=self.hidden_nf,
93 | output_nf=self.hidden_nf,
94 | hidden_nf=self.hidden_nf,
95 | normalization_factor=normalization_factor,
96 | aggregation_method=aggregation_method,
97 | edges_in_d=in_edge_nf,
98 | activation=nn.ReLU(),
99 | attention=attention,
100 | normalization=normalization
101 | )
102 |
103 | layers = []
104 | layers.append(self.gcl1)
105 | for i in range(n_layers - 1):
106 | layer = GCL(
107 | input_nf=self.hidden_nf,
108 | output_nf=self.hidden_nf,
109 | hidden_nf=self.hidden_nf,
110 | normalization_factor=normalization_factor,
111 | aggregation_method='sum',
112 | edges_in_d=in_edge_nf,
113 | activation=nn.ReLU(),
114 | attention=attention,
115 | normalization=normalization
116 | )
117 | layers.append(layer)
118 |
119 | self.gcl_layers = nn.ModuleList(layers)
120 | self.embedding_out = nn.Linear(self.hidden_nf, self.out_node_nf)
121 | self.lin_out = nn.Linear(self.out_node_nf, 1)
122 | self.act = nn.ReLU()
123 | #self.bce_loss = MaskedBCEWithLogitsLoss()
124 |
125 | def forward(self, mol_x, mol_h, node_mask, pocket_x, pocket_h, pocket_mask):
126 | """
127 | input:
128 | mol_x: [B, Ns, 3] coordinates of scaffold
129 | mol_h: [B, Ns, nf] onehot of scaffold
130 | node_mask: [B, Ns] masking on the scaffold
131 |             pocket_x: [B, Np, 3] coordinates of pocket atoms
132 |             pocket_h: [B, Np, nf_p] onehot of pocket atoms
133 | pocket_mask: [B, Np] masking on pocket atoms
134 | output:
135 | h_out: [B, Ns, 1] logits for the scaffold atoms
136 | """
137 | bs, n_lig_nodes = mol_x.shape[0], mol_x.shape[1]
138 | n_pocket_nodes = pocket_x.shape[1]
139 | node_mask = node_mask.squeeze()
140 |
141 | N = n_lig_nodes + n_pocket_nodes
142 | mol_x = mol_x[node_mask.bool()] # [N_l, 3]
143 | mol_h = mol_h[node_mask.bool()] # [N_l, nf]
144 |
145 | pocket_x = pocket_x[pocket_mask.bool()] # [N_p, 3]
146 | pocket_h = pocket_h[pocket_mask.bool()] # [N_p, nf]
147 |
148 | mol_h = self.mol_encoder(mol_h) # [N_l, joint_nf]
149 | pocket_h = self.pocket_encoder(pocket_h) # [N_p, joint_nf]
150 |
151 | h = torch.cat([mol_h, pocket_h], dim=0) # [N_l+N_p, joint_nf]
152 | x = torch.cat([mol_x, pocket_x], dim=0) # [N_l+N_p, 3]
153 |
154 | batch_mask_ligand = self.get_batch_mask(node_mask.bool(), device=x.device) # [N_l]
155 | batch_mask_pocket = self.get_batch_mask(pocket_mask.bool(), device=x.device) # [N_p]
156 |
157 | edges = self.get_edges_cutoff(batch_mask_ligand, batch_mask_pocket, mol_x, pocket_x) # [2, num_edges]
158 |
159 | h = self.embed_both(h) # [N_l+N_p, hidden_nf]
160 |
161 | distances, _ = coord2diff(x, edges)
162 | if self.gaussian_expansion:
163 | distances = self.gauss_exp(distances)
164 |
165 | for gcl in self.gcl_layers:
166 | h, _ = gcl(h, edges, edge_attr=distances, node_mask=None, edge_mask=None) # [N_l+N_p, hidden_nf]
167 |
168 | h_atoms = h[:len(batch_mask_ligand)] # [N_l, hidden_nf]
169 | h_atoms = self.act(self.embedding_out(h_atoms)) # [N_l, out_node_nf]
170 | h_out = self.lin_out(h_atoms) # [N_l, 1]
171 |
172 | # convert to batch
173 | num_atoms = node_mask.sum(dim=1).int() # [B]
174 | reshaped_h_out = torch.zeros(bs, n_lig_nodes, 1, dtype=h_out.dtype).to(h_out.device)
175 | positions = torch.zeros_like(batch_mask_ligand).to(h_out.device)
176 | for idx in range(bs):
177 | positions[batch_mask_ligand == idx] = torch.arange(num_atoms[idx]).to(x.device)
178 | reshaped_h_out[batch_mask_ligand, positions] = h_out # [B, n_lig_nodes, 1]
179 |
180 | return reshaped_h_out
181 |
182 | def get_edges_cutoff(self, batch_mask_ligand, batch_mask_pocket, x_ligand, x_pocket):
183 |
184 | adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
185 | adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
186 | adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]
187 |
188 | if self.edge_cutoff_l is not None:
189 | adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
190 |
191 | if self.edge_cutoff_p is not None:
192 | adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
193 |
194 | if self.edge_cutoff_i is not None:
195 | adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)
196 |
197 | adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
198 | torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)
199 | edges = torch.stack(torch.where(adj), dim=0)
200 |
201 | return edges
202 |
203 | @staticmethod
204 | def get_batch_mask(mask, device):
205 | n_nodes = mask.float().sum(dim=1).int()
206 | batch_size = mask.shape[0]
207 | batch_mask = torch.cat([torch.ones(n_nodes[i]) * i for i in range(batch_size)]).long().to(device)
208 | return batch_mask
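209 |
210 | # A minimal illustrative sketch (toy masks and coordinates, not used by the
211 | # model) of how get_batch_mask flattens padded per-molecule masks into a
212 | # per-atom batch index, and how cutoff edges are then limited to atom pairs
213 | # belonging to the same molecule:
214 | def _batch_mask_sketch():
215 |     mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # a 3-atom and a 2-atom ligand
216 |     batch_mask = AnchorGNNPocket.get_batch_mask(mask.bool(), device='cpu')  # tensor([0, 0, 0, 1, 1])
217 |     x = torch.randn(5, 3)
218 |     adj = (batch_mask[:, None] == batch_mask[None, :]) & (torch.cdist(x, x) <= 7.0)
219 |     return torch.stack(torch.where(adj), dim=0)  # [2, num_edges], intra-molecule edges only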
--------------------------------------------------------------------------------
/src/conv_layer.py:
--------------------------------------------------------------------------------
1 |
2 | # Following DiffHopp implementation of GVP https://github.com/jostorge/diffusion-hopping/tree/main
3 |
4 | from abc import ABC
5 | from functools import partial
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 | from torch import nn as nn
10 | from torch.nn import functional as F
11 | from torch_geometric.nn import MessagePassing
12 |
13 | from src.dropout import GVPDropout
14 | from src.gvp import GVP, s_V
15 | from src.layer_norm import GVPLayerNorm
16 |
25 |
26 | class GVPMessagePassing(MessagePassing, ABC):
27 | def __init__(
28 | self,
29 | in_dims: Tuple[int, int],
30 | out_dims: Tuple[int, int],
31 | edge_dims: Tuple[int, int],
32 | hidden_dims: Optional[Tuple[int, int]] = None,
33 | activations=(F.relu, torch.sigmoid),
34 | vector_gate: bool = False,
35 | attention: bool = True,
36 | aggr: str = "add",
37 | normalization_factor: float = 1.0,
38 | ):
39 | super().__init__(aggr)
40 | if hidden_dims is None:
41 | hidden_dims = out_dims
42 |
43 | in_scalar, in_vector = in_dims
44 | hidden_scalar, hidden_vector = hidden_dims
45 |
46 | edge_scalar, edge_vector = edge_dims
47 |
48 | self.out_scalar, self.out_vector = out_dims
49 | self.in_vector = in_vector
50 | self.hidden_scalar = hidden_scalar
51 | self.hidden_vector = hidden_vector
52 | self.normalization_factor = normalization_factor
53 |
54 | GVP_ = partial(GVP, activations=activations, vector_gate=vector_gate)
55 | self.edge_gvps = nn.Sequential(
56 | GVP_(
57 | (2 * in_scalar + edge_scalar, 2 * in_vector + edge_vector),
58 | hidden_dims,
59 | ),
60 | GVP_(hidden_dims, hidden_dims),
61 | GVP_(hidden_dims, out_dims, activations=(None, None)),
62 | )
63 |
64 | self.attention = attention
65 | if attention:
66 | self.attention_gvp = GVP_(
67 | out_dims,
68 | (1, 0),
69 | activations=(torch.sigmoid, None),
70 | )
71 |
72 | def forward(self, x: s_V, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> s_V:
73 | s, V = x
74 | v_dim = V.shape[-1]
75 | V = torch.flatten(V, start_dim=-2, end_dim=-1)
76 | return self.propagate(edge_index, s=s, V=V, edge_attr=edge_attr, v_dim=v_dim)
77 |
78 | def message(self, s_i, s_j, V_i, V_j, edge_attr, v_dim):
79 | V_i = V_i.view(*V_i.shape[:-1], self.in_vector, v_dim)
80 | V_j = V_j.view(*V_j.shape[:-1], self.in_vector, v_dim)
81 | edge_scalar, edge_vector = edge_attr
82 |
83 | s = torch.cat([s_i, s_j, edge_scalar], dim=-1)
84 | V = torch.cat([V_i, V_j, edge_vector], dim=-2)
85 | s, V = self.edge_gvps((s, V))
86 |
87 | if self.attention:
88 | att = self.attention_gvp((s, V))
89 | s, V = att * s, att[..., None] * V
90 | return self._combine(s, V)
91 |
92 | def update(self, aggr_out: torch.Tensor) -> s_V:
93 | s_aggr, V_aggr = self._split(aggr_out, self.out_scalar, self.out_vector)
94 | if self.aggr == "add" or self.aggr == "sum":
95 | s_aggr = s_aggr / self.normalization_factor
96 | V_aggr = V_aggr / self.normalization_factor
97 | return s_aggr, V_aggr
98 |
99 | @staticmethod
100 | def _combine(s, V) -> torch.Tensor:
101 | V = torch.flatten(V, start_dim=-2, end_dim=-1)
102 | return torch.cat([s, V], dim=-1)
103 |
104 | @staticmethod
105 | def _split(s_V: torch.Tensor, scalar: int, vector: int) -> s_V:
106 | s = s_V[..., :scalar]
107 | V = s_V[..., scalar:]
108 | V = V.view(*V.shape[:-1], vector, -1)
109 | return s, V
110 |
111 | def reset_parameters(self):
112 | for gvp in self.edge_gvps:
113 | gvp.reset_parameters()
114 | if self.attention:
115 | self.attention_gvp.reset_parameters()
116 |
117 | class GVPConvLayer(GVPMessagePassing, ABC):
118 | def __init__(
119 | self,
120 | node_dims: Tuple[int, int],
121 | edge_dims: Tuple[int, int],
122 | drop_rate: float = 0.0,
123 | activations=(F.relu, torch.sigmoid),
124 | vector_gate: bool = False,
125 | residual: bool = True,
126 | attention: bool = True,
127 | aggr: str = "add",
128 | normalization_factor: float = 1.0,
129 | ):
130 | super().__init__(
131 | node_dims,
132 | node_dims,
133 | edge_dims,
134 | hidden_dims=node_dims,
135 | activations=activations,
136 | vector_gate=vector_gate,
137 | attention=attention,
138 | aggr=aggr,
139 | normalization_factor=normalization_factor,
140 | )
141 | self.residual = residual
142 | self.drop_rate = drop_rate
143 | GVP_ = partial(GVP, activations=activations, vector_gate=vector_gate)
144 | self.norm = nn.ModuleList([GVPLayerNorm(node_dims) for _ in range(2)])
145 | self.dropout = nn.ModuleList([GVPDropout(drop_rate) for _ in range(2)])
146 |
147 | self.ff_func = nn.Sequential(
148 | GVP_(node_dims, node_dims),
149 | GVP_(node_dims, node_dims, activations=(None, None)),
150 | )
152 |
153 | def forward(
154 | self,
155 | x: Union[s_V, torch.Tensor],
156 | edge_index: torch.Tensor,
157 | edge_attr: torch.Tensor,
158 | ) -> s_V:
159 |
160 | s, V = super().forward(x, edge_index, edge_attr)
161 | if self.residual:
162 | s, V = self.dropout[0]((s, V))
163 | s, V = x[0] + s, x[1] + V
164 | s, V = self.norm[0]((s, V))
165 |
166 | x = (s, V)
167 | s, V = self.ff_func(x)
168 |
169 | if self.residual:
170 | s, V = self.dropout[1]((s, V))
171 | s, V = s + x[0], V + x[1]
172 | s, V = self.norm[1]((s, V))
173 |
174 | return s, V
--------------------------------------------------------------------------------
/src/dropout.py:
--------------------------------------------------------------------------------
1 | # Following diffhopp implementation of GVP https://github.com/jostorge/diffusion-hopping/tree/main
2 |
3 | from typing import Union, Tuple
4 |
5 | import torch
6 | from torch import nn as nn
7 |
8 | s_V = Tuple[torch.Tensor, torch.Tensor]
9 |
10 | class GVPDropout(nn.Module):
11 | def __init__(self, p: float=0.5) -> None:
12 | super().__init__()
13 | self.dropout_features = nn.Dropout(p)
14 | self.dropout_vector = nn.Dropout1d(p)
15 |
16 | def forward(self, x: Union[torch.Tensor, s_V]) -> Union[torch.Tensor, s_V]:
17 | if isinstance(x, torch.Tensor):
18 | return self.dropout_features(x)
19 |
20 | s, V = x
21 | s = self.dropout_features(s)
22 | V = self.dropout_vector(V)
23 | return s, V
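24 |
25 | # A quick illustrative sketch (toy tensors): scalar features are dropped
26 | # element-wise, while Dropout1d zeroes entire vector channels, so each 3D
27 | # vector is kept or dropped as a whole and rotation equivariance is preserved.
28 | def _gvp_dropout_sketch():
29 |     drop = GVPDropout(p=0.5).train()
30 |     s, V = torch.randn(4, 8), torch.randn(4, 6, 3)
31 |     s_out, V_out = drop((s, V))
32 |     return V_out.abs().sum(dim=-1)  # whole vector channels are zeroed together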
--------------------------------------------------------------------------------
/src/dynamics_gvp.py:
--------------------------------------------------------------------------------
1 | # Following DiffHopp implementation of GVP: https://github.com/jostorge/diffusion-hopping/tree/main
2 |
3 | import torch.nn as nn
4 | import torch
5 | import numpy as np
6 | from src.gvp_model import GVPNetwork
7 |
8 | class DynamicsWithPockets(nn.Module):
9 | def __init__(
10 | self, n_dims, lig_nf, pocket_nf, context_node_nf=3, joint_nf=32, hidden_nf=128, activation=nn.SiLU(),
11 | n_layers=4, attention=False, condition_time=True, tanh=False, normalization_factor=100, model='gvp',
12 | centering=False, edge_cutoff=7, edge_cutoff_interaction=4.5, edge_cutoff_pocket=4.5, edge_cutoff_ligand=None
13 | ):
14 | super().__init__()
15 |
16 | self.edge_cutoff_l = edge_cutoff_ligand
17 | self.edge_cutoff_p = edge_cutoff_pocket
18 | self.edge_cutoff_i = edge_cutoff_interaction
19 |
20 | self.atom_encoder = nn.Sequential(
21 | nn.Linear(lig_nf, joint_nf),
22 | )
23 |
24 | self.pocket_encoder = nn.Sequential(
25 | nn.Linear(pocket_nf, joint_nf),
26 | )
27 |
28 | self.atom_decoder = nn.Sequential(
29 | nn.Linear(joint_nf, lig_nf),
30 | )
31 |
32 | if condition_time:
33 | dynamics_node_nf = joint_nf + 1
34 | else:
35 |             print('Warning: dynamics model is _not_ conditioned on time')
36 | dynamics_node_nf = joint_nf
37 |
38 | self.dynamics = GVPNetwork(
39 | in_dims=(dynamics_node_nf + context_node_nf, 0), # (scalar_features, vector_features)
40 | out_dims=(joint_nf, 1),
41 | hidden_dims=(hidden_nf, hidden_nf//2),
42 | vector_gate=True,
43 | num_layers=n_layers,
44 | attention=attention,
45 | normalization_factor=normalization_factor,
46 | ) # other parameters are default
47 |
48 | self.n_dims = n_dims
49 | self.condition_time = condition_time
50 | self.centering = centering
51 | self.context_node_nf = context_node_nf
52 | self.edge_cutoff = edge_cutoff
53 | self.model = model
54 |
55 | def forward(self, t, xh, pocket_xh, extension_mask, scaffold_mask, anchors, pocket_anchors, pocket_mask):
56 | """
57 | input:
58 | t: timestep: [B]
59 | xh: ligand atoms (noised) [B, N_l, h_l+3]
60 |             pocket_xh: pocket atoms (no noise added) [B, N_p, h_p+3]
61 |             extension_mask: mask on fragment extension atoms [B, N]
62 |             scaffold_mask: mask on scaffold atoms [B, N]
63 |             anchors: mask on anchor atoms [B, N]
64 |             pocket_mask: mask on all the pocket atoms [B, N_p]
65 | output:
66 | (x_out,h_out) for ligand
67 | """
68 | bs, n_lig_nodes = xh.shape[0], xh.shape[1]
69 | n_pocket_nodes = pocket_xh.shape[1]
70 |
71 | N = n_lig_nodes + n_pocket_nodes
72 |
73 | node_mask = (scaffold_mask.bool() | extension_mask.bool()) # [B, N_l]
74 | xh = xh[node_mask] # [N_l, h_l+3]
75 | pocket_xh = pocket_xh[pocket_mask.bool()] # [N_p, h_p+3]
76 |
77 | x_atoms = xh[:, :self.n_dims].clone() # [N_l,3]
78 | h_atoms = xh[:, self.n_dims:].clone() # [N_l,nf]
79 |
80 | x_pocket = pocket_xh[:, :self.n_dims].clone() # [N_p, 3]
81 | h_pocket = pocket_xh[:, self.n_dims:].clone() # [N_p, hp]
82 |
83 | h_atoms = self.atom_encoder(h_atoms) # [N_l, joint_nf]
84 | h_pocket = self.pocket_encoder(h_pocket) # [N_p, joint_nf]
85 |
86 | x = torch.cat((x_atoms, x_pocket), dim=0) # [N_l+N_p, 3]
87 | h = torch.cat((h_atoms, h_pocket), dim=0) # [N_l+N_p, joint_nf]
88 |
89 | batch_mask_ligand = self.get_batch_mask(node_mask, device=x.device) # [N_l]
90 | batch_mask_pocket = self.get_batch_mask(pocket_mask, device=x.device) # [N_p]
91 | mask = torch.cat([batch_mask_ligand, batch_mask_pocket], dim=0) # [N_l+N_p]
92 |
93 | new_anchor_mask = torch.cat([anchors[node_mask], pocket_anchors[pocket_mask.bool()]], dim=0).unsqueeze(-1)
94 |         new_scaffold_mask = torch.cat([scaffold_mask[node_mask], torch.zeros_like(batch_mask_pocket, device=xh.device)], dim=0).unsqueeze(-1)
95 | new_pocket_mask = torch.cat([torch.zeros_like(batch_mask_ligand, device=xh.device), torch.ones_like(batch_mask_pocket)], dim=0).unsqueeze(-1)
96 |
97 |         h = torch.cat([h, new_anchor_mask, new_scaffold_mask, new_pocket_mask], dim=1) # [N_l+N_p, joint_nf+3]
98 |
99 | if self.condition_time:
100 | if np.prod(t.size()) == 1:
101 | # t is the same for all elements in batch.
102 | h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
103 | else:
104 | # t is different over the batch dimension.
105 | h_time = t[mask]
106 | h = torch.cat([h, h_time], dim=1)
107 |
108 | edges = self.get_edges_cutoff(batch_mask_ligand, batch_mask_pocket, x_atoms, x_pocket) # [2, num_edges]
109 | assert torch.all(mask[edges[0]] == mask[edges[1]])
110 |
111 | # --------------- apply the GVP dynamics ----------
112 | h_final, pos_out = self.dynamics(h, x, edges) # [N_l+N_p, joint_nf], [N_l+N_p, 3]
113 | pos_out = pos_out.reshape(-1,3) # [N_l+N_p, 3]
114 |
115 | # decode atoms
116 | h_final_atoms = self.atom_decoder(h_final[:len(batch_mask_ligand)]) # [N_l, h_l]
117 |
118 | vel_ligand = pos_out[:len(batch_mask_ligand)] # [N_l, 3]
119 | vel_h_ligand = torch.cat([vel_ligand, h_final_atoms], dim=1) # [N_l, h_l+3]
120 |
121 | # convert to batch
122 | num_atoms = node_mask.sum(dim=1).int() # [B]
123 | reshaped_vel_h = torch.zeros(bs, n_lig_nodes, vel_h_ligand.shape[-1]).to(xh.device)
124 | positions = torch.zeros_like(batch_mask_ligand).to(xh.device)
125 | for idx in range(bs):
126 | positions[batch_mask_ligand == idx] = torch.arange(num_atoms[idx]).to(xh.device)
127 | reshaped_vel_h[batch_mask_ligand, positions] = vel_h_ligand
128 |
129 | return reshaped_vel_h # [B, N_l, h_l+3]
130 |
131 | @staticmethod
132 | def get_dist_edges(x, node_mask, batch_mask):
133 | node_mask = node_mask.squeeze().bool()
134 | batch_adj = (batch_mask[:, None] == batch_mask[None, :])
135 | nodes_adj = (node_mask[:, None] & node_mask[None, :])
136 | dists_adj = (torch.cdist(x, x) <= 7)
137 | rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
138 | adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
139 | edges = torch.stack(torch.where(adj))
140 | return edges
141 |
142 | def get_edges_cutoff(self, batch_mask_ligand, batch_mask_pocket, x_ligand, x_pocket):
143 |
144 | adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
145 | adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
146 | adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]
147 |
148 | if self.edge_cutoff_l is not None:
149 | adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
150 |
151 | if self.edge_cutoff_p is not None:
152 | adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
153 |
154 | if self.edge_cutoff_i is not None:
155 | adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)
156 |
157 | adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
158 | torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)
159 | edges = torch.stack(torch.where(adj), dim=0)
160 | return edges
161 |
162 | @staticmethod
163 | def get_batch_mask(mask, device):
164 | n_nodes = mask.float().sum(dim=1).int()
165 | batch_size = mask.shape[0]
166 | batch_mask = torch.cat([torch.ones(n_nodes[i]) * i for i in range(batch_size)]).long().to(device)
167 | return batch_mask
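168 |
169 | # A tiny illustrative sketch (toy tensors) of the time conditioning in
170 | # forward(): a scalar diffusion timestep is broadcast as one extra feature
171 | # column on every node before the GVP dynamics are applied.
172 | def _time_conditioning_sketch():
173 |     h = torch.randn(7, 35)  # per-node features (joint_nf + context columns)
174 |     t = torch.tensor(0.35)  # same timestep for the whole batch
175 |     h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
176 |     return torch.cat([h, h_time], dim=1)  # [7, 36]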
--------------------------------------------------------------------------------
/src/extension_size.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from torch.distributions.categorical import Categorical
6 |
7 | class DistributionNodes:
8 | def __init__(self, histogram):
9 |
10 | self.n_nodes = []
11 | prob = []
12 | self.keys = {}
13 | for i, nodes in enumerate(histogram):
14 | self.n_nodes.append(nodes)
15 | self.keys[nodes] = i
16 | prob.append(histogram[nodes])
17 | self.n_nodes = torch.tensor(self.n_nodes)
18 | prob = np.array(prob)
19 | prob = prob / np.sum(prob)
20 |
21 | self.prob = torch.from_numpy(prob).float()
22 | self.m = Categorical(torch.tensor(prob))
23 |
24 | def sample(self, n_samples=1):
25 | idx = self.m.sample((n_samples,))
26 | return self.n_nodes[idx]
27 |
28 | def log_prob(self, batch_n_nodes):
29 | assert len(batch_n_nodes.size()) == 1
30 |
31 | idcs = [self.keys[i.item()] for i in batch_n_nodes]
32 | idcs = torch.tensor(idcs).to(batch_n_nodes.device)
33 |
34 | log_p = torch.log(self.prob + 1e-30)
35 | log_p = log_p.to(batch_n_nodes.device)
36 |
37 | log_probs = log_p[idcs]
38 |
39 | return log_probs
40 |
41 |
42 |
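43 | # A short illustrative sketch (made-up histogram) of DistributionNodes: the
44 | # histogram maps a fragment size to its frequency, sample() draws sizes from
45 | # the normalized categorical, and log_prob() scores observed sizes.
46 | def _distribution_nodes_sketch():
47 |     dist = DistributionNodes({1: 10, 2: 30, 3: 25, 4: 5})  # size -> frequency
48 |     sizes = dist.sample(n_samples=4)  # e.g. tensor([2, 3, 2, 1])
49 |     return dist.log_prob(sizes)  # log-probabilities of the sampled sizes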
--------------------------------------------------------------------------------
/src/fragment_size_gnn.py:
--------------------------------------------------------------------------------
1 | from src.egnn import GCL, GaussianSmearing
2 | import torch.nn as nn
3 | import torch
4 | from src.egnn import coord2diff
5 | from torch_scatter import scatter_mean
6 |
7 | class FragSizeGNN(nn.Module):
8 | def __init__(self,
9 | lig_nf,
10 | pocket_nf,
11 | joint_nf,
12 | hidden_nf,
13 | out_node_nf, # number of classes (fragment sizes)
14 | n_layers,
15 | normalization=True,
16 | attention=True,
17 | normalization_factor=100,
18 | aggregation_method='sum',
19 | edge_cutoff_ligand=None,
20 | edge_cutoff_pocket=5,
21 | edge_cutoff_interaction=5,
22 | dataset_type='CrossDock',
23 | gaussian_expansion=True,
24 | num_gaussians=16):
25 | super(FragSizeGNN, self).__init__()
26 |
27 | self.dataset_type = dataset_type
28 |         context_node_nf = 3 if self.dataset_type == 'CrossDock' else 0 # masks for pocket, scaffold and anchor atoms
29 |         if gaussian_expansion:
30 |             self.gauss_exp = GaussianSmearing(start=0., stop=5., num_gaussians=num_gaussians)
31 |             in_edge_nf = num_gaussians
32 |         else:
33 |             in_edge_nf = 1
34 |
35 | self.hidden_nf = hidden_nf
36 | self.out_node_nf = out_node_nf
37 | self.n_layers = n_layers
38 | self.normalization = normalization
39 | self.attention = attention
40 | self.normalization_factor = normalization_factor
41 | self.gaussian_expansion = gaussian_expansion
42 | self.edge_cutoff_l = edge_cutoff_ligand
43 | self.edge_cutoff_p = edge_cutoff_pocket
44 | self.edge_cutoff_i = edge_cutoff_interaction
45 |
46 | self.mol_encoder = nn.Sequential(
47 | nn.Linear(lig_nf, joint_nf),
48 | )
49 |
50 | self.pocket_encoder = nn.Sequential(
51 | nn.Linear(pocket_nf, joint_nf),
52 | )
53 |
54 | self.embed_both = nn.Linear(joint_nf+context_node_nf, hidden_nf) # concatenate the context features to joint space
55 |
56 | self.gcl1 = GCL(
57 | input_nf=self.hidden_nf,
58 | output_nf=self.hidden_nf,
59 | hidden_nf=self.hidden_nf,
60 | normalization_factor=normalization_factor,
61 | aggregation_method=aggregation_method,
62 | edges_in_d=in_edge_nf,
63 | activation=nn.ReLU(),
64 | attention=attention,
65 | normalization=normalization
66 | )
67 |
68 | layers = []
69 | for i in range(n_layers - 1):
70 | layer = GCL(
71 | input_nf=self.hidden_nf,
72 | output_nf=self.hidden_nf,
73 | hidden_nf=self.hidden_nf,
74 | normalization_factor=normalization_factor,
75 | aggregation_method=aggregation_method,
76 | edges_in_d=in_edge_nf,
77 | activation=nn.ReLU(),
78 | attention=attention,
79 | normalization=normalization
80 | )
81 | layers.append(layer)
82 |
83 | self.gcl_layers = nn.ModuleList(layers)
84 | self.embedding_out = nn.Linear(self.hidden_nf, self.out_node_nf)
85 | self.act = nn.ReLU()
86 |
87 | self.edge_cache = {}
88 | #self.lin_out = nn.Linear(self.out_node_nf, 1)
89 |
90 | def forward(self, mol_x, mol_h, node_mask, pocket_x, pocket_h, pocket_mask, anchors, pocket_anchors):
91 | """
92 | mol_x: [B, N, 3] positions of scaffold atoms
93 | mol_h: [B, N, nf] onehot of scaffold atoms
94 |         node_mask: [B, N] mask on scaffold atoms
95 | pocket_x: [B, N, 3] positions of pocket atoms
96 | pocket_h: [B, N, nf] onehot of pocket atoms
97 |         anchors: [B, N] mask marking ligand anchor atoms
98 |         pocket_anchors: [B, Np] mask marking pocket anchor atoms
99 | """
100 | bs, n_nodes_lig = mol_x.shape[0], mol_x.shape[1]
101 | n_nodes_pocket = pocket_x.shape[1]
102 | node_mask = node_mask.squeeze()
103 |
104 | N = n_nodes_lig + n_nodes_pocket
105 | mol_x = mol_x[node_mask.bool()] # [N_l, 3]
106 | mol_h = mol_h[node_mask.bool()] # [N_l, nf]
107 |
108 | pocket_x = pocket_x[pocket_mask.bool()] # [N_p, 3]
109 | pocket_h = pocket_h[pocket_mask.bool()] # [N_p, nf]
110 |
111 | mol_h = self.mol_encoder(mol_h) # [N_l, joint_nf]
112 | pocket_h = self.pocket_encoder(pocket_h)
113 |
114 | h = torch.cat([mol_h, pocket_h], dim=0) # [N, joint_nf]
115 |
116 | batch_mask_ligand = self.get_batch_mask(node_mask, device=mol_x.device) # [N_l]
117 | batch_mask_pocket = self.get_batch_mask(pocket_mask, device=mol_x.device) # [N_p]
118 | new_anchor_mask = torch.cat([anchors[node_mask.bool()], pocket_anchors[pocket_mask.bool()]], dim=0).unsqueeze(-1)
119 | new_scaffold_mask = torch.cat([torch.ones_like(batch_mask_ligand, device=mol_x.device), torch.zeros_like(batch_mask_pocket)], dim=0).unsqueeze(-1)
120 | new_pocket_mask = torch.cat([torch.zeros_like(batch_mask_ligand), torch.ones_like(batch_mask_pocket)], dim=0).unsqueeze(-1)
121 |
122 |         h = torch.cat([h, new_anchor_mask, new_scaffold_mask, new_pocket_mask], dim=1) # [N, joint_nf+3]
123 | x = torch.cat([mol_x, pocket_x], dim=0) # [N, 3]
124 |
125 | mask = torch.cat([batch_mask_ligand, batch_mask_pocket], dim=0) # [N]
126 | device = mol_x.device
127 |
128 | h = self.embed_both(h)
129 | edges = self.get_edges_cutoff(batch_mask_ligand, batch_mask_pocket, mol_x, pocket_x) # [2, E]
130 |
131 |         # keep only edges within the distance cutoffs (all protein and scaffold atoms considered)
132 | distances, _ = coord2diff(x, edges) # TODO: consider adding more edge info such as the type of bond
133 | if self.gaussian_expansion:
134 | distances = self.gauss_exp(distances)
135 |
136 | for gcl in self.gcl_layers:
137 | h, _ = gcl(h, edges, edge_attr=distances, node_mask=None, edge_mask=None)
138 |
139 | h_final = self.act(self.embedding_out(h)) # [N, out_node_nf]
140 |
141 | # convert to batch
142 | #out = scatter_mean(h_final, mask, dim=0, dim_size=bs) # [B, out_node_nf]
143 | num_atoms = node_mask.sum(dim=1).int() + pocket_mask.sum(dim=1).int()
144 | reshaped_out = torch.zeros(bs, N, h_final.shape[-1], dtype=h.dtype, device=h.device)
145 | positions = torch.zeros_like(mask).to(h.device)
146 | for idx in range(bs):
147 | positions[mask == idx] = torch.arange(num_atoms[idx], device=h.device)
148 | reshaped_out[mask, positions] = h_final
149 | return reshaped_out # [B, N, out_node_nf]
150 |
151 | def get_edges_cutoff(self, batch_mask_ligand, batch_mask_pocket, x_ligand, x_pocket):
152 |
153 | adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
154 | adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
155 | adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]
156 |
157 | if self.edge_cutoff_l is not None:
158 | adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
159 |
160 | if self.edge_cutoff_p is not None:
161 | adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
162 |
163 | if self.edge_cutoff_i is not None:
164 | adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)
165 |
166 | adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
167 | torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)
168 | edges = torch.stack(torch.where(adj), dim=0)
169 | return edges
170 |
171 | @staticmethod
172 | def get_batch_mask(mask, device):
173 | n_nodes = mask.float().sum(dim=1).int()
174 | batch_size = mask.shape[0]
175 | batch_mask = torch.cat([torch.ones(n_nodes[i]) * i for i in range(batch_size)]).long().to(device)
176 | return batch_mask
--------------------------------------------------------------------------------
/src/gvp.py:
--------------------------------------------------------------------------------
1 | # GVP implementation from DiffHopp https://github.com/jostorge/diffusion-hopping/tree/main
2 |
3 | import math
4 | from typing import Tuple, Union
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | s_V = Tuple[torch.Tensor, torch.Tensor]
11 |
12 |
13 | # Relevant papers:
14 | # Learning from Protein Structure with Geometric Vector Perceptrons
15 | # Equivariant Graph Neural Networks for 3D Macromolecular Structure
16 | class GVP(nn.Module):
17 | def __init__(
18 | self,
19 | in_dims: Tuple[int, int],
20 | out_dims: Tuple[int, int],
21 | activations=(F.relu, torch.sigmoid),
22 | vector_gate: bool = False,
23 | eps: float = 1e-4,
24 | ) -> None:
25 | super().__init__()
26 | in_scalar, in_vector = in_dims
27 | out_scalar, out_vector = out_dims
28 | self.sigma, self.sigma_plus = activations
29 |
30 | if self.sigma is None:
31 | self.sigma = nn.Identity()
32 | if self.sigma_plus is None:
33 | self.sigma_plus = nn.Identity()
34 |
35 | self.h = max(in_vector, out_vector)
36 | self.W_h = nn.Parameter(torch.empty((self.h, in_vector)))
37 | self.W_mu = nn.Parameter(torch.empty((out_vector, self.h)))
38 |
39 | self.W_m = nn.Linear(self.h + in_scalar, out_scalar)
40 | self.v = in_vector
41 | self.mu = out_vector
42 | self.n = in_scalar
43 | self.m = out_scalar
44 | self.vector_gate = vector_gate
45 |
46 | if vector_gate:
47 | self.sigma_g = nn.Sigmoid()
48 | self.W_g = nn.Linear(out_scalar, out_vector)
49 |
50 | self.eps = eps
51 | self.reset_parameters()
52 |
53 | def reset_parameters(self):
54 | torch.nn.init.kaiming_uniform_(self.W_h, a=math.sqrt(5))
55 | torch.nn.init.kaiming_uniform_(self.W_mu, a=math.sqrt(5))
56 | self.W_m.reset_parameters()
57 | if self.vector_gate:
58 | self.W_g.reset_parameters()
59 |
60 | def forward(self, x: Union[torch.Tensor, s_V]) -> Union[torch.Tensor, s_V]:
61 | """Geometric vector perceptron"""
62 | s, V = (
63 | x if self.v > 0 else (x, torch.empty((x.shape[0], 0, 3), device=x.device))
64 | )
65 |
66 | assert (
67 | s.shape[-1] == self.n
68 | ), f"{s.shape[-1]} != {self.n} Scalar dimension mismatch"
69 | assert (
70 | V.shape[-2] == self.v
71 | ), f" {V.shape[-2]} != {self.v} Vector dimension mismatch"
72 | assert V.shape[0] == s.shape[0], "Batch size mismatch"
73 |
74 | V_h = self.W_h @ V
75 | V_mu = self.W_mu @ V_h
76 | s_h = torch.clip(torch.norm(V_h, dim=-1), min=self.eps)
77 | s_hn = torch.cat([s, s_h], dim=-1)
78 | s_m = self.W_m(s_hn)
79 | s_dash = self.sigma(s_m)
80 | if self.vector_gate:
81 | V_dash = self.sigma_g(self.W_g(self.sigma_plus(s_m)))[..., None] * V_mu
82 | else:
83 | v_mu = torch.clip(torch.norm(V_mu, dim=-1, keepdim=True), min=self.eps)
84 | V_dash = self.sigma_plus(v_mu) * V_mu
85 | return (s_dash, V_dash) if self.mu > 0 else s_dash
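86 |
87 | # A short illustrative sketch (random inputs): scalars mix with the norms of
88 | # linearly transformed vectors, while vector outputs are built only from
89 | # linear maps (and gates) of the input vectors, so rotating the input
90 | # rotates the output.
91 | def _gvp_sketch():
92 |     gvp = GVP(in_dims=(8, 4), out_dims=(6, 2), vector_gate=True)
93 |     s, V = torch.randn(10, 8), torch.randn(10, 4, 3)
94 |     s_out, V_out = gvp((s, V))
95 |     return s_out.shape, V_out.shape  # torch.Size([10, 6]), torch.Size([10, 2, 3])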
--------------------------------------------------------------------------------
/src/gvp_model.py:
--------------------------------------------------------------------------------
1 | # GVP implementation from DiffHopp https://github.com/jostorge/diffusion-hopping/tree/main
2 | from typing import Tuple, Union, Optional
3 |
4 | import torch
5 | from torch import nn as nn
6 | from torch.nn import functional as F
7 |
8 | from src.conv_layer import GVPConvLayer
9 | from src.gvp import GVP, s_V
10 | from src.layer_norm import GVPLayerNorm
11 |
12 | class GVPNetwork(nn.Module):
13 | def __init__(
14 | self,
15 | in_dims: Tuple[int, int],
16 | out_dims: Tuple[int, int],
17 | hidden_dims: Tuple[int, int],
18 | num_layers: int,
19 | attention: bool = False,
20 | normalization_factor: float=100.0,
21 | aggr: str = "add",
22 | activations=(F.silu, None),
23 | vector_gate: bool = True,
24 | eps=1e-4
25 | ) -> None:
26 | super().__init__()
27 |         edge_dims = (1, 1)  # each edge carries one scalar (distance) and one vector (unit direction)
28 |
29 | self.eps = eps
30 | self.embedding_in = nn.Sequential(
31 | GVPLayerNorm(in_dims),
32 | GVP(
33 | in_dims,
34 | hidden_dims,
35 | activations=(None,None),
36 | vector_gate=vector_gate
37 | ),
38 | )
39 | self.embedding_out = nn.Sequential(
40 | GVPLayerNorm(hidden_dims),
41 | GVP(
42 | hidden_dims,
43 | out_dims,
44 | activations=activations,
45 | vector_gate=vector_gate
46 | ),
47 | )
48 | self.edge_embedding = nn.Sequential(
49 | GVPLayerNorm(edge_dims),
50 | GVP(
51 | edge_dims,
52 | (hidden_dims[0],1),
53 | activations=(None, None),
54 | vector_gate=vector_gate
55 | )
56 | )
57 |
58 | self.layers = nn.ModuleList(
59 | [
60 | GVPConvLayer(
61 | hidden_dims,
62 | (hidden_dims[0], 1),
63 | activations=activations,
64 | vector_gate=vector_gate,
65 | residual=True,
66 | attention=attention,
67 | aggr=aggr,
68 | normalization_factor=normalization_factor,
69 | )
70 | for _ in range(num_layers)
71 | ]
72 | )
73 |
74 | def get_edge_attr(self, edge_index, pos) -> s_V:
75 | V = pos[edge_index[0]] - pos[edge_index[1]] # [n_edges, 3]
76 | s = torch.linalg.norm(V, dim=-1, keepdim=True) # [n_edges, 1]
77 | V = (V / torch.clip(s, min=self.eps))[..., None, :] # [n_edges, 1, 3]
78 | return s, V
79 |
80 | def forward(self, h, pos, edge_index) -> s_V:
81 | edge_attr = self.get_edge_attr(edge_index, pos)
82 | edge_attr = self.edge_embedding(edge_attr)
83 |
84 | h = self.embedding_in(h)
85 | for layer in self.layers:
86 | h = layer(h, edge_index, edge_attr)
87 |
88 | return self.embedding_out(h)
--------------------------------------------------------------------------------
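
A hypothetical smoke test for `GVPNetwork` (not part of the repository), assuming the repo's `src` package and its `GVPConvLayer` dependency import cleanly: node features enter as a (scalars, vectors) tuple, while edge features are derived from the coordinates inside `get_edge_attr`.

import torch

from src.gvp_model import GVPNetwork

torch.manual_seed(0)
net = GVPNetwork(
    in_dims=(16, 1),   # per node: 16 scalars, 1 vector channel
    out_dims=(8, 1),
    hidden_dims=(32, 4),
    num_layers=2,
)

n_nodes = 6
h = (torch.randn(n_nodes, 16), torch.randn(n_nodes, 1, 3))
pos = torch.randn(n_nodes, 3)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])  # [2, n_edges]

s, V = net(h, pos, edge_index)
print(s.shape, V.shape)  # torch.Size([6, 8]) torch.Size([6, 1, 3])
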
/src/layer_norm.py:
--------------------------------------------------------------------------------
1 | # GVP implementation from DiffHopp https://github.com/jostorge/diffusion-hopping/tree/main
2 |
3 | import math
4 | from typing import Tuple, Union
5 |
6 | import torch
7 | from torch import nn as nn
8 | s_V = Tuple[torch.Tensor, torch.Tensor]
9 |
10 | class GVPLayerNorm(nn.Module):
11 |     def __init__(self, dims: Tuple[int, int], eps: float = 1e-5) -> None:
12 | super().__init__()
13 | self.eps = math.sqrt(eps)
14 | self.scalar_size, self.vector_size = dims
15 | self.feature_layer_norm = nn.LayerNorm(self.scalar_size, eps=eps)
16 |
17 |     def forward(self, x: Union[torch.Tensor, s_V]) -> Union[torch.Tensor, s_V]:
18 | if self.vector_size == 0:
19 | return self.feature_layer_norm(x)
20 |
21 | s, V = x
22 | s = self.feature_layer_norm(s)
23 | norm = torch.clip(
24 | torch.linalg.vector_norm(V, dim=(-1,-2), keepdim=True)
25 | / math.sqrt(self.vector_size),
26 | min=self.eps
27 | )
28 |
29 | V = V / norm
30 | return s, V
--------------------------------------------------------------------------------
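
A small hypothetical check of `GVPLayerNorm` (not part of the repository): scalar channels go through a standard `nn.LayerNorm`, while each node's vector channels are divided by their RMS norm, so magnitudes are normalized to one and directions are untouched.

import torch

from src.layer_norm import GVPLayerNorm

torch.manual_seed(0)
ln = GVPLayerNorm((8, 4))

s = torch.randn(10, 8)
V = 5.0 * torch.randn(10, 4, 3)  # deliberately large vectors

s_out, V_out = ln((s, V))
rms = torch.linalg.vector_norm(V_out, dim=(-1, -2)) / (4 ** 0.5)
print(torch.allclose(rms, torch.ones(10), atol=1e-5))  # True
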
/src/lightning_anchor_gnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pytorch_lightning as pl
4 |
5 | from torch.nn.functional import sigmoid
6 | from src.datasets import HierCrossDockDataset, get_dataloader, collate_pocket_aux
7 | from src.anchor_gnn import AnchorGNNPocket
8 |
9 | from typing import Dict, List, Optional
10 | from tqdm import tqdm
11 | import os
12 | import torch.nn as nn
13 |
14 | def get_activation(activation):
15 | if activation == 'silu':
16 | return torch.nn.SiLU()
17 | else:
18 |         raise ValueError(f'activation fn {activation!r} not found; add it here')
19 |
20 | class MaskedBCEWithLogitsLoss(torch.nn.Module):
21 | """ masks the pocket atoms for anchor prediction loss calculation """
22 | def __init__(self):
23 | super(MaskedBCEWithLogitsLoss, self).__init__()
24 | self.loss = torch.nn.BCEWithLogitsLoss(reduction='none')
25 |
26 | def forward(self, input, target, mask=None, return_mean=False):
27 | masked_loss = self.loss(input, target)
28 |
29 | if mask is not None:
30 | masked_loss = masked_loss * mask.float()
31 | if return_mean:
32 | if mask is not None:
33 | return masked_loss.sum() / mask.sum().float()
34 | else:
35 | return masked_loss.mean()
36 | else:
37 | return masked_loss
38 |
39 | class AnchorGNN_pl(pl.LightningModule):
40 | train_dataset = None
41 | val_dataset = None
42 | starting_epoch = None
43 | metrics: Dict[str, List[float]] = {}
44 |
45 | def __init__(
46 | self,
47 | lig_node_nf,
48 | pocket_node_nf,
49 | joint_nf,
50 | n_dims,
51 | hidden_nf,
52 | activation,
53 | tanh,
54 | n_layers,
55 | attention,
56 | norm_constant,
57 | data_path,
58 | train_data_prefix,
59 | val_data_prefix,
60 | batch_size,
61 | lr,
62 | test_epochs,
63 | dataset_type,
64 | normalization_factor,
65 | gaussian_expansion=False,
66 | normalization=None,
67 | include_charges=False,
68 | samples_dir=None,
69 | train_dataframe_path='paths_train.csv',
70 | val_dataframe_path='paths_val.csv',
71 |         num_workers=0, use_guidance=False, guidance_classes=6, guidance_feature='QED',  # guidance kwargs accepted because train_anchor_predictor.py passes them
72 | ):
73 |
74 | super(AnchorGNN_pl, self).__init__()
75 | self.save_hyperparameters()
76 | self.data_path = data_path
77 | self.train_data_prefix = train_data_prefix
78 | self.val_data_prefix = val_data_prefix
79 | self.batch_size = batch_size
80 | self.lr = lr
81 | self.test_epochs = test_epochs
82 | self.samples_dir = samples_dir
83 | self.n_dims = n_dims
84 | self.num_classes = lig_node_nf - include_charges
85 | self.include_charges = include_charges
86 | self.train_dataframe_path = train_dataframe_path
87 | self.val_dataframe_path = val_dataframe_path
88 | self.num_workers = num_workers
89 | self.n_layers = n_layers
90 | self.attention = attention
91 | self.normalization_factor = normalization_factor
92 |
93 | self.joint_nf = joint_nf
94 | self.lig_node_nf = lig_node_nf
95 | self.pocket_node_nf = pocket_node_nf
96 |
97 | self.norm_constant = norm_constant
98 | self.tanh = tanh
99 | self.dataset_type = dataset_type
100 | self.gaussian_expansion = gaussian_expansion
101 |         self.use_guidance, self.guidance_classes, self.guidance_feature = use_guidance, guidance_classes, guidance_feature
102 |
103 | if self.dataset_type == 'GEOM':
104 | self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
105 | elif self.dataset_type == 'CrossDock':
106 | self.bce_loss = MaskedBCEWithLogitsLoss()
107 |
108 | if type(activation) is str:
109 | activation = get_activation(activation)
110 |
111 | self.anchor_predictor = AnchorGNNPocket(
112 | lig_nf=lig_node_nf,
113 | pocket_nf=pocket_node_nf,
114 | joint_nf=joint_nf,
115 | hidden_nf=hidden_nf,
116 | out_node_nf=hidden_nf,
117 |             n_layers=4,  # NOTE: fixed at 4 here; the n_layers argument is stored but not forwarded
118 | normalization_factor=normalization_factor,
119 | normalization=normalization,
120 |             attention=True,  # NOTE: always on here, regardless of the attention argument
121 | aggregation_method='sum',
122 | dist_cutoff=7,
123 | gaussian_expansion=gaussian_expansion,
124 | edge_cutoff_ligand=None,
125 | edge_cutoff_pocket=4.5,
126 | edge_cutoff_interaction=4.5
127 | )
128 |
129 | def setup(self, stage: Optional[str]=None):
130 | if stage == 'fit':
131 | self.train_dataset = HierCrossDockDataset(
132 | data_path=self.data_path,
133 | prefix=self.train_data_prefix,
134 | device=self.device,
135 | dataframe_path=self.train_dataframe_path
136 | )
137 | print('loaded train data')
138 | self.val_dataset = HierCrossDockDataset(
139 | data_path=self.data_path,
140 | prefix=self.val_data_prefix,
141 | device=self.device,
142 | dataframe_path=self.val_dataframe_path
143 | )
144 | print('loaded validation data')
145 |
146 | elif stage == 'val':
147 | self.val_dataset = HierCrossDockDataset(
148 | data_path=self.data_path,
149 | prefix=self.val_data_prefix,
150 | device=self.device,
151 | dataframe_path=self.val_dataframe_path
152 | )
153 | else:
154 | raise NotImplementedError
155 |
156 | def train_dataloader(self):
157 | return get_dataloader(self.train_dataset, self.batch_size, num_workers=self.num_workers, collate_fn=collate_pocket_aux, shuffle=True)
158 |
159 | def val_dataloader(self):
160 | return get_dataloader(self.val_dataset, self.batch_size, num_workers=self.num_workers, collate_fn=collate_pocket_aux)
161 |
162 | def test_dataloader(self):
163 |         return get_dataloader(self.test_dataset, self.batch_size, num_workers=self.num_workers, collate_fn=collate_pocket_aux)  # NOTE: self.test_dataset is never assigned in setup()
164 |
165 | def forward(self, data, training):
166 |
167 | scaff_x = data['position_aux'].to(self.device) # [B, Ns, 3]
168 | scaff_h = data['onehot_aux'].to(self.device) # [B, Ns, nf]
169 | scaffold_masks = data['scaffold_masks_aux'].to(self.device) # [B, Ns]
170 | pocket_masks = data['pocket_mask_aux'].to(self.device) # [B, Np]
171 | scaffold_anchors = data['anchors_aux'].to(self.device) # [B,Ns]
172 | pocket_x = data['pocket_coords_aux'].to(self.device)
173 | pocket_h = data['pocket_onehot_aux'].to(self.device)
174 |
175 |         # batch size and atom counts (scaffold, pocket, total)
176 |         B = scaff_x.shape[0]
177 |         N_s = scaff_x.shape[1]
178 |         N_p = pocket_x.shape[1]
179 |         N = N_s + N_p
180 | 
181 |
182 |         anchor_out = self.anchor_predictor.forward(mol_x=scaff_x,            # [B, Ns, 3]
183 |                                                     mol_h=scaff_h,            # [B, Ns, nf]
184 |                                                     pocket_x=pocket_x,        # [B, Np, 3]
185 |                                                     pocket_h=pocket_h,        # [B, Np, hp]
186 |                                                     node_mask=scaffold_masks, # [B, Ns] mask on scaffold atoms
187 |                                                     pocket_mask=pocket_masks, # [B, Np] mask on pocket atoms
188 |                                                     )                         # anchor logits per scaffold atom
189 |
190 | anchor_loss = self.bce_loss(anchor_out.view(B*N_s, 1), scaffold_anchors.view(B*N_s, 1), scaffold_masks.view(B*N_s, 1), return_mean=True)
191 |         # NOTE: the mask/return_mean kwargs assume the CrossDock branch (MaskedBCEWithLogitsLoss); the GEOM branch's plain BCEWithLogitsLoss does not accept them
192 | return anchor_out, anchor_loss
193 |
194 | def training_step(self, data, *args):
195 | _, loss = self.forward(data, training=True)
196 | training_metrics = {
197 | 'loss': loss
198 | }
199 | for metric_name, metric in training_metrics.items():
200 | self.metrics.setdefault(f'{metric_name}/train', []).append(metric)
201 | self.log(f'{metric_name}/train', metric, on_step=True, on_epoch=True, batch_size=self.batch_size, prog_bar=True)
202 | self.metrics.clear()
203 | return training_metrics
204 |
205 | def validation_step(self, data, *args):
206 | _, loss = self.forward(data, training=False)
207 | validation_metrics = {
208 | 'loss': loss
209 | }
210 | return validation_metrics
211 |
212 | def training_epoch_end(self, training_step_outputs):
213 | for metric in training_step_outputs[0].keys():
214 | avg_metric = self.aggregate_metric(training_step_outputs, metric)
215 | self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
216 | self.log(f'{metric}/train', avg_metric, prog_bar=True)
217 |
218 | self.metrics.clear() # free up memory
219 |
220 | def validation_epoch_end(self, validation_step_outputs):
221 | for metric in validation_step_outputs[0].keys():
222 | avg_metric = self.aggregate_metric(validation_step_outputs, metric)
223 | self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
224 | self.log(f'{metric}/val', avg_metric, prog_bar=True)
225 |
226 | self.metrics.clear()
227 |
228 | def configure_optimizers(self):
229 | return torch.optim.AdamW(self.anchor_predictor.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)
230 |
231 | @staticmethod
232 | def aggregate_metric(step_outputs, metric):
233 | return torch.tensor([out[metric] for out in step_outputs]).mean()
234 |
--------------------------------------------------------------------------------
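
A hypothetical illustration of `MaskedBCEWithLogitsLoss` (not part of the repository; importing the module pulls in pytorch_lightning and the rest of the repo's dependencies): padded positions are zeroed out of the element-wise BCE, and the mean is taken over real atoms only.

import torch

from src.lightning_anchor_gnn import MaskedBCEWithLogitsLoss

loss_fn = MaskedBCEWithLogitsLoss()

logits = torch.tensor([[2.0], [-1.0], [0.5], [0.0]])  # anchor logits [B*Ns, 1]
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])  # anchor labels
mask = torch.tensor([[1.0], [1.0], [1.0], [0.0]])     # last row is padding

loss = loss_fn(logits, targets, mask=mask, return_mean=True)
print(loss)  # mean BCE over the three unmasked positions only
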
/src/noise.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import math
4 | import numpy as np
5 |
6 | def clip_noise_schedule(alphas2, clip_value=0.001):
7 | """
8 |     For a noise schedule given by alpha^2, clips the per-step ratio alpha_t^2 / alpha_{t-1}^2 from below.
9 |     This may help improve stability during sampling.
10 | """
11 | alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)
12 |
13 | alphas_step = (alphas2[1:] / alphas2[:-1])
14 |
15 | alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.)
16 | alphas2 = np.cumprod(alphas_step, axis=0)
17 | return alphas2
18 |
19 | def polynomial_schedule(timesteps: int, s=1e-4, power=3.):
20 | """
21 | A noise schedule based on a simple polynomial equation: 1 - x^power.
22 | """
23 | steps = timesteps + 1
24 | x = np.linspace(0, steps, steps)
25 | alphas2 = (1 - np.power(x / steps, power)) ** 2
26 |
27 | alphas2 = clip_noise_schedule(alphas2, clip_value=0.001)
28 |
29 | precision = 1 - 2 * s
30 |
31 | alphas2 = precision * alphas2 + s
32 |
33 | return alphas2
34 |
35 |
36 | def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1):
37 | """
38 |     Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.
39 |     Despite the name, this returns the cumulative product of the alphas, not the betas.
40 | """
41 | steps = timesteps + 2
42 | x = np.linspace(0, steps, steps)
43 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
44 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
45 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
46 | betas = np.clip(betas, a_min=0, a_max=0.999)
47 | alphas = 1. - betas
48 | alphas_cumprod = np.cumprod(alphas, axis=0)
49 |
50 | if raise_to_power != 1:
51 | alphas_cumprod = np.power(alphas_cumprod, raise_to_power)
52 |
53 | return alphas_cumprod
54 |
55 |
56 | class PositiveLinear(torch.nn.Module):
57 | """Linear layer with weights forced to be positive."""
58 |
59 | def __init__(self, in_features: int, out_features: int, bias: bool = True,
60 | weight_init_offset: int = -2):
61 | super(PositiveLinear, self).__init__()
62 | self.in_features = in_features
63 | self.out_features = out_features
64 | self.weight = torch.nn.Parameter(
65 | torch.empty((out_features, in_features)))
66 | if bias:
67 | self.bias = torch.nn.Parameter(torch.empty(out_features))
68 | else:
69 | self.register_parameter('bias', None)
70 | self.weight_init_offset = weight_init_offset
71 | self.reset_parameters()
72 |
73 | def reset_parameters(self) -> None:
74 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
75 |
76 | with torch.no_grad():
77 | self.weight.add_(self.weight_init_offset)
78 |
79 | if self.bias is not None:
80 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
81 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
82 | torch.nn.init.uniform_(self.bias, -bound, bound)
83 |
84 | def forward(self, x):
85 | positive_weight = F.softplus(self.weight)
86 | return F.linear(x, positive_weight, self.bias)
87 |
88 |
89 | class PredefinedNoiseSchedule(torch.nn.Module):
90 | """
91 | Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules.
92 | """
93 |
94 | def __init__(self, noise_schedule, timesteps, precision):
95 | super(PredefinedNoiseSchedule, self).__init__()
96 | self.timesteps = timesteps
97 |
98 | if noise_schedule == 'cosine':
99 | alphas2 = cosine_beta_schedule(timesteps)
100 | elif 'polynomial' in noise_schedule:
101 | splits = noise_schedule.split('_')
102 | assert len(splits) == 2
103 | power = float(splits[1])
104 | alphas2 = polynomial_schedule(timesteps, s=precision, power=power)
105 | else:
106 | raise ValueError(noise_schedule)
107 |
108 | # print('alphas2', alphas2)
109 |
110 | sigmas2 = 1 - alphas2
111 |
112 | log_alphas2 = np.log(alphas2)
113 | log_sigmas2 = np.log(sigmas2)
114 |
115 | log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2
116 |
117 | # print('gamma', -log_alphas2_to_sigmas2)
118 |
119 | self.gamma = torch.nn.Parameter(
120 | torch.from_numpy(-log_alphas2_to_sigmas2).float(),
121 | requires_grad=False)
122 |
123 | def forward(self, t):
124 | t_int = torch.round(t * self.timesteps).long()
125 | return self.gamma[t_int]
126 |
127 | class GammaNetwork(torch.nn.Module):
128 | """The gamma network models a monotonic increasing function. Construction as in the VDM paper."""
129 |
130 | def __init__(self):
131 | super().__init__()
132 |
133 | self.l1 = PositiveLinear(1, 1)
134 | self.l2 = PositiveLinear(1, 1024)
135 | self.l3 = PositiveLinear(1024, 1)
136 |
137 | self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.]))
138 | self.gamma_1 = torch.nn.Parameter(torch.tensor([10.]))
139 | self.show_schedule()
140 |
141 | def show_schedule(self, num_steps=50):
142 | t = torch.linspace(0, 1, num_steps).view(num_steps, 1)
143 | gamma = self.forward(t)
144 | print('Gamma schedule:')
145 | print(gamma.detach().cpu().numpy().reshape(num_steps))
146 |
147 | def gamma_tilde(self, t):
148 | l1_t = self.l1(t)
149 | return l1_t + self.l3(torch.sigmoid(self.l2(l1_t)))
150 |
151 | def forward(self, t):
152 | zeros, ones = torch.zeros_like(t), torch.ones_like(t)
153 | # Not super efficient.
154 | gamma_tilde_0 = self.gamma_tilde(zeros)
155 | gamma_tilde_1 = self.gamma_tilde(ones)
156 | gamma_tilde_t = self.gamma_tilde(t)
157 |
158 | # Normalize to [0, 1]
159 | normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / (
160 | gamma_tilde_1 - gamma_tilde_0)
161 |
162 | # Rescale to [gamma_0, gamma_1]
163 | gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma
164 |
165 | return gamma
--------------------------------------------------------------------------------
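
A hypothetical sketch of how these schedules are consumed (not part of the repository): build a short polynomial schedule, then read gamma(t) = -log(alpha_t^2 / sigma_t^2) back through `PredefinedNoiseSchedule`; `GammaNetwork` is the learned alternative and is monotone by construction.

import torch

from src.noise import GammaNetwork, PredefinedNoiseSchedule, polynomial_schedule

T = 10
alphas2 = polynomial_schedule(T, s=1e-4, power=3.0)  # numpy array of length T + 1
print(alphas2[0], alphas2[-1])  # ~1 at t=0, ~1e-4 at t=T

schedule = PredefinedNoiseSchedule('polynomial_3', timesteps=T, precision=1e-4)
t = torch.linspace(0, 1, T + 1)  # normalized timesteps in [0, 1]
print(schedule(t))               # gamma(t), monotonically increasing

net = GammaNetwork()  # prints its initial (learned) schedule on construction
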
/train_anchor_predictor.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | from src.lightning_anchor_gnn import AnchorGNN_pl
5 | from src.utils import disable_rdkit_logging, Logger
6 | from pytorch_lightning import Trainer, callbacks, loggers
7 | from pytorch_lightning.loggers import TensorBoardLogger
8 |
9 | def find_last_checkpoint(checkpoints_dir):
10 | epoch2fname = [
11 | (int(fname.split('=')[1].split('.')[0]), fname)
12 | for fname in os.listdir(checkpoints_dir)
13 | if fname.endswith('.ckpt')
14 | ]
15 | latest_fname = max(epoch2fname, key=lambda t: t[0])[1]
16 | return os.path.join(checkpoints_dir, latest_fname)
17 |
18 | def main(args):
19 | run_name = args.exp_name
20 | experiment = run_name if args.resume is None else args.resume
21 | checkpoints_dir = os.path.join(args.checkpoints, experiment)
22 | os.makedirs(os.path.join(args.logs, 'general_logs', experiment), exist_ok=True)
23 |     sys.stdout = Logger(logpath=os.path.join(args.logs, "general_logs", experiment, 'log.log'), syspart=sys.stdout)
24 |     sys.stderr = Logger(logpath=os.path.join(args.logs, "general_logs", experiment, 'log.log'), syspart=sys.stderr)
25 |
26 | os.makedirs(checkpoints_dir, exist_ok=True)
27 | os.makedirs(args.logs, exist_ok=True)
28 | samples_dir = os.path.join(args.logs, 'samples', experiment)
29 |
30 |     TB_Logger = TensorBoardLogger('tb_logs', name=experiment)  # NOTE: created but not passed to the Trainer; wandb_logger is used below
31 | wandb_logger = loggers.WandbLogger(
32 | save_dir=args.logs,
33 | project='diffusion-anchor-pred',
34 | name=experiment,
35 | id=experiment,
36 | resume='must' if args.resume is not None else 'allow'
37 | )
38 |
39 |     # both options are argparse store_true flags, so they are already booleans;
40 |     # the original `is not None` test on gaussian_expansion was always true
41 |     gaussian_expansion = args.gaussian_expansion
42 |     use_guidance = args.use_guidance
48 |
49 |     if args.guidance_feature in ('QED', 'SA', 'Vina'):
50 |         guidance_classes = 6  # all three guidance features are binned into six classes
51 |     else:
52 |         raise ValueError(f'unknown guidance feature: {args.guidance_feature}')
55 |
56 | # ---------------------------------------------------------
57 | lig_nf = 10 # atom types
58 | pocket_nf = 25 # node features (4) + AA type (20) + BB (1)
59 | #context_node_nf = 3 # context is (anchors + scaffold_masks + pocket_masks )
60 | joint_nf = 32
61 |
62 | anchor_predictor = AnchorGNN_pl(
63 | lig_node_nf=lig_nf,
64 | pocket_node_nf=pocket_nf,
65 | joint_nf=joint_nf, # TODO: change this?
66 | n_dims=3,
67 | hidden_nf=args.nf,
68 | activation=args.activation,
69 | tanh=args.tanh,
70 | n_layers=args.n_layers,
71 | attention=args.attention,
72 | norm_constant=args.norm_constant,
73 | data_path=args.data,
74 | train_data_prefix=args.train_data_prefix,
75 | val_data_prefix=args.val_data_prefix,
76 | batch_size=args.batch_size,
77 | lr=args.lr,
78 | test_epochs=args.test_epochs,
79 | normalization_factor=args.normalization_factor,
80 | normalization=args.normalization,
81 | include_charges=False,
82 | samples_dir=None,
83 |         train_dataframe_path=args.train_dataframe_path,
84 |         val_dataframe_path=args.valid_dataframe_path,
85 | num_workers=0,
86 | dataset_type=args.dataset_type,
87 | use_guidance=use_guidance,
88 | guidance_classes=guidance_classes,
89 | guidance_feature=args.guidance_feature,
90 | gaussian_expansion=gaussian_expansion)
91 |
92 | checkpoint_callback = callbacks.ModelCheckpoint(
93 | dirpath=checkpoints_dir,
94 | filename=experiment+'_{epoch:02}',
95 | monitor='loss/val',
96 | save_top_k=10
97 | )
98 |
99 | trainer = Trainer(
100 | max_epochs=args.n_epochs,
101 | logger=wandb_logger,
102 | callbacks=checkpoint_callback,
103 | accelerator='gpu',
104 | devices=[0,1],
105 | num_sanity_val_steps=0,
106 | enable_progress_bar=True,
107 | strategy='ddp',
108 | precision=16
109 | )
110 |
111 | if args.resume is None:
112 | last_checkpoint = None
113 | else:
114 | last_checkpoint = find_last_checkpoint(checkpoints_dir)
115 | print(f'Training will be resumed from the last checkpoint {last_checkpoint}')
116 | print('Start training')
117 | trainer.fit(model=anchor_predictor, ckpt_path=last_checkpoint)
118 |
119 | if __name__ == '__main__':
120 | p = argparse.ArgumentParser(description='anchor_predictor')
121 | p.add_argument('--data', action='store', type=str, default="")
122 | p.add_argument('--train-dataframe-path', action='store', type=str, default='paths_train.csv')
123 | p.add_argument('--valid-dataframe-path', action='store', type=str, default='paths_val.csv')
124 | p.add_argument('--train_data_prefix', action='store', type=str, default='train_data')
125 | p.add_argument('--val_data_prefix', action='store', type=str, default='val_data')
126 | p.add_argument('--checkpoints', action='store', type=str, default='checkpoints')
127 | p.add_argument('--logs', action='store', type=str, default='logs')
128 | p.add_argument('--device', action='store', type=str, default='cuda:1')
129 | p.add_argument('--trainer_params', type=dict, help='parameters with keywords of the lightning trainer')
130 |     p.add_argument('--log_iterations', action='store', type=int, default=20)
131 | p.add_argument('--exp_name', type=str, default='test_1')
132 |
133 | p.add_argument('--n_epochs', type=int, default=400)
134 | p.add_argument('--batch_size', type=int, default=16)
135 | p.add_argument('--lr', type=float, default=5e-4)
136 |
137 | p.add_argument('--activation', type=str, default='silu', help='activation function')
138 | p.add_argument('--n_layers', type=int, default=4, help='number of layers')
139 |     p.add_argument('--inv_sublayers', type=int, default=2, help='number of invariant sublayers per layer')
140 |     p.add_argument('--nf', type=int, default=128, help='hidden feature dimension')
141 | p.add_argument('--tanh', type=eval, default=False, help='use tanh in the coord_mlp')
142 | p.add_argument('--attention', type=eval, default=False, help='use attention in the EGNN')
143 | p.add_argument('--norm_constant', type=float, default=100, help='diff/(|diff| + norm_constant)')
144 |
145 | p.add_argument('--resume', type=str, default=None, help='')
146 | p.add_argument('--start_epoch', type=int, default=0, help='')
147 | p.add_argument('--ema_decay', type=float, default=0.999, help='Amount of EMA decay, 0 means off. A reasonable value is 0.999.')
148 | p.add_argument('--test_epochs', type=int, default=100)
149 | p.add_argument('--aggregation_method', type=str, default='sum',help='"sum" or "mean"')
150 | p.add_argument('--normalization', type=str, default='batch_norm', help='batch_norm')
151 | p.add_argument('--normalization_factor', type=float, default=100, help="Normalize the sum aggregation of EGNN")
152 | p.add_argument('--dataset-type', type=str, default='GEOM', help='dataset-type can be GEOM or CrossDock for now')
153 |
154 | p.add_argument('--gaussian-expansion', action='store_true', default=False, help='whether to use gaussian expansion of distances')
155 | p.add_argument('--use-guidance', action='store_true', default=False, help='whether to train anchor-predictor for a specific guidance feature')
156 | p.add_argument('--guidance-feature', type=str, default='QED', help='guidance feature for adding to anchor predictor')
157 | args = p.parse_args()
158 | main(args=args)
--------------------------------------------------------------------------------
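
A hypothetical illustration of `find_last_checkpoint` (not part of the repository; importing the script pulls in the model code, so the repo's dependencies must be installed). With the `ModelCheckpoint` filename pattern used above, checkpoints are named like `<experiment>_epoch=NN.ckpt`, and the highest epoch wins.

import os
import tempfile

from train_anchor_predictor import find_last_checkpoint

with tempfile.TemporaryDirectory() as ckpt_dir:
    for name in ('exp_epoch=03.ckpt', 'exp_epoch=12.ckpt', 'exp_epoch=07.ckpt'):
        open(os.path.join(ckpt_dir, name), 'w').close()  # create empty stand-ins
    print(find_last_checkpoint(ckpt_dir))  # ends with 'exp_epoch=12.ckpt'
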
/utils/sample_frag_size.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | def sample_fragment_size(new_score, bin_edges, distributions):
5 | # Find which bin the new score belongs to
6 | bin_idx = np.digitize(new_score, bin_edges)
7 | # Get the probability distribution for the bin
8 | probabilities = distributions.loc[bin_idx].values
9 | discrete_values = distributions.columns.values
10 | # Sample a discrete number from the distribution
11 | return np.random.choice(discrete_values, p=probabilities)
12 |
13 | bounds = [4.1, 8.1, 12.1, 16.1]  # bin edges for np.digitize: five score bins, matching the five rows below
14 | fragsize_prob = np.array([[7.11770964e-01, 1.53752812e-01, 1.06553153e-01, 1.95619411e-02,
15 | 2.82028835e-03, 5.54084154e-03, 0.00000000e+00, 0.00000000e+00,
16 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
17 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
18 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
19 | [1.20153410e-01, 4.57824749e-02, 3.31473162e-02, 2.52469754e-02,
20 | 1.16709580e-01, 4.48883000e-01, 1.22096176e-01, 4.48883000e-02,
21 | 2.69329800e-02, 1.61597880e-02, 0.00000000e+00, 0.00000000e+00,
22 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
23 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
24 | [7.74508034e-02, 3.24180141e-02, 1.07572679e-02, 7.19736364e-03,
25 | 1.01711551e-01, 2.79706766e-01, 1.27139439e-01, 7.62836633e-02,
26 | 7.62836633e-02, 1.14425495e-01, 5.08557756e-02, 2.54278878e-02,
27 | 1.27139439e-02, 3.81418317e-03, 2.54278878e-03, 1.27139439e-03,
28 | 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
29 | [4.83709778e-02, 1.47227584e-02, 3.28787372e-03, 4.42784727e-03,
30 | 6.65155568e-02, 2.08995151e-01, 1.13997355e-01, 1.13997355e-01,
31 | 1.04497576e-01, 1.42496694e-01, 6.64984572e-02, 4.74988980e-02,
32 | 2.84993388e-02, 1.89995592e-02, 9.49977961e-03, 4.74988980e-03,
33 | 2.84993388e-03, 9.49977961e-05, 0.00000000e+00, 0.00000000e+00],
34 | [1.77003281e-02, 2.15758183e-02, 8.85185643e-03, 3.94403048e-03,
35 | 5.92476134e-02, 1.69235378e-01, 1.26926533e-01, 1.10002995e-01,
36 | 8.46176888e-02, 1.52311840e-01, 1.18464764e-01, 5.92323822e-02,
37 | 3.38470755e-02, 1.69235378e-02, 8.46176888e-03, 5.07706133e-03,
38 | 3.38470755e-03, 1.69235378e-04, 8.46176888e-06, 1.69235378e-05]])
39 |
40 | fragsize_prob_df = pd.DataFrame(fragsize_prob, index=[0,1,2,3,4], columns=np.arange(1,21))  # rows: score bins (np.digitize output); columns: fragment sizes 1-20
--------------------------------------------------------------------------------
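
Hypothetical usage of `sample_fragment_size` (not part of the repository), assuming the module is importable as `utils.sample_frag_size`: a score of 6.0 falls into bin 1 (between 4.1 and 8.1), whose distribution peaks at a fragment size of 6 atoms.

import numpy as np

from utils.sample_frag_size import (bounds, fragsize_prob_df,
                                    sample_fragment_size)

np.random.seed(0)
# draws a fragment size in 1..20 from the score bin's distribution
print(sample_fragment_size(6.0, bounds, fragsize_prob_df))
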
/utils/visuals.py:
--------------------------------------------------------------------------------
1 | import py3Dmol
2 | from rdkit.Chem import AllChem
3 | from rdkit import Chem
4 | from rdkit.Geometry import Point3D
5 | from openbabel import openbabel
6 | import numpy as np
7 | import tempfile
8 | 
9 |
10 |
11 | atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9}
12 | idx2atom = {0:'C', 1:'N', 2:'O', 3:'S', 4:'B', 5:'Br', 6:'Cl', 7:'P', 8:'I', 9:'F'}
13 |
14 | def write_xyz_file(coords, atom_types, filename):
15 | out = f"{len(coords)}\n\n"
16 | assert len(coords) == len(atom_types)
17 | for i in range(len(coords)):
18 | out += f"{atom_types[i]} {coords[i, 0]:.3f} {coords[i, 1]:.3f} {coords[i, 2]:.3f}\n"
19 | with open(filename, 'w') as f:
20 | f.write(out)
21 |
22 | def visualize_molecules_grid(mols, grid_size=(3, 3), spacing=5.0, spin=True):
23 | viewer = py3Dmol.view(width=900, height=900)
24 |
25 | for i, mol in enumerate(mols):
26 | try:
27 | Chem.SanitizeMol(mol)
28 |         except Exception:
29 |             print('could not sanitize molecule', i)
30 | #AllChem.EmbedMolecule(mol) # Generate 3D coordinates
31 | #AllChem.MMFFOptimizeMolecule(mol, maxIters=500) # Optimize the geometry using MMFF94 force field
32 |
33 | # Calculate the grid position
34 | grid_x = i % grid_size[0]
35 | grid_y = i // grid_size[0]
36 |
37 | # Translate the molecule according to its position in the grid
38 | conf = mol.GetConformer()
39 | translation_vector = Point3D((grid_x * spacing) + (spacing / 2), (grid_y * spacing) + (spacing / 2), 0.0)
40 | for atom_idx in range(mol.GetNumAtoms()):
41 | atom_position = conf.GetAtomPosition(atom_idx)
42 | atom_position += translation_vector
43 | conf.SetAtomPosition(atom_idx, atom_position)
44 |
45 | mb = Chem.MolToMolBlock(mol)
46 | viewer.addModel(mb, 'sdf')
47 |
48 | #if spin:
49 | # viewer.spin({'x': 0, 'y': 1, 'z': 0})
50 |
51 | # Draw separating lines
52 | for i in range(grid_size[0] - 1):
53 | x = (i + 1) * spacing
54 | viewer.addLine({'start': {'x': x, 'y': 0, 'z': 0},
55 | 'end': {'x': x, 'y': grid_size[1] * spacing, 'z': 0},
56 | 'color': 'gray'})
57 | for i in range(grid_size[1] - 1):
58 | y = (i + 1) * spacing
59 | viewer.addLine({'start': {'x': 0, 'y': y, 'z': 0},
60 | 'end': {'x': grid_size[0] * spacing, 'y': y, 'z': 0},
61 | 'color': 'gray'})
62 |
63 | #viewer.spin({'x': 0, 'y': 1, 'z': 0}, origin=(grid_size[0] * spacing / 2, grid_size[1] * spacing / 2, 0))
64 | viewer.setStyle({}, {'stick': {'colorscheme': ['silverCarbon', 'redOxygen', 'blueNitrogen'], 'radius': 0.15, 'opacity': 1},
65 | 'sphere': {'colorscheme': ['silverCarbon', 'redOxygen', 'blueNitrogen'], 'radius': 0.35, 'opacity': 1}})
66 | viewer.zoomTo()
67 | viewer.show()
68 |
69 | def get_pocket_mol(pocket_coords, pocket_onehot):
70 |     with tempfile.NamedTemporaryFile() as tmp:  # used only to reserve a unique path;
71 |         tmp_file = tmp.name                     # the handle (and file) is gone after this block, and the name is reused below
72 |
73 |     atom_inds = pocket_onehot.argmax(1)
74 | atom_types = [idx2atom[x] for x in atom_inds]
75 | # write xyz file
76 | write_xyz_file(pocket_coords, atom_types, tmp_file)
77 |
78 | obConversion = openbabel.OBConversion()
79 | obConversion.SetInAndOutFormats('xyz', 'sdf')
80 | ob_mol = openbabel.OBMol()
81 | obConversion.ReadFile(ob_mol, tmp_file)
82 |
83 | obConversion.WriteFile(ob_mol, tmp_file)
84 | pocket_mol = Chem.SDMolSupplier(tmp_file, sanitize=False)[0]
85 |
86 | return pocket_mol
87 |
88 | def visualize_3d_pocket_molecule(pocket_mol, mol=None, spin=False, optimize_coords=False, sphere_positions1=None, sphere_positions2=None, rotate=None):
89 | viewer = py3Dmol.view()
90 |
91 | pocket_mol = Chem.RemoveHs(pocket_mol)
92 | pocket_mb = Chem.MolToMolBlock(pocket_mol)
93 | viewer.addModel(pocket_mb, 'sdf')
94 | viewer.setStyle({'model': -1}, {"sphere": {'color': 'grey', 'opacity': 0.8, 'radius':0.9}})
95 | #viewer.setStyle({'model': 0}, {'stick': {'colorscheme': ['whiteCarbon', 'redOxygen', 'blueNitrogen'], 'radius': 0.2, 'opacity': 1},
96 | # 'sphere': {'colorscheme': ['whiteCarbon', 'redOxygen', 'blueNitrogen'], 'radius': 0.3, 'opacity': 1}})
97 |
98 | viewer.zoomTo()
99 | #viewer.setStyle({'model': 0}, {'cartoon': {'color': 'spectrum'}}) # Updated style for cartoon representation
100 | #viewer.addSurface(py3Dmol.SAS, {'opacity': 0.9, 'radius': 0.5})
101 |
102 | if mol is not None:
103 | try:
104 | Chem.SanitizeMol(mol)
105 |         except Exception:
106 | print('Problem with the molecule')
107 | return
108 |
109 | mol = Chem.RemoveHs(mol)
110 | mol_mb = Chem.MolToMolBlock(mol)
111 | viewer.addModel(mol_mb, 'sdf')
112 | viewer.setStyle({'model': 1}, {'stick': {'colorscheme': 'cyanCarbon', 'radius': 0.15, 'opacity': 1},
113 | 'sphere': {'colorscheme': 'cyanCarbon', 'radius': 0.35, 'opacity': 1}})
114 |
115 | if sphere_positions1 is not None:
116 | for pos in sphere_positions1:
117 | sphere_spec = {'center': {'x': float(pos[0]), 'y': float(pos[1]), 'z': float(pos[2])}, 'radius': 1, 'color': 'green', 'opacity': 0.75}
118 | viewer.addSphere(sphere_spec)
119 |
120 | if sphere_positions2 is not None:
121 | for pos in sphere_positions2:
122 | sphere_spec = {'center': {'x': float(pos[0]), 'y': float(pos[1]), 'z': float(pos[2])}, 'radius': 0.3, 'color': 'yellow', 'opacity': 0.75}
123 | viewer.addSphere(sphere_spec)
124 |
125 |
126 | if spin:
127 | viewer.spin({'x': 0, 'y': 1, 'z': 0})
128 |
129 | if rotate:
130 |         viewer.rotate(rotate, 'y', 1)
131 | return viewer
132 |
--------------------------------------------------------------------------------
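
Hypothetical usage of `write_xyz_file` (not part of the repository); importing `utils.visuals` pulls in py3Dmol, RDKit and Open Babel, so those must be installed.

import numpy as np

from utils.visuals import write_xyz_file

coords = np.array([[0.000, 0.000, 0.000],
                   [1.210, 0.000, 0.000],
                   [-0.600, 1.100, 0.000]])
write_xyz_file(coords, ['C', 'O', 'N'], 'example.xyz')
print(open('example.xyz').read())
# 3
#
# C 0.000 0.000 0.000
# O 1.210 0.000 0.000
# N -0.600 1.100 0.000
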