├── .gitignore
├── CHANGES.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── CONTRIBUTORS.md
├── LICENSE.md
├── README.md
├── config
│   ├── __init__.py
│   └── moad_partitions.py
├── configurations
│   ├── README.md
│   ├── final.json
│   ├── layer_type_sweep
│   │   ├── lig_simple.json
│   │   ├── lig_simple_h.json
│   │   ├── lig_single.json
│   │   ├── lig_single_h.json
│   │   ├── rec_meta.json
│   │   ├── rec_meta_mix.json
│   │   ├── rec_simple_h.json
│   │   ├── rec_single.json
│   │   └── rec_single_h.json
│   └── voxelation_sweep
│       ├── t0.json
│       ├── t1.json
│       ├── t10.json
│       ├── t11.json
│       ├── t12.json
│       ├── t13.json
│       ├── t14.json
│       ├── t15.json
│       ├── t16.json
│       ├── t17.json
│       ├── t2.json
│       ├── t3.json
│       ├── t4.json
│       ├── t5.json
│       ├── t6.json
│       ├── t7.json
│       ├── t8.json
│       └── t9.json
├── data
│   └── README.md
├── deepfrag.py
├── leadopt
│   ├── __init__.py
│   ├── data_util.py
│   ├── grid_util.py
│   ├── infer.py
│   ├── metrics.py
│   ├── model_conf.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backport.py
│   │   └── voxel.py
│   └── util.py
├── requirements.txt
├── scripts
│   ├── README.md
│   ├── README_MOAD.md
│   ├── make_fingerprints.py
│   ├── merge_moad.py
│   ├── moad_training_splits.py
│   ├── moad_util.py
│   ├── process_moad.py
│   └── split_moad.py
├── test_installation.sh
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/*.pyc
2 | **/__pycache__
3 | .ipynb_checkpoints
4 |
5 | data/**
6 | !data/README.md
7 |
8 | .DS_Store
9 |
10 | .store/
11 | dist/**
12 | build/**
13 | .vscode
14 |
--------------------------------------------------------------------------------
/CHANGES.md:
--------------------------------------------------------------------------------
1 | Changes
2 | =======
3 |
4 | 1.0.4
5 | -----
6 |
7 | * Updated packages in `requirements.txt`
8 | * Fingerprint now cast as float instead of np.float
9 | * Minor updates to `README.md`
10 | * Fixed an error that prevented DeepFrag from loading PDB files without the .pdb
11 | extension. (Affects only recent versions of prody?)
12 | * Added `test_installation.sh` to make it easy to verify that DeepFrag is
13 | installed correctly. Downloads sample data (PDB ID 1XDN) and runs DeepFrag.
14 |
15 | 1.0.3
16 | -----
17 |
18 | * CLI parameters `--cx`, `--cy`, `--cz`, `--rx`, `--ry`, and `--rz` can now be
19 | floats (not just integers). We recommend specifying the exact atomic
20 | coordinates of the connection and removal points.
21 | * Fixed a bug that caused the `--full` parameter to throw an error when
22 | performing fragment addition (but not fragment replacement) using the CLI
23 | implementation.
24 | * Minor updates to the documentation.
25 |
26 | 1.0.2
27 | -----
28 |
29 | * Added a CLI implementation of the program. See `README.md` for details.
30 | * Added a version number and citation to the program output.
31 |
32 | 1.0.1
33 | -----
34 |
35 | * Removed open-babel dependency.
36 | * Added option (`cpu_gridify`) to improve use on CPUs when no GPU is
37 | available.
38 | * Updated `data/README.md` with new location of data files.
39 | * Fixed a config import.
40 |
41 | 1.0
42 | ---
43 |
44 | Original version.
45 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
6 |
7 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
8 |
9 | ## Our Standards
10 |
11 | Examples of behavior that contributes to a positive environment for our community include:
12 |
13 | * Demonstrating empathy and kindness toward other people
14 | * Being respectful of differing opinions, viewpoints, and experiences
15 | * Giving and gracefully accepting constructive feedback
16 | * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
17 | * Focusing on what is best not just for us as individuals, but for the overall community
18 |
19 | Examples of unacceptable behavior include:
20 |
21 | * The use of sexualized language or imagery, and sexual attention or
22 | advances of any kind
23 | * Trolling, insulting or derogatory comments, and personal or political attacks
24 | * Public or private harassment
25 | * Publishing others' private information, such as a physical or email
26 | address, without their explicit permission
27 | * Other conduct which could reasonably be considered inappropriate in a
28 | professional setting
29 |
30 | ## Enforcement Responsibilities
31 |
32 | Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
33 |
34 | Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.
35 |
36 | ## Scope
37 |
38 | This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
39 |
40 | ## Enforcement
41 |
42 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [http://durrantlab.com/contact/](http://durrantlab.com/contact/). All complaints will be reviewed and investigated promptly and fairly.
43 |
44 | All community leaders are obligated to respect the privacy and security of the reporter of any incident.
45 |
46 | ## Enforcement Guidelines
47 |
48 | Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
49 |
50 | ### 1. Correction
51 |
52 | **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
53 |
54 | **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.
55 |
56 | ### 2. Warning
57 |
58 | **Community Impact**: A violation through a single incident or series of actions.
59 |
60 | **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.
61 |
62 | ### 3. Temporary Ban
63 |
64 | **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.
65 |
66 | **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
67 |
68 | ### 4. Permanent Ban
69 |
70 | **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
71 |
72 | **Consequence**: A permanent ban from any sort of public interaction within the project community.
73 |
74 | ## Attribution
75 |
76 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0,
77 | available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
78 |
79 | Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
80 |
81 | [homepage]: https://www.contributor-covenant.org
82 |
83 | For answers to common questions about this code of conduct, see the FAQ at
84 | https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.
85 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Introduction
4 |
5 | Thank you for your interest in contributing! All types of contributions are
6 | encouraged and valued. Please make sure to read the relevant section before
7 | making your contribution. We at the [Durrant Lab](http://durrantlab.com) look
8 | forward to your help!
9 |
10 | ## Reporting a bug
11 |
12 | If you're unable to find an open issue addressing the bug, feel free to [open
13 | a new one](https://docs.gitlab.com/ee/user/project/issues/). Be sure to
14 | include a **title and clear description**, as much relevant information as
15 | possible (e.g., the program, platform, or operating-system version numbers),
16 | and a **code sample** or **test case** demonstrating the expected behavior
17 | that is not occurring.
18 |
19 | If you or the maintainers don't respond to an issue for 30 days, the issue may
20 | be closed. If you want to come back to it, reply (once, please), and we'll
21 | reopen the existing issue. Please avoid filing new issues as extensions of one
22 | you already made.
23 |
24 | ## Project setup to make source-code changes on your computer
25 |
26 | This project uses `git` to manage contributions, so start by [reading up on
27 | how to fork a `git`
28 | repository](https://docs.gitlab.com/ee/user/project/repository/forking_workflow.html#creating-a-fork)
29 | if you've never done it before.
30 |
31 | Forking will place a copy of the code on your own computer, where you can
32 | modify it to correct bugs or add features.
33 |
34 | ## Integrating your changes back into the main codebase
35 |
36 | Follow these steps to "push" your changes to the main online repository so
37 | others can benefit from them:
38 |
39 | * Create a [new merge
40 | request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
41 | with your changes.
42 |
43 | * Ensure the description clearly describes the problem and solution. Include
44 | the relevant issue number if applicable.
45 |
46 | * Before submitting, please read this CONTRIBUTING.md file to know more about
47 | coding conventions and benchmarks.
48 |
49 | ## Coding conventions
50 |
51 | Be sure to adequately document your code with comments so others can
52 | understand your changes. All classes and functions should have associated doc
53 | strings, formatted as appropriate given the programming language. Here are
54 | some examples:
55 |
56 | ```python
57 | """
58 | This file does important calculations. It is a Python file with nice doc strings.
59 | """
60 |
61 | class ImportantCalcs(object):
62 | """
63 | An important class where important things happen.
64 | """
65 |
66 | def __init__(self, vars=None, receptor_file=None,
67 | file_conversion_class_object=None, test_boot=True):
68 | """
69 | Required to initialize any conversion.
70 |
71 | Inputs:
72 | :param dict vars: Dictionary of user variables
73 | :param str receptor_file: the path for the receptor file
74 | :param obj file_conversion_class_object: object that is used to convert
75 | files from pdb to pdbqt
76 | :param bool test_boot: used to initialize class without objects for
77 | testing purpose
78 | """
79 |
80 | pass
81 | ```
82 |
83 | ```typescript
84 | /**
85 | * Sets the curStarePt variable externally. A useful, well-documented
86 | * TypeScript function.
87 | * @param {number[]} pt The x, y coordinates of the point as a list of
88 | * numbers.
89 | * @returns void
90 | */
 91 | export function setCurStarePt(pt: number[]): void {
92 | curStarePt.copyFrom(pt);
93 | }
94 | ```
95 |
96 | If writing Python code, be sure to use the [Black
97 | formatter](https://black.readthedocs.io/en/stable/) before submitting a merge
98 | request. If writing code in JavaScript or TypeScript, please use the [Prettier
99 | formatter](https://marketplace.visualstudio.com/items?itemName=esbenp.prettier-vscode).
100 |
101 | ## Fixing whitespace, formatting code, or making a purely cosmetic patch
102 |
103 | Changes that are cosmetic in nature and do not add anything substantial to the
104 | stability, functionality, or testability of the program are unlikely to be
105 | accepted.
106 |
107 | ## Asking questions about the program
108 |
109 | Ask any question about how to use the program on the appropriate [Durrant Lab
110 | forum](http://durrantlab.com/forums/).
111 |
112 | ## Acknowledgements
113 |
114 | This document was inspired by:
115 |
116 | * [Ruby on Rails CONTRIBUTING.md
117 | file](https://raw.githubusercontent.com/rails/rails/master/CONTRIBUTING.md)
118 | (MIT License).
119 | * [weallcontribute](https://github.com/WeAllJS/weallcontribute/blob/latest/CONTRIBUTING.md)
120 | (Public Domain License).
121 |
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
1 | Contributors
2 | ============
3 |
4 | * Harrison Green
5 | * Jacob Durrant
6 | * David Koes
7 | * Vandan Revanur
8 | * Martin Salinas
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Apache License
2 | ==============
3 |
4 | _Version 2.0, January 2004_
  5 | _<http://www.apache.org/licenses/>_
6 |
7 | ### Terms and Conditions for use, reproduction, and distribution
8 |
9 | #### 1. Definitions
10 |
11 | "License" shall mean the terms and conditions for use, reproduction, and
12 | distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by the
15 | copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all other
18 | entities that control, are controlled by, or are under common control with
19 | that entity. For the purposes of this definition, "control" means **(i)** the
20 | power, direct or indirect, to cause the direction or management of such
21 | entity, whether by contract or otherwise, or **(ii)** ownership of fifty
22 | percent (50%) or more of the outstanding shares, or **(iii)** beneficial
23 | ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity exercising
26 | permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation source, and
30 | configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical transformation or
33 | translation of a Source form, including but not limited to compiled object
34 | code, generated documentation, and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or Object form,
37 | made available under the License, as indicated by a copyright notice that is
38 | included in or attached to the work (an example is provided in the Appendix
39 | below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object form, that
42 | is based on (or derived from) the Work and for which the editorial revisions,
43 | annotations, elaborations, or other modifications represent, as a whole, an
44 | original work of authorship. For the purposes of this License, Derivative
45 | Works shall not include works that remain separable from, or merely link (or
46 | bind by name) to the interfaces of, the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including the original
49 | version of the Work and any modifications or additions to that Work or
50 | Derivative Works thereof, that is intentionally submitted to Licensor for
51 | inclusion in the Work by the copyright owner or by an individual or Legal
52 | Entity authorized to submit on behalf of the copyright owner. For the purposes
53 | of this definition, "submitted" means any form of electronic, verbal, or
54 | written communication sent to the Licensor or its representatives, including
55 | but not limited to communication on electronic mailing lists, source code
56 | control systems, and issue tracking systems that are managed by, or on behalf
57 | of, the Licensor for the purpose of discussing and improving the Work, but
58 | excluding communication that is conspicuously marked or otherwise designated
59 | in writing by the copyright owner as "Not a Contribution."
60 |
61 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf
62 | of whom a Contribution has been received by Licensor and subsequently
63 | incorporated within the Work.
64 |
65 | #### 2. Grant of Copyright License
66 |
67 | Subject to the terms and conditions of this License, each Contributor hereby
68 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
69 | irrevocable copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the Work and
71 | such Derivative Works in Source or Object form.
72 |
73 | #### 3. Grant of Patent License
74 |
75 | Subject to the terms and conditions of this License, each Contributor hereby
76 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
77 | irrevocable (except as stated in this section) patent license to make, have
78 | made, use, offer to sell, sell, import, and otherwise transfer the Work, where
79 | such license applies only to those patent claims licensable by such
80 | Contributor that are necessarily infringed by their Contribution(s) alone or
81 | by combination of their Contribution(s) with the Work to which such
82 | Contribution(s) was submitted. If You institute patent litigation against any
83 | entity (including a cross-claim or counterclaim in a lawsuit) alleging that
84 | the Work or a Contribution incorporated within the Work constitutes direct or
85 | contributory patent infringement, then any patent licenses granted to You
86 | under this License for that Work shall terminate as of the date such
87 | litigation is filed.
88 |
89 | #### 4. Redistribution
90 |
91 | You may reproduce and distribute copies of the Work or Derivative Works
92 | thereof in any medium, with or without modifications, and in Source or Object
93 | form, provided that You meet the following conditions:
94 |
95 | * **(a)** You must give any other recipients of the Work or Derivative Works a
96 | copy of this License; and
97 | * **(b)** You must cause any modified files to carry prominent notices stating
98 | that You changed the files; and
99 | * **(c)** You must retain, in the Source form of any Derivative Works that You
100 | distribute, all copyright, patent, trademark, and attribution notices from
101 | the Source form of the Work, excluding those notices that do not pertain to
102 | any part of the Derivative Works; and
103 | * **(d)** If the Work includes a "NOTICE" text file as part of its
104 | distribution, then any Derivative Works that You distribute must include a
105 | readable copy of the attribution notices contained within such NOTICE file,
106 | excluding those notices that do not pertain to any part of the Derivative
107 | Works, in at least one of the following places: within a NOTICE text file
108 | distributed as part of the Derivative Works; within the Source form or
109 | documentation, if provided along with the Derivative Works; or, within a
110 | display generated by the Derivative Works, if and wherever such third-party
111 | notices normally appear. The contents of the NOTICE file are for
112 | informational purposes only and do not modify the License. You may add Your
113 | own attribution notices within Derivative Works that You distribute,
114 | alongside or as an addendum to the NOTICE text from the Work, provided that
115 | such additional attribution notices cannot be construed as modifying the
116 | License.
117 |
118 | You may add Your own copyright statement to Your modifications and may provide
119 | additional or different license terms and conditions for use, reproduction, or
120 | distribution of Your modifications, or for any such Derivative Works as a
121 | whole, provided Your use, reproduction, and distribution of the Work otherwise
122 | complies with the conditions stated in this License.
123 |
124 | #### 5. Submission of Contributions
125 |
126 | Unless You explicitly state otherwise, any Contribution intentionally
127 | submitted for inclusion in the Work by You to the Licensor shall be under the
128 | terms and conditions of this License, without any additional terms or
129 | conditions. Notwithstanding the above, nothing herein shall supersede or
130 | modify the terms of any separate license agreement you may have executed with
131 | Licensor regarding such Contributions.
132 |
133 | #### 6. Trademarks
134 |
135 | This License does not grant permission to use the trade names, trademarks,
136 | service marks, or product names of the Licensor, except as required for
137 | reasonable and customary use in describing the origin of the Work and
138 | reproducing the content of the NOTICE file.
139 |
140 | #### 7. Disclaimer of Warranty
141 |
142 | Unless required by applicable law or agreed to in writing, Licensor provides
143 | the Work (and each Contributor provides its Contributions) on an "AS IS"
144 | BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
145 | implied, including, without limitation, any warranties or conditions of TITLE,
146 | NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You
147 | are solely responsible for determining the appropriateness of using or
148 | redistributing the Work and assume any risks associated with Your exercise of
149 | permissions under this License.
150 |
151 | #### 8. Limitation of Liability
152 |
153 | In no event and under no legal theory, whether in tort (including negligence),
154 | contract, or otherwise, unless required by applicable law (such as deliberate
155 | and grossly negligent acts) or agreed to in writing, shall any Contributor be
156 | liable to You for damages, including any direct, indirect, special,
157 | incidental, or consequential damages of any character arising as a result of
158 | this License or out of the use or inability to use the Work (including but not
159 | limited to damages for loss of goodwill, work stoppage, computer failure or
160 | malfunction, or any and all other commercial damages or losses), even if such
161 | Contributor has been advised of the possibility of such damages.
162 |
163 | #### 9. Accepting Warranty or Additional Liability
164 |
165 | While redistributing the Work or Derivative Works thereof, You may choose to
166 | offer, and charge a fee for, acceptance of support, warranty, indemnity, or
167 | other liability obligations and/or rights consistent with this License.
168 | However, in accepting such obligations, You may act only on Your own behalf
169 | and on Your sole responsibility, not on behalf of any other Contributor, and
170 | only if You agree to indemnify, defend, and hold each Contributor harmless for
171 | any liability incurred by, or claims asserted against, such Contributor by
172 | reason of your accepting any such warranty or additional liability.
173 |
174 | _END OF TERMS AND CONDITIONS_
175 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepFrag
2 |
3 | DeepFrag is a machine learning model for fragment-based lead optimization. In
4 | this repository, you will find code to train the model and code to run
5 | inference using a pre-trained model.
6 |
7 | ## Citation
8 |
9 | If you use DeepFrag in your research, please cite as:
10 |
11 | Green, H., Koes, D. R., & Durrant, J. D. (2021). DeepFrag: a deep
12 | convolutional neural network for fragment-based lead optimization. Chemical
13 | Science.
14 |
15 | ```tex
16 | @article{green2021deepfrag,
17 | title={DeepFrag: a deep convolutional neural network for fragment-based lead optimization},
18 | author={Green, Harrison and Koes, David Ryan and Durrant, Jacob D},
19 | journal={Chemical Science},
20 | year={2021},
21 | publisher={Royal Society of Chemistry}
22 | }
23 | ```
24 |
25 | ## Usage
26 |
27 | There are three ways to use DeepFrag:
28 |
29 | 1. **DeepFrag Browser App**: We have released a free, open-source browser app
30 | for DeepFrag that requires no setup and does not transmit any structures to
31 | a remote server.
32 | - View the online version at
33 | [durrantlab.pitt.edu/deepfrag](https://durrantlab.pitt.edu/deepfrag/)
34 | - See the code at
35 | [git.durrantlab.pitt.edu/jdurrant/deepfrag-app](https://git.durrantlab.pitt.edu/jdurrant/deepfrag-app)
36 | 2. **DeepFrag CLI**: In this repository we have included a `deepfrag.py`
37 | script that can perform common prediction tasks using the API.
38 | - See the `DeepFrag CLI` section below
39 | 3. **DeepFrag API**: For custom tasks or fine-grained control over
40 | predictions, you can invoke the DeepFrag API directly and interface with
41 | the raw data structures and the PyTorch model. We have created an example
42 | Google Colab (Jupyter notebook) that demonstrates how to perform manual
43 | predictions.
44 | - See the interactive
45 | [Colab](https://colab.research.google.com/drive/1If8rWQ9aVKJyJwfaOql56mA2llqC0iur).
46 |
47 | ## DeepFrag CLI
48 |
49 | The DeepFrag CLI is invoked by running `python3 deepfrag.py` in this
50 | repository. The CLI requires a pre-trained model and the fragment library to
51 | run. You will be prompted to download both when you first run the CLI and
52 | these will be saved in the `./.store` directory.
53 |
54 | ### Structure (specify exactly one)
55 |
 56 | The input structures are specified either with explicit receptor and ligand
 57 | PDB files or with a PDB ID and the ligand residue number.
 58 |
 59 | - `--receptor <rec.pdb> --ligand <lig.pdb>`
 60 | - `--pdb <pdbid> --resnum <resnum>`
61 |
62 | ### Connection Point (specify exactly one)
63 |
64 | DeepFrag will predict new fragments that connect to the _connection point_ via
65 | a single bond. You must specify the connection point atom using one of the
66 | following:
67 |
 68 | - `--cname <name>`: Specify the connection point by atom name (e.g. `C3`,
 69 |   `N5`, `O2`, ...).
 70 | - `--cx <x> --cy <y> --cz <z>`: Specify the connection point by atomic
 71 |   coordinate. DeepFrag will find the closest atom to this point.
72 |
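When coordinates are given, the closest-atom lookup can be pictured with a short sketch (an illustration only, with made-up atom names and coordinates; not DeepFrag's actual implementation):

```python
import numpy as np

def closest_atom(coords, names, point):
    """Return the name of the atom nearest to a query point.

    coords: (N, 3) array of atomic coordinates.
    names:  list of N atom names.
    point:  (3,) query coordinate, e.g. from --cx/--cy/--cz.
    """
    dist = np.linalg.norm(coords - np.asarray(point), axis=1)
    return names[int(np.argmin(dist))]

# Toy ligand: three atoms at known positions.
coords = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [3.0, 1.0, 0.0]])
names = ["C1", "N2", "O3"]
print(closest_atom(coords, names, (1.4, 0.1, 0.0)))  # prints "N2"
```
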
73 | ### Fragment Removal (optional) (specify exactly one)
74 |
75 | If you are using DeepFrag for fragment _replacement_, you must first remove
76 | the original fragment from the ligand structure. You can either do this by
77 | hand, e.g. editing the PDB, or DeepFrag can do this for you by specifying
78 | _which_ fragment should be removed.
79 |
80 | _Note: predicting fragments in place of hydrogen atoms (e.g. protons) does not
81 | require any fragment removal since hydrogen atoms are ignored by the model._
82 |
83 | To remove a fragment, you specify a second atom that is contained in the
84 | fragment. Like the connection point, you can either use the atom name or the
85 | atom coordinate.
86 |
 87 | - `--rname <name>`: Specify the fragment atom by name (e.g. `C3`,
 88 |   `N5`, `O2`, ...).
 89 | - `--rx <x> --ry <y> --rz <z>`: Specify the fragment atom by atomic
 90 |   coordinate. DeepFrag will find the closest atom to this point.
91 |
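Conceptually, once the bond between the connection point and the fragment atom is cut, the atoms to delete are the connected component that contains the fragment atom. A minimal sketch of that idea, with hypothetical atom names and bonds (not the repo's code):

```python
from collections import deque

def fragment_atoms(bonds, conn, frag_atom):
    """Atoms in the fragment containing frag_atom, once the bond
    to the connection point is severed.

    bonds:     iterable of (a, b) atom-name pairs.
    conn:      connection-point atom (stays with the parent ligand).
    frag_atom: any atom inside the fragment to remove.
    """
    adj = {}
    for a, b in bonds:
        adj.setdefault(a, set()).add(b)
        adj.setdefault(b, set()).add(a)
    seen, queue = {frag_atom}, deque([frag_atom])
    while queue:
        cur = queue.popleft()
        for nxt in adj.get(cur, ()):
            if nxt != conn and nxt not in seen:  # never cross back into the parent
                seen.add(nxt)
                queue.append(nxt)
    return seen

# Toy molecule: parent C1-C2, with fragment C3-O4 hanging off C2.
bonds = [("C1", "C2"), ("C2", "C3"), ("C3", "O4")]
print(sorted(fragment_atoms(bonds, conn="C2", frag_atom="C3")))  # ['C3', 'O4']
```
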
92 | ### Output (optional)
93 |
94 | By default, DeepFrag will print a list of fragment predictions to stdout
95 | similar to the [Browser App](https://durrantlab.pitt.edu/deepfrag/).
96 |
 97 | - `--out <out.csv>`: Save predictions in CSV format to `out.csv`. Each line
 98 |   contains the fragment rank, score, and SMILES string.
99 |
100 | ### Miscellaneous (optional)
101 |
102 | - `--full`: Generate SMILES strings with the full ligand structure instead of
103 | just the fragment. (__IMPORTANT NOTE__: Bond orders are not assigned to the
104 | parent portion of the full ligand structure. These must be added manually.)
105 | - `--cpu/--gpu`: DeepFrag will attempt to detect whether a CUDA GPU is
106 |   available and fall back to the CPU if it is not. You can set either the
107 |   `--cpu` or `--gpu` flag to explicitly specify the target device.
108 | - `--num_grids <n>`: Number of grid rotations to use. Using more will take
109 |   longer but produce a more stable prediction. (Default: 4)
110 | - `--top_k <k>`: Number of predictions to print to stdout. Use -1 to display
111 |   all. (Default: 25)
112 |
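The stabilizing effect of `--num_grids` comes from averaging the fingerprints predicted for several rotated copies of the input grid. A toy sketch of that averaging step, with a stand-in `predict` function in place of the real model:

```python
import numpy as np

rng = np.random.default_rng(0)

def predict(rotation):
    """Stand-in for the model: returns a noisy fingerprint for one
    grid rotation (the rotation itself is not simulated here)."""
    true_fp = np.array([0.9, 0.1, 0.5])
    return true_fp + rng.normal(scale=0.2, size=3)

# Averaging over rotations damps the per-rotation noise.
num_grids = 4
avg_fp = np.mean([predict(r) for r in range(num_grids)], axis=0)
print(avg_fp.shape)  # (3,)
```
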
113 | ## Reproduce Results
114 |
115 | You can use the DeepFrag CLI to reproduce the highlighted results from the
116 | main manuscript:
117 |
118 | ### 1. Fragment replacement
119 |
120 | To replace fragments, specify the connection point (`cname` or `cx/cy/cz`) and
121 | specify a second atom that is contained in the fragment (`rname` or
122 | `rx/ry/rz`).
123 |
124 | ```bash
125 | # Fig. 3: (2XP9) H. sapiens peptidyl-prolyl cis-trans isomerase NIMA-interacting 1 (HsPin1p)
126 |
127 | # Carboxylate A
128 | $ python3 deepfrag.py --pdb 2xp9 --resnum 1165 --cname C10 --rname C12
129 |
130 | # Phenyl B
131 | $ python3 deepfrag.py --pdb 2xp9 --resnum 1165 --cname C1 --rname C2
132 |
133 | # Phenyl C
134 | $ python3 deepfrag.py --pdb 2xp9 --resnum 1165 --cname C18 --rname C19
135 | ```
136 |
137 | ```bash
138 | # Fig. 4A: (6QZ8) Protein myeloid cell leukemia1 (Mcl-1)
139 |
140 | # Carboxylate group interacting with R263
141 | $ python3 deepfrag.py --pdb 6qz8 --resnum 401 --cname C12 --rname C14
142 |
143 | # Ethyl group
144 | $ python3 deepfrag.py --pdb 6qz8 --resnum 401 --cname C6 --rname C10
145 |
146 | # Methyl group
147 | $ python3 deepfrag.py --pdb 6qz8 --resnum 401 --cname C25 --rname C30
148 |
149 | # Chlorine atom
150 | $ python3 deepfrag.py --pdb 6qz8 --resnum 401 --cname C28 --rname CL
151 | ```
152 |
153 | ```bash
154 | # Fig. 4B: (1X38) Family GH3 β-D-glucan glucohydrolase (barley)
155 |
156 | # Hydroxyl group interacting with R158 and D285
157 | $ python3 deepfrag.py --pdb 1x38 --resnum 1001 --cname C2B --rname O2B
158 |
159 | # Phenyl group interacting with W286 and W434
160 | $ python3 deepfrag.py --pdb 1x38 --resnum 1001 --cname C7B --rname C1
161 | ```
162 |
163 | ```bash
164 | # Fig. 4C: (4FOW) NanB sialidase (Streptococcus pneumoniae)
165 |
166 | # Amino group
167 | $ python3 deepfrag.py --pdb 4fow --resnum 701 --cname CAE --rname NAA
168 | ```
169 |
170 | ### 2. Fragment addition
171 |
172 | For fragment addition, you only need to specify the atom connection point
173 | (`cname` or `cx/cy/cz`). In this case, DeepFrag will implicitly replace a
174 | valent hydrogen.
175 |
176 | ```bash
177 | # Fig. 5: Ligands targeting the SARS-CoV-2 main protease (MPro)
178 |
179 | # 5A: (5RGH) Extension on Z1619978933
180 | $ python3 deepfrag.py --pdb 5rgh --resnum 404 --cname C09
181 |
182 | # 5B: (5R81) Extension on Z1367324110
183 | $ python3 deepfrag.py --pdb 5r81 --resnum 1001 --cname C07
184 | ```
185 |
186 | ## Overview
187 |
188 | - `config`: fixed configuration information (e.g., TRAIN/VAL/TEST partitions)
189 | - `configurations`: benchmark model configurations (see
190 | [`configurations/README.md`](configurations/README.md))
191 | - `data`: training/inference data (see [`data/README.md`](data/README.md))
192 | - `leadopt`: main module code
193 | - `models`: pytorch architecture definitions
194 | - `data_util.py`: utility code for reading packed fragment/fingerprint data
195 | files
196 | - `grid_util.py`: GPU-accelerated grid generation code
197 | - `metrics.py`: pytorch implementations of several metrics
198 | - `model_conf.py`: contains code to configure and train models
199 | - `util.py`: utility code for rdkit/openbabel processing
200 | - `scripts`: data processing scripts (see
201 | [`scripts/README.md`](scripts/README.md))
202 | - `train.py`: CLI interface to launch training runs
203 |
204 | ## Dependencies
205 |
206 | You can build a virtualenv with the requirements:
207 |
208 | ```sh
209 | $ python3 -m venv leadopt_env
210 | $ source ./leadopt_env/bin/activate
211 | $ pip install -r requirements.txt
212 | $ pip install prody
213 | $ pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118
214 | $ sudo apt install nvidia-cuda-toolkit
215 | ```
216 |
217 | Regarding the nvidia-cuda-toolkit, you may wish to ensure that the toolkit
218 | version matches the CUDA version installed on your machine. You can check
219 | the installed CUDA version by running the following commands:
220 |
221 | ```sh
222 | $ nvcc --version
223 | $ nvidia-smi
224 | ```
225 |
226 | Note: We used CUDA 10.1 for training.
227 |
228 | ## Training
229 |
230 | To train a model, you can use the `train.py` utility script. You can specify
231 | model parameters as command-line arguments or load them from a JSON
232 | configuration file (e.g., `args.json`).
233 |
234 | ```bash
235 | python train.py \
236 | --save_path=/path/to/model \
237 | --wandb_project=my_project \
238 | {model_type} \
239 | --model_arg1=x \
240 | --model_arg2=y \
241 | ...
242 | ```
243 |
244 | or
245 |
246 | ```bash
247 | python train.py \
248 | --save_path=/path/to/model \
249 | --wandb_project=my_project \
250 | --configuration=./configurations/args.json
251 | ```
252 |
253 | `save_path` is the directory where the best model will be saved. The
254 | directory will be created if it doesn't exist. If this is not provided, the
255 | model will not be saved.
256 |
257 | `wandb_project` is an optional wandb project name. If provided, the run will
258 | be logged to wandb.
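
A JSON configuration carries the same parameters as the command-line flags. As a rough illustration of that correspondence (this helper is hypothetical, not part of `train.py`; the actual flag handling there may differ):

```python
import json

def config_to_flags(config):
    """Sketch: render a configuration dict as train.py-style CLI flags.
    Illustrative only; not the actual argument handling in train.py."""
    flags = []
    for key, value in config.items():
        if value is None:
            continue  # unset options are simply omitted
        if isinstance(value, bool):
            if value:
                flags.append(f'--{key}')  # boolean flags are presence-only
        elif isinstance(value, list):
            flags.append(f'--{key}')
            flags.extend(str(v) for v in value)  # list args become repeated values
        else:
            flags.append(f'--{key}={value}')
    return flags

config = json.loads('{"learning_rate": 0.001, "pad": false, "blocks": [64, 64]}')
print(config_to_flags(config))
# ['--learning_rate=0.001', '--blocks', '64', '64']
```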
259 |
260 | See below for available models and model-specific parameters:
261 |
262 | ## Leadopt Models
263 |
264 | In this repository, trainable models are subclasses of
265 | `model_conf.LeadoptModel`. This class encapsulates model configuration
266 | arguments and PyTorch models, and enables saving and loading multi-component
267 | models.
268 |
269 | ```py
270 | from leadopt.model_conf import LeadoptModel, MODELS
271 |
272 | model = MODELS['voxel']({args...})
273 | model.train(save_path='./mymodel')
274 |
275 | ...
276 |
277 | model2 = LeadoptModel.load('./mymodel')
278 | ```
279 |
280 | Internally, model arguments are configured by setting up an `argparse` parser
281 | and passing around a `dict` of configuration parameters in `self._args`.
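
As a minimal sketch of this pattern (the class and argument names below are hypothetical, not the actual `LeadoptModel` API):

```python
import argparse

class ToyModel:
    """Illustrative only: collect argparse-declared arguments into a config dict."""

    @staticmethod
    def setup_parser(parser):
        # Each model declares its own arguments on a shared parser.
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--batch_size', type=int, default=16)

    def __init__(self, args):
        # Configuration is carried around as a plain dict.
        self._args = dict(args)

parser = argparse.ArgumentParser()
ToyModel.setup_parser(parser)
ns = parser.parse_args(['--batch_size', '32'])
model = ToyModel(vars(ns))
print(model._args['batch_size'])  # 32
```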
282 |
283 | ### VoxelNet
284 |
285 | ```text
286 | --no_partitions If set, disable the use of TRAIN/VAL partitions during
287 | training.
288 | -f FRAGMENTS, --fragments FRAGMENTS
289 | Path to fragments file.
290 | -fp FINGERPRINTS, --fingerprints FINGERPRINTS
291 | Path to fingerprints file.
292 | -lr LEARNING_RATE, --learning_rate LEARNING_RATE
293 | --num_epochs NUM_EPOCHS
294 | Number of epochs to train for.
295 | --test_steps TEST_STEPS
296 | Number of evaluation steps per epoch.
297 | -b BATCH_SIZE, --batch_size BATCH_SIZE
298 | --grid_width GRID_WIDTH
299 | --grid_res GRID_RES
300 | --fdist_min FDIST_MIN
301 | Ignore fragments closer to the receptor than this
302 | distance (Angstroms).
303 | --fdist_max FDIST_MAX
304 | Ignore fragments further from the receptor than this
305 | distance (Angstroms).
306 | --fmass_min FMASS_MIN
307 | Ignore fragments smaller than this mass (Daltons).
308 | --fmass_max FMASS_MAX
309 | Ignore fragments larger than this mass (Daltons).
310 | --ignore_receptor
311 | --ignore_parent
312 | -rec_typer {single,single_h,simple,simple_h,desc,desc_h}
313 | -lig_typer {single,single_h,simple,simple_h,desc,desc_h}
314 | -rec_channels REC_CHANNELS
315 | -lig_channels LIG_CHANNELS
316 | --in_channels IN_CHANNELS
317 | --output_size OUTPUT_SIZE
318 | --pad
319 | --blocks BLOCKS [BLOCKS ...]
320 | --fc FC [FC ...]
321 | --use_all_labels
322 | --dist_fn {mse,bce,cos,tanimoto}
323 | --loss {direct,support_v1}
324 | ```
325 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
--------------------------------------------------------------------------------
/configurations/README.md:
--------------------------------------------------------------------------------
1 | This folder contains benchmark model configurations referenced in the paper.
2 |
3 | Overview:
4 | - `layer_type_sweep/*`: experimenting with different parent/receptor typing schemes
5 | - `voxelation_sweep/*`: experimenting with different voxelation types and atomic influence radii
6 | - `final.json`: final production model
7 |
8 | You can train new models using these configurations with the `train.py` script:
9 |
10 | ```sh
11 | python train.py \
12 | --save_path=/path/to/model \
13 | --wandb_project=my_project \
14 | --configuration=./configurations/model.json
15 | ```
16 |
17 | Note: these configuration files assume the working directory is the `leadopt` base directory and that the data directory is accessible at `./data`.
18 |
--------------------------------------------------------------------------------
/configurations/final.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 50,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [
21 | 64,
22 | 64
23 | ],
24 | "fc": [
25 | 512
26 | ],
27 | "use_all_labels": true,
28 | "dist_fn": "cos",
29 | "loss": "direct",
30 | "point_radius": 1,
31 | "point_type": 0,
32 | "rec_typer": "simple",
33 | "acc_type": 0,
34 | "lig_typer": "simple"
35 | }
36 |
--------------------------------------------------------------------------------
/configurations/layer_type_sweep/lig_simple.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 15,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [64, 64],
21 | "fc": [512],
22 | "use_all_labels": true,
23 | "dist_fn": "cos",
24 | "loss": "direct",
25 | "point_radius": 1,
26 | "point_type": 3,
27 | "rec_typer": "simple",
28 | "acc_type": 0,
29 | "lig_typer": "simple"
30 | }
31 |
--------------------------------------------------------------------------------
/configurations/layer_type_sweep/lig_simple_h.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 15,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [64, 64],
21 | "fc": [512],
22 | "use_all_labels": true,
23 | "dist_fn": "cos",
24 | "loss": "direct",
25 | "point_radius": 1,
26 | "point_type": 3,
27 | "rec_typer": "simple",
28 | "acc_type": 0,
29 | "lig_typer": "simple_h"
30 | }
31 |
--------------------------------------------------------------------------------
/configurations/layer_type_sweep/lig_single.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 15,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [64, 64],
21 | "fc": [512],
22 | "use_all_labels": true,
23 | "dist_fn": "cos",
24 | "loss": "direct",
25 | "point_radius": 1,
26 | "point_type": 3,
27 | "rec_typer": "simple",
28 | "acc_type": 0,
29 | "lig_typer": "single"
30 | }
31 |
--------------------------------------------------------------------------------
/configurations/layer_type_sweep/lig_single_h.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 15,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [64, 64],
21 | "fc": [512],
22 | "use_all_labels": true,
23 | "dist_fn": "cos",
24 | "loss": "direct",
25 | "point_radius": 1,
26 | "point_type": 3,
27 | "rec_typer": "simple",
28 | "acc_type": 0,
29 | "lig_typer": "single_h"
30 | }
31 |
--------------------------------------------------------------------------------
/configurations/layer_type_sweep/rec_meta.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 15,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [64, 64],
21 | "fc": [512],
22 | "use_all_labels": true,
23 | "dist_fn": "cos",
24 | "loss": "direct",
25 | "point_radius": 1,
26 | "point_type": 3,
27 | "rec_typer": "meta",
28 | "acc_type": 0,
29 | "lig_typer": "simple"
30 | }
31 |
--------------------------------------------------------------------------------
/configurations/layer_type_sweep/rec_meta_mix.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 15,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [64, 64],
21 | "fc": [512],
22 | "use_all_labels": true,
23 | "dist_fn": "cos",
24 | "loss": "direct",
25 | "point_radius": 1,
26 | "point_type": 3,
27 | "rec_typer": "meta_mix",
28 | "acc_type": 0,
29 | "lig_typer": "simple"
30 | }
31 |
--------------------------------------------------------------------------------
/configurations/layer_type_sweep/rec_simple_h.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 15,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [64, 64],
21 | "fc": [512],
22 | "use_all_labels": true,
23 | "dist_fn": "cos",
24 | "loss": "direct",
25 | "point_radius": 1,
26 | "point_type": 3,
27 | "rec_typer": "simple_h",
28 | "acc_type": 0,
29 | "lig_typer": "simple"
30 | }
31 |
--------------------------------------------------------------------------------
/configurations/layer_type_sweep/rec_single.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 15,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [64, 64],
21 | "fc": [512],
22 | "use_all_labels": true,
23 | "dist_fn": "cos",
24 | "loss": "direct",
25 | "point_radius": 1,
26 | "point_type": 3,
27 | "rec_typer": "single",
28 | "acc_type": 0,
29 | "lig_typer": "simple"
30 | }
31 |
--------------------------------------------------------------------------------
/configurations/layer_type_sweep/rec_single_h.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "voxelnet",
3 | "no_partitions": false,
4 | "fragments": "./data/moad.h5",
5 | "fingerprints": "./data/rdk10_moad.h5",
6 | "learning_rate": 0.001,
7 | "num_epochs": 15,
8 | "test_steps": 400,
9 | "batch_size": 16,
10 | "grid_width": 24,
11 | "grid_res": 0.75,
12 | "fdist_min": null,
13 | "fdist_max": 4,
14 | "fmass_min": null,
15 | "fmass_max": 150,
16 | "ignore_receptor": false,
17 | "ignore_parent": false,
18 | "output_size": 2048,
19 | "pad": false,
20 | "blocks": [64, 64],
21 | "fc": [512],
22 | "use_all_labels": true,
23 | "dist_fn": "cos",
24 | "loss": "direct",
25 | "point_radius": 1,
26 | "point_type": 3,
27 | "rec_typer": "single_h",
28 | "acc_type": 0,
29 | "lig_typer": "simple"
30 | }
31 |
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t0.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 0, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t1.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 1, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t10.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 4, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t11.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 5, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t12.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 0, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t13.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 1, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t14.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 2, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t15.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 3, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t16.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 4, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t17.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 5, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t2.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 2, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t3.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 3, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t4.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 4, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t5.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 5, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t6.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 0, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t7.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 1, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t8.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 2, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/configurations/voxelation_sweep/t9.json:
--------------------------------------------------------------------------------
1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 3, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"}
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data for Training and Inference
2 |
3 | This folder contains data used during training and inference.
4 |
5 | Model configuration files in `/configurations` expect the data files to be in
6 | this directory. You can either copy them directly here or use symlinks.
7 |
8 | You can download the data here: http://durrantlab.com/apps/deepfrag/files/
9 |
10 |
11 |
12 | Overview:
13 |
14 | - `moad.h5` (7 GB): processed MOAD data loaded by `data_util.FragmentDataset`
15 | - `rdk10_moad.h5` (384 MB): RDK-10 fingerprints for MOAD data loaded by
16 | `data_util.FingerprintDataset` (generated with
17 | `scripts/make_fingerprints.py`)
18 |
--------------------------------------------------------------------------------
/deepfrag.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import os
4 | import pathlib
5 | import shutil
6 | import time
7 | from typing import Tuple
8 | import zipfile
9 |
10 | import requests
11 | from tqdm.auto import tqdm
12 | import h5py
13 | import numpy as np
14 | import rdkit.Chem.AllChem as Chem
15 | import torch
16 | import prody
17 |
18 | from leadopt.model_conf import LeadoptModel, REC_TYPER, LIG_TYPER, DIST_FN
19 | from leadopt import util, grid_util
20 |
21 |
22 | USER_DIR = './.store'
23 | PDB_CACHE = 'pdb_cache'
24 |
25 | MODEL_DOWNLOAD = 'https://durrantlab.pitt.edu/apps/deepfrag/files/final_model_v2.zip'
26 | FINGERPRINTS_DOWNLOAD = 'https://durrantlab.pitt.edu/apps/deepfrag/files/fingerprints.h5'
27 |
28 | RCSB_DOWNLOAD = 'https://files.rcsb.org/download/%s.pdb1'
29 |
30 | VERSION = "1.0.4"
31 |
32 | def download_remote(url, path, compression=None):
33 | r = requests.get(url, stream=True, allow_redirects=True)
34 | if r.status_code != 200:
35 |         print(f'Can\'t access {url}')
36 |         r.raise_for_status()
37 |
38 | file_size = int(r.headers.get('Content-Length', 0))
39 |
40 | r.raw.read = functools.partial(r.raw.read, decode_content=True)
41 | with tqdm.wrapattr(r.raw, 'read', total=file_size, desc='Downloading') as r_raw:
42 | with path.open('wb') as f:
43 | shutil.copyfileobj(r_raw, f)
44 |
45 | if compression is not None:
46 | shutil.move(str(path), str(path) + '.tmp')
47 | shutil.unpack_archive(str(path) + '.tmp', str(path), format=compression)
48 |
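
`download_remote` streams the decoded response body to disk with `shutil.copyfileobj` rather than buffering the whole download in memory. The copy pattern can be sketched with an in-memory stream standing in for `r.raw` (the helper name `save_stream` is hypothetical, not part of the codebase):

```python
import io
import shutil
import tempfile

def save_stream(stream, path, chunk_size=1024 * 1024):
    """Copy a file-like object to disk in chunks, as download_remote does
    with the decoded HTTP response body."""
    with open(path, 'wb') as f:
        shutil.copyfileobj(stream, f, chunk_size)

# An in-memory stream stands in for the HTTP response body here.
payload = b'fake model weights' * 100
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp_path = tmp.name
save_stream(io.BytesIO(payload), tmp_path)
```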
49 |
50 | def get_deepfrag_user_dir() -> pathlib.Path:
51 | user_dir = pathlib.Path(os.path.realpath(__file__)).parent / USER_DIR
52 | os.makedirs(str(user_dir), exist_ok=True)
53 | return user_dir
54 |
55 |
56 | def get_model_path():
57 | return get_deepfrag_user_dir() / 'model'
58 |
59 |
60 | def get_fingerprints_path():
61 | return get_deepfrag_user_dir() / 'fingerprints.h5'
62 |
63 |
64 | def ensure_cli_data():
65 | model_path = get_model_path()
66 | fingerprints_path = get_fingerprints_path()
67 |
68 | if not os.path.exists(str(model_path)):
69 | r = input('Pre-trained DeepFrag model not found, download it now? (5.8 MB) [Y/n]: ')
70 | if r.lower() == 'n':
71 | print('Exiting...')
72 | exit(-1)
73 |
74 | print(f'Saving to {model_path}...')
75 | download_remote(MODEL_DOWNLOAD, model_path, compression='zip')
76 |
77 | if not os.path.exists(str(fingerprints_path)):
78 | r = input('Fingerprint library not found, download it now? (11 MB) [Y/n]: ')
79 | if r.lower() == 'n':
80 | print('Exiting...')
81 | exit(-1)
82 |
83 | print(f'Saving to {fingerprints_path}...')
84 | download_remote(FINGERPRINTS_DOWNLOAD, fingerprints_path, compression=None)
85 |
86 |
87 | def download_pdb(pdb_id, path):
88 | download_remote(RCSB_DOWNLOAD % pdb_id, path, compression=None)
89 |
90 |
91 | def load_pdb(pdb_id, resnum):
92 | pdb_id = pdb_id.upper()
93 | assert all([x in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' for x in pdb_id])
94 |
95 | # Check pdb cache
96 | pdb_dir = get_deepfrag_user_dir() / PDB_CACHE / pdb_id
97 |
98 | complex_path = pdb_dir / 'complex.pdb'
99 | rec_path = pdb_dir / 'receptor.pdb'
100 | lig_path = pdb_dir / 'ligand.pdb'
101 |
102 | os.makedirs(str(pdb_dir), exist_ok=True)
103 |
104 | if not os.path.exists(complex_path):
105 | download_pdb(pdb_id, complex_path)
106 |
107 | with open(str(complex_path), 'r') as f:
108 | m = prody.parsePDBStream(f)
109 | rec = m.select('not (nucleic or hetatm) and not water')
110 | lig = m.select('resnum %d' % resnum)
111 |
112 | if lig is None:
113 |         print('[!] Error: could not find ligand with resnum: %d' % resnum)
114 | exit(-1)
115 |
116 | prody.writePDB(str(rec_path), rec)
117 | prody.writePDB(str(lig_path), lig)
118 |
119 | return (str(rec_path), str(lig_path))
120 |
121 |
122 | def get_structure_paths(args) -> Tuple[str, str]:
123 | """Get structure paths specified by the command line args.
124 | Returns (rec_path, lig_path)
125 | """
126 | if args.receptor is not None and args.ligand is not None:
127 | return (args.receptor, args.ligand)
128 | elif args.pdb is not None and args.resnum is not None:
129 | return load_pdb(args.pdb, args.resnum)
130 | else:
131 | raise NotImplementedError()
132 |
133 |
134 | def preprocess_ligand_without_removal_point(lig, conn):
135 | """
136 | Mark the atom at conn as a connecting atom. Useful when adding a fragment.
137 | """
138 |
139 | lig_pos = lig.GetConformer().GetPositions()
140 | lig_atm_conn_dist = np.sum((lig_pos - conn) ** 2, axis=1)
141 |
142 | # Get index of min
143 | min_idx = int(np.argmin(lig_atm_conn_dist))
144 |
145 | # Get atom at that position
146 | lig_atm_conn = lig.GetAtomWithIdx(min_idx)
147 |
148 | # Add a dummy atom to the ligand, connected to lig_atm_conn
149 | dummy_atom = Chem.MolFromSmiles("*")
150 | merged = Chem.RWMol(Chem.CombineMols(lig, dummy_atom))
151 |
152 | idx_of_dummy_in_merged = int([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0])
153 | bond = merged.AddBond(min_idx, idx_of_dummy_in_merged, Chem.rdchem.BondType.SINGLE)
154 |
155 | return merged
156 |
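
The connection atom above is chosen as the ligand atom nearest the user-supplied point, via an argmin over squared distances. A minimal numpy sketch of that lookup (`nearest_atom_index` is a hypothetical standalone helper):

```python
import numpy as np

def nearest_atom_index(positions, point):
    """Index of the atom closest to `point`, by squared Euclidean distance
    (the same criterion preprocess_ligand_without_removal_point uses)."""
    d2 = np.sum((positions - point) ** 2, axis=1)
    return int(np.argmin(d2))

# Three atoms; the connection point sits nearly on top of the second one.
lig_pos = np.array([[0.0, 0.0, 0.0],
                    [1.5, 0.0, 0.0],
                    [3.0, 0.0, 0.0]])
conn = np.array([1.4, 0.1, 0.0])
assert nearest_atom_index(lig_pos, conn) == 1
```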
157 |
158 | def preprocess_ligand_with_removal_point(lig, conn, rvec):
159 | """
160 | Remove the fragment from lig connected via the atom at conn and containing
161 | the atom at rvec. Useful when replacing a fragment.
162 | """
163 | # Generate all fragments.
164 | frags = util.generate_fragments(lig)
165 |
166 | for parent, frag in frags:
167 | # Get the index of the dummy (connection) atom on the fragment.
168 | cidx = [a for a in frag.GetAtoms() if a.GetAtomicNum() == 0][0].GetIdx()
169 |
170 |         # Get the coordinates of the dummy atom, which sits at the
171 |         # fragment's connection point.
172 | vec = frag.GetConformer().GetAtomPosition(cidx)
173 | c_vec = np.array([vec.x, vec.y, vec.z])
174 |
175 | # Check connection point.
176 | if np.linalg.norm(c_vec - conn) < 1e-3:
177 | # Check removal point.
178 | frag_pos = frag.GetConformer().GetPositions()
179 | min_dist = np.min(np.sum((frag_pos - rvec) ** 2, axis=1))
180 |
181 | if min_dist < 1e-3:
182 |                 # This parent/fragment split matches both the
183 |                 # user-specified connection and removal points.
184 |
185 | # Found fragment.
186 | print('[*] Removing fragment with %d atoms (%s)' % (
187 | frag_pos.shape[0] - 1, Chem.MolToSmiles(frag, False)))
188 |
189 | return parent
190 |
191 | print('[!] Could not find a suitable fragment to remove.')
192 | exit(-1)
193 |
194 |
195 | def lookup_atom_name(lig_path, name):
196 |     """Look up an atom by name. Returns the coordinates of the atom if
197 |     found; exits with an error otherwise."""
198 | with open(lig_path, 'r') as f:
199 | p = prody.parsePDBStream(f)
200 | p = p.select(f'name {name}')
201 | if p is None:
202 | print(f'[!] Error: no atom with name "{name}" in ligand')
203 | exit(-1)
204 | elif len(p) > 1:
205 | print(f'[!] Error: multiple atoms with name "{name}" in ligand')
206 | exit(-1)
207 | return p.getCoords()[0]
208 |
209 |
210 | def get_structures(args):
211 | rec_path, lig_path = get_structure_paths(args)
212 |
213 | print(f'[*] Loading receptor: {rec_path} ... ', end='')
214 | rec_coords, rec_types = util.load_receptor_ob(rec_path)
215 | print('done.')
216 |
217 | print(f'[*] Loading ligand: {lig_path} ... ', end='')
218 | lig = Chem.MolFromPDBFile(lig_path)
219 | print('done.')
220 |
221 | conn = None
222 | if args.cx is not None and args.cy is not None and args.cz is not None:
223 | conn = np.array([float(args.cx), float(args.cy), float(args.cz)])
224 | elif args.cname is not None:
225 | conn = lookup_atom_name(lig_path, args.cname)
226 | else:
227 | raise NotImplementedError()
228 |
229 | rvec = None
230 | if args.rx is not None and args.ry is not None and args.rz is not None:
231 | rvec = np.array([float(args.rx), float(args.ry), float(args.rz)])
232 | elif args.rname is not None:
233 | rvec = lookup_atom_name(lig_path, args.rname)
234 | else:
235 | pass
236 |
237 | if rvec is not None:
238 |         # Fragment replacement (rvec specified)
239 | lig = preprocess_ligand_with_removal_point(lig, conn, rvec)
240 | else:
241 | # Only fragment addition
242 | lig = preprocess_ligand_without_removal_point(lig, conn)
243 |
244 | parent_coords = util.get_coords(lig)
245 | parent_types = np.array(util.get_types(lig)).reshape((-1,1))
246 |
247 | return (rec_coords, rec_types, parent_coords, parent_types, conn, lig)
248 |
249 |
250 | def get_model(args, device):
251 | """Load a pre-trained DeepFrag model."""
252 | print('[*] Loading model ... ', end='')
253 | model = LeadoptModel.load(str(get_model_path() / 'final_model'), device=('cuda' if device == 'gpu' else device))
254 | print('done.')
255 | return model
256 |
257 |
258 | def get_fingerprints(args):
259 | """Load the fingerprint library.
260 | Returns (smiles, fingerprints).
261 | """
262 | f_smiles = None
263 | f_fingerprints = None
264 | print('[*] Loading fingerprint library ... ', end='')
265 | with h5py.File(str(get_fingerprints_path()), 'r') as f:
266 | f_smiles = f['smiles'][()]
267 | f_fingerprints = f['fingerprints'][()].astype(float)
268 | print('done.')
269 |
270 | return (f_smiles, f_fingerprints)
271 |
272 |
273 | def get_target_device(args) -> str:
274 | """Infer the target device or use the argument overrides."""
275 | device = 'gpu' if torch.cuda.device_count() > 0 else 'cpu'
276 |
277 | if args.cpu:
278 | if device == 'gpu':
279 | print('[*] Warning: GPU is available but running on CPU due to --cpu flag')
280 | device = 'cpu'
281 | elif args.gpu:
282 | if device == 'cpu':
283 | print('[*] Error: No CUDA-enabled GPU was found. Exiting due to --gpu flag. You can run on the CPU instead with the --cpu flag.')
284 | exit(-1)
285 | device = 'gpu'
286 |
287 | print('[*] Running on device: %s' % device)
288 |
289 | return device
290 |
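
The device-resolution logic above can be isolated as a pure function over the CUDA check and the two flags. A sketch, where `resolve_device` is a hypothetical helper that raises instead of calling `exit`:

```python
def resolve_device(cuda_available, cpu_flag=False, gpu_flag=False):
    """Mirror of get_target_device's flag logic: prefer the GPU when
    available, let --cpu force the CPU, and fail fast on --gpu without CUDA."""
    device = 'gpu' if cuda_available else 'cpu'
    if cpu_flag:
        device = 'cpu'
    elif gpu_flag and device == 'cpu':
        raise RuntimeError('No CUDA-enabled GPU was found')
    return device

assert resolve_device(True) == 'gpu'
assert resolve_device(True, cpu_flag=True) == 'cpu'
assert resolve_device(False) == 'cpu'
```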
291 |
292 | def generate_grids(args, model_args, rec_coords, rec_types, parent_coords, parent_types, conn, device):
293 | start = time.time()
294 |
295 | print('[*] Generating grids ... ', end='', flush=True)
296 | batch = grid_util.get_raw_batch(
297 | rec_coords, rec_types, parent_coords, parent_types,
298 | rec_typer=REC_TYPER[model_args['rec_typer']],
299 | lig_typer=LIG_TYPER[model_args['lig_typer']],
300 | conn=conn,
301 | num_samples=args.num_grids,
302 | width=model_args['grid_width'],
303 | res=model_args['grid_res'],
304 | point_radius=model_args['point_radius'],
305 | point_type=model_args['point_type'],
306 | acc_type=model_args['acc_type'],
307 | cpu=(device == 'cpu')
308 | )
309 | print('done.')
310 | end = time.time()
311 | print(f'[*] Generated grids in {end-start:.3f} seconds.')
312 |
313 | return batch
314 |
315 |
316 | def get_predictions(model, batch, f_smiles, f_fingerprints):
317 | start = time.time()
318 | pred = model.predict(torch.tensor(batch).float()).cpu().numpy()
319 | end = time.time()
320 |     print(f'[*] Generated prediction in {end-start:.3f} seconds.')
321 |
322 | avg_fp = np.mean(pred, axis=0)
323 | dist_fn = DIST_FN[model._args['dist_fn']]
324 |
325 |     # The distance functions are implemented in PyTorch, so we need to
326 |     # convert our numpy arrays to torch Tensors.
327 | dist = 1 - dist_fn(
328 | torch.tensor(avg_fp).unsqueeze(0),
329 | torch.tensor(f_fingerprints))
330 |
331 | # Pair smiles strings and distances.
332 | dist = list(dist.numpy())
333 | scores = list(zip(f_smiles, dist))
334 | scores = sorted(scores, key=lambda x:x[1], reverse=True)
335 | scores = [(a.decode('ascii'), b) for a,b in scores]
336 |
337 | return scores
338 |
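
Scoring averages the per-rotation predicted fingerprints and ranks the library by similarity to that average. A simplified numpy sketch that ranks by raw cosine similarity (the CLI itself goes through `DIST_FN` and `1 - dist`; `rank_by_cosine` and the toy fingerprints are hypothetical):

```python
import numpy as np

def rank_by_cosine(avg_fp, library_fps, smiles):
    """Rank library fingerprints by cosine similarity to the averaged
    predicted fingerprint, best match first."""
    sim = (library_fps @ avg_fp) / (
        np.linalg.norm(library_fps, axis=1) * np.linalg.norm(avg_fp))
    order = np.argsort(-sim)  # highest similarity first
    return [(smiles[i], float(sim[i])) for i in order]

avg_fp = np.array([1.0, 0.0, 1.0])
library = np.array([[1.0, 0.0, 1.0],   # same direction as avg_fp
                    [0.0, 1.0, 0.0],   # orthogonal
                    [1.0, 0.0, 0.0]])  # partial overlap
scores = rank_by_cosine(avg_fp, library, ['*C', '*N', '*O'])
```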
339 |
340 | def gen_output(args, scores):
341 | if args.out is None:
342 | # Write results to stdout.
343 | print('%4s %8s %s' % ('#', 'Score', 'SMILES'))
344 | for i in range(len(scores)):
345 | smi, score = scores[i]
346 | print('%4d %8f %s' % (i+1, score, smi))
347 | else:
348 | # Write csv output.
349 | csv = 'Rank,SMILES,Score\n'
350 | for i in range(len(scores)):
351 | smi, score = scores[i]
352 | csv += '%d,%s,%f\n' % (
353 | i+1, smi, score
354 | )
355 |
356 |         pathlib.Path(args.out).write_text(csv)
357 | print('[*] Wrote output to %s' % args.out)
358 |
359 |
360 | def fuse(lig, frag):
361 |     # Combine the ligand and fragment. CombineMols does not form a bond
362 |     # between the two; the bond is added explicitly below.
363 | merged = Chem.RWMol(Chem.CombineMols(lig, frag))
364 |
365 | conn_atoms = [a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0]
366 | neighbors = [merged.GetAtomWithIdx(x).GetNeighbors()[0].GetIdx() for x in conn_atoms]
367 |
368 | bond = merged.AddBond(neighbors[0], neighbors[1], Chem.rdchem.BondType.SINGLE)
369 |
370 | merged.RemoveAtom([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0])
371 | merged.RemoveAtom([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0])
372 |
373 | Chem.SanitizeMol(merged)
374 |
375 | return merged
376 |
377 |
378 | def fuse_fragments(lig, conn, scores):
379 | # Note: lig is rdkit.Chem.rdchem.Mol; scores is a list of (smiles, score)
380 | # tuples.
381 | new_sc = []
382 | for smi, score in scores:
383 | try:
384 | frag = Chem.MolFromSmiles(smi)
385 | fused = fuse(Chem.Mol(lig), frag)
386 | new_sc.append((Chem.MolToSmiles(fused, False), score))
387 |         except Exception:
388 | print('[*] Error: couldn\'t process mol.')
389 | new_sc.append(('', score))
390 |
391 | return new_sc
392 |
393 |
394 | def run(args):
395 | device = get_target_device(args)
396 |
397 | model = get_model(args, device)
398 | f_smiles, f_fingerprints = get_fingerprints(args)
399 |
400 | rec_coords, rec_types, parent_coords, parent_types, conn, lig = get_structures(args)
401 |
402 | batch = generate_grids(args, model._args, rec_coords, rec_types,
403 | parent_coords, parent_types, conn, device)
404 |
405 | scores = get_predictions(model, batch, f_smiles, f_fingerprints)
406 |
407 | if args.top_k != -1:
408 | scores = scores[:args.top_k]
409 |
410 | if args.full:
411 | scores = fuse_fragments(lig, conn, scores)
412 |
413 | gen_output(args, scores)
414 |
415 |
416 | def main():
417 | global VERSION
418 |
419 | print("\nDeepFrag " + VERSION)
420 | print("\nIf you use DeepFrag in your research, please cite:\n")
421 | print("Green, H., Koes, D. R., & Durrant, J. D. (2021). DeepFrag: a deep convolutional")
422 | print("neural network for fragment-based lead optimization. Chemical Science.\n")
423 |
424 | ensure_cli_data()
425 |
426 | parser = argparse.ArgumentParser()
427 |
428 | # Structure
429 | parser.add_argument('--receptor', help='Path to receptor structure.')
430 | parser.add_argument('--ligand', help='Path to ligand structure.')
431 | parser.add_argument('--pdb', help='PDB ID to download.')
432 | parser.add_argument('--resnum', type=int, help='Residue number of ligand.')
433 |
434 | # Connection point
435 | parser.add_argument('--cx', type=float, help='Connection point x coordinate.')
436 | parser.add_argument('--cy', type=float, help='Connection point y coordinate.')
437 | parser.add_argument('--cz', type=float, help='Connection point z coordinate.')
438 | parser.add_argument('--cname', type=str, help='Connection point atom name.')
439 |
440 | # Removal point
441 | parser.add_argument('--rx', type=float, help='Removal point x coordinate.')
442 | parser.add_argument('--ry', type=float, help='Removal point y coordinate.')
443 | parser.add_argument('--rz', type=float, help='Removal point z coordinate.')
444 | parser.add_argument('--rname', type=str, help='Removal point atom name.')
445 |
446 | # Misc
447 | parser.add_argument('--full', action='store_true', default=False,
448 | help='Print the full (fused) ligand structure.')
449 | parser.add_argument('--num_grids', type=int, default=4,
450 | help='Number of grid rotations.')
451 | parser.add_argument('--top_k', type=int, default=25,
452 | help='Number of results to show. Set to -1 to show all.')
453 | parser.add_argument('--out', type=str,
454 | help='Path to output CSV file.')
455 | parser.add_argument('--cpu', action='store_true', default=False,
456 | help='Use the CPU for grid generation and predictions.')
457 | parser.add_argument('--gpu', action='store_true', default=False,
458 | help='Use a (CUDA-capable) GPU for grid generation and predictions.')
459 |
460 | args = parser.parse_args()
461 |
462 | groupings = [
463 | ([('receptor', 'ligand'), ('pdb', 'resnum')], True),
464 | ([('cx', 'cy', 'cz'), ('cname',)], True),
465 | ([('rx', 'ry', 'rz'), ('rname',)], False),
466 | ([('cpu',), ('gpu',)], False)
467 | ]
468 |
469 | for grp, req in groupings:
470 | partial = []
471 | complete = 0
472 |
473 | for subset in grp:
474 | res = [not (getattr(args, name) in [None, False]) for name in subset]
475 | partial.append(any(res) and not all(res))
476 | complete += int(all(res))
477 |
478 | if any(partial) or complete > 1 or (complete != 1 and req):
479 | # Invalid arg combination.
480 | print('Invalid arguments, must specify exactly one of the following combinations:')
481 | for subset in grp:
482 | print('\t%s' % ', '.join(['--' + x for x in subset]))
483 | exit(-1)
484 |
485 | run(args)
486 |
487 |
488 | if __name__ == '__main__':
489 | main()
490 |
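
The `groupings` validation in `main()` enforces mutually exclusive flag combinations: each grouping lists alternative flag subsets, a subset must be given in full or not at all, at most one subset may be complete, and required groupings need exactly one complete subset. A standalone sketch of the same check over a plain dict (`check_groupings` is a hypothetical helper returning False instead of exiting):

```python
def check_groupings(args, groupings):
    """Return True when the flag dict satisfies every grouping constraint."""
    for grp, required in groupings:
        partial = []
        complete = 0
        for subset in grp:
            res = [args.get(name) not in [None, False] for name in subset]
            partial.append(any(res) and not all(res))  # subset only half-given
            complete += int(all(res))                  # subset fully given
        if any(partial) or complete > 1 or (complete != 1 and required):
            return False
    return True

# Connection point: either --cx/--cy/--cz together, or --cname alone.
groupings = [([('cx', 'cy', 'cz'), ('cname',)], True)]
assert check_groupings({'cname': 'C4'}, groupings)
assert check_groupings({'cx': 1.0, 'cy': 2.0, 'cz': 3.0}, groupings)
assert not check_groupings({'cx': 1.0}, groupings)  # partial coordinate triple
```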
--------------------------------------------------------------------------------
/leadopt/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
--------------------------------------------------------------------------------
/leadopt/data_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | """
17 | Contains utility code for reading packed data files.
18 | """
19 | import os
20 |
21 | import torch
22 | from torch.utils.data import DataLoader, Dataset
23 | import numpy as np
24 | import h5py
25 | import tqdm
26 |
27 | # Atom typing
28 | #
29 | # Atom typing is the process of figuring out which layer each atom should be
30 | # written to. For ease of testing, the packed data file contains a lot of
31 | # potentially useful atomic information which can be distilled during the
32 | # data loading process.
33 | #
34 | # Atom typing is implemented by map functions of the type:
35 | # (atom descriptor) -> (layer index)
36 | #
37 | # If the layer index is -1, the atom is ignored.
38 |
39 |
40 | class AtomTyper(object):
41 | def __init__(self, fn, num_layers):
42 | """Initialize an atom typer.
43 |
44 | Args:
45 | fn: a function of type:
46 | (atomic_num, aro, hdon, hacc, pcharge) -> (mask)
47 |             num_layers: number of output layers (<=16)
48 | """
49 | self._fn = fn
50 | self._num_layers = num_layers
51 |
52 | def size(self):
53 | return self._num_layers
54 |
55 | def apply(self, *args):
56 | return self._fn(*args)
57 |
58 |
59 | class CondAtomTyper(AtomTyper):
60 | def __init__(self, cond_func):
61 | assert len(cond_func) <= 16
62 | def _fn(*args):
63 | v = 0
64 | for k in range(len(cond_func)):
65 | if cond_func[k](*args):
66 | v |= 1 << k
67 | return v
68 | super(CondAtomTyper, self).__init__(_fn, len(cond_func))
69 |
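
`CondAtomTyper` packs the per-layer predicate results into a bitmask, one bit per layer. A standalone sketch using the `simple` ligand predicates defined in `LIG_TYPER` below (`layer_mask` is a hypothetical helper equivalent to the inner `_fn`):

```python
def layer_mask(predicates, *atom):
    """Bitmask over typing predicates: bit k is set when predicates[k]
    matches the atom descriptor."""
    v = 0
    for k, cond in enumerate(predicates):
        if cond(*atom):
            v |= 1 << k
    return v

# The 'simple' ligand predicates: carbon, nitrogen, oxygen, other heavy atom.
preds = [
    lambda num: num == 6,
    lambda num: num == 7,
    lambda num: num == 8,
    lambda num: num not in [0, 1, 6, 7, 8],
]
assert layer_mask(preds, 6) == 0b0001   # carbon -> layer 0
assert layer_mask(preds, 8) == 0b0100   # oxygen -> layer 2
assert layer_mask(preds, 16) == 0b1000  # sulfur -> "extra" layer
assert layer_mask(preds, 1) == 0        # hydrogen ignored by this typer
```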
70 |
71 | REC_TYPER = {
72 | # 1 channel, no hydrogen
73 | 'single': CondAtomTyper([
74 | lambda num, aro, hdon, hacc, pcharge: num not in [0,1]
75 | ]),
76 |
77 | # 1 channel, including hydrogen
78 | 'single_h': CondAtomTyper([
79 | lambda num, aro, hdon, hacc, pcharge: num != 0
80 | ]),
81 |
82 | # (C,N,O,S,*)
83 | 'simple': CondAtomTyper([
84 | lambda num, aro, hdon, hacc, pcharge: num == 6,
85 | lambda num, aro, hdon, hacc, pcharge: num == 7,
86 | lambda num, aro, hdon, hacc, pcharge: num == 8,
87 | lambda num, aro, hdon, hacc, pcharge: num == 16,
88 | lambda num, aro, hdon, hacc, pcharge: num not in [0,1,6,7,8,16],
89 | ]),
90 |
91 | # (H,C,N,O,S,*)
92 | 'simple_h': CondAtomTyper([
93 | lambda num, aro, hdon, hacc, pcharge: num == 1,
94 | lambda num, aro, hdon, hacc, pcharge: num == 6,
95 | lambda num, aro, hdon, hacc, pcharge: num == 7,
96 | lambda num, aro, hdon, hacc, pcharge: num == 8,
97 | lambda num, aro, hdon, hacc, pcharge: num == 16,
98 | lambda num, aro, hdon, hacc, pcharge: num not in [0,1,6,7,8,16],
99 | ]),
100 |
101 | # (aro, hdon, hacc, positive, negative, occ)
102 | 'meta': CondAtomTyper([
103 | lambda num, aro, hdon, hacc, pcharge: bool(aro), # aromatic
104 | lambda num, aro, hdon, hacc, pcharge: bool(hdon), # hydrogen donor
105 | lambda num, aro, hdon, hacc, pcharge: bool(hacc), # hydrogen acceptor
106 | lambda num, aro, hdon, hacc, pcharge: pcharge >= 128, # partial positive
107 | lambda num, aro, hdon, hacc, pcharge: pcharge < 128, # partial negative
108 | lambda num, aro, hdon, hacc, pcharge: num != 0, # occupancy
109 | ]),
110 |
111 | # (aro, hdon, hacc, positive, negative, occ)
112 | 'meta_mix': CondAtomTyper([
113 | lambda num, aro, hdon, hacc, pcharge: bool(aro), # aromatic
114 | lambda num, aro, hdon, hacc, pcharge: bool(hdon), # hydrogen donor
115 | lambda num, aro, hdon, hacc, pcharge: bool(hacc), # hydrogen acceptor
116 | lambda num, aro, hdon, hacc, pcharge: pcharge >= 128, # partial positive
117 | lambda num, aro, hdon, hacc, pcharge: pcharge < 128, # partial negative
118 | lambda num, aro, hdon, hacc, pcharge: num != 0, # occupancy
119 | lambda num, aro, hdon, hacc, pcharge: num == 1, # hydrogen
120 | lambda num, aro, hdon, hacc, pcharge: num == 6, # carbon
121 | lambda num, aro, hdon, hacc, pcharge: num == 7, # nitrogen
122 | lambda num, aro, hdon, hacc, pcharge: num == 8, # oxygen
123 | lambda num, aro, hdon, hacc, pcharge: num == 16, # sulfur
124 | ])
125 | }
126 |
127 | LIG_TYPER = {
128 | # 1 channel, no hydrogen
129 | 'single': CondAtomTyper([
130 | lambda num: num not in [0,1]
131 | ]),
132 |
133 | # 1 channel, including hydrogen
134 | 'single_h': CondAtomTyper([
135 | lambda num: num != 0
136 | ]),
137 |
138 | 'simple': CondAtomTyper([
139 | lambda num: num == 6, # carbon
140 | lambda num: num == 7, # nitrogen
141 | lambda num: num == 8, # oxygen
142 | lambda num: num not in [0,1,6,7,8] # extra
143 | ]),
144 |
145 | 'simple_h': CondAtomTyper([
146 | lambda num: num == 1, # hydrogen
147 | lambda num: num == 6, # carbon
148 | lambda num: num == 7, # nitrogen
149 | lambda num: num == 8, # oxygen
150 | lambda num: num not in [0,1,6,7,8] # extra
151 | ])
152 | }
153 |
154 |
155 | class FragmentDataset(Dataset):
156 | """Utility class to work with the packed fragments.h5 format."""
157 |
158 | def __init__(self, fragment_file, rec_typer=REC_TYPER['simple'],
159 | lig_typer=LIG_TYPER['simple'], filter_rec=None, filter_smi=None,
160 | fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None,
161 | verbose=False, lazy_loading=True):
162 | """Initializes the fragment dataset.
163 |
164 | Args:
165 | fragment_file: path to fragments.h5
166 | rec_typer: AtomTyper for receptor
167 | lig_typer: AtomTyper for ligand
168 | filter_rec: list of receptor ids to use (or None to use all)
169 |             lazy_loading: if True, defer atom typing until each example is accessed
170 |
171 | (filtering options):
172 | fdist_min: minimum fragment distance
173 | fdist_max: maximum fragment distance
174 | fmass_min: minimum fragment mass (Da)
175 | fmass_max: maximum fragment mass (Da)
176 | """
177 | self._rec_typer = rec_typer
178 | self._lig_typer = lig_typer
179 |
180 | self.verbose = verbose
181 | self._lazy_loading = lazy_loading
182 |
183 | self.rec = self._load_rec(fragment_file, rec_typer)
184 | self.frag = self._load_fragments(fragment_file, lig_typer)
185 |
186 | self.valid_idx = self._get_valid_examples(
187 | filter_rec, filter_smi, fdist_min, fdist_max, fmass_min, fmass_max, verbose)
188 |
189 | def _load_rec(self, fragment_file, rec_typer):
190 | """Loads receptor information."""
191 | f = h5py.File(fragment_file, 'r')
192 |
193 | rec_coords = f['rec_coords'][()]
194 | rec_types = f['rec_types'][()]
195 | rec_lookup = f['rec_lookup'][()]
196 |
197 | r = range(len(rec_types))
198 | if self.verbose:
199 | r = tqdm.tqdm(r, desc='Remap receptor atoms')
200 |
201 | rec_remapped = np.zeros(len(rec_types), dtype=np.uint16)
202 | if not self._lazy_loading:
203 | for i in r:
204 | rec_remapped[i] = rec_typer.apply(*rec_types[i])
205 |
206 |         rec_loaded = np.zeros(len(rec_lookup)).astype(bool)
207 |
208 | # create rec mapping
209 | rec_mapping = {}
210 | for i in range(len(rec_lookup)):
211 | rec_mapping[rec_lookup[i][0].decode('ascii')] = i
212 |
213 | rec = {
214 | 'rec_coords': rec_coords,
215 | 'rec_types': rec_types,
216 | 'rec_remapped': rec_remapped,
217 | 'rec_lookup': rec_lookup,
218 | 'rec_mapping': rec_mapping,
219 | 'rec_loaded': rec_loaded
220 | }
221 |
222 | f.close()
223 |
224 | return rec
225 |
226 | def _load_fragments(self, fragment_file, lig_typer):
227 | """Loads fragment information."""
228 | f = h5py.File(fragment_file, 'r')
229 |
230 | frag_data = f['frag_data'][()]
231 | frag_lookup = f['frag_lookup'][()]
232 | frag_smiles = f['frag_smiles'][()]
233 | frag_mass = f['frag_mass'][()]
234 | frag_dist = f['frag_dist'][()]
235 |
236 | frag_lig_smi = None
237 | frag_lig_idx = None
238 | if 'frag_lig_smi' in f.keys():
239 | frag_lig_smi = f['frag_lig_smi'][()]
240 | frag_lig_idx = f['frag_lig_idx'][()]
241 |
242 | # unpack frag data into separate structures
243 | frag_coords = frag_data[:,:3].astype(np.float32)
244 | frag_types = frag_data[:,3].astype(np.uint8)
245 |
246 | frag_remapped = np.zeros(len(frag_types), dtype=np.uint16)
247 | if not self._lazy_loading:
248 | for i in range(len(frag_types)):
249 | frag_remapped[i] = lig_typer.apply(frag_types[i])
250 |
251 |         frag_loaded = np.zeros(len(frag_lookup)).astype(bool)
252 |
253 | # find and save connection point
254 | r = range(len(frag_lookup))
255 | if self.verbose:
256 | r = tqdm.tqdm(r, desc='Frag connection point')
257 |
258 | frag_conn = np.zeros((len(frag_lookup), 3))
259 | for i in r:
260 | _,f_start,f_end,_,_ = frag_lookup[i]
261 | fdat = frag_data[f_start:f_end]
262 |
263 | found = False
264 | for j in range(len(fdat)):
265 | if fdat[j][3] == 0:
266 | frag_conn[i,:] = tuple(fdat[j])[:3]
267 | found = True
268 | break
269 |
270 | assert found, "missing fragment connection point at %d" % i
271 |
272 | frag = {
273 | 'frag_coords': frag_coords, # d_idx -> (x,y,z)
274 | 'frag_types': frag_types, # d_idx -> (type)
275 | 'frag_remapped': frag_remapped, # d_idx -> (layer)
276 | 'frag_lookup': frag_lookup, # f_idx -> (rec_id, fstart, fend, pstart, pend)
277 | 'frag_conn': frag_conn, # f_idx -> (x,y,z)
278 | 'frag_smiles': frag_smiles, # f_idx -> smiles
279 | 'frag_mass': frag_mass, # f_idx -> mass
280 | 'frag_dist': frag_dist, # f_idx -> dist
281 | 'frag_lig_smi': frag_lig_smi,
282 | 'frag_lig_idx': frag_lig_idx,
283 | 'frag_loaded': frag_loaded
284 | }
285 |
286 | f.close()
287 |
288 | return frag
289 |
290 | def _get_valid_examples(self, filter_rec, filter_smi, fdist_min, fdist_max, fmass_min,
291 | fmass_max, verbose):
292 | """Returns an array of valid fragment indexes.
293 |
294 | "Valid" in this context means the fragment belongs to a receptor in
295 | filter_rec and the fragment abides by the optional mass/distance
296 | constraints.
297 | """
298 | # keep track of valid examples
299 |         valid_mask = np.ones(self.frag['frag_lookup'].shape[0]).astype(bool)
300 |
301 | num_frags = self.frag['frag_lookup'].shape[0]
302 |
303 | # filter by receptor id
304 | if filter_rec is not None:
305 |             valid_rec = np.zeros(num_frags, dtype=bool)
306 |
307 | r = range(num_frags)
308 | if verbose:
309 | r = tqdm.tqdm(r, desc='filter rec')
310 |
311 | for i in r:
312 | rec = self.frag['frag_lookup'][i][0].decode('ascii')
313 | if rec in filter_rec:
314 | valid_rec[i] = 1
315 | valid_mask *= valid_rec
316 |
317 | # filter by ligand smiles string
318 | if filter_smi is not None:
319 |             valid_lig = np.zeros(num_frags, dtype=bool)
320 |
321 | r = range(num_frags)
322 | if verbose:
323 | r = tqdm.tqdm(r, desc='filter lig')
324 |
325 | for i in r:
326 | smi = self.frag['frag_lig_smi'][self.frag['frag_lig_idx'][i]]
327 | smi = smi.decode('ascii')
328 | if smi in filter_smi:
329 | valid_lig[i] = 1
330 |
331 | valid_mask *= valid_lig
332 |
333 | # filter by fragment distance
334 | if fdist_min is not None:
335 | valid_mask[self.frag['frag_dist'] < fdist_min] = 0
336 |
337 | if fdist_max is not None:
338 | valid_mask[self.frag['frag_dist'] > fdist_max] = 0
339 |
340 | # filter by fragment mass
341 | if fmass_min is not None:
342 | valid_mask[self.frag['frag_mass'] < fmass_min] = 0
343 |
344 | if fmass_max is not None:
345 | valid_mask[self.frag['frag_mass'] > fmass_max] = 0
346 |
347 | # convert to a list of indexes
348 | valid_idx = np.where(valid_mask)[0]
349 |
350 | return valid_idx
351 |
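
The filtering above composes boolean masks (one per constraint) and converts the survivors to indexes with `np.where`. A toy numpy example of the distance/mass filters (the values are made up; the thresholds match the `final.json` configuration):

```python
import numpy as np

# Toy fragment table mirroring the filtering in _get_valid_examples.
frag_dist = np.array([1.2, 5.0, 3.1, 0.4])
frag_mass = np.array([90.0, 120.0, 200.0, 45.0])

valid_mask = np.ones(4, dtype=bool)
valid_mask[frag_dist > 4.0] = 0    # fdist_max = 4
valid_mask[frag_mass > 150.0] = 0  # fmass_max = 150

# Fragments 1 (too far) and 2 (too heavy) are dropped.
valid_idx = np.where(valid_mask)[0]
assert valid_idx.tolist() == [0, 3]
```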
352 | def __len__(self):
353 | """Returns the number of valid fragment examples."""
354 | return self.valid_idx.shape[0]
355 |
356 | def __getitem__(self, idx):
357 | """Returns the Nth example.
358 |
359 | Returns a dict with:
360 | f_coords: fragment coordinates (Fx3)
361 | f_types: fragment layers (Fx1)
362 | p_coords: parent coordinates (Px3)
363 | p_types: parent layers (Px1)
364 | r_coords: receptor coordinates (Rx3)
365 | r_types: receptor layers (Rx1)
366 | conn: fragment connection point in the parent molecule (x,y,z)
367 | smiles: fragment smiles string
368 | """
369 | # convert to fragment index
370 | frag_idx = self.valid_idx[idx]
371 | return self.get_raw(frag_idx)
372 |
373 | def get_raw(self, frag_idx):
374 | # lookup fragment
375 | rec_id, f_start, f_end, p_start, p_end = self.frag['frag_lookup'][frag_idx]
376 | smiles = self.frag['frag_smiles'][frag_idx].decode('ascii')
377 | conn = self.frag['frag_conn'][frag_idx]
378 |
379 | # lookup receptor
380 | rec_idx = self.rec['rec_mapping'][rec_id.decode('ascii')]
381 | _, r_start, r_end = self.rec['rec_lookup'][rec_idx]
382 |
383 | # fetch data
384 | # f_coords = self.frag['frag_coords'][f_start:f_end]
385 | # f_types = self.frag['frag_types'][f_start:f_end]
386 | p_coords = self.frag['frag_coords'][p_start:p_end]
387 | r_coords = self.rec['rec_coords'][r_start:r_end]
388 |
389 | if self._lazy_loading and self.frag['frag_loaded'][frag_idx] == 0:
390 | frag_types = self.frag['frag_types']
391 | frag_remapped = self.frag['frag_remapped']
392 |
393 | # load parent
394 | for i in range(p_start, p_end):
395 | frag_remapped[i] = self._lig_typer.apply(frag_types[i])
396 |
397 | self.frag['frag_loaded'][frag_idx] = 1
398 |
399 | if self._lazy_loading and self.rec['rec_loaded'][rec_idx] == 0:
400 | rec_types = self.rec['rec_types']
401 | rec_remapped = self.rec['rec_remapped']
402 |
403 | # load receptor
404 | for i in range(r_start, r_end):
405 | rec_remapped[i] = self._rec_typer.apply(*rec_types[i])
406 |
407 | self.rec['rec_loaded'][rec_idx] = 1
408 |
409 | p_mask = self.frag['frag_remapped'][p_start:p_end]
410 | r_mask = self.rec['rec_remapped'][r_start:r_end]
411 |
412 | return {
413 | # 'f_coords': f_coords,
414 | # 'f_types': f_types,
415 | 'p_coords': p_coords,
416 | 'p_types': p_mask,
417 | 'r_coords': r_coords,
418 | 'r_types': r_mask,
419 | 'conn': conn,
420 | 'smiles': smiles
421 | }
422 |
423 | def get_valid_smiles(self):
424 | """Returns a list of all valid smiles fragments."""
425 | valid_smiles = set()
426 |
427 | for idx in self.valid_idx:
428 | smiles = self.frag['frag_smiles'][idx].decode('ascii')
429 | valid_smiles.add(smiles)
430 |
431 | return list(valid_smiles)
432 |
433 | def lig_layers(self):
434 | return self._lig_typer.size()
435 |
436 | def rec_layers(self):
437 | return self._rec_typer.size()
438 |
439 |
440 | class SharedFragmentDataset(object):
441 | def __init__(self, dat, filter_rec=None, filter_smi=None, fdist_min=None,
442 | fdist_max=None, fmass_min=None, fmass_max=None):
443 |
444 | self._dat = dat
445 |
446 | self.valid_idx = self._dat._get_valid_examples(
447 | filter_rec, filter_smi, fdist_min, fdist_max, fmass_min, fmass_max, verbose=True)
448 |
449 | def __len__(self):
450 | return self.valid_idx.shape[0]
451 |
452 | def __getitem__(self, idx):
453 | frag_idx = self.valid_idx[idx]
454 | return self._dat.get_raw(frag_idx)
455 |
456 | def get_valid_smiles(self):
457 | """Returns a list of all valid smiles fragments."""
458 | valid_smiles = set()
459 |
460 | for idx in self.valid_idx:
461 | smiles = self._dat.frag['frag_smiles'][idx].decode('ascii')
462 | valid_smiles.add(smiles)
463 |
464 | return list(valid_smiles)
465 |
466 | def lig_layers(self):
467 | return self._dat.lig_layers()
468 |
469 | def rec_layers(self):
470 | return self._dat.rec_layers()
471 |
472 |
473 | class FingerprintDataset(Dataset):
474 |
475 | def __init__(self, fingerprint_file):
476 | """Initializes a fingerprint dataset.
477 |
478 | Args:
479 | fingerprint_file: path to a fingerprint .h5 file
480 | """
481 | self.fingerprints = self._load_fingerprints(fingerprint_file)
482 |
483 | def _load_fingerprints(self, fingerprint_file):
484 | """Loads fingerprint information."""
485 | f = h5py.File(fingerprint_file, 'r')
486 |
487 | fingerprint_data = f['fingerprints'][()]
488 | fingerprint_smiles = f['smiles'][()]
489 |
490 | # create smiles->idx mapping
491 | fingerprint_mapping = {}
492 | for i in range(len(fingerprint_smiles)):
493 | sm = fingerprint_smiles[i].decode('ascii')
494 | fingerprint_mapping[sm] = i
495 |
496 | fingerprints = {
497 | 'fingerprint_data': fingerprint_data,
498 | 'fingerprint_mapping': fingerprint_mapping,
499 | 'fingerprint_smiles': fingerprint_smiles,
500 | }
501 |
502 | f.close()
503 |
504 | return fingerprints
505 |
506 | def for_smiles(self, smiles):
507 | """Return a Tensor of fingerprints for a list of smiles.
508 |
509 | Args:
510 | smiles: size N list of smiles strings (as str not bytes)
511 | """
512 | fp = np.zeros((len(smiles), self.fingerprints['fingerprint_data'].shape[1]))
513 |
514 | for i in range(len(smiles)):
515 | fp_idx = self.fingerprints['fingerprint_mapping'][smiles[i]]
516 | fp[i] = self.fingerprints['fingerprint_data'][fp_idx]
517 |
518 | return torch.Tensor(fp)
519 |
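The lookup that `_load_fingerprints` and `for_smiles` implement reduces to a smiles-to-row-index dictionary over a fingerprint matrix. A toy NumPy sketch, with hypothetical in-memory data standing in for the `.h5` contents (SMILES stored as bytes, as in the real file):

```python
import numpy as np

# Hypothetical stand-ins for the .h5 datasets: an NxF fingerprint
# matrix and a parallel list of SMILES (bytes, as stored on disk).
fingerprint_data = np.array([[1.0, 0.0, 1.0],
                             [0.0, 1.0, 1.0]])
fingerprint_smiles = [b'*C', b'*O']

# smiles -> row-index mapping, as built in _load_fingerprints
mapping = {s.decode('ascii'): i for i, s in enumerate(fingerprint_smiles)}

def for_smiles(smiles):
    # Gather the fingerprint row for each query smiles string.
    fp = np.zeros((len(smiles), fingerprint_data.shape[1]))
    for i, sm in enumerate(smiles):
        fp[i] = fingerprint_data[mapping[sm]]
    return fp

out = for_smiles(['*O', '*C'])
```

The real class wraps the result in a `torch.Tensor`; the gather logic is identical.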
--------------------------------------------------------------------------------
/leadopt/grid_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | """
17 | Contains code for gpu-accelerated grid generation.
18 | """
19 |
20 | import math
21 | import ctypes
22 |
23 | import torch
24 | import numba
25 | import numba.cuda
26 | import numpy as np
27 |
28 |
29 | GPU_DIM = 8
30 |
31 |
32 | class POINT_TYPE(object):
33 | EXP = 0 # simple exponential sphere fill
34 | SPHERE = 1 # fixed sphere fill
35 | CUBE = 2 # fixed cube fill
36 |     GAUSSIAN = 3 # continuous piecewise gaussian fill
37 |     LJ = 4       # Lennard-Jones-style fill (Jimenez, 2017)
38 |     DISCRETE = 5 # nearest-gridpoint fill
39 |
40 | class ACC_TYPE(object):
41 | SUM = 0
42 | MAX = 1
43 |
44 |
45 | @numba.cuda.jit
46 | def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
47 | batch_idx, width, res, center, rot,
48 | point_radius, point_type, acc_type
49 | ):
50 | """Adds atoms to the grid in a GPU kernel.
51 |
52 | This kernel converts atom coordinate information to 3d voxel information.
53 | Each GPU thread is responsible for one specific grid point. This function
54 | receives a list of atomic coordinates and atom layers and simply iterates
55 | over the list to find nearby atoms and add their effect.
56 |
57 |     Voxel information is stored in a 5D tensor of shape BxTxNxNxN where:
58 | B = batch size
59 | T = number of atom types (receptor + ligand)
60 | N = grid width (in gridpoints)
61 |
62 | Each invocation of this function will write information to a specific batch
63 | index specified by batch_idx. Additionally, the layer_offset parameter can
64 | be set to specify a fixed offset to add to each atom_layer item.
65 |
66 | How it works:
67 | 1. Each GPU thread controls a single gridpoint. This gridpoint coordinate
68 | is translated to a "real world" coordinate by applying rotation and
69 | translation vectors.
70 | 2. Each thread iterates over the list of atoms and checks for atoms within
71 | a threshold to add to the grid.
72 |
73 | Args:
74 | grid: DeviceNDArray tensor where grid information is stored
75 | atom_num: number of atoms
76 | atom_coords: array containing (x,y,z) atom coordinates
77 | atom_mask: uint32 array of size atom_num containing a destination
78 | layer bitmask (i.e. if bit k is set, write atom to index k)
79 |         layer_offset: a fixed offset added to each atom layer index
80 |         batch_idx: index specifying which batch to write information to
81 | width: number of grid points in each dimension
82 | res: distance between neighboring grid points in angstroms
83 | (1 == gridpoint every angstrom)
84 |             (0.5 == gridpoint every half angstrom, i.e. a tighter grid)
85 | center: (x,y,z) coordinate of grid center
86 |         rot: (w,x,y,z) rotation quaternion
87 | """
88 | x,y,z = numba.cuda.grid(3)
89 |
90 | # center around origin
91 | tx = x - (width/2)
92 | ty = y - (width/2)
93 | tz = z - (width/2)
94 |
95 | # scale by resolution
96 | tx = tx * res
97 | ty = ty * res
98 | tz = tz * res
99 |
100 | # apply rotation vector
101 | aw = rot[0]
102 | ax = rot[1]
103 | ay = rot[2]
104 | az = rot[3]
105 |
106 | bw = 0
107 | bx = tx
108 | by = ty
109 | bz = tz
110 |
111 | # multiply by rotation vector
112 | cw = (aw * bw) - (ax * bx) - (ay * by) - (az * bz)
113 | cx = (aw * bx) + (ax * bw) + (ay * bz) - (az * by)
114 | cy = (aw * by) + (ay * bw) + (az * bx) - (ax * bz)
115 | cz = (aw * bz) + (az * bw) + (ax * by) - (ay * bx)
116 |
117 | # multiply by conjugate
118 | # dw = (cw * aw) - (cx * (-ax)) - (cy * (-ay)) - (cz * (-az))
119 | dx = (cw * (-ax)) + (cx * aw) + (cy * (-az)) - (cz * (-ay))
120 | dy = (cw * (-ay)) + (cy * aw) + (cz * (-ax)) - (cx * (-az))
121 | dz = (cw * (-az)) + (cz * aw) + (cx * (-ay)) - (cy * (-ax))
122 |
123 | # apply translation vector
124 | tx = dx + center[0]
125 | ty = dy + center[1]
126 | tz = dz + center[2]
127 |
128 | i = 0
129 | while i < atom_num:
130 | # fetch atom
131 | fx, fy, fz = atom_coords[i]
132 | mask = atom_mask[i]
133 | i += 1
134 |
135 | # invisible atoms
136 | if mask == 0:
137 | continue
138 |
139 | # point radius squared
140 | r = point_radius
141 | r2 = point_radius * point_radius
142 |
143 |         # quick cube bounds check (conservative for the default point radii >= 1.5)
144 | if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2:
145 | continue
146 |
147 | # value to add to this gridpoint
148 | val = 0
149 |
150 | if point_type == 0: # POINT_TYPE.EXP
151 | # exponential sphere fill
152 | # compute squared distance to atom
153 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
154 | if d2 > r2:
155 | continue
156 |
157 | # compute effect
158 | val = math.exp((-2 * d2) / r2)
159 | elif point_type == 1: # POINT_TYPE.SPHERE
160 | # solid sphere fill
161 | # compute squared distance to atom
162 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
163 | if d2 > r2:
164 | continue
165 |
166 | val = 1
167 | elif point_type == 2: # POINT_TYPE.CUBE
168 | # solid cube fill
169 | val = 1
170 | elif point_type == 3: # POINT_TYPE.GAUSSIAN
171 | # (Ragoza, 2016)
172 | #
173 | # piecewise gaussian sphere fill
174 | # compute squared distance to atom
175 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
176 | d = math.sqrt(d2)
177 |
178 | if d > r * 1.5:
179 | continue
180 | elif d > r:
181 | val = math.exp(-2.0) * ( (4*d2/r2) - (12*d/r) + 9 )
182 | else:
183 | val = math.exp((-2 * d2) / r2)
184 | elif point_type == 4: # POINT_TYPE.LJ
185 | # (Jimenez, 2017) - DeepSite
186 | #
187 | # LJ potential
188 | # compute squared distance to atom
189 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
190 | d = math.sqrt(d2)
191 |
192 | if d > r * 1.5:
193 | continue
194 | else:
195 | val = 1 - math.exp(-((r/d)**12))
196 | elif point_type == 5: # POINT_TYPE.DISCRETE
197 | # nearest-gridpoint
198 | # L1 distance
199 | if abs(fx-tx) < (res/2) and abs(fy-ty) < (res/2) and abs(fz-tz) < (res/2):
200 | val = 1
201 |
202 | # add value to layers
203 | for k in range(32):
204 | if (mask >> k) & 1:
205 | idx = (batch_idx, layer_offset+k, x, y, z)
206 | if acc_type == 0: # ACC_TYPE.SUM
207 | numba.cuda.atomic.add(grid, idx, val)
208 | elif acc_type == 1: # ACC_TYPE.MAX
209 | numba.cuda.atomic.max(grid, idx, val)
210 |
211 |
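The expanded quaternion product in the kernel above (q * (0, v) * conj(q)) can be sanity-checked in isolation. This standalone NumPy sketch repeats the same arithmetic with `rot` as a (w,x,y,z) quaternion and rotates a unit vector 90 degrees about z:

```python
import numpy as np

def quat_rotate(rot, v):
    # Same expanded q * (0, v) * conj(q) product used in gpu_gridify,
    # with rot given as a (w, x, y, z) quaternion and bw = 0 folded in.
    aw, ax, ay, az = rot
    bx, by, bz = v
    cw = -(ax * bx) - (ay * by) - (az * bz)
    cx = (aw * bx) + (ay * bz) - (az * by)
    cy = (aw * by) + (az * bx) - (ax * bz)
    cz = (aw * bz) + (ax * by) - (ay * bx)
    dx = (cw * -ax) + (cx * aw) + (cy * -az) - (cz * -ay)
    dy = (cw * -ay) + (cy * aw) + (cz * -ax) - (cx * -az)
    dz = (cw * -az) + (cz * aw) + (cx * -ay) - (cy * -ax)
    return np.array([dx, dy, dz])

# 90-degree rotation about the z axis should map x onto y
s = np.sqrt(0.5)
rotated = quat_rotate((s, 0.0, 0.0, s), (1.0, 0.0, 0.0))
```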
212 | @numba.jit(nopython=True)
213 | def cpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
214 | batch_idx, width, res, center, rot,
215 | point_radius, point_type, acc_type
216 | ):
217 | """Adds atoms to the grid in a GPU kernel.
218 |
219 | This kernel converts atom coordinate information to 3d voxel information.
220 | Each GPU thread is responsible for one specific grid point. This function
221 | receives a list of atomic coordinates and atom layers and simply iterates
222 | over the list to find nearby atoms and add their effect.
223 |
224 |     Voxel information is stored in a 5D tensor of shape BxTxNxNxN where:
225 | B = batch size
226 | T = number of atom types (receptor + ligand)
227 | N = grid width (in gridpoints)
228 |
229 | Each invocation of this function will write information to a specific batch
230 | index specified by batch_idx. Additionally, the layer_offset parameter can
231 | be set to specify a fixed offset to add to each atom_layer item.
232 |
233 | How it works:
234 |     1. Each gridpoint coordinate is translated to a "real world"
235 |        coordinate by applying rotation and translation vectors.
236 |     2. For each gridpoint, iterate over the list of atoms and add the
237 |        effect of atoms within a threshold distance to the grid.
239 |
240 | Args:
241 | grid: DeviceNDArray tensor where grid information is stored
242 | atom_num: number of atoms
243 | atom_coords: array containing (x,y,z) atom coordinates
244 | atom_mask: uint32 array of size atom_num containing a destination
245 | layer bitmask (i.e. if bit k is set, write atom to index k)
246 |         layer_offset: a fixed offset added to each atom layer index
247 |         batch_idx: index specifying which batch to write information to
248 | width: number of grid points in each dimension
249 | res: distance between neighboring grid points in angstroms
250 | (1 == gridpoint every angstrom)
251 |             (0.5 == gridpoint every half angstrom, i.e. a tighter grid)
252 | center: (x,y,z) coordinate of grid center
253 |         rot: (w,x,y,z) rotation quaternion
254 | """
255 | # x,y,z = numba.cuda.grid(3)
256 | for x in range(width):
257 | for y in range(width):
258 | for z in range(width):
259 |
260 | # center around origin
261 | tx = x - (width/2)
262 | ty = y - (width/2)
263 | tz = z - (width/2)
264 |
265 | # scale by resolution
266 | tx = tx * res
267 | ty = ty * res
268 | tz = tz * res
269 |
270 | # apply rotation vector
271 | aw = rot[0]
272 | ax = rot[1]
273 | ay = rot[2]
274 | az = rot[3]
275 |
276 | bw = 0
277 | bx = tx
278 | by = ty
279 | bz = tz
280 |
281 | # multiply by rotation vector
282 | cw = (aw * bw) - (ax * bx) - (ay * by) - (az * bz)
283 | cx = (aw * bx) + (ax * bw) + (ay * bz) - (az * by)
284 | cy = (aw * by) + (ay * bw) + (az * bx) - (ax * bz)
285 | cz = (aw * bz) + (az * bw) + (ax * by) - (ay * bx)
286 |
287 | # multiply by conjugate
288 | # dw = (cw * aw) - (cx * (-ax)) - (cy * (-ay)) - (cz * (-az))
289 | dx = (cw * (-ax)) + (cx * aw) + (cy * (-az)) - (cz * (-ay))
290 | dy = (cw * (-ay)) + (cy * aw) + (cz * (-ax)) - (cx * (-az))
291 | dz = (cw * (-az)) + (cz * aw) + (cx * (-ay)) - (cy * (-ax))
292 |
293 | # apply translation vector
294 | tx = dx + center[0]
295 | ty = dy + center[1]
296 | tz = dz + center[2]
297 |
298 | i = 0
299 | while i < atom_num:
300 | # fetch atom
301 | fx, fy, fz = atom_coords[i]
302 | mask = atom_mask[i]
303 | i += 1
304 |
305 | # invisible atoms
306 | if mask == 0:
307 | continue
308 |
309 | # point radius squared
310 | r = point_radius
311 | r2 = point_radius * point_radius
312 |
313 |                 # quick cube bounds check (conservative for the default point radii >= 1.5)
314 | if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2:
315 | continue
316 |
317 | # value to add to this gridpoint
318 | val = 0
319 |
320 | if point_type == 0: # POINT_TYPE.EXP
321 | # exponential sphere fill
322 | # compute squared distance to atom
323 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
324 | if d2 > r2:
325 | continue
326 |
327 | # compute effect
328 | val = math.exp((-2 * d2) / r2)
329 | elif point_type == 1: # POINT_TYPE.SPHERE
330 | # solid sphere fill
331 | # compute squared distance to atom
332 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
333 | if d2 > r2:
334 | continue
335 |
336 | val = 1
337 | elif point_type == 2: # POINT_TYPE.CUBE
338 | # solid cube fill
339 | val = 1
340 | elif point_type == 3: # POINT_TYPE.GAUSSIAN
341 | # (Ragoza, 2016)
342 | #
343 | # piecewise gaussian sphere fill
344 | # compute squared distance to atom
345 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
346 | d = math.sqrt(d2)
347 |
348 | if d > r * 1.5:
349 | continue
350 | elif d > r:
351 | val = math.exp(-2.0) * ( (4*d2/r2) - (12*d/r) + 9 )
352 | else:
353 | val = math.exp((-2 * d2) / r2)
354 | elif point_type == 4: # POINT_TYPE.LJ
355 | # (Jimenez, 2017) - DeepSite
356 | #
357 | # LJ potential
358 | # compute squared distance to atom
359 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
360 | d = math.sqrt(d2)
361 |
362 | if d > r * 1.5:
363 | continue
364 | else:
365 | val = 1 - math.exp(-((r/d)**12))
366 | elif point_type == 5: # POINT_TYPE.DISCRETE
367 | # nearest-gridpoint
368 | # L1 distance
369 | if abs(fx-tx) < (res/2) and abs(fy-ty) < (res/2) and abs(fz-tz) < (res/2):
370 | val = 1
371 |
372 | # add value to layers
373 | for k in range(32):
374 | if (mask >> k) & 1:
375 | idx = (batch_idx, layer_offset+k, x, y, z)
376 | if acc_type == 0: # ACC_TYPE.SUM
377 | grid[idx] += val
378 | elif acc_type == 1: # ACC_TYPE.MAX
379 | grid[idx] = max(grid[idx], val)
380 |
381 |
382 | def mol_gridify(grid, atom_coords, atom_mask, layer_offset, batch_idx,
383 | width, res, center, rot, point_radius, point_type, acc_type,
384 | cpu=False):
385 | """Wrapper around gpu_gridify.
386 |
387 | (See gpu_gridify() for details)
388 | """
389 | if cpu:
390 | cpu_gridify(
391 | grid, len(atom_coords), atom_coords, atom_mask, layer_offset,
392 | batch_idx, width, res, center, rot, point_radius, point_type, acc_type
393 | )
394 | else:
395 | dw = ((width - 1) // GPU_DIM) + 1
396 | gpu_gridify[(dw,dw,dw), (GPU_DIM,GPU_DIM,GPU_DIM)](
397 | grid, len(atom_coords), atom_coords, atom_mask, layer_offset,
398 | batch_idx, width, res, center, rot, point_radius, point_type, acc_type
399 | )
400 |
401 |
402 | def make_tensor(shape):
403 | """Creates a pytorch tensor and numba array with shared GPU memory backing.
404 |
405 | Args:
406 | shape: the shape of the array
407 |
408 | Returns:
409 | (torch_arr, cuda_arr)
410 | """
411 | # get cuda context
412 | ctx = numba.cuda.cudadrv.driver.driver.get_active_context()
413 |
414 | # setup tensor on gpu
415 | t = torch.zeros(size=shape, dtype=torch.float32).cuda()
416 |
417 | memory = numba.cuda.cudadrv.driver.MemoryPointer(ctx, ctypes.c_ulong(t.data_ptr()), t.numel() * 4)
418 | cuda_arr = numba.cuda.cudadrv.devicearray.DeviceNDArray(
419 | t.size(),
420 | [i*4 for i in t.stride()],
421 | np.dtype('float32'),
422 | gpu_data=memory,
423 | stream=torch.cuda.current_stream().cuda_stream
424 | )
425 |
426 | return (t, cuda_arr)
427 |
428 |
429 | def rand_rot():
430 | """Returns a random uniform quaternion rotation."""
431 | q = np.random.normal(size=4) # sample quaternion from normal distribution
432 | q = q / np.sqrt(np.sum(q**2)) # normalize
433 | return q
434 |
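`rand_rot` relies on the fact that normalizing a 4-D Gaussian sample projects it uniformly onto the unit 3-sphere, which is exactly the space of rotation quaternions. A quick standalone check of the invariants (using a seeded NumPy generator for reproducibility; the function above uses the global `np.random` state instead):

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator, an illustration choice

def rand_rot():
    # Sample a 4-D normal vector and normalize it; the result is a
    # uniformly distributed unit quaternion on the 3-sphere.
    q = rng.normal(size=4)
    return q / np.sqrt(np.sum(q ** 2))

quats = np.array([rand_rot() for _ in range(1000)])
```

Every sample should be unit-norm, and by symmetry the component means should be near zero.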
435 |
436 | def get_batch(data, batch_size=16, batch_set=None, width=48, res=0.5,
437 | ignore_receptor=False, ignore_parent=False, fixed_rot=None,
438 | point_radius=2, point_type=POINT_TYPE.EXP,
439 | acc_type=ACC_TYPE.SUM):
440 | """Builds a batch grid from a FragmentDataset.
441 |
442 |     Args:
443 |         data: a FragmentDataset object
444 |         batch_size: size of the batch
445 |         batch_set: optional list of data indexes (sampled randomly if None)
446 |         width: grid width
447 |         res: grid resolution
448 |         ignore_receptor: if True, ignore receptor atoms
449 |         ignore_parent: if True, ignore parent atoms
450 |         fixed_rot: None or a fixed 4-element rotation quaternion
451 |         point_radius: atom radius in Angstroms
452 |         point_type: shape of the atom densities (POINT_TYPE)
453 |         acc_type: atom density accumulation type (ACC_TYPE)
454 | 
455 |     Returns: (torch_grid, examples)
456 |         torch_grid: pytorch Tensor with voxel information
457 |         examples: list of examples used
457 | """
458 | assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!"
459 |
460 | batch_size = int(batch_size)
461 | width = int(width)
462 |
463 | rec_channels = data.rec_layers()
464 | lig_channels = data.lig_layers()
465 |
466 | dim = 0
467 | if not ignore_receptor:
468 | dim += rec_channels
469 | if not ignore_parent:
470 | dim += lig_channels
471 |
472 | # create a tensor with shared memory on the gpu
473 | torch_grid, cuda_grid = make_tensor((batch_size, dim, width, width, width))
474 |
475 | if batch_set is None:
476 | batch_set = np.random.choice(len(data), size=batch_size, replace=False)
477 |
478 | examples = [data[idx] for idx in batch_set]
479 |
480 | for i in range(len(examples)):
481 | example = examples[i]
482 | rot = fixed_rot
483 | if rot is None:
484 | rot = rand_rot()
485 |
486 | if ignore_receptor:
487 | mol_gridify(
488 | cuda_grid,
489 | example['p_coords'],
490 | example['p_types'],
491 | layer_offset=0,
492 | batch_idx=i,
493 | width=width,
494 | res=res,
495 | center=example['conn'],
496 | rot=rot,
497 | point_radius=point_radius,
498 | point_type=point_type,
499 | acc_type=acc_type
500 | )
501 | elif ignore_parent:
502 | mol_gridify(
503 | cuda_grid,
504 | example['r_coords'],
505 | example['r_types'],
506 | layer_offset=0,
507 | batch_idx=i,
508 | width=width,
509 | res=res,
510 | center=example['conn'],
511 | rot=rot,
512 | point_radius=point_radius,
513 | point_type=point_type,
514 | acc_type=acc_type
515 | )
516 | else:
517 | mol_gridify(
518 | cuda_grid,
519 | example['p_coords'],
520 | example['p_types'],
521 | layer_offset=0,
522 | batch_idx=i,
523 | width=width,
524 | res=res,
525 | center=example['conn'],
526 | rot=rot,
527 | point_radius=point_radius,
528 | point_type=point_type,
529 | acc_type=acc_type
530 | )
531 | mol_gridify(
532 | cuda_grid,
533 | example['r_coords'],
534 | example['r_types'],
535 | layer_offset=lig_channels,
536 | batch_idx=i,
537 | width=width,
538 | res=res,
539 | center=example['conn'],
540 | rot=rot,
541 | point_radius=point_radius,
542 | point_type=point_type,
543 | acc_type=acc_type
544 | )
545 |
546 | return torch_grid, examples
547 |
548 |
549 | def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
550 | conn, num_samples=32, width=24, res=1, fixed_rot=None,
551 | point_radius=1.5, point_type=0, acc_type=0, cpu=False):
552 | """Sample a raw batch with provided atom coordinates.
553 |
554 | Args:
555 | r_coords: receptor coordinates
556 | r_types: receptor types (layers)
557 | p_coords: parent coordinates
558 |         p_types: parent types (layers)
559 |         rec_typer: typer used to map receptor atoms to layer bitmasks
560 |         lig_typer: typer used to map parent atoms to layer bitmasks
561 |         conn: (x,y,z) connection point
560 | num_samples: number of rotations to sample
561 | width: grid width
562 | res: grid resolution
563 | fixed_rot: None or a fixed 4-element rotation vector
564 | point_radius: atom radius in Angstroms
565 | point_type: shape of the atom densities
566 | acc_type: atom density accumulation type
567 | cpu: if True, generate batches with cpu_gridify
568 | """
569 | B = num_samples
570 | T = rec_typer.size() + lig_typer.size()
571 | N = width
572 |
573 | if cpu:
574 | t = np.zeros((B,T,N,N,N))
575 | torch_grid = t
576 | cuda_grid = t
577 | else:
578 | torch_grid, cuda_grid = make_tensor((B,T,N,N,N))
579 |
580 | r_mask = np.zeros(len(r_types), dtype=np.uint32)
581 | p_mask = np.zeros(len(p_types), dtype=np.uint32)
582 |
583 | for i in range(len(r_types)):
584 | r_mask[i] = rec_typer.apply(*r_types[i])
585 |
586 | for i in range(len(p_types)):
587 | p_mask[i] = lig_typer.apply(*p_types[i])
588 |
589 | for i in range(num_samples):
590 | rot = fixed_rot
591 | if rot is None:
592 | rot = rand_rot()
593 |
594 | mol_gridify(
595 | cuda_grid,
596 | p_coords,
597 | p_mask,
598 | layer_offset=0,
599 | batch_idx=i,
600 | width=width,
601 | res=res,
602 | center=conn,
603 | rot=rot,
604 | point_radius=point_radius,
605 | point_type=point_type,
606 | acc_type=acc_type,
607 | cpu=cpu
608 | )
609 |
610 | mol_gridify(
611 | cuda_grid,
612 | r_coords,
613 | r_mask,
614 | layer_offset=lig_typer.size(),
615 | batch_idx=i,
616 | width=width,
617 | res=res,
618 | center=conn,
619 | rot=rot,
620 | point_radius=point_radius,
621 | point_type=point_type,
622 | acc_type=acc_type,
623 | cpu=cpu
624 | )
625 |
626 | return torch_grid
627 |
--------------------------------------------------------------------------------
/leadopt/infer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | import os
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 |
22 | import rdkit.Chem.AllChem as Chem
23 |
24 | import numpy as np
25 | import h5py
26 | import tqdm
27 |
28 | from leadopt.grid_util import get_raw_batch
29 | import leadopt.util as util
30 |
31 |
32 | def get_nearest_fp(fingerprints, fp, k=10):
33 | '''
34 |     Return the top-k rows in fingerprints closest to fp (squared Euclidean distance)
35 |
36 | Returns [(idx1, dist1), (idx2, dist2), ...]
37 | '''
38 | def mse(a,b):
39 | return np.sum((a-b)**2, axis=1)
40 |
41 | d = mse(fingerprints, fp)
42 | arr = [(i,d[i]) for i in range(len(d))]
43 | arr = sorted(arr, key=lambda x: x[1])
44 |
45 | return arr[:k]
46 |
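The same nearest-neighbor lookup can be written with `argsort` instead of building and sorting a list of tuples; an equivalent sketch on toy data:

```python
import numpy as np

def get_nearest_fp(fingerprints, fp, k=10):
    # Squared Euclidean distance from fp to every row, then the k
    # smallest, returned as (index, distance) pairs in ascending order.
    d = np.sum((fingerprints - fp) ** 2, axis=1)
    order = np.argsort(d, kind='stable')
    return [(int(i), float(d[i])) for i in order[:k]]

fps = np.array([[0.0, 0.0, 1.0],
                [1.0, 1.0, 0.0],
                [1.0, 0.0, 1.0]])
query = np.array([1.0, 0.0, 1.0])
top = get_nearest_fp(fps, query, k=2)
```

Row 2 is an exact match (distance 0) and row 0 differs in one position (distance 1), so those two come back first.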
47 |
48 | def infer_all(model, fingerprints, smiles, rec_path, lig_path, num_samples=16, k=25):
49 |     '''Suggest replacement fragments for each removable fragment of a ligand.
50 | 
51 |     For every (parent, fragment) split of the ligand, voxelize the parent and
52 |     receptor around the connection point, average the predicted fingerprint
53 |     over several random rotations, and return the k closest fragments from
54 |     the fingerprint library as (frag_smiles, merged_smiles, score) tuples.
55 |     '''
52 | # load ligand and receptor
53 | lig, frags = util.load_ligand(lig_path)
54 | rec = util.load_receptor(rec_path)
55 |
56 |
57 | # compute shared receptor coords and layers
58 | rec_coords, rec_layers = util.mol_to_points(rec)
59 | # [
60 | # (parent_sm, orig_frag_sm, conn, [
61 | # (new_frag_sm, merged_sm, score),
62 | # ...
63 | # ])
64 | # ]
65 | res = []
66 |
67 | for parent, frag in frags:
68 | # compute parent coords and layers
69 | parent_coords, parent_layers = util.mol_to_points(parent)
70 |
71 | # find connection point
72 | conn = util.get_connection_point(frag)
73 |
74 | # generate batch
75 | grid = get_raw_batch(rec_coords, rec_layers, parent_coords, parent_layers, conn)
76 |
77 | # infer
78 | fp = model(grid).detach().cpu().numpy()
79 | fp_mean = np.mean(fp, axis=0)
80 |
81 | # find closest fingerprints
82 | top = get_nearest_fp(fingerprints, fp_mean, k=k)
83 |
84 | # convert to (frag_smiles, merged_smiles, score) tuples
85 | top_smiles = [(smiles[x[0]].decode('ascii'), util.merge_smiles(Chem.MolToSmiles(parent), smiles[x[0]].decode('ascii')), x[1]) for x in top]
86 |
87 | res.append(
88 | (Chem.MolToSmiles(parent, isomericSmiles=False), Chem.MolToSmiles(frag, isomericSmiles=False), tuple(conn), top_smiles)
89 | )
90 |
91 | return res
92 |
--------------------------------------------------------------------------------
/leadopt/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 |
20 |
21 | def mse(yp, yt):
22 | """Mean squared error loss."""
23 | return torch.sum((yp - yt) ** 2, axis=1)
24 |
25 |
26 | def bce(yp, yt):
27 | """Binary cross entropy loss."""
28 | return torch.sum(F.binary_cross_entropy(yp, yt, reduction='none'), axis=1)
29 |
30 |
31 | def tanimoto(yp, yt):
32 | """Tanimoto distance metric."""
33 | intersect = torch.sum(yt * torch.round(yp), axis=1)
34 | union = torch.sum(torch.clamp(yt + torch.round(yp), 0, 1), axis=1)
35 | return 1 - (intersect / union)
36 |
37 |
38 | _cos = nn.CosineSimilarity(dim=1, eps=1e-6)
39 | def cos(yp, yt):
40 | """Cosine distance as a loss (inverted)."""
41 | return 1 - _cos(yp,yt)
42 |
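The Tanimoto distance above rounds predictions to {0,1} and computes 1 - |intersection| / |union|. A NumPy restatement of the same arithmetic on a single toy fingerprint pair:

```python
import numpy as np

def tanimoto_dist(yp, yt):
    # NumPy restatement of the torch `tanimoto` metric: round the
    # prediction to binary, then 1 - |and| / |or| per row.
    yp_bin = np.round(yp)
    intersect = np.sum(yt * yp_bin, axis=1)
    union = np.sum(np.clip(yt + yp_bin, 0, 1), axis=1)
    return 1 - intersect / union

yt = np.array([[1.0, 1.0, 0.0, 0.0]])
yp = np.array([[0.9, 0.2, 0.8, 0.1]])  # rounds to [1, 0, 1, 0]
d = tanimoto_dist(yp, yt)
```

Here the intersection is 1 bit and the union is 3 bits, giving a distance of 2/3.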
43 |
44 | def broadcast_fn(fn, yp, yt):
45 | """Broadcast a distance function."""
46 | yp_b, yt_b = torch.broadcast_tensors(yp, yt)
47 | return fn(yp_b, yt_b)
48 |
49 |
50 | def average_position(fingerprints, fn, norm=True):
51 | """Returns the average ranking of the correct fragment relative to all
52 | possible fragments.
53 |
54 | Args:
55 | fingerprints: NxF tensor of fingerprint data
56 | fn: distance function to compare fingerprints
57 | norm: if True, normalize position in range (0,1)
58 | """
59 | def _average_position(yp, yt):
60 | # distance to correct fragment
61 | p_dist = broadcast_fn(fn, yp, yt.detach())
62 |
63 | c = torch.empty(yp.shape[0])
64 | for i in range(yp.shape[0]):
65 | # compute distance to all other fragments
66 | dist = broadcast_fn(fn, yp[i].unsqueeze(0), fingerprints)
67 |
68 |             # number of fragments that are closer or equal
69 | count = torch.sum((dist <= p_dist[i]).to(torch.float))
70 | c[i] = count
71 |
72 |         score = torch.mean(c)
73 |         if norm:
74 |             # normalize the average rank by the number of candidate fragments
75 |             score = score / fingerprints.shape[0]
76 |         return score
74 |
75 | return _average_position
76 |
77 |
78 | def average_support(fingerprints, fn):
79 | """
80 | """
81 | def _average_support(yp, yt):
82 | # correct distance
83 | p_dist = broadcast_fn(fn, yp, yt)
84 |
85 | c = torch.empty(yp.shape[0])
86 | for i in range(yp.shape[0]):
87 | # compute distance to all other fragments
88 | dist = broadcast_fn(fn, yp[i].unsqueeze(0), fingerprints)
89 |
90 | # shift distance so bad examples are positive
91 | dist -= p_dist[i]
92 | dist *= -1
93 |
94 | dist_n = torch.sigmoid(dist)
95 |
96 | c[i] = torch.mean(dist_n)
97 |
98 | score = torch.mean(c)
99 |
100 | return score
101 |
102 | return _average_support
103 |
104 |
105 | def inside_support(fingerprints, fn):
106 | """
107 | """
108 | def _inside_support(yp, yt):
109 | # correct distance
110 | p_dist = broadcast_fn(fn, yp, yt)
111 |
112 | c = torch.empty(yp.shape[0])
113 | for i in range(yp.shape[0]):
114 | # compute distance to all other fragments
115 | dist = broadcast_fn(fn, yp[i].unsqueeze(0), fingerprints)
116 |
117 | # shift distance so bad examples are positive
118 | dist -= p_dist[i]
119 | dist *= -1
120 |
121 | # ignore labels that are further away
122 | dist[dist < 0] = 0
123 |
124 | dist_n = torch.sigmoid(dist)
125 |
126 | c[i] = torch.mean(dist_n)
127 |
128 | score = torch.mean(c)
129 |
130 | return score
131 |
132 | return _inside_support
133 |
134 |
135 | def top_k_acc(fingerprints, fn, k, pre=''):
136 | """Top-k accuracy metric.
137 |
138 | Returns a dict containing top-k accuracies:
139 | {
140 | {pre}acc_{k1}: acc_k1,
141 | {pre}acc_{k2}: acc_k2,
142 | }
143 |
144 | Args:
145 | fingerprints: NxF tensor of fingerprints
146 | fn: distance function to compare fingerprints
147 | k: List[int] containing K-positions to evaluate (e.g. [1,5,10])
148 | pre: optional prefix on the metric name
149 | """
150 |
151 | def _top_k_acc(yp, yt):
152 | # correct distance
153 | p_dist = broadcast_fn(fn, yp.detach(), yt.detach())
154 |
155 | c = torch.empty(yp.shape[0], len(k))
156 | for i in range(yp.shape[0]):
157 | # compute distance to all other fragments
158 | dist = broadcast_fn(fn, yp[i].unsqueeze(0).detach(), fingerprints)
159 |
161 |             # number of fragments that are strictly closer
161 | count = torch.sum((dist < p_dist[i]).to(torch.float))
162 |
163 | for j in range(len(k)):
164 | c[i,j] = int(count < k[j])
165 |
166 | score = torch.mean(c, 0)
167 | m = {'%sacc_%d' % (pre, h): v.item() for h,v in zip(k,score)}
168 |
169 | return m
170 |
171 | return _top_k_acc
172 |
--------------------------------------------------------------------------------
/leadopt/model_conf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | import os
17 | import json
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import tqdm
23 | import numpy as np
24 |
25 | try:
26 |     import wandb
27 | except ImportError:
28 |     pass
29 |
30 | from leadopt.models.voxel import VoxelFingerprintNet
31 | from leadopt.data_util import FragmentDataset, SharedFragmentDataset, FingerprintDataset, LIG_TYPER,\
32 | REC_TYPER
33 | from leadopt.grid_util import get_batch
34 | from leadopt.metrics import mse, bce, tanimoto, cos, top_k_acc,\
35 | average_support, inside_support
36 |
37 | from config import moad_partitions
38 |
39 |
40 | def get_bios(p):
41 | part = []
42 | for n in range(20):
43 | part += [x.lower() + '_bio%d' % n for x in p]
44 | return part
45 |
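`get_bios` expands each PDB id in a partition into its 20 possible biological-assembly names; a standalone copy of the expansion with a quick check of its output:

```python
def get_bios(p):
    # Expand each PDB id into its 20 possible biological-assembly
    # names, e.g. '1ABC' -> '1abc_bio0' ... '1abc_bio19'. Note the
    # ordering: all ids at bio0 come before any id at bio1.
    part = []
    for n in range(20):
        part += [x.lower() + '_bio%d' % n for x in p]
    return part

names = get_bios(['1ABC', '2XYZ'])
```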
46 |
47 | def _do_mean(fn):
48 | def g(yp, yt):
49 | return torch.mean(fn(yp,yt))
50 | return g
51 |
52 |
53 | def _direct_loss(fingerprints, fn):
54 | return _do_mean(fn)
55 |
56 |
57 | DIST_FN = {
58 | 'mse': mse,
59 | 'bce': bce,
60 | 'cos': cos,
61 | 'tanimoto': tanimoto
62 | }
63 |
64 |
65 | LOSS_TYPE = {
66 |     # minimize distance to target fingerprint
67 | 'direct': _direct_loss,
68 |
69 | # minimize distance to target and maximize distance to all other
70 | 'support_v1': average_support,
71 |
72 | # support, limited to closer points
72 |     # support, limited to closer points
73 |     'support_v2': inside_support,
75 |
76 |
77 | class RunLog(object):
78 | def __init__(self, args, models, wandb_project=None):
79 | """Initialize a run logger.
80 |
81 | Args:
82 | args: command line training arguments
83 | models: {name: model} mapping
84 | wandb_project: a project to initialize wandb or None
85 | """
86 |         self._use_wandb = wandb_project is not None
87 | if self._use_wandb:
88 | wandb.init(
89 | project=wandb_project,
90 | config=args
91 | )
92 |
93 | for m in models:
94 | wandb.watch(models[m])
95 |
96 | def log(self, x):
97 | if self._use_wandb:
98 | wandb.log(x)
99 |
100 |
101 | class MetricTracker(object):
102 | def __init__(self, name, metric_fns):
103 | self._name = name
104 | self._metric_fns = metric_fns
105 | self._metrics = {}
106 |
107 | def evaluate(self, yp, yt):
108 | for m in self._metric_fns:
109 | self.update(m, self._metric_fns[m](yp, yt))
110 |
111 | def update(self, name, metric):
112 |         if isinstance(metric, dict):
113 | for subname in metric:
114 | fullname = '%s_%s' % (self._name, subname)
115 |                 if fullname not in self._metrics:
116 | self._metrics[fullname] = 0
117 | self._metrics[fullname] += metric[subname]
118 | else:
119 | fullname = '%s_%s' % (self._name, name)
120 |             if fullname not in self._metrics:
121 | self._metrics[fullname] = 0
122 | self._metrics[fullname] += metric
123 |
124 | def normalize(self, size):
125 | for m in self._metrics:
126 | self._metrics[m] /= size
127 |
128 | def clear(self):
129 | self._metrics = {}
130 |
131 | def get(self, name):
132 | fullname = '%s_%s' % (self._name, name)
133 | return self._metrics[fullname]
134 |
135 | def get_all(self):
136 | return self._metrics
137 |
138 |
139 | class LeadoptModel(object):
140 | """Abstract LeadoptModel base class."""
141 |
142 | @staticmethod
143 | def setup_base_parser(parser):
144 | """Configures base parser arguments."""
145 | parser.add_argument('--wandb_project', default=None, help='''
146 | Set this argument to track run in wandb.
147 | ''')
148 |
149 | @staticmethod
150 | def setup_parser(sub):
151 | """Adds arguments to a subparser.
152 |
153 | Args:
154 | sub: an argparse subparser
155 | """
156 | raise NotImplementedError()
157 |
158 | @staticmethod
159 | def get_defaults():
160 | return {}
161 |
162 | @classmethod
163 | def load(cls, path, device='cuda'):
164 | """Load model configuration saved with save().
165 |
166 | Call LeadoptModel.load to infer model type.
167 | Or call subclass.load to load a specific model type.
168 |
169 | Args:
170 | path: full path to saved model
171 | """
172 | args_path = os.path.join(path, 'args.json')
173 | args = json.loads(open(args_path, 'r').read())
174 |
175 | model_type = MODELS[args['version']] if cls is LeadoptModel else cls
176 |
177 | default_args = model_type.get_defaults()
178 | for k in default_args:
179 |             if k not in args:
180 | args[k] = default_args[k]
181 |
182 | instance = model_type(args, device=device, with_log=False)
183 | for name in instance._models:
184 | model_path = os.path.join(path, '%s.pt' % name)
185 | instance._models[name].load_state_dict(torch.load(model_path, map_location=torch.device(device)))
186 |
187 | return instance
188 |
189 | def __init__(self, args, device='cuda', with_log=True):
190 | self._args = args
191 | self._device = torch.device(device)
192 | self._models = self.init_models()
193 | if with_log:
194 | wandb_project = None
195 | if 'wandb_project' in self._args:
196 | wandb_project = self._args['wandb_project']
197 |
198 | self._log = RunLog(self._args, self._models, wandb_project)
199 |
200 | def save(self, path):
201 | """Save model configuration to a path.
202 |
203 | Args:
204 | path: path to an existing directory to save models
205 | """
206 | os.makedirs(path, exist_ok=True)
207 |
208 | args_path = os.path.join(path, 'args.json')
209 | open(args_path, 'w').write(json.dumps(self._args))
210 |
211 | for name in self._models:
212 | model_path = os.path.join(path, '%s.pt' % name)
213 | torch.save(self._models[name].state_dict(), model_path)
214 |
215 | def init_models(self):
216 | """Initializes any pytorch models.
217 |
218 | Returns a dict of name->model mapping:
219 | {'model_1': m1, 'model_2': m2, ...}
220 | """
221 | return {}
222 |
223 | def train(self, save_path=None):
224 | """Train the models."""
225 | raise NotImplementedError()
226 |
227 |
228 | class VoxelNet(LeadoptModel):
229 | @staticmethod
230 | def setup_parser(sub):
231 | # testing
232 | sub.add_argument('--no_partitions', action='store_true', default=False, help='''
233 | If set, disable the use of TRAIN/VAL partitions during training.
234 | ''')
235 |
236 | # dataset
237 | sub.add_argument('-f', '--fragments', required=True, help='''
238 | Path to fragments file.
239 | ''')
240 | sub.add_argument('-fp', '--fingerprints', required=True, help='''
241 | Path to fingerprints file.
242 | ''')
243 |
244 | # training parameters
245 | sub.add_argument('-lr', '--learning_rate', type=float, default=1e-4)
246 | sub.add_argument('--num_epochs', type=int, default=50, help='''
247 | Number of epochs to train for.
248 | ''')
249 | sub.add_argument('--test_steps', type=int, default=500, help='''
250 | Number of evaluation steps per epoch.
251 | ''')
252 | sub.add_argument('-b', '--batch_size', default=32, type=int)
253 |
254 | # grid generation
255 | sub.add_argument('--grid_width', type=int, default=24)
256 | sub.add_argument('--grid_res', type=float, default=1)
257 |
258 | # fragment filtering
259 | sub.add_argument('--fdist_min', type=float, help='''
260 | Ignore fragments closer to the receptor than this distance (Angstroms).
261 | ''')
262 | sub.add_argument('--fdist_max', type=float, help='''
263 | Ignore fragments further from the receptor than this distance (Angstroms).
264 | ''')
265 | sub.add_argument('--fmass_min', type=float, help='''
266 | Ignore fragments smaller than this mass (Daltons).
267 | ''')
268 | sub.add_argument('--fmass_max', type=float, help='''
269 | Ignore fragments larger than this mass (Daltons).
270 | ''')
271 |
272 | # receptor/parent options
273 | sub.add_argument('--ignore_receptor', action='store_true', default=False)
274 | sub.add_argument('--ignore_parent', action='store_true', default=False)
275 | sub.add_argument('-rec_typer', required=True, choices=[k for k in REC_TYPER])
276 | sub.add_argument('-lig_typer', required=True, choices=[k for k in LIG_TYPER])
277 | # sub.add_argument('-rec_channels', required=True, type=int)
278 | # sub.add_argument('-lig_channels', required=True, type=int)
279 |
280 | # model parameters
281 | # sub.add_argument('--in_channels', type=int, default=18)
282 | sub.add_argument('--output_size', type=int, default=2048)
283 | sub.add_argument('--pad', default=False, action='store_true')
284 | sub.add_argument('--blocks', nargs='+', type=int, default=[32,64])
285 | sub.add_argument('--fc', nargs='+', type=int, default=[2048])
286 | sub.add_argument('--use_all_labels', default=False, action='store_true')
287 | sub.add_argument('--dist_fn', default='mse', choices=[k for k in DIST_FN])
288 | sub.add_argument('--loss', default='direct', choices=[k for k in LOSS_TYPE])
289 |
290 | @staticmethod
291 | def get_defaults():
292 | return {
293 | 'point_radius': 1,
294 | 'point_type': 0,
295 | 'acc_type': 0
296 | }
297 |
298 | def init_models(self):
299 | in_channels = 0
300 | if not self._args['ignore_receptor']:
301 | in_channels += REC_TYPER[self._args['rec_typer']].size()
302 | if not self._args['ignore_parent']:
303 | in_channels += LIG_TYPER[self._args['lig_typer']].size()
304 |
305 | voxel = VoxelFingerprintNet(
306 | in_channels=in_channels,
307 | output_size=self._args['output_size'],
308 | blocks=self._args['blocks'],
309 | fc=self._args['fc'],
310 | pad=self._args['pad']
311 | ).to(self._device)
312 | return {'voxel': voxel}
313 |
314 | def load_data(self):
315 | print('[*] Loading data...', flush=True)
316 | dat = FragmentDataset(
317 | self._args['fragments'],
318 | rec_typer=REC_TYPER[self._args['rec_typer']],
319 | lig_typer=LIG_TYPER[self._args['lig_typer']],
320 | verbose=True
321 | )
322 |
323 | train_dat = SharedFragmentDataset(
324 | dat,
325 | filter_rec=set(get_bios(moad_partitions.TRAIN)),
326 | filter_smi=set(moad_partitions.TRAIN_SMI),
327 | fdist_min=self._args['fdist_min'],
328 | fdist_max=self._args['fdist_max'],
329 | fmass_min=self._args['fmass_min'],
330 | fmass_max=self._args['fmass_max'],
331 | )
332 |
333 | val_dat = SharedFragmentDataset(
334 | dat,
335 | filter_rec=set(get_bios(moad_partitions.VAL)),
336 | filter_smi=set(moad_partitions.VAL_SMI),
337 | fdist_min=self._args['fdist_min'],
338 | fdist_max=self._args['fdist_max'],
339 | fmass_min=self._args['fmass_min'],
340 | fmass_max=self._args['fmass_max'],
341 | )
342 |
343 | return train_dat, val_dat
344 |
345 | def train(self, save_path=None, custom_steps=None, checkpoint_callback=None, data=None):
346 |
347 | if data is None:
348 | data = self.load_data()
349 |
350 | train_dat, val_dat = data
351 |
352 | fingerprints = FingerprintDataset(self._args['fingerprints'])
353 |
354 | train_smiles = train_dat.get_valid_smiles()
355 | val_smiles = val_dat.get_valid_smiles()
356 | all_smiles = list(set(train_smiles) | set(val_smiles))
357 |
358 | train_fingerprints = fingerprints.for_smiles(train_smiles).cuda()
359 | val_fingerprints = fingerprints.for_smiles(val_smiles).cuda()
360 | all_fingerprints = fingerprints.for_smiles(all_smiles).cuda()
361 |
362 | # fingerprint metrics
363 | print('[*] Train smiles: %d' % len(train_smiles))
364 | print('[*] Val smiles: %d' % len(val_smiles))
365 | print('[*] All smiles: %d' % len(all_smiles))
366 |
367 |         print('[*] Train fingerprints: %d' % train_fingerprints.shape[0])
368 |         print('[*] Val fingerprints: %d' % val_fingerprints.shape[0])
369 |         print('[*] All fingerprints: %d' % all_fingerprints.shape[0])
370 |
371 | # memory optimization, drop some unnecessary columns
372 | train_dat._dat.frag['frag_mass'] = None
373 | train_dat._dat.frag['frag_dist'] = None
374 | train_dat._dat.frag['frag_lig_smi'] = None
375 | train_dat._dat.frag['frag_lig_idx'] = None
376 |
377 | print('[*] Training...', flush=True)
378 | opt = torch.optim.Adam(
379 | self._models['voxel'].parameters(), lr=self._args['learning_rate'])
380 | steps_per_epoch = len(train_dat) // self._args['batch_size']
381 | steps_per_epoch = custom_steps if custom_steps is not None else steps_per_epoch
382 |
383 | # configure metrics
384 | dist_fn = DIST_FN[self._args['dist_fn']]
385 |
386 | loss_fingerprints = train_fingerprints
387 | if self._args['use_all_labels']:
388 | loss_fingerprints = all_fingerprints
389 |
390 | loss_fn = LOSS_TYPE[self._args['loss']](loss_fingerprints, dist_fn)
391 |
392 | train_metrics = MetricTracker('train', {
393 | 'all': top_k_acc(all_fingerprints, dist_fn, [1,8,64], pre='all')
394 | })
395 | val_metrics = MetricTracker('val', {
396 | 'all': top_k_acc(all_fingerprints, dist_fn, [1,8,64], pre='all'),
397 | # 'val': top_k_acc(val_fingerprints, dist_fn, [1,5,10,50,100], pre='val'),
398 | })
399 |
400 | best_loss = None
401 |
402 | for epoch in range(self._args['num_epochs']):
403 | self._models['voxel'].train()
404 | train_pbar = tqdm.tqdm(
405 | range(steps_per_epoch),
406 | desc='Train (epoch %d)' % epoch
407 | )
408 | for step in train_pbar:
409 | torch_grid, examples = get_batch(
410 | train_dat,
411 | batch_size=self._args['batch_size'],
412 | batch_set=None,
413 | width=self._args['grid_width'],
414 | res=self._args['grid_res'],
415 | ignore_receptor=self._args['ignore_receptor'],
416 | ignore_parent=self._args['ignore_parent'],
417 | point_radius=self._args['point_radius'],
418 | point_type=self._args['point_type'],
419 | acc_type=self._args['acc_type']
420 | )
421 |
422 | smiles = [example['smiles'] for example in examples]
423 | correct_fp = torch.Tensor(
424 | fingerprints.for_smiles(smiles)).cuda()
425 |
426 | predicted_fp = self._models['voxel'](torch_grid)
427 |
428 | loss = loss_fn(predicted_fp, correct_fp)
429 |
430 | opt.zero_grad()
431 | loss.backward()
432 | opt.step()
433 |
434 | train_metrics.update('loss', loss)
435 | train_metrics.evaluate(predicted_fp, correct_fp)
436 | self._log.log(train_metrics.get_all())
437 | train_metrics.clear()
438 |
439 | self._models['voxel'].eval()
440 |
441 | val_pbar = tqdm.tqdm(
442 | range(self._args['test_steps']),
443 | desc='Val %d' % epoch
444 | )
445 | with torch.no_grad():
446 | for step in val_pbar:
447 | torch_grid, examples = get_batch(
448 | val_dat,
449 | batch_size=self._args['batch_size'],
450 | batch_set=None,
451 | width=self._args['grid_width'],
452 | res=self._args['grid_res'],
453 | ignore_receptor=self._args['ignore_receptor'],
454 | ignore_parent=self._args['ignore_parent'],
455 | point_radius=self._args['point_radius'],
456 | point_type=self._args['point_type'],
457 | acc_type=self._args['acc_type']
458 | )
459 |
460 | smiles = [example['smiles'] for example in examples]
461 | correct_fp = torch.Tensor(
462 | fingerprints.for_smiles(smiles)).cuda()
463 |
464 | predicted_fp = self._models['voxel'](torch_grid)
465 |
466 | loss = loss_fn(predicted_fp, correct_fp)
467 |
468 | val_metrics.update('loss', loss)
469 | val_metrics.evaluate(predicted_fp, correct_fp)
470 |
471 | if checkpoint_callback:
472 | checkpoint_callback(self, epoch)
473 |
474 | val_metrics.normalize(self._args['test_steps'])
475 | self._log.log(val_metrics.get_all())
476 |
477 | val_loss = val_metrics.get('loss')
478 | if best_loss is None or val_loss < best_loss:
479 | # save new best model
480 | best_loss = val_loss
481 | print('[*] New best loss: %f' % best_loss, flush=True)
482 | if save_path:
483 | self.save(save_path)
484 |
485 | val_metrics.clear()
486 |
487 | def run_test(self, save_path, use_val=False):
488 | # load test dataset
489 | test_dat = FragmentDataset(
490 | self._args['fragments'],
491 | rec_typer=REC_TYPER[self._args['rec_typer']],
492 | lig_typer=LIG_TYPER[self._args['lig_typer']],
493 | # filter_rec=partitions.TEST,
494 | filter_rec=set(get_bios(moad_partitions.VAL if use_val else moad_partitions.TEST)),
495 | filter_smi=set(moad_partitions.VAL_SMI if use_val else moad_partitions.TEST_SMI),
496 | fdist_min=self._args['fdist_min'],
497 | fdist_max=self._args['fdist_max'],
498 | fmass_min=self._args['fmass_min'],
499 | fmass_max=self._args['fmass_max'],
500 | verbose=True
501 | )
502 |
503 | fingerprints = FingerprintDataset(self._args['fingerprints'])
504 |
505 | self._models['voxel'].eval()
506 |
507 | predicted_fp = np.zeros((
508 | len(test_dat),
509 | self._args['samples_per_example'],
510 | self._args['output_size']))
511 |
512 | smiles = [test_dat[i]['smiles'] for i in range(len(test_dat))]
513 | correct_fp = fingerprints.for_smiles(smiles).numpy()
514 |
515 | # (example_idx, sample_idx)
516 | queries = []
517 | for i in range(len(test_dat)):
518 | queries += [(i,x) for x in range(self._args['samples_per_example'])]
519 |
520 | # run inference
521 | pbar = tqdm.tqdm(
522 | range(0, len(queries), self._args['batch_size']), desc='Inference')
523 | for i in pbar:
524 | batch = queries[i:i+self._args['batch_size']]
525 |
526 | torch_grid, examples = get_batch(
527 | test_dat,
528 | batch_size=self._args['batch_size'],
529 | batch_set=[x[0] for x in batch],
530 | width=self._args['grid_width'],
531 | res=self._args['grid_res'],
532 | ignore_receptor=self._args['ignore_receptor'],
533 | ignore_parent=self._args['ignore_parent'],
534 | point_radius=self._args['point_radius'],
535 | point_type=self._args['point_type'],
536 | acc_type=self._args['acc_type']
537 | )
538 |
539 | predicted = self._models['voxel'](torch_grid)
540 |
541 | for j in range(len(batch)):
542 | example_idx, sample_idx = batch[j]
543 | predicted_fp[example_idx][sample_idx] = predicted[j].detach().cpu().numpy()
544 |
545 | if use_val:
546 | np.save(os.path.join(save_path, 'val_predicted_fp.npy'), predicted_fp)
547 | np.save(os.path.join(save_path, 'val_correct_fp.npy'), correct_fp)
548 | else:
549 | np.save(os.path.join(save_path, 'predicted_fp.npy'), predicted_fp)
550 | np.save(os.path.join(save_path, 'correct_fp.npy'), correct_fp)
551 |
552 | print('done.')
553 |
554 | def predict(self, batch):
555 | with torch.no_grad():
556 | pred = self._models['voxel'](batch)
557 | return pred
558 |
559 |
560 | MODELS = {
561 | 'voxelnet': VoxelNet
562 | }
563 |
--------------------------------------------------------------------------------
/leadopt/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
--------------------------------------------------------------------------------
/leadopt/models/backport.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | from torch.nn import Module
17 | from typing import Tuple, Union
18 | from torch import Tensor
19 |
20 | '''
21 | On some older versions of Cuda, we may need to use an early version of PyTorch
22 | that doesn't have the Flatten layer builtin.
23 | '''
24 |
25 | class Flatten(Module):
26 | __constants__ = ['start_dim', 'end_dim']
27 | start_dim: int
28 | end_dim: int
29 |
30 | def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
31 | super(Flatten, self).__init__()
32 | self.start_dim = start_dim
33 | self.end_dim = end_dim
34 |
35 | def forward(self, input: Tensor) -> Tensor:
36 | return input.flatten(self.start_dim, self.end_dim)
37 |
38 | def extra_repr(self) -> str:
39 | return 'start_dim={}, end_dim={}'.format(
40 | self.start_dim, self.end_dim
41 | )
42 |
--------------------------------------------------------------------------------
/leadopt/models/voxel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 |
20 | Flatten = None
21 | try:
22 | Flatten = nn.Flatten
23 | except AttributeError:  # nn.Flatten is missing on very old PyTorch versions
24 | from . import backport
25 | Flatten = backport.Flatten
26 |
27 | class VoxelFingerprintNet(nn.Module):
28 | def __init__(self, in_channels, output_size, blocks=[32,64], fc=[2048], pad=True):
29 | super(VoxelFingerprintNet, self).__init__()
30 |
31 | blocks = list(blocks)
32 | fc = list(fc)
33 |
34 | self.blocks = nn.ModuleList()
35 | prev = in_channels
36 | for i in range(len(blocks)):
37 | b = blocks[i]
38 | parts = []
39 | parts += [
40 | nn.BatchNorm3d(prev),
41 | nn.Conv3d(prev, b, (3,3,3), padding=(1 if pad else 0)),
42 | nn.ReLU(),
43 | nn.Conv3d(b, b, (3,3,3), padding=(1 if pad else 0)),
44 | nn.ReLU(),
45 | nn.Conv3d(b, b, (3,3,3), padding=(1 if pad else 0)),
46 | nn.ReLU()
47 | ]
48 | if i != len(blocks)-1:
49 | parts += [
50 | nn.MaxPool3d((2,2,2))
51 | ]
52 |
53 | self.blocks.append(nn.Sequential(*parts))
54 | prev = b
55 |
56 | self.reduce = nn.Sequential(
57 | nn.AdaptiveAvgPool3d((1,1,1)),
58 | Flatten(),
59 | )
60 |
61 | pred = []
62 | prev = blocks[-1]
63 | for f in fc + [output_size]:
64 | pred += [
65 | nn.ReLU(),
66 | nn.Dropout(),
67 | nn.Linear(prev, f)
68 | ]
69 | prev = f
70 |
71 | self.pred = nn.Sequential(*pred)
72 | self.norm = nn.Sigmoid()
73 |
74 | def forward(self, x):
75 | for b in self.blocks:
76 | x = b(x)
77 | x = self.reduce(x)
78 | x = self.pred(x)
79 | x = self.norm(x)
80 |
81 | return x
82 |
--------------------------------------------------------------------------------
/leadopt/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | '''
17 | rdkit/openbabel utility scripts
18 | '''
19 | import numpy as np
20 |
21 |
22 | # try:
23 | # import pybel
24 | # except:
25 | # from openbabel import pybel
26 |
27 | from rdkit import Chem
28 |
29 |
30 | def get_coords(mol):
31 | """Returns an array of atom coordinates from an rdkit mol."""
32 | conf = mol.GetConformer()
33 | coords = np.array([conf.GetAtomPosition(i) for i in range(conf.GetNumAtoms())])
34 | return coords
35 |
36 |
37 | def get_types(mol):
38 | """Returns an array of atomic numbers from an rdkit mol."""
39 | return [mol.GetAtomWithIdx(i).GetAtomicNum() for i in range(mol.GetNumAtoms())]
40 |
41 |
42 | def combine_all(frags):
43 | """Combines a list of rdkit mols."""
44 | if len(frags) == 0:
45 | return None
46 |
47 | c = frags[0]
48 | for f in frags[1:]:
49 | c = Chem.CombineMols(c,f)
50 |
51 | return c
52 |
53 |
54 | def generate_fragments(mol, max_heavy_atoms=0, only_single_bonds=True):
55 | """Takes an rdkit molecule and returns a list of (parent, fragment) tuples.
56 |
57 | Args:
58 | mol: The molecule to fragment.
59 | max_heavy_atoms: The maximum number of heavy atoms to include
60 | in generated fragments.
61 |         only_single_bonds: If set to true, this method will only return
62 | fragments generated by breaking single bonds.
63 |
64 | Returns:
65 | A list of (parent, fragment) tuples where mol is larger than fragment.
66 | """
67 | # list of (parent, fragment) tuples
68 | splits = []
69 |
70 | # if we have multiple ligands already, split into pieces and then iterate
71 | ligands = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
72 |
73 | for i in range(len(ligands)):
74 | lig = ligands[i]
75 | other = list(ligands[:i] + ligands[i+1:])
76 |
77 | # iterate over bonds
78 | for i in range(lig.GetNumBonds()):
79 | # (optional) filter base on single bonds
80 | if only_single_bonds and lig.GetBondWithIdx(i).GetBondType() != Chem.rdchem.BondType.SINGLE:
81 | continue
82 |
83 | # split the molecule
84 | split_mol = Chem.rdmolops.FragmentOnBonds(lig, [i])
85 |
86 | # obtain fragments
87 | fragments = Chem.GetMolFrags(split_mol, asMols=True, sanitizeFrags=False)
88 |
89 | # skip if this did not break the molecule into two pieces
90 | if len(fragments) != 2:
91 | continue
92 |
93 | # otherwise make sure the first fragment is larger
94 | if fragments[0].GetNumAtoms() < fragments[1].GetNumAtoms():
95 | fragments = fragments[::-1]
96 |
97 | # make sure the fragment has at least one heavy atom
98 | if fragments[1].GetNumHeavyAtoms() == 0:
99 | continue
100 |
101 | # (optional) filter based on number of heavy atoms in the fragment
102 | if max_heavy_atoms > 0 and fragments[1].GetNumHeavyAtoms() > max_heavy_atoms:
103 | continue
104 |
105 | # if we have other ligands present, merge them with the parent
106 | parent = fragments[0]
107 |
108 | if len(other) > 0:
109 | parent = combine_all([parent] + other)
110 |
111 | # add this pair
112 | splits.append((parent, fragments[1]))
113 |
114 | return splits
115 |
116 |
117 | def load_ligand(sdf):
118 | """Loads a ligand from an sdf file and fragments it.
119 |
120 | Args:
121 | sdf: Path to sdf file containing a ligand.
122 | """
123 | lig = next(Chem.SDMolSupplier(sdf, sanitize=False))
124 | frags = generate_fragments(lig)
125 |
126 | return lig, frags
127 |
128 |
129 | def load_ligands_pdb(pdb):
130 | """Load multiple ligands from a pdb file.
131 |
132 | Args:
133 | pdb: Path to pdb file containing a ligand.
134 | """
135 | lig_mult = Chem.MolFromPDBFile(pdb)
136 | ligands = Chem.GetMolFrags(lig_mult, asMols=True, sanitizeFrags=True)
137 |
138 | return ligands
139 |
140 |
141 | def remove_water(m):
142 | """Removes water molecules from an rdkit mol."""
143 | parts = Chem.GetMolFrags(m, asMols=True, sanitizeFrags=False)
144 | valid = [k for k in parts if not Chem.MolToSmiles(k, allHsExplicit=True) == '[OH2]']
145 |
146 | assert len(valid) > 0, 'error: molecule contains only water'
147 |
148 | merged = valid[0]
149 | for part in valid[1:]:
150 | merged = Chem.CombineMols(merged, part)
151 |
152 | return merged
153 |
154 |
155 | def load_receptor(rec_path):
156 | """Loads a receptor from a pdb file and retrieves atomic information.
157 |
158 | Args:
159 | rec_path: Path to a pdb file.
160 | """
161 | rec = Chem.MolFromPDBFile(rec_path, sanitize=False)
162 | rec = remove_water(rec)
163 |
164 | return rec
165 |
166 |
167 | # def load_receptor_ob(rec_path):
168 | # rec = next(pybel.readfile('pdb', rec_path))
169 | # valid = [r for r in rec.residues if r.name != 'HOH']
170 |
171 | # # map partial charge into byte range
172 | # def conv_charge(x):
173 | # x = max(x,-0.5)
174 | # x = min(x,0.5)
175 | # x += 0.5
176 | # x *= 255
177 | # x = int(x)
178 | # return x
179 |
180 | # coords = []
181 | # types = []
182 | # for v in valid:
183 | # coords += [k.coords for k in v.atoms]
184 | # types += [(
185 | # k.atomicnum,
186 | # int(k.OBAtom.IsAromatic()),
187 | # int(k.OBAtom.IsHbondDonor()),
188 | # int(k.OBAtom.IsHbondAcceptor()),
189 | # conv_charge(k.OBAtom.GetPartialCharge())
190 | # ) for k in v.atoms]
191 |
192 | # return np.array(coords), np.array(types)
193 |
194 |
195 | def load_receptor_ob(rec_path):
196 | rec = load_receptor(rec_path)
197 |
198 | coords = get_coords(rec)
199 | types = np.array(get_types(rec))
200 | types = np.concatenate([
201 | types.reshape(-1,1),
202 | np.zeros((len(types), 4))
203 | ], 1)
204 |
205 | return coords, types
206 |
207 |
208 | def get_connection_point(frag):
209 | '''return the coordinates of the dummy atom as a numpy array [x,y,z]'''
210 | dummy_idx = get_types(frag).index(0)
211 | coords = get_coords(frag)[dummy_idx]
212 |
213 | return coords
214 |
215 |
216 | def frag_dist_to_receptor(rec, frag):
217 |     '''compute the minimum distance between the fragment connection point and any receptor atom'''
218 | rec_coords = rec.GetConformer().GetPositions()
219 | conn = get_connection_point(frag)
220 |
221 | dist = np.sum((rec_coords - conn) ** 2, axis=1)
222 | min_dist = np.sqrt(np.min(dist))
223 |
224 | return min_dist
225 |
226 |
227 | def frag_dist_to_receptor_raw(coords, frag):
228 |     '''compute the minimum distance between the fragment connection point and any receptor atom'''
229 | rec_coords = np.array(coords)
230 | conn = get_connection_point(frag)
231 |
232 | dist = np.sum((rec_coords - conn) ** 2, axis=1)
233 | min_dist = np.sqrt(np.min(dist))
234 |
235 | return min_dist
236 |
237 |
238 | def mol_array(mol):
239 | '''convert an rdkit mol to an array of coordinates and atom types'''
240 | coords = get_coords(mol)
241 | types = np.array(get_types(mol)).reshape(-1,1)
242 |
243 | arr = np.concatenate([coords, types], axis=1)
244 |
245 | return arr
246 |
247 |
248 | def desc_mol_array(mol, atom_fn):
249 | '''user-defined atomic mapping function'''
250 | coords = get_coords(mol)
251 | atoms = list(mol.GetAtoms())
252 | types = np.array([atom_fn(x) for x in atoms]).reshape(-1,1)
253 |
254 | arr = np.concatenate([coords, types], axis=1)
255 |
256 | return arr
257 |
258 |
259 | def desc_mol_array_ob(atoms, atom_fn):
260 | coords = np.array([k[0] for k in atoms])
261 | types = np.array([atom_fn(k[1]) for k in atoms]).reshape(-1,1)
262 |
263 | # arr = np.concatenate([coords, types], axis=1)
264 |
265 | return coords, types
266 |
267 |
268 | def mol_to_points(mol, atom_types=[6,7,8,9,15,16,17,35,53]):
269 | '''convert an rdkit mol to an array of coordinates and layers'''
270 | coords = get_coords(mol)
271 |
272 | types = get_types(mol)
273 | layers = np.array([(atom_types.index(k) if k in atom_types else -1) for k in types])
274 |
275 | # filter by existing layer
276 | coords = coords[layers != -1]
277 | layers = layers[layers != -1].reshape(-1,1)
278 |
279 | return coords, layers
280 |
281 |
282 | def merge_smiles(sma, smb):
283 |     '''merge two SMILES fragment strings by combining them at the dummy connection point'''
284 | a = Chem.MolFromSmiles(sma, sanitize=False)
285 | b = Chem.MolFromSmiles(smb, sanitize=False)
286 |
287 | # merge molecules
288 | c = Chem.CombineMols(a,b)
289 |
290 | # find dummy atoms
291 | da,db = np.where(np.array([k.GetAtomicNum() for k in c.GetAtoms()]) == 0)[0]
292 |
293 | # find neighbors to connect
294 | na = c.GetAtomWithIdx(int(da)).GetNeighbors()[0].GetIdx()
295 | nb = c.GetAtomWithIdx(int(db)).GetNeighbors()[0].GetIdx()
296 |
297 | e = Chem.EditableMol(c)
298 | for d in sorted([da,db])[::-1]:
299 | e.RemoveAtom(int(d))
300 |
301 | # adjust atom indexes
302 | na -= int(da < na) + int(db < na)
303 | nb -= int(da < nb) + int(db < nb)
304 |
305 | e.AddBond(na,nb,Chem.rdchem.BondType.SINGLE)
306 |
307 | r = e.GetMol()
308 |
309 | sm = Chem.MolToSmiles(Chem.RemoveHs(r, sanitize=False), isomericSmiles=False)
310 |
311 | return sm
312 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2020.6.20
2 | chardet==3.0.4
3 | click==7.1.2
4 | configparser==5.0.0
5 | docker-pycreds==0.4.0
6 | future==0.18.2
7 | gitdb==4.0.5
8 | GitPython==3.1.7
9 | gql==0.2.0
10 | graphql-core==1.1
11 | h5py==3.10.0
12 | idna==2.10
13 | llvmlite==0.41.1
14 | numba==0.58.1
15 | numpy==1.24.4
16 | nvidia-ml-py3==7.352.0
17 | pathtools==0.1.2
18 | Pillow==7.2.0
19 | promise==2.3
20 | psutil==5.7.2
21 | python-dateutil==2.8.1
22 | PyYAML==5.3.1
23 | requests==2.24.0
24 | sentry-sdk==0.16.3
25 | shortuuid==1.0.1
26 | six==1.15.0
27 | smmap==3.0.4
28 | subprocess32==3.5.4
29 | tqdm==4.48.2
30 | urllib3==1.25.10
31 | wandb==0.9.4
32 | watchdog==0.10.3
33 | rdkit==2023.09.2
34 |
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Data processing scripts
3 |
4 | ## `make_fingerprints.py`
5 |
6 | Utility script to generate fingerprints for a set of SMILES strings. By precomputing the fingerprints for all the fragments in our dataset, we can speed up training.
7 |
8 | To use, pass in the `moad.h5` file and specify the fingerprint type and output path.
9 |
10 | Supported fingerprints:
11 | - `rdk`: RDKFingerprint (2048 bits)
12 | - `rdk10`: RDKFingerprint (path size 10) (2048 bits)
13 | - `morgan`: Morgan fingerprint (r=2) (2048 bits)
14 | - `gobbi2d`: Gobbi 2D pharmacophore fingerprint (folded to 2048 bits)
15 |
16 | Usage:
17 |
18 | ```
19 | usage: make_fingerprints.py [-h] -f FRAGMENTS -fp {rdk,rdk10,morgan,gobbi2d}
20 | [-o OUTPUT]
21 |
22 | optional arguments:
23 | -h, --help show this help message and exit
24 | -f FRAGMENTS, --fragments FRAGMENTS
25 |                         Path to fragments.h5 containing "frag_smiles" array
26 | -fp {rdk,rdk10,morgan,gobbi2d}, --fingerprint {rdk,rdk10,morgan,gobbi2d}
27 | Which fingerprint type to generate
28 | -o OUTPUT, --output OUTPUT
29 | Output file path (.h5)
30 | ```
31 |
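As a rough sketch of how these fingerprint types map onto RDKit calls (the authoritative implementation is in `make_fingerprints.py`; the helper name and exact parameters below are illustrative, and `gobbi2d` is omitted since it requires the Pharm2D signature factory):

```python
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

def smiles_to_fp(smiles, kind='rdk'):
    """Illustrative: compute one 2048-bit fingerprint for a SMILES string."""
    mol = Chem.MolFromSmiles(smiles)
    if kind == 'rdk':
        fp = Chem.RDKFingerprint(mol, fpSize=2048)
    elif kind == 'rdk10':
        fp = Chem.RDKFingerprint(mol, maxPath=10, fpSize=2048)
    elif kind == 'morgan':
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    else:
        raise ValueError('unsupported fingerprint type: %s' % kind)
    # cast to plain float, matching the CHANGES.md 1.0.4 note
    return np.array(fp, dtype=float)

fp = smiles_to_fp('CCO', 'morgan')
```

Precomputing this once per unique fragment SMILES (rather than per training example) is what makes the `FingerprintDataset` lookup during training cheap.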
32 | ## MOAD Dataset
33 |
34 | For instructions on working with MOAD data, see [`README_MOAD.md`](./README_MOAD.md).
35 |
--------------------------------------------------------------------------------
/scripts/README_MOAD.md:
--------------------------------------------------------------------------------
1 |
2 | # MOAD Data
3 |
4 | This README describes how to process raw MOAD data for use in training the DeepFrag model. Note that we already provide fully processed datasets for
5 | training, so this section mostly serves as a description of how to process future versions of the MOAD dataset, or as an example for researchers looking
6 | to accomplish similar things.
7 |
8 | See [`data/README.md`](../data/README.md) for instructions on how to download the packed `moad.h5` data.
9 |
10 | See [`config/moad_partitions.py`](../config/moad_partitions.py) for a pre-computed MOAD TRAIN/VAL/TEST split (generated with seed 7).
11 |
12 | # 1. Download MOAD datasets
13 |
14 | Go to https://bindingmoad.org/Home/download and download `every_part_a.zip`, `every_part_b.zip`, and the "Binding data" file (`every.csv`).
15 |
16 | This readme assumes these files are stored in `$MOAD_DIR`.
17 |
18 | # 2. Unpack MOAD datasets
19 |
20 | Run:
21 |
22 | ```sh
23 | $ unzip every_part_a.zip
24 | ...
25 | $ unzip every_part_b.zip
26 | ...
27 | ```
28 |
29 | # 3. Process MOAD pdb files
30 |
31 | The MOAD dataset contains ligand/receptor structures combined in a single PDB file (named with a `.bio` extension). In this step, we separate the receptor and each ligand into individual files.
32 |
33 | Run:
34 |
35 | ```sh
36 | $ cd $MOAD_DIR && mkdir split
37 | $ python3 scripts/split_moad.py \
38 | -d $MOAD_DIR/BindingMOAD_2020 \
39 | -c $MOAD_DIR/every.csv \
40 | -o $MOAD_DIR/split \
41 |     -n <num_cores>
42 | ```
43 |
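The packing step below (`process_moad.py`) expects the split files to follow a fixed naming scheme: one `<rec>_rec.pdb` per receptor and one `<rec>_<ligand>.pdb` per ligand, with spaces in the MOAD ligand name replaced by underscores. A sketch of that layout, based on the filename patterns used in `process_moad.py` (the receptor and ligand IDs here are hypothetical):

```python
import os

def split_output_paths(out_dir, rec_name, ligand_names):
    """Filenames produced by the split step and consumed by process_moad.py."""
    rec = os.path.join(out_dir, '%s_rec.pdb' % rec_name)
    ligs = [os.path.join(out_dir, '%s_%s.pdb' % (rec_name, lig.replace(' ', '_')))
            for lig in ligand_names]
    return rec, ligs

# Hypothetical example: one receptor with a single ATP ligand.
rec, ligs = split_output_paths('split', '1xdn_bio1', ['ATP A 1'])
```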
44 | # 4. Generate packed data files.
45 |
46 | We pack all of the relevant information into a single `.h5` file so that it can be loaded entirely into memory during training.
47 |
48 | This step will produce several similar `.h5` files that can be combined later.
49 |
50 | Run:
51 |
52 | ```sh
53 | $ cd $MOAD_DIR && mkdir packed
54 | $ python3 scripts/process_moad.py \
55 | -d $MOAD_DIR/split \
56 | -c $MOAD_DIR/every.csv \
57 | -o $MOAD_DIR/packed/moad.h5 \
58 |     -n <num_cores> \
59 | -s
60 | ```
61 |
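The packed file stores every receptor's coordinates concatenated into one flat array, with a lookup table of `(id, start, end)` rows used to slice out a single structure (the same scheme is used for fragments). A minimal sketch of that layout, with plain Python lists standing in for the h5 arrays and made-up PDB IDs:

```python
rec_coords = []  # every receptor's atom coordinates, concatenated
rec_lookup = []  # (pdb_id, start, end) rows indexing into rec_coords

def pack(pdb_id, coords):
    """Append one receptor and record its slice in the lookup table."""
    start = len(rec_coords)
    rec_coords.extend(coords)
    rec_lookup.append((pdb_id, start, len(rec_coords)))

pack('1ABC', [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)])
pack('2XYZ', [(5.0, 5.0, 5.0)])

# Slice out the second receptor via its lookup row.
pdb_id, start, end = rec_lookup[1]
single = rec_coords[start:end]
```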
62 | # 5. Merge packed data files.
63 |
64 | ```sh
65 | $ python3 scripts/merge_moad.py \
66 | -i $MOAD_DIR/packed \
67 | -o moad.h5
68 | ```
69 |
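Merging works because every intermediate file uses the same layout: the second file's `start`/`end` offsets just need to be shifted by the length of the first file's coordinate array, which is what `merge_moad.py` does for each lookup column. A minimal sketch of that offset shift (simplified to a single lookup table, with made-up IDs):

```python
def merge_packed(a_lookup, a_coords, b_lookup, b_coords):
    """Concatenate two packed files, shifting the second file's offsets."""
    off = len(a_coords)
    shifted = [(rid, s + off, e + off) for rid, s, e in b_lookup]
    return a_lookup + shifted, a_coords + b_coords

lookup, coords = merge_packed(
    [('1ABC', 0, 2)], [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
    [('2XYZ', 0, 1)], [(5.0, 5.0, 5.0)],
)
```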
70 | # 6. Generate MOAD Training splits
71 |
72 | ```sh
73 | $ python3 scripts/moad_training_splits.py \
74 | -c $MOAD_DIR/every.csv \
75 | -s 7 \
76 | -o moad_partitions.py
77 | ```
78 |
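The split script partitions MOAD *families* (not individual targets), so related proteins never straddle TRAIN/VAL/TEST: roughly 60% of families go to TRAIN and the remainder is halved between VAL and TEST, all driven by the seed. A simplified sketch of the family-level partition (using `random.Random` rather than the script's `numpy` RNG, so the exact assignments differ from the shipped split):

```python
import random

def split_families(families, seed=7, p_train=0.6):
    """Seeded, approximately 60/20/20 family-level partition."""
    rng = random.Random(seed)
    train, other = [], []
    for fam in families:
        (train if rng.random() < p_train else other).append(fam)
    val, test = [], []
    for fam in other:
        (val if rng.random() < 0.5 else test).append(fam)
    return train, val, test

train, val, test = split_families(list(range(100)))
```

Because the RNG is seeded, rerunning with the same seed reproduces the same partition.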
--------------------------------------------------------------------------------
/scripts/make_fingerprints.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | '''
17 | Utility script to generate fingerprints for smiles strings
18 | '''
19 | import argparse
20 |
21 | import h5py
22 | import numpy as np
23 | import tqdm
24 |
25 | from rdkit.Chem import rdMolDescriptors
26 | import rdkit.Chem.AllChem as Chem
27 | from rdkit.Chem.Pharm2D import Generate, Gobbi_Pharm2D
28 |
29 |
30 | def fold_to(bv, size=2048):
31 | '''fold a SparseBitVec to a certain length'''
32 | fp = np.zeros(size)
33 |
34 | for b in list(bv.GetOnBits()):
35 | fp[b % size] = 1
36 |
37 | return fp
38 |
39 |
40 | def rdkfingerprint(m):
41 | '''rdkfingerprint as 2048-len bit array'''
42 | fp = Chem.rdmolops.RDKFingerprint(m)
43 | n_fp = list(map(int, list(fp.ToBitString())))
44 | return n_fp
45 |
46 |
47 | def rdkfingerprint10(m):
48 | '''rdkfingerprint as 2048-len bit array (maxPath=10)'''
49 | fp = Chem.rdmolops.RDKFingerprint(m, maxPath=10)
50 | n_fp = list(map(int, list(fp.ToBitString())))
51 | return n_fp
52 |
53 |
54 | def morganfingerprint(m):
55 | '''morgan fingerprint as 2048-len bit array'''
56 | m.UpdatePropertyCache(strict=False)
57 | Chem.rdmolops.FastFindRings(m)
58 | fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(m, 2)
59 | n_fp = list(map(int, list(fp.ToBitString())))
60 | return n_fp
61 |
62 |
63 | def gobbi2d(m):
64 | '''gobbi 2d pharmacophore as 2048-len bit array'''
65 | m.UpdatePropertyCache(strict=False)
66 | Chem.rdmolops.FastFindRings(m)
67 | bv = Generate.Gen2DFingerprint(m, Gobbi_Pharm2D.factory)
68 | n_fp = fold_to(bv, size=2048)
69 | return n_fp
70 |
71 |
72 | FINGERPRINTS = {
73 | 'rdk': (rdkfingerprint, 2048),
74 | 'rdk10': (rdkfingerprint10, 2048),
75 | 'morgan': (morganfingerprint, 2048),
76 | 'gobbi2d': (gobbi2d, 2048),
77 | }
78 |
79 |
80 | def process(fragments_path, fp_func, fp_size, out_path):
81 | # open fragments file
82 | f = h5py.File(fragments_path, 'r')
83 | smiles = f['frag_smiles'][()]
84 | f.close()
85 |
86 | # deduplicate smiles strings
87 | all_smiles = list(set(smiles))
88 | n_smiles = np.array(all_smiles)
89 |
90 | n_fingerprints = np.zeros((len(all_smiles), fp_size))
91 |
92 | for i in tqdm.tqdm(range(len(all_smiles))):
93 | # generate fingerprint
94 | m = Chem.MolFromSmiles(all_smiles[i].decode('ascii'), sanitize=False)
95 | n_fingerprints[i] = fp_func(m)
96 |
97 | # save
98 | with h5py.File(out_path, 'w') as f:
99 | f['fingerprints'] = n_fingerprints
100 | f['smiles'] = n_smiles
101 |
102 | print('Done!')
103 |
104 |
105 | def main():
106 | parser = argparse.ArgumentParser()
107 |
108 |     parser.add_argument('-f', '--fragments', required=True, help='Path to fragments.h5 containing "frag_smiles" array')
109 | parser.add_argument('-fp', '--fingerprint', required=True, choices=[k for k in FINGERPRINTS], help='Which fingerprint type to generate')
110 | parser.add_argument('-o', '--output', default='fingerprints.h5', help='Output file path (.h5)')
111 |
112 | args = parser.parse_args()
113 |
114 | fn, size = FINGERPRINTS[args.fingerprint]
115 | process(args.fragments, fn, size, out_path=args.output)
116 |
117 |
118 | if __name__=='__main__':
119 | main()
120 |
--------------------------------------------------------------------------------
/scripts/merge_moad.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | import argparse
17 | import os
18 |
19 | import h5py
20 | import numpy as np
21 |
22 |
23 | def unpack(path):
24 | f = h5py.File(path, 'r')
25 |
26 | dat = {}
27 | for k in f.keys():
28 | dat[k] = f[k][()]
29 |
30 | f.close()
31 |
32 | return dat
33 |
34 | def append(dat, other):
35 | frag_coord_off = dat['frag_data'].shape[0]
36 | frag_lig_smi_off = dat['frag_lig_smi'].shape[0]
37 | rec_coord_off = dat['rec_coords'].shape[0]
38 |
39 | # Update fragment coords.
40 | other['frag_lookup']['f1'] += frag_coord_off
41 | other['frag_lookup']['f2'] += frag_coord_off
42 | other['frag_lookup']['f3'] += frag_coord_off
43 | other['frag_lookup']['f4'] += frag_coord_off
44 |
45 | # Update receptor coords.
46 | other['rec_lookup']['f1'] += rec_coord_off
47 | other['rec_lookup']['f2'] += rec_coord_off
48 |
49 | # Update ligand index.
50 | other['frag_lig_idx'] += frag_lig_smi_off
51 |
52 | # Concatenate everything.
53 | for k in dat:
54 | dat[k] = np.concatenate((dat[k], other[k]), axis=0)
55 |
56 | def cat_all(paths):
57 | dat = unpack(paths[0])
58 |
59 | for i in range(1, len(paths)):
60 | append(dat, unpack(paths[i]))
61 |
62 | return dat
63 |
64 | def main():
65 | parser = argparse.ArgumentParser()
66 |
67 | parser.add_argument('-i', '--input', required=True, help='Path to folder containing intermediate .h5 fragments')
68 | parser.add_argument('-o', '--output', default='moad.h5', help='Output file path (.h5)')
69 |
70 | args = parser.parse_args()
71 |
72 | inp = args.input
73 | paths = [x for x in os.listdir(inp) if x.endswith('.h5')]
74 | paths = [os.path.join(inp, x) for x in paths]
75 |
76 | print('Merging:')
77 | for k in paths:
78 | print('- %s' % k)
79 |
80 | full = cat_all(paths)
81 |
82 | f = h5py.File(args.output, 'w')
83 |     for k in full:
84 |         f[k] = full[k]
85 | f.close()
86 |
87 | print('Done!')
88 |
89 | if __name__=='__main__':
90 | main()
91 |
--------------------------------------------------------------------------------
/scripts/moad_training_splits.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | import argparse
17 | import os
18 |
19 | import numpy as np
20 |
21 | from moad_util import parse_moad
22 |
23 |
24 | def all_smi(families):
25 | smi = []
26 | for f in families:
27 | for t in f.targets:
28 | for lig in t.ligands:
29 | smi.append(lig[1])
30 | return set(smi)
31 |
32 | def do_split(families, pa=0.6):
33 | fa = []
34 | fb = []
35 |
36 | for f in families:
37 | if np.random.rand() < pa:
38 | fa.append(f)
39 | else:
40 | fb.append(f)
41 |
42 | return (fa, fb)
43 |
44 | def split_smi(smi):
45 | l = list(smi)
46 | sz = len(l)
47 |
48 | np.random.shuffle(l)
49 |
50 | return (set(l[:sz//2]), set(l[sz//2:]))
51 |
52 | def split_smi3(smi):
53 | l = list(smi)
54 | sz = len(l)
55 |
56 | np.random.shuffle(l)
57 |
58 | v = sz//3
59 | return (set(l[:v]), set(l[v:v*2]), set(l[v*2:]))
60 |
61 | def get_ids(fam):
62 | ids = []
63 | for f in fam:
64 | for t in f.targets:
65 | ids.append(t.pdb_id)
66 | return ids
67 |
68 | def gen_split(csv, output, seed=7):
69 | np.random.seed(seed)
70 |
71 | moad_families, moad_targets = parse_moad(csv)
72 |
73 | train, other = do_split(moad_families, pa=0.6)
74 | val, test = do_split(other, pa=0.5)
75 |
76 | train_sum = np.sum([len(x.targets) for x in train])
77 | val_sum = np.sum([len(x.targets) for x in val])
78 | test_sum = np.sum([len(x.targets) for x in test])
79 |
80 | print('[Targets] (Train: %d) (Val: %d) (Test %d)' % (train_sum, val_sum, test_sum))
81 |
82 | train_smi = all_smi(train)
83 | val_smi = all_smi(val)
84 | test_smi = all_smi(test)
85 |
86 | train_smi_uniq = train_smi - (val_smi | test_smi)
87 | val_smi_uniq = val_smi - (train_smi | test_smi)
88 | test_smi_uniq = test_smi - (val_smi | train_smi)
89 |
90 | print('[Unique ligands] (Train: %d) (Val: %d) (Test %d)' % (
91 | len(train_smi_uniq), len(val_smi_uniq), len(test_smi_uniq)))
92 |
93 | print('[Total unique ligands] %d' % len(train_smi | val_smi | test_smi))
94 |
95 | split_train_val = (train_smi & val_smi) - test_smi
96 | split_train_test = (train_smi & test_smi) - val_smi
97 | split_val_test = (val_smi & test_smi) - train_smi
98 |
99 | split_all = (train_smi & val_smi & test_smi)
100 |
101 | split_train_val_a, split_train_val_b = split_smi(split_train_val)
102 | split_train_test_a, split_train_test_b = split_smi(split_train_test)
103 | split_val_test_a, split_val_test_b = split_smi(split_val_test)
104 |
105 | split_all_train, split_all_val, split_all_test = split_smi3(split_all)
106 |
107 | train_full = (train_smi_uniq | split_train_val_a | split_train_test_a | split_all_train)
108 | val_full = (val_smi_uniq | split_train_val_b | split_val_test_a | split_all_val)
109 | test_full = (test_smi_uniq | split_train_test_b | split_val_test_b | split_all_test)
110 |
111 | print('[Full ligands] (Train: %d) (Val: %d) (Test %d)' % (
112 | len(train_full), len(val_full), len(test_full)))
113 |
114 | mixed = (train_full & val_full) | (val_full & test_full) | (train_full & test_full)
115 |
116 | train_ids = sorted(get_ids(train))
117 | val_ids = sorted(get_ids(val))
118 | test_ids = sorted(get_ids(test))
119 |
120 | train_s = sorted(train_full)
121 | val_s = sorted(val_full)
122 | test_s = sorted(test_full)
123 |
124 | # Format as a python file.
125 | out = ''
126 | out += 'TRAIN = ' + repr(train_ids).replace(' ','') + '\n'
127 | out += 'TRAIN_SMI = ' + repr(train_s).replace(' ','') + '\n'
128 | out += 'VAL = ' + repr(val_ids).replace(' ','') + '\n'
129 | out += 'VAL_SMI = ' + repr(val_s).replace(' ','') + '\n'
130 | out += 'TEST = ' + repr(test_ids).replace(' ','') + '\n'
131 | out += 'TEST_SMI = ' + repr(test_s).replace(' ','') + '\n'
132 |
133 | open(output, 'w').write(out)
134 |
135 | def main():
136 | parser = argparse.ArgumentParser()
137 |
138 | parser.add_argument('-c', '--csv', required=True, help='Path to every.csv file')
139 | parser.add_argument('-s', '--seed', required=False, default=7, type=int, help='Integer seed')
140 | parser.add_argument('-o', '--output', default='moad_partitions.py', help='Output file path (.py)')
141 |
142 | args = parser.parse_args()
143 |
144 | gen_split(args.csv, args.output, args.seed)
145 | print('Done!')
146 |
147 | if __name__=='__main__':
148 | main()
149 |
--------------------------------------------------------------------------------
/scripts/moad_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | # MOAD csv parsing utility
17 |
18 | class Family(object):
19 | def __init__(self):
20 | self.targets = []
21 |
22 | def __repr__(self):
23 | return 'F(%d)' % len(self.targets)
24 |
25 | class Protein(object):
26 | def __init__(self, pdb_id):
27 | # (chain, smi)
28 | self.pdb_id = pdb_id.upper()
29 | self.ligands = []
30 |
31 | def __repr__(self):
32 | return '%s(%d)' % (self.pdb_id, len(self.ligands))
33 |
34 | def parse_moad(csv):
35 | csv_dat = open(csv, 'r').read().strip().split('\n')
36 | csv_dat = [x.split(',') for x in csv_dat]
37 |
38 | families = []
39 |
40 | curr_f = None
41 | curr_t = None
42 |
43 | for line in csv_dat:
44 | if line[0] != '':
45 | # new class
46 | continue
47 | elif line[1] != '':
48 | # new family
49 | if curr_t != None:
50 | curr_f.targets.append(curr_t)
51 | if curr_f != None:
52 | families.append(curr_f)
53 | curr_f = Family()
54 | curr_t = Protein(line[2])
55 | elif line[2] != '':
56 | # new target
57 | if curr_t != None:
58 | curr_f.targets.append(curr_t)
59 | curr_t = Protein(line[2])
60 | elif line[3] != '':
61 | # new ligand
62 | if line[4] != 'valid':
63 | continue
64 | curr_t.ligands.append((line[3], line[9]))
65 |
66 | curr_f.targets.append(curr_t)
67 | families.append(curr_f)
68 |
69 | by_target = {}
70 | for f in families:
71 | for t in f.targets:
72 | by_target[t.pdb_id] = t
73 |
74 | return families, by_target
75 |
--------------------------------------------------------------------------------
/scripts/process_moad.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | '''
17 | Utility script to convert the MOAD dataset into a packed format
18 | '''
19 | import sys
20 | import argparse
21 | import os
22 | import re
23 | import multiprocessing
24 | import threading
25 |
26 | from moad_util import parse_moad
27 |
28 | import rdkit.Chem.AllChem as Chem
29 | from rdkit.Chem.Descriptors import ExactMolWt
30 | import molvs
31 | import h5py
32 | import tqdm
33 | import numpy as np
34 |
35 | # add leadopt to path
36 | sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir))
37 |
38 | from leadopt import util
39 |
40 |
41 | # Data Format
42 | #
43 | # Receptor data:
44 | # - rec_lookup: [id][start][end]
45 | # - rec_coords: [x][y][z]
46 | # - rec_types: [num][is_hacc][is_hdon][is_aro][pcharge]
47 | #
48 | # Fragment data:
49 | # - frag_lookup: [id][fstart][fend][pstart][pend]
50 | # - frag_lig_id: [lig_id]
51 | # - frag_coords: [x][y][z]
52 | # - frag_types: [num]
53 | # - frag_smiles: [smiles]
54 | # - frag_mass: [mass]
55 | # - frag_dist: [dist]
56 |
57 | LOAD_TIMEOUT = 60
58 |
59 |
60 | u = molvs.charge.Uncharger()
61 | t = molvs.tautomer.TautomerCanonicalizer()
62 |
63 |
64 | REPLACE = [
65 | ('C(=O)[O-]', 'C(=O)O'),
66 | ('N=[N+]=N', 'N=[N+]=[N-]'),
67 | ('[N+](=O)O', '[N+](=O)[O-]'),
68 | ('S(O)(O)O', '[S+2](O)(O)O'),
69 | ]
70 |
71 |
72 | def basic_replace(sm):
73 | m = Chem.MolFromSmiles(sm, False)
74 |
75 | for a,b in REPLACE:
76 | m = Chem.ReplaceSubstructs(
77 | m,
78 | Chem.MolFromSmiles(a, sanitize=False),
79 | Chem.MolFromSmiles(b, sanitize=False),
80 | replaceAll=True
81 | )[0]
82 |
83 | return Chem.MolToSmiles(m)
84 |
85 |
86 | def neutralize_smiles(sm):
87 | m = Chem.MolFromSmiles(sm)
88 | m = u.uncharge(m)
89 | m = t.canonicalize(m)
90 | sm = Chem.MolToSmiles(m)
91 | sm = basic_replace(sm)
92 |
93 | try:
94 | return molvs.standardize_smiles(sm)
95 | except:
96 | print(sm)
97 | return sm
98 |
99 |
100 | def load_example(base, rec_id, target):
101 |
102 | rec_path = os.path.join(base, '%s_rec.pdb' % rec_id)
103 |
104 | # Load receptor data.
105 | rec_coords, rec_types = util.load_receptor_ob(rec_path)
106 |
107 | # (frag_data, parent_data, smiles, mass, dist, lig_off)
108 | fragments = []
109 |
110 | # (smi)
111 | lig_smiles = []
112 |
113 |
114 | lig_off = 0
115 | for lig in target.ligands:
116 |
117 | lig_path = os.path.join(base, '%s_%s.pdb' % (rec_id, lig[0].replace(' ','_')))
118 | try:
119 | lig_mol = Chem.MolFromPDBFile(lig_path, True)
120 | except:
121 | continue
122 |
123 | if lig_mol is None:
124 | continue
125 |
126 | lig_smi = lig[1]
127 | lig_smiles.append(lig_smi)
128 |
129 | ref = Chem.MolFromSmiles(lig_smi)
130 | lig_fixed = Chem.AssignBondOrdersFromTemplate(ref, lig_mol)
131 |
132 | splits = util.generate_fragments(lig_fixed)
133 |
134 | for parent, frag in splits:
135 | frag_data = util.mol_array(frag)
136 | parent_data = util.mol_array(parent)
137 |
138 | frag_smi = Chem.MolToSmiles(
139 | frag,
140 | isomericSmiles=False,
141 | kekuleSmiles=False,
142 | canonical=True,
143 | allHsExplicit=False
144 | )
145 |
146 | frag_smi = neutralize_smiles(frag_smi)
147 |
148 | frag.UpdatePropertyCache(strict=False)
149 | mass = ExactMolWt(frag)
150 |
151 | dist = util.frag_dist_to_receptor_raw(rec_coords, frag)
152 |
153 | fragments.append((frag_data, parent_data, frag_smi, mass, dist, lig_off))
154 |
155 | lig_off += 1
156 |
157 | return (rec_coords, rec_types, fragments, lig_smiles)
158 |
159 | def do_thread(out, args):
160 | try:
161 | out[0] = load_example(*args)
162 | except:
163 | out[0] = None
164 |
165 | def multi_load(packed):
166 | out = [None]
167 |
168 | t = threading.Thread(target=do_thread, args=(out, packed))
169 | t.start()
170 | t.join(timeout=LOAD_TIMEOUT)
171 |
172 | if t.is_alive():
173 | print('timeout', packed[1])
174 |
175 | return (packed, out[0])
176 |
177 |
178 | def process(work, processed, moad_csv, out_path='moad.h5', num_cores=1):
179 | '''Process MOAD data and save to a packed format.
180 |
181 | Args:
182 | - out_path: where to save the .h5 packed data
183 | '''
184 | rec_lookup = [] # (id, start, end)
185 | rec_coords = [] # (x,y,z)
186 | rec_types = [] # (num, aro, hdon, hacc, pcharge)
187 |
188 | frag_lookup = [] # (id, f_start, f_end, p_start, p_end)
189 | frag_lig_idx = [] # (lig_idx)
190 | frag_lig_smi = [] # (lig_smi)
191 | frag_data = [] # (x,y,z,type)
192 | frag_smiles = [] # (frag_smi)
193 | frag_mass = [] # (mass)
194 | frag_dist = [] # (dist)
195 |
196 | # Data pointers.
197 | rec_i = 0
198 | frag_i = 0
199 |
200 | # Multiprocess.
201 | with multiprocessing.Pool(num_cores) as p:
202 | with tqdm.tqdm(total=len(work)) as pbar:
203 | for w, res in p.imap_unordered(multi_load, work):
204 | pbar.update()
205 |
206 | if res == None:
207 | print('[!] Failed: %s' % w[1])
208 | continue
209 |
210 | rcoords, rtypes, fragments, ex_lig_smiles = res
211 |
212 | if len(fragments) == 0:
213 | print('Empty', w[1])
214 | continue
215 |
216 | rec_id = w[1]
217 |
218 | # Add receptor info.
219 | rec_start = rec_i
220 | rec_end = rec_i + rcoords.shape[0]
221 | rec_i += rcoords.shape[0]
222 |
223 | rec_coords.append(rcoords)
224 | rec_types.append(rtypes)
225 | rec_lookup.append((rec_id.encode('ascii'), rec_start, rec_end))
226 |
227 | lig_idx = len(frag_lig_smi)
228 |
229 | # Add fragment info.
230 | for fdat, pdat, frag_smi, mass, dist, lig_off in fragments:
231 | frag_start = frag_i
232 | frag_end = frag_i + fdat.shape[0]
233 | frag_i += fdat.shape[0]
234 |
235 | parent_start = frag_i
236 | parent_end = frag_i + pdat.shape[0]
237 | frag_i += pdat.shape[0]
238 |
239 | frag_data.append(fdat)
240 | frag_data.append(pdat)
241 |
242 | frag_lookup.append((rec_id.encode('ascii'), frag_start, frag_end, parent_start, parent_end))
243 | frag_lig_idx.append(lig_idx+lig_off)
244 | frag_smiles.append(frag_smi)
245 | frag_mass.append(mass)
246 | frag_dist.append(dist)
247 |
248 | # Add ligand smiles.
249 | frag_lig_smi += ex_lig_smiles
250 |
251 | # Convert to numpy format.
252 | print('Convert numpy...', flush=True)
253 |     n_rec_lookup = np.array(rec_lookup, dtype='
--------------------------------------------------------------------------------
/scripts/split_moad.py:
--------------------------------------------------------------------------------
40 |  3 and lig_resnum == '1':
41 | # assume peptide, take the whole chain
42 | sel = m.select('chain %s' % lig_chain)
43 | else:
44 | # assume small molecule, single residue
45 | sel = m.select('chain %s and resnum = %s' % (lig_chain, lig_resnum))
46 |
47 | ligands.append((lig[0], sel))
48 |
49 | return rec, ligands
50 |
51 |
52 | def do_proc(packed):
53 | out_dir, rec_name, path, target = packed
54 |
55 | try:
56 | rec, ligands = load_example(path, target)
57 |
58 | # Save receptor.
59 | prody.writePDB(os.path.join(out_dir, rec_name + '_rec.pdb'), rec)
60 |
61 | # Save ligands.
62 | for lig_name, lig_sel in ligands:
63 | if lig_sel is None:
64 | continue
65 | lig_name = lig_name.replace(' ', '_')
66 | prody.writePDB(os.path.join(out_dir, rec_name + '_' + lig_name + '.pdb'), lig_sel)
67 |
68 | except Exception as e:
69 | print('failed', path)
70 | print(e)
71 |
72 | return None
73 |
74 |
75 | def load_all(moad_dir, moad_csv, out_dir, num_cores=1):
76 | computed = []
77 |
78 | if os.path.exists(out_dir):
79 | names = os.listdir(out_dir)
80 | names = [x.split('_rec')[0] for x in names if '_rec' in x]
81 | computed = names
82 | else:
83 | os.mkdir(out_dir)
84 |
85 | # Load MOAD csv.
86 | moad_families, moad_targets = parse_moad(moad_csv)
87 |
88 | # Collect input files.
89 | files = []
90 | for fname in os.listdir(moad_dir):
91 | if fname.startswith('.'):
92 | continue
93 | files.append(os.path.join(moad_dir, fname))
94 |
95 | files = sorted(files)
96 | print('[*] Loading %d files...' % len(files))
97 |
98 | failed = []
99 | info = {}
100 |
101 | # (path, target)
102 | work = []
103 |
104 | for path in tqdm.tqdm(files):
105 | rec_name = os.path.split(path)[-1].replace('.','_')
106 | if rec_name in computed:
107 | # Skip computed.
108 | continue
109 |
110 | pdb_id = rec_name.split('_')[0].upper()
111 | target = moad_targets[pdb_id]
112 |
113 | work.append((out_dir, rec_name, path, target))
114 |
115 | print('[*] Starting...')
116 | with multiprocessing.Pool(num_cores) as p:
117 | with tqdm.tqdm(total=len(work)) as pbar:
118 | for r in p.imap_unordered(do_proc, work):
119 | pbar.update()
120 |
121 | print('[*] Done.')
122 |
123 |
124 | def run():
125 | parser = argparse.ArgumentParser()
126 |
127 | parser.add_argument('-d', '--dataset', required=True, help='Path to MOAD folder')
128 | parser.add_argument('-c', '--csv', required=True, help='Path to MOAD "every.csv"')
129 | parser.add_argument('-o', '--output', default='./processed', help='Output directory')
130 | parser.add_argument('-n', '--num_cores', default=1, type=int, help='Number of cores')
131 |
132 | args = parser.parse_args()
133 |
134 | load_all(args.dataset, args.csv, args.output, args.num_cores)
135 |
136 |
137 | if __name__=='__main__':
138 | run()
139 |
--------------------------------------------------------------------------------
/test_installation.sh:
--------------------------------------------------------------------------------
1 | rm -rf test_installation
2 | mkdir test_installation
3 | cd test_installation
4 | wget https://files.rcsb.org/view/1XDN.pdb
5 | cat 1XDN.pdb | grep -v ATP > receptor.pdb
6 | cat 1XDN.pdb | grep ATP > ligand.pdb
7 |
8 | # Remove terminal phosphate
9 | cat ligand.pdb | grep -v "O1G" | grep -v "PG" | grep -v "O2G" | grep -v "O3G" > ligand2.pdb
10 |
11 | cd ../
12 | python deepfrag.py --receptor test_installation/receptor.pdb --ligand test_installation/ligand2.pdb --cx 44.807 --cy 16.562 --cz 14.092
13 |
14 |
15 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Jacob Durrant
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | # use this file except in compliance with the License. You may obtain a copy
5 | # of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | # License for the specific language governing permissions and limitations
13 | # under the License.
14 |
15 |
16 | """
17 | Utility script to launch training jobs.
18 | """
19 |
20 | import argparse
21 | import json
22 |
23 | from leadopt.model_conf import MODELS
24 |
25 |
26 | def main():
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--save_path', help='location to save the model')
29 | parser.add_argument('--wandb_project', help='set this to log run to wandb')
30 | parser.add_argument('--configuration', help='path to a configuration args.json file')
31 |
32 | subparsers = parser.add_subparsers(dest='version')
33 |
34 | for m in MODELS:
35 | sub = subparsers.add_parser(m)
36 | MODELS[m].setup_parser(sub)
37 |
38 | args = parser.parse_args()
39 | args_dict = args.__dict__
40 |
41 |     if args.configuration is not None and args.version is not None:
42 |         print('You can specify a model or a configuration file, but not both.')
43 |         exit(-1)
44 |     elif args.configuration is None and args.version is None:
45 |         parser.print_help()
46 |         exit(0)
47 |     elif args.configuration is not None:
48 |         _args = {}
49 |         try:
50 |             _args = json.loads(open(args.configuration, 'r').read())
51 |         except Exception as e:
52 |             print('Error reading configuration file: %s' % args.configuration)
53 |             print(e)
54 |             exit(-1)
55 |         args_dict.update(_args)
58 |
59 | # Initialize model.
60 | model_type = args_dict['version']
61 | model = MODELS[model_type](args_dict)
62 |
63 | model.train(args.save_path)
64 |
65 |
66 | if __name__=='__main__':
67 | main()
68 |
--------------------------------------------------------------------------------