├── .gitignore ├── CHANGES.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── CONTRIBUTORS.md ├── LICENSE.md ├── README.md ├── config ├── __init__.py └── moad_partitions.py ├── configurations ├── README.md ├── final.json ├── layer_type_sweep │ ├── lig_simple.json │ ├── lig_simple_h.json │ ├── lig_single.json │ ├── lig_single_h.json │ ├── rec_meta.json │ ├── rec_meta_mix.json │ ├── rec_simple_h.json │ ├── rec_single.json │ └── rec_single_h.json └── voxelation_sweep │ ├── t0.json │ ├── t1.json │ ├── t10.json │ ├── t11.json │ ├── t12.json │ ├── t13.json │ ├── t14.json │ ├── t15.json │ ├── t16.json │ ├── t17.json │ ├── t2.json │ ├── t3.json │ ├── t4.json │ ├── t5.json │ ├── t6.json │ ├── t7.json │ ├── t8.json │ └── t9.json ├── data └── README.md ├── deepfrag.py ├── leadopt ├── __init__.py ├── data_util.py ├── grid_util.py ├── infer.py ├── metrics.py ├── model_conf.py ├── models │ ├── __init__.py │ ├── backport.py │ └── voxel.py └── util.py ├── requirements.txt ├── scripts ├── README.md ├── README_MOAD.md ├── make_fingerprints.py ├── merge_moad.py ├── moad_training_splits.py ├── moad_util.py ├── process_moad.py └── split_moad.py ├── test_installation.sh └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.pyc 2 | **/__pycache__ 3 | .ipynb_checkpoints 4 | 5 | data/** 6 | !data/README.md 7 | 8 | .DS_Store 9 | 10 | .store/ 11 | dist/** 12 | build/** 13 | .vscode 14 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | Changes 2 | ======= 3 | 4 | 1.0.4 5 | ----- 6 | 7 | * Updated packages in `requirements.txt` 8 | * Fingerprint now cast as float instead of np.float 9 | * Minor updates to `README.md` 10 | * Fixed an error that prevented DeepFrag from loading PDB files without the .pdb 11 | extension. (Affects only recent versions of prody?) 
12 | * Added `test_installation.sh` to make it easy to verify that DeepFrag is 13 | installed correctly. Downloads sample data (PDB ID 1XDN) and runs DeepFrag. 14 | 15 | 1.0.3 16 | ----- 17 | 18 | * CLI parameters `--cx`, `--cy`, `--cz`, `--rx`, `--ry`, and `--rz` can now be 19 | floats (not just integers). We recommend specifying the exact atomic 20 | coordinates of the connection and removal points. 21 | * Fixed a bug that caused the `--full` parameter to throw an error when 22 | performing fragment addition (but not fragment replacement) using the CLI 23 | implementation. 24 | * Minor updates to the documentation. 25 | 26 | 1.0.2 27 | ----- 28 | 29 | * Added a CLI implementation of the program. See `README.md` for details. 30 | * Added a version number and citation to the program output. 31 | 32 | 1.0.1 33 | ----- 34 | 35 | * Removed open-babel dependency. 36 | * Added option (`cpu_gridify`) to improve use on CPUs when no GPU is 37 | available. 38 | * Updated `data/README.md` with new location of data files. 39 | * Fixed a config import. 40 | 41 | 1.0 42 | --- 43 | 44 | Original version. 45 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
8 | 9 | ## Our Standards 10 | 11 | Examples of behavior that contributes to a positive environment for our community include: 12 | 13 | * Demonstrating empathy and kindness toward other people 14 | * Being respectful of differing opinions, viewpoints, and experiences 15 | * Giving and gracefully accepting constructive feedback 16 | * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience 17 | * Focusing on what is best not just for us as individuals, but for the overall community 18 | 19 | Examples of unacceptable behavior include: 20 | 21 | * The use of sexualized language or imagery, and sexual attention or 22 | advances of any kind 23 | * Trolling, insulting or derogatory comments, and personal or political attacks 24 | * Public or private harassment 25 | * Publishing others' private information, such as a physical or email 26 | address, without their explicit permission 27 | * Other conduct which could reasonably be considered inappropriate in a 28 | professional setting 29 | 30 | ## Enforcement Responsibilities 31 | 32 | Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. 33 | 34 | Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. 35 | 36 | ## Scope 37 | 38 | This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
39 | 40 | ## Enforcement 41 | 42 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [http://durrantlab.com/contact/](http://durrantlab.com/contact/). All complaints will be reviewed and investigated promptly and fairly. 43 | 44 | All community leaders are obligated to respect the privacy and security of the reporter of any incident. 45 | 46 | ## Enforcement Guidelines 47 | 48 | Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: 49 | 50 | ### 1. Correction 51 | 52 | **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. 53 | 54 | **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. 55 | 56 | ### 2. Warning 57 | 58 | **Community Impact**: A violation through a single incident or series of actions. 59 | 60 | **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. 61 | 62 | ### 3. Temporary Ban 63 | 64 | **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. 65 | 66 | **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. 
Violating these terms may lead to a permanent ban. 67 | 68 | ### 4. Permanent Ban 69 | 70 | **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. 71 | 72 | **Consequence**: A permanent ban from any sort of public interaction within the project community. 73 | 74 | ## Attribution 75 | 76 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, 77 | available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 78 | 79 | Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). 80 | 81 | [homepage]: https://www.contributor-covenant.org 82 | 83 | For answers to common questions about this code of conduct, see the FAQ at 84 | https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. 85 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Introduction 4 | 5 | Thank you for your interest in contributing! All types of contributions are 6 | encouraged and valued. Please make sure to read the relevant section before 7 | making your contribution. We at the [Durrant Lab](http://durrantlab.com) look 8 | forward to your help! 9 | 10 | ## Reporting a bug 11 | 12 | If you're unable to find an open issue addressing the bug, feel free to [open 13 | a new one](https://docs.gitlab.com/ee/user/project/issues/). 
Be sure to 14 | include a **title and clear description**, as much relevant information as 15 | possible (e.g., the program, platform, or operating-system version numbers), 16 | and a **code sample** or **test case** demonstrating the expected behavior 17 | that is not occurring. 18 | 19 | If you or the maintainers don't respond to an issue for 30 days, the issue may 20 | be closed. If you want to come back to it, reply (once, please), and we'll 21 | reopen the existing issue. Please avoid filing new issues as extensions of one 22 | you already made. 23 | 24 | ## Project setup to make source-code changes on your computer 25 | 26 | This project uses `git` to manage contributions, so start by [reading up on 27 | how to fork a `git` 28 | repository](https://docs.gitlab.com/ee/user/project/repository/forking_workflow.html#creating-a-fork) 29 | if you've never done it before. 30 | 31 | Forking will place a copy of the code on your own computer, where you can 32 | modify it to correct bugs or add features. 33 | 34 | ## Integrating your changes back into the main codebase 35 | 36 | Follow these steps to "push" your changes to the main online repository so 37 | others can benefit from them: 38 | 39 | * Create a [new merge 40 | request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html) 41 | with your changes. 42 | 43 | * Ensure the description clearly describes the problem and solution. Include 44 | the relevant issue number if applicable. 45 | 46 | * Before submitting, please read this CONTRIBUTING.md file to know more about 47 | coding conventions and benchmarks. 48 | 49 | ## Coding conventions 50 | 51 | Be sure to adequately document your code with comments so others can 52 | understand your changes. All classes and functions should have associated doc 53 | strings, formatted as appropriate given the programming language. Here are 54 | some examples: 55 | 56 | ```python 57 | """ 58 | This file does important calculations. 
It is a Python file with nice doc strings. 59 | """ 60 | 61 | class ImportantCalcs(object): 62 | """ 63 | An important class where important things happen. 64 | """ 65 | 66 | def __init__(self, vars=None, receptor_file=None, 67 | file_conversion_class_object=None, test_boot=True): 68 | """ 69 | Required to initialize any conversion. 70 | 71 | Inputs: 72 | :param dict vars: Dictionary of user variables 73 | :param str receptor_file: the path for the receptor file 74 | :param obj file_conversion_class_object: object that is used to convert 75 | files from pdb to pdbqt 76 | :param bool test_boot: used to initialize class without objects for 77 | testing purposes 78 | """ 79 | 80 | pass 81 | ``` 82 | 83 | ```typescript 84 | /** 85 | * Sets the curStarePt variable externally. A useful, well-documented 86 | * TypeScript function. 87 | * @param {number[]} pt The x, y coordinates of the point as a list of 88 | * numbers. 89 | * @returns void 90 | */ 91 | export function setCurStarePt(pt: number[]): void { 92 | curStarePt.copyFrom(pt); 93 | } 94 | ``` 95 | 96 | If writing Python code, be sure to use the [Black 97 | formatter](https://black.readthedocs.io/en/stable/) before submitting a merge 98 | request. If writing code in JavaScript or TypeScript, please use the [Prettier 99 | formatter](https://marketplace.visualstudio.com/items?itemName=esbenp.prettier-vscode). 100 | 101 | ## Fixing whitespace, formatting code, or making a purely cosmetic patch 102 | 103 | Changes that are cosmetic in nature and do not add anything substantial to the 104 | stability, functionality, or testability of the program are unlikely to be 105 | accepted. 106 | 107 | ## Asking questions about the program 108 | 109 | Ask any question about how to use the program on the appropriate [Durrant Lab 110 | forum](http://durrantlab.com/forums/). 
111 | 112 | ## Acknowledgements 113 | 114 | This document was inspired by: 115 | 116 | * [Ruby on Rails CONTRIBUTING.md 117 | file](https://raw.githubusercontent.com/rails/rails/master/CONTRIBUTING.md) 118 | (MIT License). 119 | * [weallcontribute](https://github.com/WeAllJS/weallcontribute/blob/latest/CONTRIBUTING.md) 120 | (Public Domain License). 121 | -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | Contributors 2 | ============ 3 | 4 | * Harrison Green 5 | * Jacob Durrant 6 | * David Koes 7 | * Vandan Revanur 8 | * Martin Salinas -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | ============== 3 | 4 | _Version 2.0, January 2004_ 5 | _<>_ 6 | 7 | ### Terms and Conditions for use, reproduction, and distribution 8 | 9 | #### 1. Definitions 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, and 12 | distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by the 15 | copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all other 18 | entities that control, are controlled by, or are under common control with 19 | that entity. For the purposes of this definition, "control" means **(i)** the 20 | power, direct or indirect, to cause the direction or management of such 21 | entity, whether by contract or otherwise, or **(ii)** ownership of fifty 22 | percent (50%) or more of the outstanding shares, or **(iii)** beneficial 23 | ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity exercising 26 | permissions granted by this License. 
27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation source, and 30 | configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical transformation or 33 | translation of a Source form, including but not limited to compiled object 34 | code, generated documentation, and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or Object form, 37 | made available under the License, as indicated by a copyright notice that is 38 | included in or attached to the work (an example is provided in the Appendix 39 | below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object form, that 42 | is based on (or derived from) the Work and for which the editorial revisions, 43 | annotations, elaborations, or other modifications represent, as a whole, an 44 | original work of authorship. For the purposes of this License, Derivative 45 | Works shall not include works that remain separable from, or merely link (or 46 | bind by name) to the interfaces of, the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including the original 49 | version of the Work and any modifications or additions to that Work or 50 | Derivative Works thereof, that is intentionally submitted to Licensor for 51 | inclusion in the Work by the copyright owner or by an individual or Legal 52 | Entity authorized to submit on behalf of the copyright owner. 
For the purposes 53 | of this definition, "submitted" means any form of electronic, verbal, or 54 | written communication sent to the Licensor or its representatives, including 55 | but not limited to communication on electronic mailing lists, source code 56 | control systems, and issue tracking systems that are managed by, or on behalf 57 | of, the Licensor for the purpose of discussing and improving the Work, but 58 | excluding communication that is conspicuously marked or otherwise designated 59 | in writing by the copyright owner as "Not a Contribution." 60 | 61 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf 62 | of whom a Contribution has been received by Licensor and subsequently 63 | incorporated within the Work. 64 | 65 | #### 2. Grant of Copyright License 66 | 67 | Subject to the terms and conditions of this License, each Contributor hereby 68 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, 69 | irrevocable copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the Work and 71 | such Derivative Works in Source or Object form. 72 | 73 | #### 3. Grant of Patent License 74 | 75 | Subject to the terms and conditions of this License, each Contributor hereby 76 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, 77 | irrevocable (except as stated in this section) patent license to make, have 78 | made, use, offer to sell, sell, import, and otherwise transfer the Work, where 79 | such license applies only to those patent claims licensable by such 80 | Contributor that are necessarily infringed by their Contribution(s) alone or 81 | by combination of their Contribution(s) with the Work to which such 82 | Contribution(s) was submitted. 
If You institute patent litigation against any 83 | entity (including a cross-claim or counterclaim in a lawsuit) alleging that 84 | the Work or a Contribution incorporated within the Work constitutes direct or 85 | contributory patent infringement, then any patent licenses granted to You 86 | under this License for that Work shall terminate as of the date such 87 | litigation is filed. 88 | 89 | #### 4. Redistribution 90 | 91 | You may reproduce and distribute copies of the Work or Derivative Works 92 | thereof in any medium, with or without modifications, and in Source or Object 93 | form, provided that You meet the following conditions: 94 | 95 | * **(a)** You must give any other recipients of the Work or Derivative Works a 96 | copy of this License; and 97 | * **(b)** You must cause any modified files to carry prominent notices stating 98 | that You changed the files; and 99 | * **(c)** You must retain, in the Source form of any Derivative Works that You 100 | distribute, all copyright, patent, trademark, and attribution notices from 101 | the Source form of the Work, excluding those notices that do not pertain to 102 | any part of the Derivative Works; and 103 | * **(d)** If the Work includes a "NOTICE" text file as part of its 104 | distribution, then any Derivative Works that You distribute must include a 105 | readable copy of the attribution notices contained within such NOTICE file, 106 | excluding those notices that do not pertain to any part of the Derivative 107 | Works, in at least one of the following places: within a NOTICE text file 108 | distributed as part of the Derivative Works; within the Source form or 109 | documentation, if provided along with the Derivative Works; or, within a 110 | display generated by the Derivative Works, if and wherever such third-party 111 | notices normally appear. The contents of the NOTICE file are for 112 | informational purposes only and do not modify the License. 
You may add Your 113 | own attribution notices within Derivative Works that You distribute, 114 | alongside or as an addendum to the NOTICE text from the Work, provided that 115 | such additional attribution notices cannot be construed as modifying the 116 | License. 117 | 118 | You may add Your own copyright statement to Your modifications and may provide 119 | additional or different license terms and conditions for use, reproduction, or 120 | distribution of Your modifications, or for any such Derivative Works as a 121 | whole, provided Your use, reproduction, and distribution of the Work otherwise 122 | complies with the conditions stated in this License. 123 | 124 | #### 5. Submission of Contributions 125 | 126 | Unless You explicitly state otherwise, any Contribution intentionally 127 | submitted for inclusion in the Work by You to the Licensor shall be under the 128 | terms and conditions of this License, without any additional terms or 129 | conditions. Notwithstanding the above, nothing herein shall supersede or 130 | modify the terms of any separate license agreement you may have executed with 131 | Licensor regarding such Contributions. 132 | 133 | #### 6. Trademarks 134 | 135 | This License does not grant permission to use the trade names, trademarks, 136 | service marks, or product names of the Licensor, except as required for 137 | reasonable and customary use in describing the origin of the Work and 138 | reproducing the content of the NOTICE file. 139 | 140 | #### 7. Disclaimer of Warranty 141 | 142 | Unless required by applicable law or agreed to in writing, Licensor provides 143 | the Work (and each Contributor provides its Contributions) on an "AS IS" 144 | BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 145 | implied, including, without limitation, any warranties or conditions of TITLE, 146 | NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 
You 147 | are solely responsible for determining the appropriateness of using or 148 | redistributing the Work and assume any risks associated with Your exercise of 149 | permissions under this License. 150 | 151 | #### 8. Limitation of Liability 152 | 153 | In no event and under no legal theory, whether in tort (including negligence), 154 | contract, or otherwise, unless required by applicable law (such as deliberate 155 | and grossly negligent acts) or agreed to in writing, shall any Contributor be 156 | liable to You for damages, including any direct, indirect, special, 157 | incidental, or consequential damages of any character arising as a result of 158 | this License or out of the use or inability to use the Work (including but not 159 | limited to damages for loss of goodwill, work stoppage, computer failure or 160 | malfunction, or any and all other commercial damages or losses), even if such 161 | Contributor has been advised of the possibility of such damages. 162 | 163 | #### 9. Accepting Warranty or Additional Liability 164 | 165 | While redistributing the Work or Derivative Works thereof, You may choose to 166 | offer, and charge a fee for, acceptance of support, warranty, indemnity, or 167 | other liability obligations and/or rights consistent with this License. 168 | However, in accepting such obligations, You may act only on Your own behalf 169 | and on Your sole responsibility, not on behalf of any other Contributor, and 170 | only if You agree to indemnify, defend, and hold each Contributor harmless for 171 | any liability incurred by, or claims asserted against, such Contributor by 172 | reason of your accepting any such warranty or additional liability. 
173 | 174 | _END OF TERMS AND CONDITIONS_ 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepFrag 2 | 3 | DeepFrag is a machine learning model for fragment-based lead optimization. In 4 | this repository, you will find code to train the model and code to run 5 | inference using a pre-trained model. 6 | 7 | ## Citation 8 | 9 | If you use DeepFrag in your research, please cite as: 10 | 11 | Green, H., Koes, D. R., & Durrant, J. D. (2021). DeepFrag: a deep 12 | convolutional neural network for fragment-based lead optimization. Chemical 13 | Science. 14 | 15 | ```tex 16 | @article{green2021deepfrag, 17 | title={DeepFrag: a deep convolutional neural network for fragment-based lead optimization}, 18 | author={Green, Harrison and Koes, David Ryan and Durrant, Jacob D}, 19 | journal={Chemical Science}, 20 | year={2021}, 21 | publisher={Royal Society of Chemistry} 22 | } 23 | ``` 24 | 25 | ## Usage 26 | 27 | There are three ways to use DeepFrag: 28 | 29 | 1. **DeepFrag Browser App**: We have released a free, open-source browser app 30 | for DeepFrag that requires no setup and does not transmit any structures to 31 | a remote server. 32 | - View the online version at 33 | [durrantlab.pitt.edu/deepfrag](https://durrantlab.pitt.edu/deepfrag/) 34 | - See the code at 35 | [git.durrantlab.pitt.edu/jdurrant/deepfrag-app](https://git.durrantlab.pitt.edu/jdurrant/deepfrag-app) 36 | 2. **DeepFrag CLI**: In this repository we have included a `deepfrag.py` 37 | script that can perform common prediction tasks using the API. 38 | - See the `DeepFrag CLI` section below 39 | 3. **DeepFrag API**: For custom tasks or fine-grained control over 40 | predictions, you can invoke the DeepFrag API directly and interface with 41 | the raw data structures and the PyTorch model. 
We have created an example 42 | Google Colab (Jupyter notebook) that demonstrates how to perform manual 43 | predictions. 44 | - See the interactive 45 | [Colab](https://colab.research.google.com/drive/1If8rWQ9aVKJyJwfaOql56mA2llqC0iur). 46 | 47 | ## DeepFrag CLI 48 | 49 | The DeepFrag CLI is invoked by running `python3 deepfrag.py` in this 50 | repository. The CLI requires a pre-trained model and the fragment library to 51 | run. You will be prompted to download both when you first run the CLI and 52 | these will be saved in the `./.store` directory. 53 | 54 | ### Structure (specify exactly one) 55 | 56 | The input structures are specified using either a receptor and ligand 57 | PDB file or a PDB ID and the ligand residue number. 58 | 59 | - `--receptor <rec.pdb> --ligand <lig.pdb>` 60 | - `--pdb <pdbid> --resnum <resnum>` 61 | 62 | ### Connection Point (specify exactly one) 63 | 64 | DeepFrag will predict new fragments that connect to the _connection point_ via 65 | a single bond. You must specify the connection point atom using one of the 66 | following: 67 | 68 | - `--cname <name>`: Specify the connection point by atom name (e.g. `C3`, 69 | `N5`, `O2`, ...). 70 | - `--cx <x> --cy <y> --cz <z>`: Specify the connection point by atomic 71 | coordinate. DeepFrag will find the closest atom to this point. 72 | 73 | ### Fragment Removal (optional) (specify exactly one) 74 | 75 | If you are using DeepFrag for fragment _replacement_, you must first remove 76 | the original fragment from the ligand structure. You can either do this by 77 | hand, e.g. editing the PDB, or DeepFrag can do this for you by specifying 78 | _which_ fragment should be removed. 79 | 80 | _Note: predicting fragments in place of hydrogen atoms (e.g. protons) does not 81 | require any fragment removal since hydrogen atoms are ignored by the model._ 82 | 83 | To remove a fragment, you specify a second atom that is contained in the 84 | fragment. Like the connection point, you can either use the atom name or the 85 | atom coordinate. 
86 | 87 | - `--rname <name>`: Specify the removal point by atom name (e.g. `C3`, 88 | `N5`, `O2`, ...). 89 | - `--rx <x> --ry <y> --rz <z>`: Specify the removal point by atomic 90 | coordinate. DeepFrag will find the closest atom to this point. 91 | 92 | ### Output (optional) 93 | 94 | By default, DeepFrag will print a list of fragment predictions to stdout 95 | similar to the [Browser App](https://durrantlab.pitt.edu/deepfrag/). 96 | 97 | - `--out <out.csv>`: Save predictions in CSV format to `out.csv`. Each line 98 | contains the fragment rank, score, and SMILES string. 99 | 100 | ### Miscellaneous (optional) 101 | 102 | - `--full`: Generate SMILES strings with the full ligand structure instead of 103 | just the fragment. (__IMPORTANT NOTE__: Bond orders are not assigned to the 104 | parent portion of the full ligand structure. These must be added manually.) 105 | - `--cpu/--gpu`: DeepFrag will attempt to infer if a CUDA GPU is available and 106 | fall back to the CPU if it is not. You can set either the `--cpu` or `--gpu` 107 | flag to explicitly specify the target device. 108 | - `--num_grids <num>`: Number of grid rotations to use. Using more will take 109 | longer but produce a more stable prediction. (Default: 4) 110 | - `--top_k <k>`: Number of predictions to print in stdout. Use -1 to display 111 | all. (Default: 25) 112 | 113 | ## Reproduce Results 114 | 115 | You can use the DeepFrag CLI to reproduce the highlighted results from the 116 | main manuscript: 117 | 118 | ### 1. Fragment replacement 119 | 120 | To replace fragments, specify the connection point (`cname` or `cx/cy/cz`) and 121 | specify a second atom that is contained in the fragment (`rname` or 122 | `rx/ry/rz`). 123 | 124 | ```bash 125 | # Fig. 3: (2XP9) H. 
sapiens peptidyl-prolyl cis-trans isomerase NIMA-interacting 1 (HsPin1p) 126 | 127 | # Carboxylate A 128 | $ python3 deepfrag.py --pdb 2xp9 --resnum 1165 --cname C10 --rname C12 129 | 130 | # Phenyl B 131 | $ python3 deepfrag.py --pdb 2xp9 --resnum 1165 --cname C1 --rname C2 132 | 133 | # Phenyl C 134 | $ python3 deepfrag.py --pdb 2xp9 --resnum 1165 --cname C18 --rname C19 135 | ``` 136 | 137 | ```bash 138 | # Fig. 4A: (6QZ8) Protein myeloid cell leukemia1 (Mcl-1) 139 | 140 | # Carboxylate group interacting with R263 141 | $ python3 deepfrag.py --pdb 6qz8 --resnum 401 --cname C12 --rname C14 142 | 143 | # Ethyl group 144 | $ python3 deepfrag.py --pdb 6qz8 --resnum 401 --cname C6 --rname C10 145 | 146 | # Methyl group 147 | $ python3 deepfrag.py --pdb 6qz8 --resnum 401 --cname C25 --rname C30 148 | 149 | # Chlorine atom 150 | $ python3 deepfrag.py --pdb 6qz8 --resnum 401 --cname C28 --rname CL 151 | ``` 152 | 153 | ```bash 154 | # Fig. 4B: (1X38) Family GH3 b-D-glucan glucohydrolase (barley) 155 | 156 | # Hydroxyl group interacting with R158 and D285 157 | $ python3 deepfrag.py --pdb 1x38 --resnum 1001 --cname C2B --rname O2B 158 | 159 | # Phenyl group interacting with W286 and W434 160 | $ python3 deepfrag.py --pdb 1x38 --resnum 1001 --cname C7B --rname C1 161 | ``` 162 | 163 | ```bash 164 | # Fig. 4C: (4FOW) NanB sialidase (Streptococcus pneumoniae) 165 | 166 | # Amino group 167 | $ python3 deepfrag.py --pdb 4fow --resnum 701 --cname CAE --rname NAA 168 | ``` 169 | 170 | ### 2. Fragment addition 171 | 172 | For fragment addition, you only need to specify the atom connection point 173 | (`cname` or `cx/cy/cz`). In this case, DeepFrag will implicitly replace a 174 | valent hydrogen. 175 | 176 | ```bash 177 | # Fig. 
5: Ligands targeting the SARS-CoV-2 main protease (MPro) 178 | 179 | # 5A: (5RGH) Extension on Z1619978933 180 | $ python3 deepfrag.py --pdb 5rgh --resnum 404 --cname C09 181 | 182 | # 5B: (5R81) Extension on Z1367324110 183 | $ python3 deepfrag.py --pdb 5r81 --resnum 1001 --cname C07 184 | ``` 185 | 186 | ## Overview 187 | 188 | - `config`: fixed configuration information (e.g., TRAIN/VAL/TEST partitions) 189 | - `configurations`: benchmark model configurations (see 190 | [`configurations/README.md`](configurations/README.md)) 191 | - `data`: training/inference data (see [`data/README.md`](data/README.md)) 192 | - `leadopt`: main module code 193 | - `models`: pytorch architecture definitions 194 | - `data_util.py`: utility code for reading packed fragment/fingerprint data 195 | files 196 | - `grid_util.py`: GPU-accelerated grid generation code 197 | - `metrics.py`: pytorch implementations of several metrics 198 | - `model_conf.py`: contains code to configure and train models 199 | - `util.py`: utility code for rdkit/openbabel processing 200 | - `scripts`: data processing scripts (see 201 | [`scripts/README.md`](scripts/README.md)) 202 | - `train.py`: CLI interface to launch training runs 203 | 204 | ## Dependencies 205 | 206 | You can build a virtualenv with the requirements: 207 | 208 | ```sh 209 | $ python3 -m venv leadopt_env 210 | $ source ./leadopt_env/bin/activate 211 | $ pip install -r requirements.txt 212 | $ pip install prody 213 | $ pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118 214 | $ sudo apt install nvidia-cuda-toolkit 215 | ``` 216 | 217 | Regarding the nvidia-cuda-toolkit, you may wish to ensure that the toolkit 218 | version matches cuda installed on your machine. You can check the version of 219 | cuda by running the following commands: 220 | 221 | ```sh 222 | $ nvcc --version 223 | $ nvidia-smi 224 | ``` 225 | 226 | Note: We used `Cuda 10.1` for training. 
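Before training or running inference, it can help to confirm that the installed PyTorch build actually sees your GPU, since DeepFrag falls back to the CPU when no CUDA device is detected. The snippet below is a minimal, illustrative check (it is not part of the DeepFrag codebase):

```python
# Sanity check: was PyTorch installed with working CUDA support?
import torch

print("torch version:", torch.__version__)

if torch.cuda.is_available():
    # deepfrag.py and train.py will use this device unless --cpu is passed.
    print("CUDA device:", torch.cuda.get_device_name(0))
else:
    print("No CUDA GPU detected; DeepFrag will fall back to the CPU.")
```

If no GPU is reported, the CLI and training scripts will still run on the CPU (see the `--cpu`/`--gpu` flags above), just more slowly.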
227 | 228 | ## Training 229 | 230 | To train a model, you can use the `train.py` utility script. You can specify 231 | model parameters as command line arguments or load parameters from a 232 | configuration args.json file. 233 | 234 | ```bash 235 | python train.py \ 236 | --save_path=/path/to/model \ 237 | --wandb_project=my_project \ 238 | {model_type} \ 239 | --model_arg1=x \ 240 | --model_arg2=y \ 241 | ... 242 | ``` 243 | 244 | or 245 | 246 | ```bash 247 | python train.py \ 248 | --save_path=/path/to/model \ 249 | --wandb_project=my_project \ 250 | --configuration=./configurations/args.json 251 | ``` 252 | 253 | `save_path` is a directory to save the best model. The directory will be 254 | created if it doesn't exist. If this is not provided, the model will not be 255 | saved. 256 | 257 | `wandb_project` is an optional wandb project name. If provided, the run will 258 | be logged to wandb. 259 | 260 | See below for available models and model-specific parameters: 261 | 262 | ## Leadopt Models 263 | 264 | In this repository, trainable models are subclasses of 265 | `model_conf.LeadoptModel`. This class encapsulates model configuration 266 | arguments and pytorch models and enables saving and loading multi-component 267 | models. 268 | 269 | ```py 270 | from leadopt.model_conf import LeadoptModel, MODELS 271 | 272 | model = MODELS['voxel']({args...}) 273 | model.train(save_path='./mymodel') 274 | 275 | ... 276 | 277 | model2 = LeadoptModel.load('./mymodel') 278 | ``` 279 | 280 | Internally, model arguments are configured by setting up an `argparse` parser 281 | and passing around a `dict` of configuration parameters in `self._args`. 282 | 283 | ### VoxelNet 284 | 285 | ```text 286 | --no_partitions If set, disable the use of TRAIN/VAL partitions during 287 | training. 288 | -f FRAGMENTS, --fragments FRAGMENTS 289 | Path to fragments file. 290 | -fp FINGERPRINTS, --fingerprints FINGERPRINTS 291 | Path to fingerprints file. 
292 | -lr LEARNING_RATE, --learning_rate LEARNING_RATE 293 | --num_epochs NUM_EPOCHS 294 | Number of epochs to train for. 295 | --test_steps TEST_STEPS 296 | Number of evaluation steps per epoch. 297 | -b BATCH_SIZE, --batch_size BATCH_SIZE 298 | --grid_width GRID_WIDTH 299 | --grid_res GRID_RES 300 | --fdist_min FDIST_MIN 301 | Ignore fragments closer to the receptor than this 302 | distance (Angstroms). 303 | --fdist_max FDIST_MAX 304 | Ignore fragments further from the receptor than this 305 | distance (Angstroms). 306 | --fmass_min FMASS_MIN 307 | Ignore fragments smaller than this mass (Daltons). 308 | --fmass_max FMASS_MAX 309 | Ignore fragments larger than this mass (Daltons). 310 | --ignore_receptor 311 | --ignore_parent 312 | -rec_typer {single,single_h,simple,simple_h,desc,desc_h} 313 | -lig_typer {single,single_h,simple,simple_h,desc,desc_h} 314 | -rec_channels REC_CHANNELS 315 | -lig_channels LIG_CHANNELS 316 | --in_channels IN_CHANNELS 317 | --output_size OUTPUT_SIZE 318 | --pad 319 | --blocks BLOCKS [BLOCKS ...] 320 | --fc FC [FC ...] 321 | --use_all_labels 322 | --dist_fn {mse,bce,cos,tanimoto} 323 | --loss {direct,support_v1} 324 | ``` 325 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 
14 | -------------------------------------------------------------------------------- /configurations/README.md: -------------------------------------------------------------------------------- 1 | This folder contains benchmark model configurations referenced in the paper. 2 | 3 | Overview: 4 | - `layer_type_sweep/*`: experimenting with different parent/receptor typing schemes 5 | - `voxelation_sweep/*`: experimenting with different voxelation types and atomic influence radii 6 | - `final.json`: final production model 7 | 8 | You can train new models using these configurations with the `train.py` script: 9 | 10 | ```sh 11 | python train.py \ 12 | --save_path=/path/to/model \ 13 | --wandb_project=my_project \ 14 | --configuration=./configurations/model.json 15 | ``` 16 | 17 | Note: these configuration files assume the working directory is the `leadopt` base directory and that the data directory is accessible at `./data`. 18 | -------------------------------------------------------------------------------- /configurations/final.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 | "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | "learning_rate": 0.001, 7 | "num_epochs": 50, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [ 21 | 64, 22 | 64 23 | ], 24 | "fc": [ 25 | 512 26 | ], 27 | "use_all_labels": true, 28 | "dist_fn": "cos", 29 | "loss": "direct", 30 | "point_radius": 1, 31 | "point_type": 0, 32 | "rec_typer": "simple", 33 | "acc_type": 0, 34 | "lig_typer": "simple" 35 | } 36 | -------------------------------------------------------------------------------- 
/configurations/layer_type_sweep/lig_simple.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 | "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | "learning_rate": 0.001, 7 | "num_epochs": 15, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [64, 64], 21 | "fc": [512], 22 | "use_all_labels": true, 23 | "dist_fn": "cos", 24 | "loss": "direct", 25 | "point_radius": 1, 26 | "point_type": 3, 27 | "rec_typer": "simple", 28 | "acc_type": 0, 29 | "lig_typer": "simple" 30 | } 31 | -------------------------------------------------------------------------------- /configurations/layer_type_sweep/lig_simple_h.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 | "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | "learning_rate": 0.001, 7 | "num_epochs": 15, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [64, 64], 21 | "fc": [512], 22 | "use_all_labels": true, 23 | "dist_fn": "cos", 24 | "loss": "direct", 25 | "point_radius": 1, 26 | "point_type": 3, 27 | "rec_typer": "simple", 28 | "acc_type": 0, 29 | "lig_typer": "simple_h" 30 | } 31 | -------------------------------------------------------------------------------- /configurations/layer_type_sweep/lig_single.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 | "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | "learning_rate": 0.001, 7 | "num_epochs": 15, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [64, 64], 21 | "fc": [512], 22 | "use_all_labels": true, 23 | "dist_fn": "cos", 24 | "loss": "direct", 25 | "point_radius": 1, 26 | "point_type": 3, 27 | "rec_typer": "simple", 28 | "acc_type": 0, 29 | "lig_typer": "single" 30 | } 31 | -------------------------------------------------------------------------------- /configurations/layer_type_sweep/lig_single_h.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 | "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | "learning_rate": 0.001, 7 | "num_epochs": 15, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [64, 64], 21 | "fc": [512], 22 | "use_all_labels": true, 23 | "dist_fn": "cos", 24 | "loss": "direct", 25 | "point_radius": 1, 26 | "point_type": 3, 27 | "rec_typer": "simple", 28 | "acc_type": 0, 29 | "lig_typer": "single_h" 30 | } 31 | -------------------------------------------------------------------------------- /configurations/layer_type_sweep/rec_meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 
| "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | "learning_rate": 0.001, 7 | "num_epochs": 15, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [64, 64], 21 | "fc": [512], 22 | "use_all_labels": true, 23 | "dist_fn": "cos", 24 | "loss": "direct", 25 | "point_radius": 1, 26 | "point_type": 3, 27 | "rec_typer": "meta", 28 | "acc_type": 0, 29 | "lig_typer": "simple" 30 | } 31 | -------------------------------------------------------------------------------- /configurations/layer_type_sweep/rec_meta_mix.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 | "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | "learning_rate": 0.001, 7 | "num_epochs": 15, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [64, 64], 21 | "fc": [512], 22 | "use_all_labels": true, 23 | "dist_fn": "cos", 24 | "loss": "direct", 25 | "point_radius": 1, 26 | "point_type": 3, 27 | "rec_typer": "meta_mix", 28 | "acc_type": 0, 29 | "lig_typer": "simple" 30 | } 31 | -------------------------------------------------------------------------------- /configurations/layer_type_sweep/rec_simple_h.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 | "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | 
"learning_rate": 0.001, 7 | "num_epochs": 15, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [64, 64], 21 | "fc": [512], 22 | "use_all_labels": true, 23 | "dist_fn": "cos", 24 | "loss": "direct", 25 | "point_radius": 1, 26 | "point_type": 3, 27 | "rec_typer": "simple_h", 28 | "acc_type": 0, 29 | "lig_typer": "simple" 30 | } 31 | -------------------------------------------------------------------------------- /configurations/layer_type_sweep/rec_single.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 | "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | "learning_rate": 0.001, 7 | "num_epochs": 15, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [64, 64], 21 | "fc": [512], 22 | "use_all_labels": true, 23 | "dist_fn": "cos", 24 | "loss": "direct", 25 | "point_radius": 1, 26 | "point_type": 3, 27 | "rec_typer": "single", 28 | "acc_type": 0, 29 | "lig_typer": "simple" 30 | } 31 | -------------------------------------------------------------------------------- /configurations/layer_type_sweep/rec_single_h.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "voxelnet", 3 | "no_partitions": false, 4 | "fragments": "./data/moad.h5", 5 | "fingerprints": "./data/rdk10_moad.h5", 6 | "learning_rate": 0.001, 7 | "num_epochs": 15, 8 | "test_steps": 400, 9 | "batch_size": 16, 10 | "grid_width": 24, 
11 | "grid_res": 0.75, 12 | "fdist_min": null, 13 | "fdist_max": 4, 14 | "fmass_min": null, 15 | "fmass_max": 150, 16 | "ignore_receptor": false, 17 | "ignore_parent": false, 18 | "output_size": 2048, 19 | "pad": false, 20 | "blocks": [64, 64], 21 | "fc": [512], 22 | "use_all_labels": true, 23 | "dist_fn": "cos", 24 | "loss": "direct", 25 | "point_radius": 1, 26 | "point_type": 3, 27 | "rec_typer": "single_h", 28 | "acc_type": 0, 29 | "lig_typer": "simple" 30 | } 31 | -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t0.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 0, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t1.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 1, 
"rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t10.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 4, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t11.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 5, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t12.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, 
"num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 0, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t13.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 1, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t14.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 2, "rec_typer": "simple", "acc_type": 
0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t15.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 3, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t16.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 4, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t17.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, 
"batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 2.5, "point_type": 5, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t2.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 2, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t3.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 3, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} 
-------------------------------------------------------------------------------- /configurations/voxelation_sweep/t4.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 4, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t5.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1, "point_type": 5, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t6.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 
24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 0, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t7.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 1, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /configurations/voxelation_sweep/t8.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 2, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} 
-------------------------------------------------------------------------------- /configurations/voxelation_sweep/t9.json: -------------------------------------------------------------------------------- 1 | {"version": "voxelnet", "no_partitions": false, "fragments": "./data/moad.h5", "fingerprints": "./data/rdk10_moad.h5", "learning_rate": 0.001, "num_epochs": 15, "test_steps": 400, "batch_size": 16, "grid_width": 24, "grid_res": 0.75, "fdist_min": null, "fdist_max": 4, "fmass_min": null, "fmass_max": 150, "ignore_receptor": false, "ignore_parent": false, "output_size": 2048, "pad": false, "blocks": [64, 64], "fc": [512], "use_all_labels": true, "dist_fn": "cos", "loss": "direct", "point_radius": 1.75, "point_type": 3, "rec_typer": "simple", "acc_type": 0, "lig_typer": "simple"} -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data for Training and Inference 2 | 3 | This folder contains data used during training and inference. 4 | 5 | Model configuration files in `/configurations` expect the data files to be in 6 | this directory. You can either copy them directly here or use symlinks. 
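For example, assuming the files were downloaded to `~/Downloads` (the download location and the `rdk10_moad.h5` filename here follow the configuration files and are illustrative):

```sh
# Run from the leadopt base directory; link the downloaded files into ./data.
mkdir -p ./data
ln -sf ~/Downloads/moad.h5 ./data/moad.h5
ln -sf ~/Downloads/rdk10_moad.h5 ./data/rdk10_moad.h5
```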
7 | 8 | You can download the data here: http://durrantlab.com/apps/deepfrag/files/ 9 | 10 | 11 | 12 | Overview: 13 | 14 | - `moad.h5` (7 GB): processed MOAD data loaded by `data_util.FragmentDataset` 15 | - `rdk10_moad` (384 MB): RDK-10 fingerprints for MOAD data loaded by 16 | `data_util.FingerprintDataset` (generated with 17 | `scripts/make_fingerprints.py`) 18 | -------------------------------------------------------------------------------- /deepfrag.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import os 4 | import pathlib 5 | import shutil 6 | import time 7 | from typing import Tuple 8 | import zipfile 9 | 10 | import requests 11 | from tqdm.auto import tqdm 12 | import h5py 13 | import numpy as np 14 | import rdkit.Chem.AllChem as Chem 15 | import torch 16 | import prody 17 | 18 | from leadopt.model_conf import LeadoptModel, REC_TYPER, LIG_TYPER, DIST_FN 19 | from leadopt import util, grid_util 20 | 21 | 22 | USER_DIR = './.store' 23 | PDB_CACHE = 'pdb_cache' 24 | 25 | MODEL_DOWNLOAD = 'https://durrantlab.pitt.edu/apps/deepfrag/files/final_model_v2.zip' 26 | FINGERPRINTS_DOWNLOAD = 'https://durrantlab.pitt.edu/apps/deepfrag/files/fingerprints.h5' 27 | 28 | RCSB_DOWNLOAD = 'https://files.rcsb.org/download/%s.pdb1' 29 | 30 | VERSION = "1.0.4" 31 | 32 | def download_remote(url, path, compression=None): 33 | r = requests.get(url, stream=True, allow_redirects=True) 34 | if r.status_code != 200: 35 | print(f'Can\'t access {url}') 36 | r.raise_for_status() 37 | 38 | file_size = int(r.headers.get('Content-Length', 0)) 39 | 40 | r.raw.read = functools.partial(r.raw.read, decode_content=True) 41 | with tqdm.wrapattr(r.raw, 'read', total=file_size, desc='Downloading') as r_raw: 42 | with path.open('wb') as f: 43 | shutil.copyfileobj(r_raw, f) 44 | 45 | if compression is not None: 46 | shutil.move(str(path), str(path) + '.tmp') 47 | shutil.unpack_archive(str(path) + '.tmp', str(path),
format=compression) 48 | 49 | 50 | def get_deepfrag_user_dir() -> pathlib.Path: 51 | user_dir = pathlib.Path(os.path.realpath(__file__)).parent / USER_DIR 52 | os.makedirs(str(user_dir), exist_ok=True) 53 | return user_dir 54 | 55 | 56 | def get_model_path(): 57 | return get_deepfrag_user_dir() / 'model' 58 | 59 | 60 | def get_fingerprints_path(): 61 | return get_deepfrag_user_dir() / 'fingerprints.h5' 62 | 63 | 64 | def ensure_cli_data(): 65 | model_path = get_model_path() 66 | fingerprints_path = get_fingerprints_path() 67 | 68 | if not os.path.exists(str(model_path)): 69 | r = input('Pre-trained DeepFrag model not found, download it now? (5.8 MB) [Y/n]: ') 70 | if r.lower() == 'n': 71 | print('Exiting...') 72 | exit(-1) 73 | 74 | print(f'Saving to {model_path}...') 75 | download_remote(MODEL_DOWNLOAD, model_path, compression='zip') 76 | 77 | if not os.path.exists(str(fingerprints_path)): 78 | r = input('Fingerprint library not found, download it now? (11 MB) [Y/n]: ') 79 | if r.lower() == 'n': 80 | print('Exiting...') 81 | exit(-1) 82 | 83 | print(f'Saving to {fingerprints_path}...') 84 | download_remote(FINGERPRINTS_DOWNLOAD, fingerprints_path, compression=None) 85 | 86 | 87 | def download_pdb(pdb_id, path): 88 | download_remote(RCSB_DOWNLOAD % pdb_id, path, compression=None) 89 | 90 | 91 | def load_pdb(pdb_id, resnum): 92 | pdb_id = pdb_id.upper() 93 | assert all([x in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' for x in pdb_id]) 94 | 95 | # Check pdb cache 96 | pdb_dir = get_deepfrag_user_dir() / PDB_CACHE / pdb_id 97 | 98 | complex_path = pdb_dir / 'complex.pdb' 99 | rec_path = pdb_dir / 'receptor.pdb' 100 | lig_path = pdb_dir / 'ligand.pdb' 101 | 102 | os.makedirs(str(pdb_dir), exist_ok=True) 103 | 104 | if not os.path.exists(complex_path): 105 | download_pdb(pdb_id, complex_path) 106 | 107 | with open(str(complex_path), 'r') as f: 108 | m = prody.parsePDBStream(f) 109 | rec = m.select('not (nucleic or hetatm) and not water') 110 | lig = m.select('resnum %d' % 
resnum) 111 | 112 | if lig is None: 113 | print('[!] Error: could not find ligand with resnum: %d' % resnum) 114 | exit(-1) 115 | 116 | prody.writePDB(str(rec_path), rec) 117 | prody.writePDB(str(lig_path), lig) 118 | 119 | return (str(rec_path), str(lig_path)) 120 | 121 | 122 | def get_structure_paths(args) -> Tuple[str, str]: 123 | """Get structure paths specified by the command line args. 124 | Returns (rec_path, lig_path) 125 | """ 126 | if args.receptor is not None and args.ligand is not None: 127 | return (args.receptor, args.ligand) 128 | elif args.pdb is not None and args.resnum is not None: 129 | return load_pdb(args.pdb, args.resnum) 130 | else: 131 | raise NotImplementedError() 132 | 133 | 134 | def preprocess_ligand_without_removal_point(lig, conn): 135 | """ 136 | Mark the atom at conn as a connecting atom. Useful when adding a fragment. 137 | """ 138 | 139 | lig_pos = lig.GetConformer().GetPositions() 140 | lig_atm_conn_dist = np.sum((lig_pos - conn) ** 2, axis=1) 141 | 142 | # Get index of min 143 | min_idx = int(np.argmin(lig_atm_conn_dist)) 144 | 145 | # Get atom at that position 146 | lig_atm_conn = lig.GetAtomWithIdx(min_idx) 147 | 148 | # Add a dummy atom to the ligand, connected to lig_atm_conn 149 | dummy_atom = Chem.MolFromSmiles("*") 150 | merged = Chem.RWMol(Chem.CombineMols(lig, dummy_atom)) 151 | 152 | idx_of_dummy_in_merged = int([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0]) 153 | bond = merged.AddBond(min_idx, idx_of_dummy_in_merged, Chem.rdchem.BondType.SINGLE) 154 | 155 | return merged 156 | 157 | 158 | def preprocess_ligand_with_removal_point(lig, conn, rvec): 159 | """ 160 | Remove the fragment from lig connected via the atom at conn and containing 161 | the atom at rvec. Useful when replacing a fragment. 162 | """ 163 | # Generate all fragments. 164 | frags = util.generate_fragments(lig) 165 | 166 | for parent, frag in frags: 167 | # Get the index of the dummy (connection) atom on the fragment. 
168 | cidx = [a for a in frag.GetAtoms() if a.GetAtomicNum() == 0][0].GetIdx() 169 | 170 | # Get the coordinates of the associated atom (the dummy atom's 171 | # neighbor). 172 | vec = frag.GetConformer().GetAtomPosition(cidx) 173 | c_vec = np.array([vec.x, vec.y, vec.z]) 174 | 175 | # Check connection point. 176 | if np.linalg.norm(c_vec - conn) < 1e-3: 177 | # Check removal point. 178 | frag_pos = frag.GetConformer().GetPositions() 179 | min_dist = np.min(np.sum((frag_pos - rvec) ** 2, axis=1)) 180 | 181 | if min_dist < 1e-3: 182 | # You have found the parent/fragment split that correctly 183 | # exposes the user-specified connection-point atom. 184 | 185 | # Found fragment. 186 | print('[*] Removing fragment with %d atoms (%s)' % ( 187 | frag_pos.shape[0] - 1, Chem.MolToSmiles(frag, False))) 188 | 189 | return parent 190 | 191 | print('[!] Could not find a suitable fragment to remove.') 192 | exit(-1) 193 | 194 | 195 | def lookup_atom_name(lig_path, name): 196 | """Try to look up an atom by name. Returns the coordinate of the atom if 197 | found.""" 198 | with open(lig_path, 'r') as f: 199 | p = prody.parsePDBStream(f) 200 | p = p.select(f'name {name}') 201 | if p is None: 202 | print(f'[!] Error: no atom with name "{name}" in ligand') 203 | exit(-1) 204 | elif len(p) > 1: 205 | print(f'[!] Error: multiple atoms with name "{name}" in ligand') 206 | exit(-1) 207 | return p.getCoords()[0] 208 | 209 | 210 | def get_structures(args): 211 | rec_path, lig_path = get_structure_paths(args) 212 | 213 | print(f'[*] Loading receptor: {rec_path} ... ', end='') 214 | rec_coords, rec_types = util.load_receptor_ob(rec_path) 215 | print('done.') 216 | 217 | print(f'[*] Loading ligand: {lig_path} ... 
', end='') 218 | lig = Chem.MolFromPDBFile(lig_path) 219 | print('done.') 220 | 221 | conn = None 222 | if args.cx is not None and args.cy is not None and args.cz is not None: 223 | conn = np.array([float(args.cx), float(args.cy), float(args.cz)]) 224 | elif args.cname is not None: 225 | conn = lookup_atom_name(lig_path, args.cname) 226 | else: 227 | raise NotImplementedError() 228 | 229 | rvec = None 230 | if args.rx is not None and args.ry is not None and args.rz is not None: 231 | rvec = np.array([float(args.rx), float(args.ry), float(args.rz)]) 232 | elif args.rname is not None: 233 | rvec = lookup_atom_name(lig_path, args.rname) 234 | else: 235 | pass 236 | 237 | if rvec is not None: 238 | # Fragment replacement (rvec specified) 239 | lig = preprocess_ligand_with_removal_point(lig, conn, rvec) 240 | else: 241 | # Only fragment addition 242 | lig = preprocess_ligand_without_removal_point(lig, conn) 243 | 244 | parent_coords = util.get_coords(lig) 245 | parent_types = np.array(util.get_types(lig)).reshape((-1,1)) 246 | 247 | return (rec_coords, rec_types, parent_coords, parent_types, conn, lig) 248 | 249 | 250 | def get_model(args, device): 251 | """Load a pre-trained DeepFrag model.""" 252 | print('[*] Loading model ... ', end='') 253 | model = LeadoptModel.load(str(get_model_path() / 'final_model'), device=('cuda' if device == 'gpu' else device)) 254 | print('done.') 255 | return model 256 | 257 | 258 | def get_fingerprints(args): 259 | """Load the fingerprint library. 260 | Returns (smiles, fingerprints). 261 | """ 262 | f_smiles = None 263 | f_fingerprints = None 264 | print('[*] Loading fingerprint library ... 
', end='') 265 | with h5py.File(str(get_fingerprints_path()), 'r') as f: 266 | f_smiles = f['smiles'][()] 267 | f_fingerprints = f['fingerprints'][()].astype(float) 268 | print('done.') 269 | 270 | return (f_smiles, f_fingerprints) 271 | 272 | 273 | def get_target_device(args) -> str: 274 | """Infer the target device or use the argument overrides.""" 275 | device = 'gpu' if torch.cuda.device_count() > 0 else 'cpu' 276 | 277 | if args.cpu: 278 | if device == 'gpu': 279 | print('[*] Warning: GPU is available but running on CPU due to --cpu flag') 280 | device = 'cpu' 281 | elif args.gpu: 282 | if device == 'cpu': 283 | print('[*] Error: No CUDA-enabled GPU was found. Exiting due to --gpu flag. You can run on the CPU instead with the --cpu flag.') 284 | exit(-1) 285 | device = 'gpu' 286 | 287 | print('[*] Running on device: %s' % device) 288 | 289 | return device 290 | 291 | 292 | def generate_grids(args, model_args, rec_coords, rec_types, parent_coords, parent_types, conn, device): 293 | start = time.time() 294 | 295 | print('[*] Generating grids ... 
', end='', flush=True) 296 | batch = grid_util.get_raw_batch( 297 | rec_coords, rec_types, parent_coords, parent_types, 298 | rec_typer=REC_TYPER[model_args['rec_typer']], 299 | lig_typer=LIG_TYPER[model_args['lig_typer']], 300 | conn=conn, 301 | num_samples=args.num_grids, 302 | width=model_args['grid_width'], 303 | res=model_args['grid_res'], 304 | point_radius=model_args['point_radius'], 305 | point_type=model_args['point_type'], 306 | acc_type=model_args['acc_type'], 307 | cpu=(device == 'cpu') 308 | ) 309 | print('done.') 310 | end = time.time() 311 | print(f'[*] Generated grids in {end-start:.3f} seconds.') 312 | 313 | return batch 314 | 315 | 316 | def get_predictions(model, batch, f_smiles, f_fingerprints): 317 | start = time.time() 318 | pred = model.predict(torch.tensor(batch).float()).cpu().numpy() 319 | end = time.time() 320 | print(f'[*] Generated prediction in {end-start} seconds.') 321 | 322 | avg_fp = np.mean(pred, axis=0) 323 | dist_fn = DIST_FN[model._args['dist_fn']] 324 | 325 | # The distance functions are implemented in pytorch so we need to convert our 326 | # numpy arrays to a torch Tensor. 327 | dist = 1 - dist_fn( 328 | torch.tensor(avg_fp).unsqueeze(0), 329 | torch.tensor(f_fingerprints)) 330 | 331 | # Pair smiles strings and distances. 332 | dist = list(dist.numpy()) 333 | scores = list(zip(f_smiles, dist)) 334 | scores = sorted(scores, key=lambda x:x[1], reverse=True) 335 | scores = [(a.decode('ascii'), b) for a,b in scores] 336 | 337 | return scores 338 | 339 | 340 | def gen_output(args, scores): 341 | if args.out is None: 342 | # Write results to stdout. 343 | print('%4s %8s %s' % ('#', 'Score', 'SMILES')) 344 | for i in range(len(scores)): 345 | smi, score = scores[i] 346 | print('%4d %8f %s' % (i+1, score, smi)) 347 | else: 348 | # Write csv output. 
349 | csv = 'Rank,SMILES,Score\n' 350 | for i in range(len(scores)): 351 | smi, score = scores[i] 352 | csv += '%d,%s,%f\n' % ( 353 | i+1, smi, score 354 | ) 355 | 356 | with open(args.out, 'w') as f: f.write(csv) 357 | print('[*] Wrote output to %s' % args.out) 358 | 359 | 360 | def fuse(lig, frag): 361 | # Combine the ligand and fragment, though this does not form a bond between 362 | # the two. 363 | merged = Chem.RWMol(Chem.CombineMols(lig, frag)) 364 | 365 | conn_atoms = [a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0] 366 | neighbors = [merged.GetAtomWithIdx(x).GetNeighbors()[0].GetIdx() for x in conn_atoms] 367 | 368 | bond = merged.AddBond(neighbors[0], neighbors[1], Chem.rdchem.BondType.SINGLE) 369 | 370 | merged.RemoveAtom([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0]) 371 | merged.RemoveAtom([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0]) 372 | 373 | Chem.SanitizeMol(merged) 374 | 375 | return merged 376 | 377 | 378 | def fuse_fragments(lig, conn, scores): 379 | # Note: lig is rdkit.Chem.rdchem.Mol; scores is a list of (smiles, score) 380 | # tuples. 
381 | new_sc = [] 382 | for smi, score in scores: 383 | try: 384 | frag = Chem.MolFromSmiles(smi) 385 | fused = fuse(Chem.Mol(lig), frag) 386 | new_sc.append((Chem.MolToSmiles(fused, False), score)) 387 | except Exception: 388 | print('[!] Error: couldn\'t process mol.') 389 | new_sc.append(('', score)) 390 | 391 | return new_sc 392 | 393 | 394 | def run(args): 395 | device = get_target_device(args) 396 | 397 | model = get_model(args, device) 398 | f_smiles, f_fingerprints = get_fingerprints(args) 399 | 400 | rec_coords, rec_types, parent_coords, parent_types, conn, lig = get_structures(args) 401 | 402 | batch = generate_grids(args, model._args, rec_coords, rec_types, 403 | parent_coords, parent_types, conn, device) 404 | 405 | scores = get_predictions(model, batch, f_smiles, f_fingerprints) 406 | 407 | if args.top_k != -1: 408 | scores = scores[:args.top_k] 409 | 410 | if args.full: 411 | scores = fuse_fragments(lig, conn, scores) 412 | 413 | gen_output(args, scores) 414 | 415 | 416 | def main(): 417 | global VERSION 418 | 419 | print("\nDeepFrag " + VERSION) 420 | print("\nIf you use DeepFrag in your research, please cite:\n") 421 | print("Green, H., Koes, D. R., & Durrant, J. D. (2021). DeepFrag: a deep convolutional") 422 | print("neural network for fragment-based lead optimization. 
Chemical Science.\n") 423 | 424 | ensure_cli_data() 425 | 426 | parser = argparse.ArgumentParser() 427 | 428 | # Structure 429 | parser.add_argument('--receptor', help='Path to receptor structure.') 430 | parser.add_argument('--ligand', help='Path to ligand structure.') 431 | parser.add_argument('--pdb', help='PDB ID to download.') 432 | parser.add_argument('--resnum', type=int, help='Residue number of ligand.') 433 | 434 | # Connection point 435 | parser.add_argument('--cx', type=float, help='Connection point x coordinate.') 436 | parser.add_argument('--cy', type=float, help='Connection point y coordinate.') 437 | parser.add_argument('--cz', type=float, help='Connection point z coordinate.') 438 | parser.add_argument('--cname', type=str, help='Connection point atom name.') 439 | 440 | # Removal point 441 | parser.add_argument('--rx', type=float, help='Removal point x coordinate.') 442 | parser.add_argument('--ry', type=float, help='Removal point y coordinate.') 443 | parser.add_argument('--rz', type=float, help='Removal point z coordinate.') 444 | parser.add_argument('--rname', type=str, help='Removal point atom name.') 445 | 446 | # Misc 447 | parser.add_argument('--full', action='store_true', default=False, 448 | help='Print the full (fused) ligand structure.') 449 | parser.add_argument('--num_grids', type=int, default=4, 450 | help='Number of grid rotations.') 451 | parser.add_argument('--top_k', type=int, default=25, 452 | help='Number of results to show. 
Set to -1 to show all.') 453 | parser.add_argument('--out', type=str, 454 | help='Path to output CSV file.') 455 | parser.add_argument('--cpu', action='store_true', default=False, 456 | help='Use the CPU for grid generation and predictions.') 457 | parser.add_argument('--gpu', action='store_true', default=False, 458 | help='Use a (CUDA-capable) GPU for grid generation and predictions.') 459 | 460 | args = parser.parse_args() 461 | 462 | groupings = [ 463 | ([('receptor', 'ligand'), ('pdb', 'resnum')], True), 464 | ([('cx', 'cy', 'cz'), ('cname',)], True), 465 | ([('rx', 'ry', 'rz'), ('rname',)], False), 466 | ([('cpu',), ('gpu',)], False) 467 | ] 468 | 469 | for grp, req in groupings: 470 | partial = [] 471 | complete = 0 472 | 473 | for subset in grp: 474 | res = [not (getattr(args, name) in [None, False]) for name in subset] 475 | partial.append(any(res) and not all(res)) 476 | complete += int(all(res)) 477 | 478 | if any(partial) or complete > 1 or (complete != 1 and req): 479 | # Invalid arg combination. 480 | print('Invalid arguments, must specify exactly one of the following combinations:') 481 | for subset in grp: 482 | print('\t%s' % ', '.join(['--' + x for x in subset])) 483 | exit(-1) 484 | 485 | run(args) 486 | 487 | 488 | if __name__=='__main__': 489 | main() 490 | -------------------------------------------------------------------------------- /leadopt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | -------------------------------------------------------------------------------- /leadopt/data_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | """ 17 | Contains utility code for reading packed data files. 18 | """ 19 | import os 20 | 21 | import torch 22 | from torch.utils.data import DataLoader, Dataset 23 | import numpy as np 24 | import h5py 25 | import tqdm 26 | 27 | # Atom typing 28 | # 29 | # Atom typing is the process of figuring out which layer each atom should be 30 | # written to. For ease of testing, the packed data file contains a lot of 31 | # potentially useful atomic information which can be distilled during the 32 | # data loading process. 33 | # 34 | # Atom typing is implemented by map functions of the type: 35 | # (atom descriptor) -> (layer index) 36 | # 37 | # If the layer index is -1, the atom is ignored. 38 | 39 | 40 | class AtomTyper(object): 41 | def __init__(self, fn, num_layers): 42 | """Initialize an atom typer. 
43 | 44 | Args: 45 | fn: a function of type: 46 | (atomic_num, aro, hdon, hacc, pcharge) -> (mask) 47 | num_layers: number of output layers (<=32) 48 | """ 49 | self._fn = fn 50 | self._num_layers = num_layers 51 | 52 | def size(self): 53 | return self._num_layers 54 | 55 | def apply(self, *args): 56 | return self._fn(*args) 57 | 58 | 59 | class CondAtomTyper(AtomTyper): 60 | def __init__(self, cond_func): 61 | assert len(cond_func) <= 16 62 | def _fn(*args): 63 | v = 0 64 | for k in range(len(cond_func)): 65 | if cond_func[k](*args): 66 | v |= 1 << k 67 | return v 68 | super(CondAtomTyper, self).__init__(_fn, len(cond_func)) 69 | 70 | 71 | REC_TYPER = { 72 | # 1 channel, no hydrogen 73 | 'single': CondAtomTyper([ 74 | lambda num, aro, hdon, hacc, pcharge: num not in [0,1] 75 | ]), 76 | 77 | # 1 channel, including hydrogen 78 | 'single_h': CondAtomTyper([ 79 | lambda num, aro, hdon, hacc, pcharge: num != 0 80 | ]), 81 | 82 | # (C,N,O,S,*) 83 | 'simple': CondAtomTyper([ 84 | lambda num, aro, hdon, hacc, pcharge: num == 6, 85 | lambda num, aro, hdon, hacc, pcharge: num == 7, 86 | lambda num, aro, hdon, hacc, pcharge: num == 8, 87 | lambda num, aro, hdon, hacc, pcharge: num == 16, 88 | lambda num, aro, hdon, hacc, pcharge: num not in [0,1,6,7,8,16], 89 | ]), 90 | 91 | # (H,C,N,O,S,*) 92 | 'simple_h': CondAtomTyper([ 93 | lambda num, aro, hdon, hacc, pcharge: num == 1, 94 | lambda num, aro, hdon, hacc, pcharge: num == 6, 95 | lambda num, aro, hdon, hacc, pcharge: num == 7, 96 | lambda num, aro, hdon, hacc, pcharge: num == 8, 97 | lambda num, aro, hdon, hacc, pcharge: num == 16, 98 | lambda num, aro, hdon, hacc, pcharge: num not in [0,1,6,7,8,16], 99 | ]), 100 | 101 | # (aro, hdon, hacc, positive, negative, occ) 102 | 'meta': CondAtomTyper([ 103 | lambda num, aro, hdon, hacc, pcharge: bool(aro), # aromatic 104 | lambda num, aro, hdon, hacc, pcharge: bool(hdon), # hydrogen donor 105 | lambda num, aro, hdon, hacc, pcharge: bool(hacc), # hydrogen acceptor 106 | lambda num, 
aro, hdon, hacc, pcharge: pcharge >= 128, # partial positive 107 | lambda num, aro, hdon, hacc, pcharge: pcharge < 128, # partial negative 108 | lambda num, aro, hdon, hacc, pcharge: num != 0, # occupancy 109 | ]), 110 | 111 | # (aro, hdon, hacc, positive, negative, occ) 112 | 'meta_mix': CondAtomTyper([ 113 | lambda num, aro, hdon, hacc, pcharge: bool(aro), # aromatic 114 | lambda num, aro, hdon, hacc, pcharge: bool(hdon), # hydrogen donor 115 | lambda num, aro, hdon, hacc, pcharge: bool(hacc), # hydrogen acceptor 116 | lambda num, aro, hdon, hacc, pcharge: pcharge >= 128, # partial positive 117 | lambda num, aro, hdon, hacc, pcharge: pcharge < 128, # partial negative 118 | lambda num, aro, hdon, hacc, pcharge: num != 0, # occupancy 119 | lambda num, aro, hdon, hacc, pcharge: num == 1, # hydrogen 120 | lambda num, aro, hdon, hacc, pcharge: num == 6, # carbon 121 | lambda num, aro, hdon, hacc, pcharge: num == 7, # nitrogen 122 | lambda num, aro, hdon, hacc, pcharge: num == 8, # oxygen 123 | lambda num, aro, hdon, hacc, pcharge: num == 16, # sulfur 124 | ]) 125 | } 126 | 127 | LIG_TYPER = { 128 | # 1 channel, no hydrogen 129 | 'single': CondAtomTyper([ 130 | lambda num: num not in [0,1] 131 | ]), 132 | 133 | # 1 channel, including hydrogen 134 | 'single_h': CondAtomTyper([ 135 | lambda num: num != 0 136 | ]), 137 | 138 | 'simple': CondAtomTyper([ 139 | lambda num: num == 6, # carbon 140 | lambda num: num == 7, # nitrogen 141 | lambda num: num == 8, # oxygen 142 | lambda num: num not in [0,1,6,7,8] # extra 143 | ]), 144 | 145 | 'simple_h': CondAtomTyper([ 146 | lambda num: num == 1, # hydrogen 147 | lambda num: num == 6, # carbon 148 | lambda num: num == 7, # nitrogen 149 | lambda num: num == 8, # oxygen 150 | lambda num: num not in [0,1,6,7,8] # extra 151 | ]) 152 | } 153 | 154 | 155 | class FragmentDataset(Dataset): 156 | """Utility class to work with the packed fragments.h5 format.""" 157 | 158 | def __init__(self, fragment_file, rec_typer=REC_TYPER['simple'], 159 
| lig_typer=LIG_TYPER['simple'], filter_rec=None, filter_smi=None, 160 | fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None, 161 | verbose=False, lazy_loading=True): 162 | """Initializes the fragment dataset. 163 | 164 | Args: 165 | fragment_file: path to fragments.h5 166 | rec_typer: AtomTyper for receptor 167 | lig_typer: AtomTyper for ligand 168 | filter_rec: list of receptor ids to use (or None to use all) 169 | filter_smi: list of ligand smiles strings to use (or None to use all) 170 | lazy_loading: if True, postpone atom typing until an example is accessed 171 | (filtering options): 172 | fdist_min: minimum fragment distance 173 | fdist_max: maximum fragment distance 174 | fmass_min: minimum fragment mass (Da) 175 | fmass_max: maximum fragment mass (Da) 176 | """ 177 | self._rec_typer = rec_typer 178 | self._lig_typer = lig_typer 179 | 180 | self.verbose = verbose 181 | self._lazy_loading = lazy_loading 182 | 183 | self.rec = self._load_rec(fragment_file, rec_typer) 184 | self.frag = self._load_fragments(fragment_file, lig_typer) 185 | 186 | self.valid_idx = self._get_valid_examples( 187 | filter_rec, filter_smi, fdist_min, fdist_max, fmass_min, fmass_max, verbose) 188 | 189 | def _load_rec(self, fragment_file, rec_typer): 190 | """Loads receptor information.""" 191 | f = h5py.File(fragment_file, 'r') 192 | 193 | rec_coords = f['rec_coords'][()] 194 | rec_types = f['rec_types'][()] 195 | rec_lookup = f['rec_lookup'][()] 196 | 197 | r = range(len(rec_types)) 198 | if self.verbose: 199 | r = tqdm.tqdm(r, desc='Remap receptor atoms') 200 | 201 | rec_remapped = np.zeros(len(rec_types), dtype=np.uint16) 202 | if not self._lazy_loading: 203 | for i in r: 204 | rec_remapped[i] = rec_typer.apply(*rec_types[i]) 205 | 206 | rec_loaded = np.zeros(len(rec_lookup)).astype(bool) 207 | 208 | # create rec mapping 209 | rec_mapping = {} 210 | for i in range(len(rec_lookup)): 211 | rec_mapping[rec_lookup[i][0].decode('ascii')] = i 212 | 213 | rec = { 214 | 'rec_coords': rec_coords, 215 | 'rec_types': rec_types, 216 | 'rec_remapped': rec_remapped, 
217 | 'rec_lookup': rec_lookup, 218 | 'rec_mapping': rec_mapping, 219 | 'rec_loaded': rec_loaded 220 | } 221 | 222 | f.close() 223 | 224 | return rec 225 | 226 | def _load_fragments(self, fragment_file, lig_typer): 227 | """Loads fragment information.""" 228 | f = h5py.File(fragment_file, 'r') 229 | 230 | frag_data = f['frag_data'][()] 231 | frag_lookup = f['frag_lookup'][()] 232 | frag_smiles = f['frag_smiles'][()] 233 | frag_mass = f['frag_mass'][()] 234 | frag_dist = f['frag_dist'][()] 235 | 236 | frag_lig_smi = None 237 | frag_lig_idx = None 238 | if 'frag_lig_smi' in f.keys(): 239 | frag_lig_smi = f['frag_lig_smi'][()] 240 | frag_lig_idx = f['frag_lig_idx'][()] 241 | 242 | # unpack frag data into separate structures 243 | frag_coords = frag_data[:,:3].astype(np.float32) 244 | frag_types = frag_data[:,3].astype(np.uint8) 245 | 246 | frag_remapped = np.zeros(len(frag_types), dtype=np.uint16) 247 | if not self._lazy_loading: 248 | for i in range(len(frag_types)): 249 | frag_remapped[i] = lig_typer.apply(frag_types[i]) 250 | 251 | frag_loaded = np.zeros(len(frag_lookup)).astype(bool) 252 | 253 | # find and save connection point 254 | r = range(len(frag_lookup)) 255 | if self.verbose: 256 | r = tqdm.tqdm(r, desc='Frag connection point') 257 | 258 | frag_conn = np.zeros((len(frag_lookup), 3)) 259 | for i in r: 260 | _,f_start,f_end,_,_ = frag_lookup[i] 261 | fdat = frag_data[f_start:f_end] 262 | 263 | found = False 264 | for j in range(len(fdat)): 265 | if fdat[j][3] == 0: 266 | frag_conn[i,:] = tuple(fdat[j])[:3] 267 | found = True 268 | break 269 | 270 | assert found, "missing fragment connection point at %d" % i 271 | 272 | frag = { 273 | 'frag_coords': frag_coords, # d_idx -> (x,y,z) 274 | 'frag_types': frag_types, # d_idx -> (type) 275 | 'frag_remapped': frag_remapped, # d_idx -> (layer) 276 | 'frag_lookup': frag_lookup, # f_idx -> (rec_id, fstart, fend, pstart, pend) 277 | 'frag_conn': frag_conn, # f_idx -> (x,y,z) 278 | 'frag_smiles': frag_smiles, # f_idx 
-> smiles 279 | 'frag_mass': frag_mass, # f_idx -> mass 280 | 'frag_dist': frag_dist, # f_idx -> dist 281 | 'frag_lig_smi': frag_lig_smi, 282 | 'frag_lig_idx': frag_lig_idx, 283 | 'frag_loaded': frag_loaded 284 | } 285 | 286 | f.close() 287 | 288 | return frag 289 | 290 | def _get_valid_examples(self, filter_rec, filter_smi, fdist_min, fdist_max, fmass_min, 291 | fmass_max, verbose): 292 | """Returns an array of valid fragment indexes. 293 | 294 | "Valid" in this context means the fragment belongs to a receptor in 295 | filter_rec and the fragment abides by the optional mass/distance 296 | constraints. 297 | """ 298 | # keep track of valid examples 299 | valid_mask = np.ones(self.frag['frag_lookup'].shape[0]).astype(bool) 300 | 301 | num_frags = self.frag['frag_lookup'].shape[0] 302 | 303 | # filter by receptor id 304 | if filter_rec is not None: 305 | valid_rec = np.zeros(num_frags, dtype=bool) 306 | 307 | r = range(num_frags) 308 | if verbose: 309 | r = tqdm.tqdm(r, desc='filter rec') 310 | 311 | for i in r: 312 | rec = self.frag['frag_lookup'][i][0].decode('ascii') 313 | if rec in filter_rec: 314 | valid_rec[i] = 1 315 | valid_mask *= valid_rec 316 | 317 | # filter by ligand smiles string 318 | if filter_smi is not None: 319 | valid_lig = np.zeros(num_frags, dtype=bool) 320 | 321 | r = range(num_frags) 322 | if verbose: 323 | r = tqdm.tqdm(r, desc='filter lig') 324 | 325 | for i in r: 326 | smi = self.frag['frag_lig_smi'][self.frag['frag_lig_idx'][i]] 327 | smi = smi.decode('ascii') 328 | if smi in filter_smi: 329 | valid_lig[i] = 1 330 | 331 | valid_mask *= valid_lig 332 | 333 | # filter by fragment distance 334 | if fdist_min is not None: 335 | valid_mask[self.frag['frag_dist'] < fdist_min] = 0 336 | 337 | if fdist_max is not None: 338 | valid_mask[self.frag['frag_dist'] > fdist_max] = 0 339 | 340 | # filter by fragment mass 341 | if fmass_min is not None: 342 | valid_mask[self.frag['frag_mass'] < fmass_min] = 0 343 | 344 | if fmass_max is not None: 
345 | valid_mask[self.frag['frag_mass'] > fmass_max] = 0 346 | 347 | # convert to a list of indexes 348 | valid_idx = np.where(valid_mask)[0] 349 | 350 | return valid_idx 351 | 352 | def __len__(self): 353 | """Returns the number of valid fragment examples.""" 354 | return self.valid_idx.shape[0] 355 | 356 | def __getitem__(self, idx): 357 | """Returns the Nth example. 358 | 359 | Returns a dict with: 360 | f_coords: fragment coordinates (Fx3) 361 | f_types: fragment layers (Fx1) 362 | p_coords: parent coordinates (Px3) 363 | p_types: parent layers (Px1) 364 | r_coords: receptor coordinates (Rx3) 365 | r_types: receptor layers (Rx1) 366 | conn: fragment connection point in the parent molecule (x,y,z) 367 | smiles: fragment smiles string 368 | """ 369 | # convert to fragment index 370 | frag_idx = self.valid_idx[idx] 371 | return self.get_raw(frag_idx) 372 | 373 | def get_raw(self, frag_idx): 374 | # lookup fragment 375 | rec_id, f_start, f_end, p_start, p_end = self.frag['frag_lookup'][frag_idx] 376 | smiles = self.frag['frag_smiles'][frag_idx].decode('ascii') 377 | conn = self.frag['frag_conn'][frag_idx] 378 | 379 | # lookup receptor 380 | rec_idx = self.rec['rec_mapping'][rec_id.decode('ascii')] 381 | _, r_start, r_end = self.rec['rec_lookup'][rec_idx] 382 | 383 | # fetch data 384 | # f_coords = self.frag['frag_coords'][f_start:f_end] 385 | # f_types = self.frag['frag_types'][f_start:f_end] 386 | p_coords = self.frag['frag_coords'][p_start:p_end] 387 | r_coords = self.rec['rec_coords'][r_start:r_end] 388 | 389 | if self._lazy_loading and self.frag['frag_loaded'][frag_idx] == 0: 390 | frag_types = self.frag['frag_types'] 391 | frag_remapped = self.frag['frag_remapped'] 392 | 393 | # load parent 394 | for i in range(p_start, p_end): 395 | frag_remapped[i] = self._lig_typer.apply(frag_types[i]) 396 | 397 | self.frag['frag_loaded'][frag_idx] = 1 398 | 399 | if self._lazy_loading and self.rec['rec_loaded'][rec_idx] == 0: 400 | rec_types = self.rec['rec_types'] 401 | 
rec_remapped = self.rec['rec_remapped'] 402 | 403 | # load receptor 404 | for i in range(r_start, r_end): 405 | rec_remapped[i] = self._rec_typer.apply(*rec_types[i]) 406 | 407 | self.rec['rec_loaded'][rec_idx] = 1 408 | 409 | p_mask = self.frag['frag_remapped'][p_start:p_end] 410 | r_mask = self.rec['rec_remapped'][r_start:r_end] 411 | 412 | return { 413 | # 'f_coords': f_coords, 414 | # 'f_types': f_types, 415 | 'p_coords': p_coords, 416 | 'p_types': p_mask, 417 | 'r_coords': r_coords, 418 | 'r_types': r_mask, 419 | 'conn': conn, 420 | 'smiles': smiles 421 | } 422 | 423 | def get_valid_smiles(self): 424 | """Returns a list of all valid smiles fragments.""" 425 | valid_smiles = set() 426 | 427 | for idx in self.valid_idx: 428 | smiles = self.frag['frag_smiles'][idx].decode('ascii') 429 | valid_smiles.add(smiles) 430 | 431 | return list(valid_smiles) 432 | 433 | def lig_layers(self): 434 | return self._lig_typer.size() 435 | 436 | def rec_layers(self): 437 | return self._rec_typer.size() 438 | 439 | 440 | class SharedFragmentDataset(object): 441 | def __init__(self, dat, filter_rec=None, filter_smi=None, fdist_min=None, 442 | fdist_max=None, fmass_min=None, fmass_max=None): 443 | 444 | self._dat = dat 445 | 446 | self.valid_idx = self._dat._get_valid_examples( 447 | filter_rec, filter_smi, fdist_min, fdist_max, fmass_min, fmass_max, verbose=True) 448 | 449 | def __len__(self): 450 | return self.valid_idx.shape[0] 451 | 452 | def __getitem__(self, idx): 453 | frag_idx = self.valid_idx[idx] 454 | return self._dat.get_raw(frag_idx) 455 | 456 | def get_valid_smiles(self): 457 | """Returns a list of all valid smiles fragments.""" 458 | valid_smiles = set() 459 | 460 | for idx in self.valid_idx: 461 | smiles = self._dat.frag['frag_smiles'][idx].decode('ascii') 462 | valid_smiles.add(smiles) 463 | 464 | return list(valid_smiles) 465 | 466 | def lig_layers(self): 467 | return self._dat.lig_layers() 468 | 469 | def rec_layers(self): 470 | return self._dat.rec_layers() 471 | 
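The `CondAtomTyper` classes defined earlier in this file distill each atom descriptor into a per-layer bitmask: bit k of the result is set when the k-th predicate matches. A minimal, self-contained sketch of that mechanism (the helper `type_atom` below is a hypothetical stand-in for the typer's internal function, not part of `leadopt`), using predicates mirroring `LIG_TYPER['simple']`:

```python
# Illustrative sketch only: `type_atom` mimics how CondAtomTyper packs one
# boolean predicate per grid layer into a single integer bitmask.

def type_atom(cond_funcs, *atom_descriptor):
    v = 0
    for k, cond in enumerate(cond_funcs):
        if cond(*atom_descriptor):
            v |= 1 << k
    return v

# Predicates mirroring the LIG_TYPER['simple'] layers: (C, N, O, extra).
simple_lig = [
    lambda num: num == 6,                     # carbon
    lambda num: num == 7,                     # nitrogen
    lambda num: num == 8,                     # oxygen
    lambda num: num not in [0, 1, 6, 7, 8],   # extra heavy atom
]

print(type_atom(simple_lig, 6))   # carbon   -> bit 0 set -> 1
print(type_atom(simple_lig, 8))   # oxygen   -> bit 2 set -> 4
print(type_atom(simple_lig, 16))  # sulfur   -> "extra" bit 3 -> 8
print(type_atom(simple_lig, 1))   # hydrogen -> matches no layer -> 0
```

Note that a dummy atom (atomic number 0) matches no predicate and maps to an empty mask, which is how connection-point markers stay out of the voxel grid.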
472 | 473 | class FingerprintDataset(Dataset): 474 | 475 | def __init__(self, fingerprint_file): 476 | """Initializes a fingerprint dataset. 477 | 478 | Args: 479 | fingerprint_file: path to a fingerprint .h5 file 480 | """ 481 | self.fingerprints = self._load_fingerprints(fingerprint_file) 482 | 483 | def _load_fingerprints(self, fingerprint_file): 484 | """Loads fingerprint information.""" 485 | f = h5py.File(fingerprint_file, 'r') 486 | 487 | fingerprint_data = f['fingerprints'][()] 488 | fingerprint_smiles = f['smiles'][()] 489 | 490 | # create smiles->idx mapping 491 | fingerprint_mapping = {} 492 | for i in range(len(fingerprint_smiles)): 493 | sm = fingerprint_smiles[i].decode('ascii') 494 | fingerprint_mapping[sm] = i 495 | 496 | fingerprints = { 497 | 'fingerprint_data': fingerprint_data, 498 | 'fingerprint_mapping': fingerprint_mapping, 499 | 'fingerprint_smiles': fingerprint_smiles, 500 | } 501 | 502 | f.close() 503 | 504 | return fingerprints 505 | 506 | def for_smiles(self, smiles): 507 | """Return a Tensor of fingerprints for a list of smiles. 508 | 509 | Args: 510 | smiles: size N list of smiles strings (as str not bytes) 511 | """ 512 | fp = np.zeros((len(smiles), self.fingerprints['fingerprint_data'].shape[1])) 513 | 514 | for i in range(len(smiles)): 515 | fp_idx = self.fingerprints['fingerprint_mapping'][smiles[i]] 516 | fp[i] = self.fingerprints['fingerprint_data'][fp_idx] 517 | 518 | return torch.Tensor(fp) 519 | -------------------------------------------------------------------------------- /leadopt/grid_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. 
You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | """ 17 | Contains code for GPU-accelerated grid generation. 18 | """ 19 | 20 | import math 21 | import ctypes 22 | 23 | import torch 24 | import numba 25 | import numba.cuda 26 | import numpy as np 27 | 28 | 29 | GPU_DIM = 8 30 | 31 | 32 | class POINT_TYPE(object): 33 | EXP = 0 # simple exponential sphere fill 34 | SPHERE = 1 # fixed sphere fill 35 | CUBE = 2 # fixed cube fill 36 | GAUSSIAN = 3 # continuous piecewise exponential fill 37 | LJ = 4 38 | DISCRETE = 5 39 | 40 | class ACC_TYPE(object): 41 | SUM = 0 42 | MAX = 1 43 | 44 | 45 | @numba.cuda.jit 46 | def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset, 47 | batch_idx, width, res, center, rot, 48 | point_radius, point_type, acc_type 49 | ): 50 | """Adds atoms to the grid in a GPU kernel. 51 | 52 | This kernel converts atom coordinate information to 3D voxel information. 53 | Each GPU thread is responsible for one specific grid point. This function 54 | receives a list of atomic coordinates and atom layers and simply iterates 55 | over the list to find nearby atoms and add their effect. 56 | 57 | Voxel information is stored in a 5D tensor of type: BxTxNxNxN where: 58 | B = batch size 59 | T = number of atom types (receptor + ligand) 60 | N = grid width (in gridpoints) 61 | 62 | Each invocation of this function will write information to a specific batch 63 | index specified by batch_idx. Additionally, the layer_offset parameter can 64 | be set to specify a fixed offset to add to each atom_layer item. 65 | 66 | How it works: 67 | 1. 
Each GPU thread controls a single gridpoint. This gridpoint coordinate 68 | is translated to a "real world" coordinate by applying rotation and 69 | translation vectors. 70 | 2. Each thread iterates over the list of atoms and checks for atoms within 71 | a threshold to add to the grid. 72 | 73 | Args: 74 | grid: DeviceNDArray tensor where grid information is stored 75 | atom_num: number of atoms 76 | atom_coords: array containing (x,y,z) atom coordinates 77 | atom_mask: uint32 array of size atom_num containing a destination 78 | layer bitmask (i.e. if bit k is set, write atom to index k) 79 | layer_offset: a fixed offset added to each atom layer index 80 | batch_idx: index specifying which batch to write information to 81 | width: number of grid points in each dimension 82 | res: distance between neighboring grid points in angstroms 83 | (1 == gridpoint every angstrom) 84 | (0.5 == gridpoint every half angstrom, i.e. a tighter grid) 85 | center: (x,y,z) coordinate of grid center 86 | rot: (w,x,y,z) rotation quaternion 87 | """ 88 | x,y,z = numba.cuda.grid(3) 89 | 90 | # center around origin 91 | tx = x - (width/2) 92 | ty = y - (width/2) 93 | tz = z - (width/2) 94 | 95 | # scale by resolution 96 | tx = tx * res 97 | ty = ty * res 98 | tz = tz * res 99 | 100 | # apply rotation vector 101 | aw = rot[0] 102 | ax = rot[1] 103 | ay = rot[2] 104 | az = rot[3] 105 | 106 | bw = 0 107 | bx = tx 108 | by = ty 109 | bz = tz 110 | 111 | # multiply by rotation vector 112 | cw = (aw * bw) - (ax * bx) - (ay * by) - (az * bz) 113 | cx = (aw * bx) + (ax * bw) + (ay * bz) - (az * by) 114 | cy = (aw * by) + (ay * bw) + (az * bx) - (ax * bz) 115 | cz = (aw * bz) + (az * bw) + (ax * by) - (ay * bx) 116 | 117 | # multiply by conjugate 118 | # dw = (cw * aw) - (cx * (-ax)) - (cy * (-ay)) - (cz * (-az)) 119 | dx = (cw * (-ax)) + (cx * aw) + (cy * (-az)) - (cz * (-ay)) 120 | dy = (cw * (-ay)) + (cy * aw) + (cz * (-ax)) - (cx * (-az)) 121 | dz = (cw * (-az)) + (cz * aw) + (cx * (-ay)) - (cy * 
(-ax)) 122 | 123 | # apply translation vector 124 | tx = dx + center[0] 125 | ty = dy + center[1] 126 | tz = dz + center[2] 127 | 128 | i = 0 129 | while i < atom_num: 130 | # fetch atom 131 | fx, fy, fz = atom_coords[i] 132 | mask = atom_mask[i] 133 | i += 1 134 | 135 | # invisible atoms 136 | if mask == 0: 137 | continue 138 | 139 | # point radius squared 140 | r = point_radius 141 | r2 = point_radius * point_radius 142 | 143 | # quick cube bounds check 144 | if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2: 145 | continue 146 | 147 | # value to add to this gridpoint 148 | val = 0 149 | 150 | if point_type == 0: # POINT_TYPE.EXP 151 | # exponential sphere fill 152 | # compute squared distance to atom 153 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2 154 | if d2 > r2: 155 | continue 156 | 157 | # compute effect 158 | val = math.exp((-2 * d2) / r2) 159 | elif point_type == 1: # POINT_TYPE.SPHERE 160 | # solid sphere fill 161 | # compute squared distance to atom 162 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2 163 | if d2 > r2: 164 | continue 165 | 166 | val = 1 167 | elif point_type == 2: # POINT_TYPE.CUBE 168 | # solid cube fill 169 | val = 1 170 | elif point_type == 3: # POINT_TYPE.GAUSSIAN 171 | # (Ragoza, 2016) 172 | # 173 | # piecewise gaussian sphere fill 174 | # compute squared distance to atom 175 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2 176 | d = math.sqrt(d2) 177 | 178 | if d > r * 1.5: 179 | continue 180 | elif d > r: 181 | val = math.exp(-2.0) * ( (4*d2/r2) - (12*d/r) + 9 ) 182 | else: 183 | val = math.exp((-2 * d2) / r2) 184 | elif point_type == 4: # POINT_TYPE.LJ 185 | # (Jimenez, 2017) - DeepSite 186 | # 187 | # LJ potential 188 | # compute squared distance to atom 189 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2 190 | d = math.sqrt(d2) 191 | 192 | if d > r * 1.5: 193 | continue 194 | else: 195 | val = 1 - math.exp(-((r/d)**12)) 196 | elif point_type == 5: # POINT_TYPE.DISCRETE 197 | # nearest-gridpoint 198 | # L1 distance 199 | if abs(fx-tx) 
< (res/2) and abs(fy-ty) < (res/2) and abs(fz-tz) < (res/2): 200 | val = 1 201 | 202 | # add value to layers 203 | for k in range(32): 204 | if (mask >> k) & 1: 205 | idx = (batch_idx, layer_offset+k, x, y, z) 206 | if acc_type == 0: # ACC_TYPE.SUM 207 | numba.cuda.atomic.add(grid, idx, val) 208 | elif acc_type == 1: # ACC_TYPE.MAX 209 | numba.cuda.atomic.max(grid, idx, val) 210 | 211 | 212 | @numba.jit(nopython=True) 213 | def cpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset, 214 | batch_idx, width, res, center, rot, 215 | point_radius, point_type, acc_type 216 | ): 217 | """Adds atoms to the grid on the CPU. 218 | 219 | This function is the CPU analog of gpu_gridify. It converts atom 220 | coordinate information to 3d voxel information, visiting every grid point 221 | in nested loops. It receives a list of atomic coordinates and atom layers 222 | and simply iterates over the list to find nearby atoms and add their effect. 223 | 224 | Voxel information is stored in a 5D tensor of type: BxTxNxNxN where: 225 | B = batch size 226 | T = number of atom types (receptor + ligand) 227 | N = grid width (in gridpoints) 228 | 229 | Each invocation of this function will write information to a specific batch 230 | index specified by batch_idx. Additionally, the layer_offset parameter can 231 | be set to specify a fixed offset to add to each atom_layer item. 232 | 233 | How it works: 234 | 1. Each gridpoint coordinate is visited in turn and translated to a 235 | "real world" coordinate by applying rotation and 236 | translation vectors. 237 | 2. For each gridpoint, the function iterates over the list of atoms and 238 | checks for atoms within a threshold to add to the grid. 239 | 240 | Args: 241 | grid: array where grid information is stored 242 | atom_num: number of atoms 243 | atom_coords: array containing (x,y,z) atom coordinates 244 | atom_mask: uint32 array of size atom_num containing a destination 245 | layer bitmask (i.e.
if bit k is set, write atom to index k) 246 | layer_offset: a fixed offset added to each atom layer index 247 | batch_idx: index specifying which batch to write information to 248 | width: number of grid points in each dimension 249 | res: distance between neighboring grid points in angstroms 250 | (1 == gridpoint every angstrom) 251 | (0.5 == gridpoint every half angstrom, i.e. a tighter grid) 252 | center: (x,y,z) coordinate of grid center 253 | rot: (w,x,y,z) rotation quaternion 254 | """ 255 | # x,y,z = numba.cuda.grid(3) 256 | for x in range(width): 257 | for y in range(width): 258 | for z in range(width): 259 | 260 | # center around origin 261 | tx = x - (width/2) 262 | ty = y - (width/2) 263 | tz = z - (width/2) 264 | 265 | # scale by resolution 266 | tx = tx * res 267 | ty = ty * res 268 | tz = tz * res 269 | 270 | # apply rotation vector 271 | aw = rot[0] 272 | ax = rot[1] 273 | ay = rot[2] 274 | az = rot[3] 275 | 276 | bw = 0 277 | bx = tx 278 | by = ty 279 | bz = tz 280 | 281 | # multiply by rotation vector 282 | cw = (aw * bw) - (ax * bx) - (ay * by) - (az * bz) 283 | cx = (aw * bx) + (ax * bw) + (ay * bz) - (az * by) 284 | cy = (aw * by) + (ay * bw) + (az * bx) - (ax * bz) 285 | cz = (aw * bz) + (az * bw) + (ax * by) - (ay * bx) 286 | 287 | # multiply by conjugate 288 | # dw = (cw * aw) - (cx * (-ax)) - (cy * (-ay)) - (cz * (-az)) 289 | dx = (cw * (-ax)) + (cx * aw) + (cy * (-az)) - (cz * (-ay)) 290 | dy = (cw * (-ay)) + (cy * aw) + (cz * (-ax)) - (cx * (-az)) 291 | dz = (cw * (-az)) + (cz * aw) + (cx * (-ay)) - (cy * (-ax)) 292 | 293 | # apply translation vector 294 | tx = dx + center[0] 295 | ty = dy + center[1] 296 | tz = dz + center[2] 297 | 298 | i = 0 299 | while i < atom_num: 300 | # fetch atom 301 | fx, fy, fz = atom_coords[i] 302 | mask = atom_mask[i] 303 | i += 1 304 | 305 | # invisible atoms 306 | if mask == 0: 307 | continue 308 | 309 | # point radius squared 310 | r = point_radius 311 | r2 = point_radius * point_radius 312 | 313 | # quick
cube bounds check 314 | if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2: 315 | continue 316 | 317 | # value to add to this gridpoint 318 | val = 0 319 | 320 | if point_type == 0: # POINT_TYPE.EXP 321 | # exponential sphere fill 322 | # compute squared distance to atom 323 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2 324 | if d2 > r2: 325 | continue 326 | 327 | # compute effect 328 | val = math.exp((-2 * d2) / r2) 329 | elif point_type == 1: # POINT_TYPE.SPHERE 330 | # solid sphere fill 331 | # compute squared distance to atom 332 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2 333 | if d2 > r2: 334 | continue 335 | 336 | val = 1 337 | elif point_type == 2: # POINT_TYPE.CUBE 338 | # solid cube fill 339 | val = 1 340 | elif point_type == 3: # POINT_TYPE.GAUSSIAN 341 | # (Ragoza, 2016) 342 | # 343 | # piecewise gaussian sphere fill 344 | # compute squared distance to atom 345 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2 346 | d = math.sqrt(d2) 347 | 348 | if d > r * 1.5: 349 | continue 350 | elif d > r: 351 | val = math.exp(-2.0) * ( (4*d2/r2) - (12*d/r) + 9 ) 352 | else: 353 | val = math.exp((-2 * d2) / r2) 354 | elif point_type == 4: # POINT_TYPE.LJ 355 | # (Jimenez, 2017) - DeepSite 356 | # 357 | # LJ potential 358 | # compute squared distance to atom 359 | d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2 360 | d = math.sqrt(d2) 361 | 362 | if d > r * 1.5: 363 | continue 364 | else: 365 | val = 1 - math.exp(-((r/d)**12)) 366 | elif point_type == 5: # POINT_TYPE.DISCRETE 367 | # nearest-gridpoint 368 | # L1 distance 369 | if abs(fx-tx) < (res/2) and abs(fy-ty) < (res/2) and abs(fz-tz) < (res/2): 370 | val = 1 371 | 372 | # add value to layers 373 | for k in range(32): 374 | if (mask >> k) & 1: 375 | idx = (batch_idx, layer_offset+k, x, y, z) 376 | if acc_type == 0: # ACC_TYPE.SUM 377 | grid[idx] += val 378 | elif acc_type == 1: # ACC_TYPE.MAX 379 | grid[idx] = max(grid[idx], val) 380 | 381 | 382 | def mol_gridify(grid, atom_coords, atom_mask, layer_offset, batch_idx, 
383 | width, res, center, rot, point_radius, point_type, acc_type, 384 | cpu=False): 385 | """Wrapper around gpu_gridify. 386 | 387 | (See gpu_gridify() for details) 388 | """ 389 | if cpu: 390 | cpu_gridify( 391 | grid, len(atom_coords), atom_coords, atom_mask, layer_offset, 392 | batch_idx, width, res, center, rot, point_radius, point_type, acc_type 393 | ) 394 | else: 395 | dw = ((width - 1) // GPU_DIM) + 1 396 | gpu_gridify[(dw,dw,dw), (GPU_DIM,GPU_DIM,GPU_DIM)]( 397 | grid, len(atom_coords), atom_coords, atom_mask, layer_offset, 398 | batch_idx, width, res, center, rot, point_radius, point_type, acc_type 399 | ) 400 | 401 | 402 | def make_tensor(shape): 403 | """Creates a pytorch tensor and numba array with shared GPU memory backing. 404 | 405 | Args: 406 | shape: the shape of the array 407 | 408 | Returns: 409 | (torch_arr, cuda_arr) 410 | """ 411 | # get cuda context 412 | ctx = numba.cuda.cudadrv.driver.driver.get_active_context() 413 | 414 | # setup tensor on gpu 415 | t = torch.zeros(size=shape, dtype=torch.float32).cuda() 416 | 417 | memory = numba.cuda.cudadrv.driver.MemoryPointer(ctx, ctypes.c_ulong(t.data_ptr()), t.numel() * 4) 418 | cuda_arr = numba.cuda.cudadrv.devicearray.DeviceNDArray( 419 | t.size(), 420 | [i*4 for i in t.stride()], 421 | np.dtype('float32'), 422 | gpu_data=memory, 423 | stream=torch.cuda.current_stream().cuda_stream 424 | ) 425 | 426 | return (t, cuda_arr) 427 | 428 | 429 | def rand_rot(): 430 | """Returns a random uniform quaternion rotation.""" 431 | q = np.random.normal(size=4) # sample quaternion from normal distribution 432 | q = q / np.sqrt(np.sum(q**2)) # normalize 433 | return q 434 | 435 | 436 | def get_batch(data, batch_size=16, batch_set=None, width=48, res=0.5, 437 | ignore_receptor=False, ignore_parent=False, fixed_rot=None, 438 | point_radius=2, point_type=POINT_TYPE.EXP, 439 | acc_type=ACC_TYPE.SUM): 440 | """Builds a batch grid from a FragmentDataset. 
441 | 442 | Args: 443 | data: a FragmentDataset object 444 | fixed_rot: None or a fixed 4-element rotation quaternion 445 | point_radius, point_type, acc_type: atom density parameters (see gpu_gridify) 446 | batch_size: size of the batch 447 | batch_set: if not None, specify a list of data indexes to use for each 448 | item in the batch 449 | width: grid width 450 | res: grid resolution 451 | ignore_receptor: if True, ignore receptor atoms 452 | ignore_parent: if True, ignore parent atoms 453 | 454 | Returns: (torch_grid, examples) 455 | torch_grid: pytorch Tensor with voxel information 456 | examples: list of examples used 457 | """ 458 | assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!" 459 | 460 | batch_size = int(batch_size) 461 | width = int(width) 462 | 463 | rec_channels = data.rec_layers() 464 | lig_channels = data.lig_layers() 465 | 466 | dim = 0 467 | if not ignore_receptor: 468 | dim += rec_channels 469 | if not ignore_parent: 470 | dim += lig_channels 471 | 472 | # create a tensor with shared memory on the gpu 473 | torch_grid, cuda_grid = make_tensor((batch_size, dim, width, width, width)) 474 | 475 | if batch_set is None: 476 | batch_set = np.random.choice(len(data), size=batch_size, replace=False) 477 | 478 | examples = [data[idx] for idx in batch_set] 479 | 480 | for i in range(len(examples)): 481 | example = examples[i] 482 | rot = fixed_rot 483 | if rot is None: 484 | rot = rand_rot() 485 | 486 | if ignore_receptor: 487 | mol_gridify( 488 | cuda_grid, 489 | example['p_coords'], 490 | example['p_types'], 491 | layer_offset=0, 492 | batch_idx=i, 493 | width=width, 494 | res=res, 495 | center=example['conn'], 496 | rot=rot, 497 | point_radius=point_radius, 498 | point_type=point_type, 499 | acc_type=acc_type 500 | ) 501 | elif ignore_parent: 502 | mol_gridify( 503 | cuda_grid, 504 | example['r_coords'], 505 | example['r_types'], 506 | layer_offset=0, 507 | batch_idx=i, 508 | width=width, 509 | res=res, 510 | center=example['conn'], 511 | rot=rot, 512 | 
point_radius=point_radius, 513 | point_type=point_type, 514 | acc_type=acc_type 515 | ) 516 | else: 517 | mol_gridify( 518 | cuda_grid, 519 | example['p_coords'], 520 | example['p_types'], 521 | layer_offset=0, 522 | batch_idx=i, 523 | width=width, 524 | res=res, 525 | center=example['conn'], 526 | rot=rot, 527 | point_radius=point_radius, 528 | point_type=point_type, 529 | acc_type=acc_type 530 | ) 531 | mol_gridify( 532 | cuda_grid, 533 | example['r_coords'], 534 | example['r_types'], 535 | layer_offset=lig_channels, 536 | batch_idx=i, 537 | width=width, 538 | res=res, 539 | center=example['conn'], 540 | rot=rot, 541 | point_radius=point_radius, 542 | point_type=point_type, 543 | acc_type=acc_type 544 | ) 545 | 546 | return torch_grid, examples 547 | 548 | 549 | def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer, 550 | conn, num_samples=32, width=24, res=1, fixed_rot=None, 551 | point_radius=1.5, point_type=0, acc_type=0, cpu=False): 552 | """Sample a raw batch with provided atom coordinates. 
553 | 554 | Args: 555 | r_coords: receptor coordinates 556 | r_types: receptor types (layers) 557 | p_coords: parent coordinates 558 | p_types: parent types (layers) 559 | conn: (x,y,z) connection point 560 | num_samples: number of rotations to sample 561 | width: grid width 562 | res: grid resolution 563 | fixed_rot: None or a fixed 4-element rotation vector 564 | point_radius: atom radius in Angstroms 565 | point_type: shape of the atom densities 566 | acc_type: atom density accumulation type 567 | cpu: if True, generate batches with cpu_gridify 568 | """ 569 | B = num_samples 570 | T = rec_typer.size() + lig_typer.size() 571 | N = width 572 | 573 | if cpu: 574 | t = np.zeros((B,T,N,N,N)) 575 | torch_grid = t 576 | cuda_grid = t 577 | else: 578 | torch_grid, cuda_grid = make_tensor((B,T,N,N,N)) 579 | 580 | r_mask = np.zeros(len(r_types), dtype=np.uint32) 581 | p_mask = np.zeros(len(p_types), dtype=np.uint32) 582 | 583 | for i in range(len(r_types)): 584 | r_mask[i] = rec_typer.apply(*r_types[i]) 585 | 586 | for i in range(len(p_types)): 587 | p_mask[i] = lig_typer.apply(*p_types[i]) 588 | 589 | for i in range(num_samples): 590 | rot = fixed_rot 591 | if rot is None: 592 | rot = rand_rot() 593 | 594 | mol_gridify( 595 | cuda_grid, 596 | p_coords, 597 | p_mask, 598 | layer_offset=0, 599 | batch_idx=i, 600 | width=width, 601 | res=res, 602 | center=conn, 603 | rot=rot, 604 | point_radius=point_radius, 605 | point_type=point_type, 606 | acc_type=acc_type, 607 | cpu=cpu 608 | ) 609 | 610 | mol_gridify( 611 | cuda_grid, 612 | r_coords, 613 | r_mask, 614 | layer_offset=lig_typer.size(), 615 | batch_idx=i, 616 | width=width, 617 | res=res, 618 | center=conn, 619 | rot=rot, 620 | point_radius=point_radius, 621 | point_type=point_type, 622 | acc_type=acc_type, 623 | cpu=cpu 624 | ) 625 | 626 | return torch_grid 627 | -------------------------------------------------------------------------------- /leadopt/infer.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | import os 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | import rdkit.Chem.AllChem as Chem 23 | 24 | import numpy as np 25 | import h5py 26 | import tqdm 27 | 28 | from leadopt.grid_util import get_raw_batch 29 | import leadopt.util as util 30 | 31 | 32 | def get_nearest_fp(fingerprints, fp, k=10): 33 | ''' 34 | Return the top-k closest rows in fingerprints 35 | 36 | Returns [(idx1, dist1), (idx2, dist2), ...] 37 | ''' 38 | def mse(a,b): 39 | return np.sum((a-b)**2, axis=1) 40 | 41 | d = mse(fingerprints, fp) 42 | arr = [(i,d[i]) for i in range(len(d))] 43 | arr = sorted(arr, key=lambda x: x[1]) 44 | 45 | return arr[:k] 46 | 47 | 48 | def infer_all(model, fingerprints, smiles, rec_path, lig_path, num_samples=16, k=25): 49 | '''Run inference on each fragment of a ligand. 50 | Returns a list of (parent_smiles, frag_smiles, conn, top_suggestions) tuples. 51 | ''' 52 | # load ligand and receptor 53 | lig, frags = util.load_ligand(lig_path) 54 | rec = util.load_receptor(rec_path) 55 | 56 | 57 | # compute shared receptor coords and layers 58 | rec_coords, rec_layers = util.mol_to_points(rec) 59 | # [ 60 | # (parent_sm, orig_frag_sm, conn, [ 61 | # (new_frag_sm, merged_sm, score), 62 | # ...
63 | # ]) 64 | # ] 65 | res = [] 66 | 67 | for parent, frag in frags: 68 | # compute parent coords and layers 69 | parent_coords, parent_layers = util.mol_to_points(parent) 70 | 71 | # find connection point 72 | conn = util.get_connection_point(frag) 73 | 74 | # generate batch 75 | grid = get_raw_batch(rec_coords, rec_layers, parent_coords, parent_layers, conn) 76 | 77 | # infer 78 | fp = model(grid).detach().cpu().numpy() 79 | fp_mean = np.mean(fp, axis=0) 80 | 81 | # find closest fingerprints 82 | top = get_nearest_fp(fingerprints, fp_mean, k=k) 83 | 84 | # convert to (frag_smiles, merged_smiles, score) tuples 85 | top_smiles = [(smiles[x[0]].decode('ascii'), util.merge_smiles(Chem.MolToSmiles(parent), smiles[x[0]].decode('ascii')), x[1]) for x in top] 86 | 87 | res.append( 88 | (Chem.MolToSmiles(parent, isomericSmiles=False), Chem.MolToSmiles(frag, isomericSmiles=False), tuple(conn), top_smiles) 89 | ) 90 | 91 | return res 92 | -------------------------------------------------------------------------------- /leadopt/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 
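As a side note on `get_nearest_fp()` in `infer.py` above: it is a brute-force nearest-neighbor search that ranks every fingerprint row by squared Euclidean distance to the query. A minimal standalone sketch of the same logic, using made-up two-dimensional fingerprints rather than real RDKit fragment fingerprints:

```python
import numpy as np

def get_nearest_fp(fingerprints, fp, k=10):
    # squared Euclidean distance from the query to every row
    d = np.sum((fingerprints - fp) ** 2, axis=1)
    # rank all rows by distance; keep the k best as (index, distance) pairs
    arr = sorted(((i, d[i]) for i in range(len(d))), key=lambda x: x[1])
    return arr[:k]

# toy 2-D "fingerprints" (illustrative only)
fps = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
top = get_nearest_fp(fps, np.array([0.9, 0.1]), k=2)
# row 1 is nearest (distance 0.02), then row 0 (distance 0.82)
```

In `infer_all()` the same ranking is applied to the mean predicted fingerprint over all sampled rotations, and the winning indices are mapped back to fragment SMILES.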
14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | 21 | def mse(yp, yt): 22 | """Mean squared error loss.""" 23 | return torch.sum((yp - yt) ** 2, axis=1) 24 | 25 | 26 | def bce(yp, yt): 27 | """Binary cross entropy loss.""" 28 | return torch.sum(F.binary_cross_entropy(yp, yt, reduction='none'), axis=1) 29 | 30 | 31 | def tanimoto(yp, yt): 32 | """Tanimoto distance metric.""" 33 | intersect = torch.sum(yt * torch.round(yp), axis=1) 34 | union = torch.sum(torch.clamp(yt + torch.round(yp), 0, 1), axis=1) 35 | return 1 - (intersect / union) 36 | 37 | 38 | _cos = nn.CosineSimilarity(dim=1, eps=1e-6) 39 | def cos(yp, yt): 40 | """Cosine distance as a loss (inverted).""" 41 | return 1 - _cos(yp,yt) 42 | 43 | 44 | def broadcast_fn(fn, yp, yt): 45 | """Broadcast a distance function.""" 46 | yp_b, yt_b = torch.broadcast_tensors(yp, yt) 47 | return fn(yp_b, yt_b) 48 | 49 | 50 | def average_position(fingerprints, fn, norm=True): 51 | """Returns the average ranking of the correct fragment relative to all 52 | possible fragments. 
53 | 54 | Args: 55 | fingerprints: NxF tensor of fingerprint data 56 | fn: distance function to compare fingerprints 57 | norm: if True, normalize position in range (0,1) 58 | """ 59 | def _average_position(yp, yt): 60 | # distance to correct fragment 61 | p_dist = broadcast_fn(fn, yp, yt.detach()) 62 | 63 | c = torch.empty(yp.shape[0]) 64 | for i in range(yp.shape[0]): 65 | # compute distance to all other fragments 66 | dist = broadcast_fn(fn, yp[i].unsqueeze(0), fingerprints) 67 | 68 | # number of fragments that are closer or equal 69 | count = torch.sum((dist <= p_dist[i]).to(torch.float)) 70 | c[i] = count 71 | 72 | score = torch.mean(c) 73 | return score 74 | 75 | return _average_position 76 | 77 | 78 | def average_support(fingerprints, fn): 79 | """Support loss: penalizes candidate fingerprints that are closer to the 80 | prediction than the correct fragment (sigmoid-weighted).""" 81 | def _average_support(yp, yt): 82 | # correct distance 83 | p_dist = broadcast_fn(fn, yp, yt) 84 | 85 | c = torch.empty(yp.shape[0]) 86 | for i in range(yp.shape[0]): 87 | # compute distance to all other fragments 88 | dist = broadcast_fn(fn, yp[i].unsqueeze(0), fingerprints) 89 | 90 | # shift distance so bad examples are positive 91 | dist -= p_dist[i] 92 | dist *= -1 93 | 94 | dist_n = torch.sigmoid(dist) 95 | 96 | c[i] = torch.mean(dist_n) 97 | 98 | score = torch.mean(c) 99 | 100 | return score 101 | 102 | return _average_support 103 | 104 | 105 | def inside_support(fingerprints, fn): 106 | """Like average_support, but candidates further away than the correct 107 | fragment are ignored (clamped to zero margin).""" 108 | def _inside_support(yp, yt): 109 | # correct distance 110 | p_dist = broadcast_fn(fn, yp, yt) 111 | 112 | c = torch.empty(yp.shape[0]) 113 | for i in range(yp.shape[0]): 114 | # compute distance to all other fragments 115 | dist = broadcast_fn(fn, yp[i].unsqueeze(0), fingerprints) 116 | 117 | # shift distance so bad examples are positive 118 | dist -= p_dist[i] 119 | dist *= -1 120 | 121 | # ignore labels that are further away 122 | dist[dist < 0] = 0 123 | 124 | dist_n = torch.sigmoid(dist) 125 | 126 | c[i] = torch.mean(dist_n) 127 | 128 | score = torch.mean(c) 129 | 130 | return score 131 | 
132 | return _inside_support 133 | 134 | 135 | def top_k_acc(fingerprints, fn, k, pre=''): 136 | """Top-k accuracy metric. 137 | 138 | Returns a dict containing top-k accuracies: 139 | { 140 | {pre}acc_{k1}: acc_k1, 141 | {pre}acc_{k2}: acc_k2, 142 | } 143 | 144 | Args: 145 | fingerprints: NxF tensor of fingerprints 146 | fn: distance function to compare fingerprints 147 | k: List[int] containing K-positions to evaluate (e.g. [1,5,10]) 148 | pre: optional prefix on the metric name 149 | """ 150 | 151 | def _top_k_acc(yp, yt): 152 | # correct distance 153 | p_dist = broadcast_fn(fn, yp.detach(), yt.detach()) 154 | 155 | c = torch.empty(yp.shape[0], len(k)) 156 | for i in range(yp.shape[0]): 157 | # compute distance to all other fragments 158 | dist = broadcast_fn(fn, yp[i].unsqueeze(0).detach(), fingerprints) 159 | 160 | # number of fragments that are strictly closer 161 | count = torch.sum((dist < p_dist[i]).to(torch.float)) 162 | 163 | for j in range(len(k)): 164 | c[i,j] = int(count < k[j]) 165 | 166 | score = torch.mean(c, 0) 167 | m = {'%sacc_%d' % (pre, h): v.item() for h,v in zip(k,score)} 168 | 169 | return m 170 | 171 | return _top_k_acc 172 | -------------------------------------------------------------------------------- /leadopt/model_conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License.
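The `tanimoto()` metric in `metrics.py` above rounds the predicted fingerprint to binary bits and reports 1 minus intersection-over-union against the target. Re-expressed with NumPy instead of torch, purely as a standalone illustration of the arithmetic:

```python
import numpy as np

def tanimoto_distance(yp, yt):
    # round predictions to a binary fingerprint
    yp_bin = np.round(yp)
    # bits set in both fingerprints vs. bits set in either
    intersect = np.sum(yt * yp_bin, axis=1)
    union = np.sum(np.clip(yt + yp_bin, 0, 1), axis=1)
    return 1 - (intersect / union)

yp = np.array([[0.9, 0.8, 0.1, 0.2]])  # rounds to [1, 1, 0, 0]
yt = np.array([[1.0, 0.0, 1.0, 0.0]])
d = tanimoto_distance(yp, yt)
# intersection = 1 bit, union = 3 bits -> distance 1 - 1/3
```

Because the prediction is rounded, this metric is not differentiable; it is used for evaluation rather than as a training loss in the `DIST_FN`/`LOSS_TYPE` setup below.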
14 | 15 | 16 | import os 17 | import json 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import tqdm 23 | import numpy as np 24 | 25 | try: 26 | import wandb 27 | except ImportError: 28 | pass 29 | 30 | from leadopt.models.voxel import VoxelFingerprintNet 31 | from leadopt.data_util import FragmentDataset, SharedFragmentDataset, FingerprintDataset, LIG_TYPER,\ 32 | REC_TYPER 33 | from leadopt.grid_util import get_batch 34 | from leadopt.metrics import mse, bce, tanimoto, cos, top_k_acc,\ 35 | average_support, inside_support 36 | 37 | from config import moad_partitions 38 | 39 | 40 | def get_bios(p): 41 | part = [] 42 | for n in range(20): 43 | part += [x.lower() + '_bio%d' % n for x in p] 44 | return part 45 | 46 | 47 | def _do_mean(fn): 48 | def g(yp, yt): 49 | return torch.mean(fn(yp,yt)) 50 | return g 51 | 52 | 53 | def _direct_loss(fingerprints, fn): 54 | return _do_mean(fn) 55 | 56 | 57 | DIST_FN = { 58 | 'mse': mse, 59 | 'bce': bce, 60 | 'cos': cos, 61 | 'tanimoto': tanimoto 62 | } 63 | 64 | 65 | LOSS_TYPE = { 66 | # minimize distance to target fingerprint 67 | 'direct': _direct_loss, 68 | 69 | # minimize distance to target and maximize distance to all others 70 | 'support_v1': average_support, 71 | 72 | # support, limited to closer points 73 | 'support_v2': inside_support, 74 | } 75 | 76 | 77 | class RunLog(object): 78 | def __init__(self, args, models, wandb_project=None): 79 | """Initialize a run logger.
80 | 81 | Args: 82 | args: command line training arguments 83 | models: {name: model} mapping 84 | wandb_project: a project to initialize wandb or None 85 | """ 86 | self._use_wandb = wandb_project != None 87 | if self._use_wandb: 88 | wandb.init( 89 | project=wandb_project, 90 | config=args 91 | ) 92 | 93 | for m in models: 94 | wandb.watch(models[m]) 95 | 96 | def log(self, x): 97 | if self._use_wandb: 98 | wandb.log(x) 99 | 100 | 101 | class MetricTracker(object): 102 | def __init__(self, name, metric_fns): 103 | self._name = name 104 | self._metric_fns = metric_fns 105 | self._metrics = {} 106 | 107 | def evaluate(self, yp, yt): 108 | for m in self._metric_fns: 109 | self.update(m, self._metric_fns[m](yp, yt)) 110 | 111 | def update(self, name, metric): 112 | if type(metric) is dict: 113 | for subname in metric: 114 | fullname = '%s_%s' % (self._name, subname) 115 | if not fullname in self._metrics: 116 | self._metrics[fullname] = 0 117 | self._metrics[fullname] += metric[subname] 118 | else: 119 | fullname = '%s_%s' % (self._name, name) 120 | if not fullname in self._metrics: 121 | self._metrics[fullname] = 0 122 | self._metrics[fullname] += metric 123 | 124 | def normalize(self, size): 125 | for m in self._metrics: 126 | self._metrics[m] /= size 127 | 128 | def clear(self): 129 | self._metrics = {} 130 | 131 | def get(self, name): 132 | fullname = '%s_%s' % (self._name, name) 133 | return self._metrics[fullname] 134 | 135 | def get_all(self): 136 | return self._metrics 137 | 138 | 139 | class LeadoptModel(object): 140 | """Abstract LeadoptModel base class.""" 141 | 142 | @staticmethod 143 | def setup_base_parser(parser): 144 | """Configures base parser arguments.""" 145 | parser.add_argument('--wandb_project', default=None, help=''' 146 | Set this argument to track run in wandb. 147 | ''') 148 | 149 | @staticmethod 150 | def setup_parser(sub): 151 | """Adds arguments to a subparser. 
152 | 153 | Args: 154 | sub: an argparse subparser 155 | """ 156 | raise NotImplementedError() 157 | 158 | @staticmethod 159 | def get_defaults(): 160 | return {} 161 | 162 | @classmethod 163 | def load(cls, path, device='cuda'): 164 | """Load model configuration saved with save(). 165 | 166 | Call LeadoptModel.load to infer model type. 167 | Or call subclass.load to load a specific model type. 168 | 169 | Args: 170 | path: full path to saved model 171 | """ 172 | args_path = os.path.join(path, 'args.json') 173 | args = json.loads(open(args_path, 'r').read()) 174 | 175 | model_type = MODELS[args['version']] if cls is LeadoptModel else cls 176 | 177 | default_args = model_type.get_defaults() 178 | for k in default_args: 179 | if not k in args: 180 | args[k] = default_args[k] 181 | 182 | instance = model_type(args, device=device, with_log=False) 183 | for name in instance._models: 184 | model_path = os.path.join(path, '%s.pt' % name) 185 | instance._models[name].load_state_dict(torch.load(model_path, map_location=torch.device(device))) 186 | 187 | return instance 188 | 189 | def __init__(self, args, device='cuda', with_log=True): 190 | self._args = args 191 | self._device = torch.device(device) 192 | self._models = self.init_models() 193 | if with_log: 194 | wandb_project = None 195 | if 'wandb_project' in self._args: 196 | wandb_project = self._args['wandb_project'] 197 | 198 | self._log = RunLog(self._args, self._models, wandb_project) 199 | 200 | def save(self, path): 201 | """Save model configuration to a path. 
202 | 203 | Args: 204 | path: path to an existing directory to save models 205 | """ 206 | os.makedirs(path, exist_ok=True) 207 | 208 | args_path = os.path.join(path, 'args.json') 209 | open(args_path, 'w').write(json.dumps(self._args)) 210 | 211 | for name in self._models: 212 | model_path = os.path.join(path, '%s.pt' % name) 213 | torch.save(self._models[name].state_dict(), model_path) 214 | 215 | def init_models(self): 216 | """Initializes any pytorch models. 217 | 218 | Returns a dict of name->model mapping: 219 | {'model_1': m1, 'model_2': m2, ...} 220 | """ 221 | return {} 222 | 223 | def train(self, save_path=None): 224 | """Train the models.""" 225 | raise NotImplementedError() 226 | 227 | 228 | class VoxelNet(LeadoptModel): 229 | @staticmethod 230 | def setup_parser(sub): 231 | # testing 232 | sub.add_argument('--no_partitions', action='store_true', default=False, help=''' 233 | If set, disable the use of TRAIN/VAL partitions during training. 234 | ''') 235 | 236 | # dataset 237 | sub.add_argument('-f', '--fragments', required=True, help=''' 238 | Path to fragments file. 239 | ''') 240 | sub.add_argument('-fp', '--fingerprints', required=True, help=''' 241 | Path to fingerprints file. 242 | ''') 243 | 244 | # training parameters 245 | sub.add_argument('-lr', '--learning_rate', type=float, default=1e-4) 246 | sub.add_argument('--num_epochs', type=int, default=50, help=''' 247 | Number of epochs to train for. 248 | ''') 249 | sub.add_argument('--test_steps', type=int, default=500, help=''' 250 | Number of evaluation steps per epoch. 251 | ''') 252 | sub.add_argument('-b', '--batch_size', default=32, type=int) 253 | 254 | # grid generation 255 | sub.add_argument('--grid_width', type=int, default=24) 256 | sub.add_argument('--grid_res', type=float, default=1) 257 | 258 | # fragment filtering 259 | sub.add_argument('--fdist_min', type=float, help=''' 260 | Ignore fragments closer to the receptor than this distance (Angstroms). 
261 | ''') 262 | sub.add_argument('--fdist_max', type=float, help=''' 263 | Ignore fragments further from the receptor than this distance (Angstroms). 264 | ''') 265 | sub.add_argument('--fmass_min', type=float, help=''' 266 | Ignore fragments smaller than this mass (Daltons). 267 | ''') 268 | sub.add_argument('--fmass_max', type=float, help=''' 269 | Ignore fragments larger than this mass (Daltons). 270 | ''') 271 | 272 | # receptor/parent options 273 | sub.add_argument('--ignore_receptor', action='store_true', default=False) 274 | sub.add_argument('--ignore_parent', action='store_true', default=False) 275 | sub.add_argument('-rec_typer', required=True, choices=[k for k in REC_TYPER]) 276 | sub.add_argument('-lig_typer', required=True, choices=[k for k in LIG_TYPER]) 277 | # sub.add_argument('-rec_channels', required=True, type=int) 278 | # sub.add_argument('-lig_channels', required=True, type=int) 279 | 280 | # model parameters 281 | # sub.add_argument('--in_channels', type=int, default=18) 282 | sub.add_argument('--output_size', type=int, default=2048) 283 | sub.add_argument('--pad', default=False, action='store_true') 284 | sub.add_argument('--blocks', nargs='+', type=int, default=[32,64]) 285 | sub.add_argument('--fc', nargs='+', type=int, default=[2048]) 286 | sub.add_argument('--use_all_labels', default=False, action='store_true') 287 | sub.add_argument('--dist_fn', default='mse', choices=[k for k in DIST_FN]) 288 | sub.add_argument('--loss', default='direct', choices=[k for k in LOSS_TYPE]) 289 | 290 | @staticmethod 291 | def get_defaults(): 292 | return { 293 | 'point_radius': 1, 294 | 'point_type': 0, 295 | 'acc_type': 0 296 | } 297 | 298 | def init_models(self): 299 | in_channels = 0 300 | if not self._args['ignore_receptor']: 301 | in_channels += REC_TYPER[self._args['rec_typer']].size() 302 | if not self._args['ignore_parent']: 303 | in_channels += LIG_TYPER[self._args['lig_typer']].size() 304 | 305 | voxel = VoxelFingerprintNet( 306 | 
in_channels=in_channels, 307 | output_size=self._args['output_size'], 308 | blocks=self._args['blocks'], 309 | fc=self._args['fc'], 310 | pad=self._args['pad'] 311 | ).to(self._device) 312 | return {'voxel': voxel} 313 | 314 | def load_data(self): 315 | print('[*] Loading data...', flush=True) 316 | dat = FragmentDataset( 317 | self._args['fragments'], 318 | rec_typer=REC_TYPER[self._args['rec_typer']], 319 | lig_typer=LIG_TYPER[self._args['lig_typer']], 320 | verbose=True 321 | ) 322 | 323 | train_dat = SharedFragmentDataset( 324 | dat, 325 | filter_rec=set(get_bios(moad_partitions.TRAIN)), 326 | filter_smi=set(moad_partitions.TRAIN_SMI), 327 | fdist_min=self._args['fdist_min'], 328 | fdist_max=self._args['fdist_max'], 329 | fmass_min=self._args['fmass_min'], 330 | fmass_max=self._args['fmass_max'], 331 | ) 332 | 333 | val_dat = SharedFragmentDataset( 334 | dat, 335 | filter_rec=set(get_bios(moad_partitions.VAL)), 336 | filter_smi=set(moad_partitions.VAL_SMI), 337 | fdist_min=self._args['fdist_min'], 338 | fdist_max=self._args['fdist_max'], 339 | fmass_min=self._args['fmass_min'], 340 | fmass_max=self._args['fmass_max'], 341 | ) 342 | 343 | return train_dat, val_dat 344 | 345 | def train(self, save_path=None, custom_steps=None, checkpoint_callback=None, data=None): 346 | 347 | if data is None: 348 | data = self.load_data() 349 | 350 | train_dat, val_dat = data 351 | 352 | fingerprints = FingerprintDataset(self._args['fingerprints']) 353 | 354 | train_smiles = train_dat.get_valid_smiles() 355 | val_smiles = val_dat.get_valid_smiles() 356 | all_smiles = list(set(train_smiles) | set(val_smiles)) 357 | 358 | train_fingerprints = fingerprints.for_smiles(train_smiles).cuda() 359 | val_fingerprints = fingerprints.for_smiles(val_smiles).cuda() 360 | all_fingerprints = fingerprints.for_smiles(all_smiles).cuda() 361 | 362 | # fingerprint metrics 363 | print('[*] Train smiles: %d' % len(train_smiles)) 364 | print('[*] Val smiles: %d' % len(val_smiles)) 365 | print('[*] All 
smiles: %d' % len(all_smiles)) 366 | 367 | print('[*] Train fingerprints: %d' % train_fingerprints.shape[0]) 368 | print('[*] Val fingerprints: %d' % val_fingerprints.shape[0]) 369 | print('[*] All fingerprints: %d' % all_fingerprints.shape[0]) 370 | 371 | # memory optimization, drop some unnecessary columns 372 | train_dat._dat.frag['frag_mass'] = None 373 | train_dat._dat.frag['frag_dist'] = None 374 | train_dat._dat.frag['frag_lig_smi'] = None 375 | train_dat._dat.frag['frag_lig_idx'] = None 376 | 377 | print('[*] Training...', flush=True) 378 | opt = torch.optim.Adam( 379 | self._models['voxel'].parameters(), lr=self._args['learning_rate']) 380 | steps_per_epoch = len(train_dat) // self._args['batch_size'] 381 | steps_per_epoch = custom_steps if custom_steps is not None else steps_per_epoch 382 | 383 | # configure metrics 384 | dist_fn = DIST_FN[self._args['dist_fn']] 385 | 386 | loss_fingerprints = train_fingerprints 387 | if self._args['use_all_labels']: 388 | loss_fingerprints = all_fingerprints 389 | 390 | loss_fn = LOSS_TYPE[self._args['loss']](loss_fingerprints, dist_fn) 391 | 392 | train_metrics = MetricTracker('train', { 393 | 'all': top_k_acc(all_fingerprints, dist_fn, [1,8,64], pre='all') 394 | }) 395 | val_metrics = MetricTracker('val', { 396 | 'all': top_k_acc(all_fingerprints, dist_fn, [1,8,64], pre='all'), 397 | # 'val': top_k_acc(val_fingerprints, dist_fn, [1,5,10,50,100], pre='val'), 398 | }) 399 | 400 | best_loss = None 401 | 402 | for epoch in range(self._args['num_epochs']): 403 | self._models['voxel'].train() 404 | train_pbar = tqdm.tqdm( 405 | range(steps_per_epoch), 406 | desc='Train (epoch %d)' % epoch 407 | ) 408 | for step in train_pbar: 409 | torch_grid, examples = get_batch( 410 | train_dat, 411 | batch_size=self._args['batch_size'], 412 | batch_set=None, 413 | width=self._args['grid_width'], 414 | res=self._args['grid_res'], 415 | ignore_receptor=self._args['ignore_receptor'], 416 | ignore_parent=self._args['ignore_parent'], 417 |
point_radius=self._args['point_radius'], 418 | point_type=self._args['point_type'], 419 | acc_type=self._args['acc_type'] 420 | ) 421 | 422 | smiles = [example['smiles'] for example in examples] 423 | correct_fp = torch.Tensor( 424 | fingerprints.for_smiles(smiles)).cuda() 425 | 426 | predicted_fp = self._models['voxel'](torch_grid) 427 | 428 | loss = loss_fn(predicted_fp, correct_fp) 429 | 430 | opt.zero_grad() 431 | loss.backward() 432 | opt.step() 433 | 434 | train_metrics.update('loss', loss) 435 | train_metrics.evaluate(predicted_fp, correct_fp) 436 | self._log.log(train_metrics.get_all()) 437 | train_metrics.clear() 438 | 439 | self._models['voxel'].eval() 440 | 441 | val_pbar = tqdm.tqdm( 442 | range(self._args['test_steps']), 443 | desc='Val %d' % epoch 444 | ) 445 | with torch.no_grad(): 446 | for step in val_pbar: 447 | torch_grid, examples = get_batch( 448 | val_dat, 449 | batch_size=self._args['batch_size'], 450 | batch_set=None, 451 | width=self._args['grid_width'], 452 | res=self._args['grid_res'], 453 | ignore_receptor=self._args['ignore_receptor'], 454 | ignore_parent=self._args['ignore_parent'], 455 | point_radius=self._args['point_radius'], 456 | point_type=self._args['point_type'], 457 | acc_type=self._args['acc_type'] 458 | ) 459 | 460 | smiles = [example['smiles'] for example in examples] 461 | correct_fp = torch.Tensor( 462 | fingerprints.for_smiles(smiles)).cuda() 463 | 464 | predicted_fp = self._models['voxel'](torch_grid) 465 | 466 | loss = loss_fn(predicted_fp, correct_fp) 467 | 468 | val_metrics.update('loss', loss) 469 | val_metrics.evaluate(predicted_fp, correct_fp) 470 | 471 | if checkpoint_callback: 472 | checkpoint_callback(self, epoch) 473 | 474 | val_metrics.normalize(self._args['test_steps']) 475 | self._log.log(val_metrics.get_all()) 476 | 477 | val_loss = val_metrics.get('loss') 478 | if best_loss is None or val_loss < best_loss: 479 | # save new best model 480 | best_loss = val_loss 481 | print('[*] New best loss: %f' % 
best_loss, flush=True) 482 | if save_path: 483 | self.save(save_path) 484 | 485 | val_metrics.clear() 486 | 487 | def run_test(self, save_path, use_val=False): 488 | # load test dataset 489 | test_dat = FragmentDataset( 490 | self._args['fragments'], 491 | rec_typer=REC_TYPER[self._args['rec_typer']], 492 | lig_typer=LIG_TYPER[self._args['lig_typer']], 493 | # filter_rec=partitions.TEST, 494 | filter_rec=set(get_bios(moad_partitions.VAL if use_val else moad_partitions.TEST)), 495 | filter_smi=set(moad_partitions.VAL_SMI if use_val else moad_partitions.TEST_SMI), 496 | fdist_min=self._args['fdist_min'], 497 | fdist_max=self._args['fdist_max'], 498 | fmass_min=self._args['fmass_min'], 499 | fmass_max=self._args['fmass_max'], 500 | verbose=True 501 | ) 502 | 503 | fingerprints = FingerprintDataset(self._args['fingerprints']) 504 | 505 | self._models['voxel'].eval() 506 | 507 | predicted_fp = np.zeros(( 508 | len(test_dat), 509 | self._args['samples_per_example'], 510 | self._args['output_size'])) 511 | 512 | smiles = [test_dat[i]['smiles'] for i in range(len(test_dat))] 513 | correct_fp = fingerprints.for_smiles(smiles).numpy() 514 | 515 | # (example_idx, sample_idx) 516 | queries = [] 517 | for i in range(len(test_dat)): 518 | queries += [(i,x) for x in range(self._args['samples_per_example'])] 519 | 520 | # run inference 521 | pbar = tqdm.tqdm( 522 | range(0, len(queries), self._args['batch_size']), desc='Inference') 523 | for i in pbar: 524 | batch = queries[i:i+self._args['batch_size']] 525 | 526 | torch_grid, examples = get_batch( 527 | test_dat, 528 | batch_size=self._args['batch_size'], 529 | batch_set=[x[0] for x in batch], 530 | width=self._args['grid_width'], 531 | res=self._args['grid_res'], 532 | ignore_receptor=self._args['ignore_receptor'], 533 | ignore_parent=self._args['ignore_parent'], 534 | point_radius=self._args['point_radius'], 535 | point_type=self._args['point_type'], 536 | acc_type=self._args['acc_type'] 537 | ) 538 | 539 | predicted = 
self._models['voxel'](torch_grid) 540 | 541 | for j in range(len(batch)): 542 | example_idx, sample_idx = batch[j] 543 | predicted_fp[example_idx][sample_idx] = predicted[j].detach().cpu().numpy() 544 | 545 | if use_val: 546 | np.save(os.path.join(save_path, 'val_predicted_fp.npy'), predicted_fp) 547 | np.save(os.path.join(save_path, 'val_correct_fp.npy'), correct_fp) 548 | else: 549 | np.save(os.path.join(save_path, 'predicted_fp.npy'), predicted_fp) 550 | np.save(os.path.join(save_path, 'correct_fp.npy'), correct_fp) 551 | 552 | print('done.') 553 | 554 | def predict(self, batch): 555 | with torch.no_grad(): 556 | pred = self._models['voxel'](batch) 557 | return pred 558 | 559 | 560 | MODELS = { 561 | 'voxelnet': VoxelNet 562 | } 563 | -------------------------------------------------------------------------------- /leadopt/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | -------------------------------------------------------------------------------- /leadopt/models/backport.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. 
You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | from torch.nn import Module 17 | from typing import Tuple, Union 18 | from torch import Tensor 19 | 20 | ''' 21 | On some older versions of Cuda, we may need to use an early version of PyTorch 22 | that doesn't have the Flatten layer builtin. 23 | ''' 24 | 25 | class Flatten(Module): 26 | __constants__ = ['start_dim', 'end_dim'] 27 | start_dim: int 28 | end_dim: int 29 | 30 | def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: 31 | super(Flatten, self).__init__() 32 | self.start_dim = start_dim 33 | self.end_dim = end_dim 34 | 35 | def forward(self, input: Tensor) -> Tensor: 36 | return input.flatten(self.start_dim, self.end_dim) 37 | 38 | def extra_repr(self) -> str: 39 | return 'start_dim={}, end_dim={}'.format( 40 | self.start_dim, self.end_dim 41 | ) 42 | -------------------------------------------------------------------------------- /leadopt/models/voxel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | Flatten = None 21 | try: 22 | Flatten = nn.Flatten 23 | except AttributeError: 24 | from . import backport 25 | Flatten = backport.Flatten 26 | 27 | class VoxelFingerprintNet(nn.Module): 28 | def __init__(self, in_channels, output_size, blocks=[32,64], fc=[2048], pad=True): 29 | super(VoxelFingerprintNet, self).__init__() 30 | 31 | blocks = list(blocks) 32 | fc = list(fc) 33 | 34 | self.blocks = nn.ModuleList() 35 | prev = in_channels 36 | for i in range(len(blocks)): 37 | b = blocks[i] 38 | parts = [] 39 | parts += [ 40 | nn.BatchNorm3d(prev), 41 | nn.Conv3d(prev, b, (3,3,3), padding=(1 if pad else 0)), 42 | nn.ReLU(), 43 | nn.Conv3d(b, b, (3,3,3), padding=(1 if pad else 0)), 44 | nn.ReLU(), 45 | nn.Conv3d(b, b, (3,3,3), padding=(1 if pad else 0)), 46 | nn.ReLU() 47 | ] 48 | if i != len(blocks)-1: 49 | parts += [ 50 | nn.MaxPool3d((2,2,2)) 51 | ] 52 | 53 | self.blocks.append(nn.Sequential(*parts)) 54 | prev = b 55 | 56 | self.reduce = nn.Sequential( 57 | nn.AdaptiveAvgPool3d((1,1,1)), 58 | Flatten(), 59 | ) 60 | 61 | pred = [] 62 | prev = blocks[-1] 63 | for f in fc + [output_size]: 64 | pred += [ 65 | nn.ReLU(), 66 | nn.Dropout(), 67 | nn.Linear(prev, f) 68 | ] 69 | prev = f 70 | 71 | self.pred = nn.Sequential(*pred) 72 | self.norm = nn.Sigmoid() 73 | 74 | def forward(self, x): 75 | for b in self.blocks: 76 | x = b(x) 77 | x = self.reduce(x) 78 | x = self.pred(x) 79 | x = self.norm(x) 80 | 81 | return x 82 | -------------------------------------------------------------------------------- /leadopt/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License.
You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | ''' 17 | rdkit/openbabel utility scripts 18 | ''' 19 | import numpy as np 20 | 21 | 22 | # try: 23 | # import pybel 24 | # except: 25 | # from openbabel import pybel 26 | 27 | from rdkit import Chem 28 | 29 | 30 | def get_coords(mol): 31 | """Returns an array of atom coordinates from an rdkit mol.""" 32 | conf = mol.GetConformer() 33 | coords = np.array([conf.GetAtomPosition(i) for i in range(conf.GetNumAtoms())]) 34 | return coords 35 | 36 | 37 | def get_types(mol): 38 | """Returns an array of atomic numbers from an rdkit mol.""" 39 | return [mol.GetAtomWithIdx(i).GetAtomicNum() for i in range(mol.GetNumAtoms())] 40 | 41 | 42 | def combine_all(frags): 43 | """Combines a list of rdkit mols.""" 44 | if len(frags) == 0: 45 | return None 46 | 47 | c = frags[0] 48 | for f in frags[1:]: 49 | c = Chem.CombineMols(c,f) 50 | 51 | return c 52 | 53 | 54 | def generate_fragments(mol, max_heavy_atoms=0, only_single_bonds=True): 55 | """Takes an rdkit molecule and returns a list of (parent, fragment) tuples. 56 | 57 | Args: 58 | mol: The molecule to fragment. 59 | max_heavy_atoms: The maximum number of heavy atoms to include 60 | in generated fragments. 61 | nly_single_bonds: If set to true, this method will only return 62 | fragments generated by breaking single bonds. 63 | 64 | Returns: 65 | A list of (parent, fragment) tuples where mol is larger than fragment. 
66 | """ 67 | # list of (parent, fragment) tuples 68 | splits = [] 69 | 70 | # if we have multiple ligands already, split into pieces and then iterate 71 | ligands = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) 72 | 73 | for i in range(len(ligands)): 74 | lig = ligands[i] 75 | other = list(ligands[:i] + ligands[i+1:]) 76 | 77 | # iterate over bonds 78 | for j in range(lig.GetNumBonds()): 79 | # (optional) filter based on single bonds 80 | if only_single_bonds and lig.GetBondWithIdx(j).GetBondType() != Chem.rdchem.BondType.SINGLE: 81 | continue 82 | 83 | # split the molecule 84 | split_mol = Chem.rdmolops.FragmentOnBonds(lig, [j]) 85 | 86 | # obtain fragments 87 | fragments = Chem.GetMolFrags(split_mol, asMols=True, sanitizeFrags=False) 88 | 89 | # skip if this did not break the molecule into two pieces 90 | if len(fragments) != 2: 91 | continue 92 | 93 | # otherwise make sure the first fragment is larger 94 | if fragments[0].GetNumAtoms() < fragments[1].GetNumAtoms(): 95 | fragments = fragments[::-1] 96 | 97 | # make sure the fragment has at least one heavy atom 98 | if fragments[1].GetNumHeavyAtoms() == 0: 99 | continue 100 | 101 | # (optional) filter based on number of heavy atoms in the fragment 102 | if max_heavy_atoms > 0 and fragments[1].GetNumHeavyAtoms() > max_heavy_atoms: 103 | continue 104 | 105 | # if we have other ligands present, merge them with the parent 106 | parent = fragments[0] 107 | 108 | if len(other) > 0: 109 | parent = combine_all([parent] + other) 110 | 111 | # add this pair 112 | splits.append((parent, fragments[1])) 113 | 114 | return splits 115 | 116 | 117 | def load_ligand(sdf): 118 | """Loads a ligand from an sdf file and fragments it. 119 | 120 | Args: 121 | sdf: Path to sdf file containing a ligand. 122 | """ 123 | lig = next(Chem.SDMolSupplier(sdf, sanitize=False)) 124 | frags = generate_fragments(lig) 125 | 126 | return lig, frags 127 | 128 | 129 | def load_ligands_pdb(pdb): 130 | """Load multiple ligands from a pdb file.
131 | 132 | Args: 133 | pdb: Path to pdb file containing a ligand. 134 | """ 135 | lig_mult = Chem.MolFromPDBFile(pdb) 136 | ligands = Chem.GetMolFrags(lig_mult, asMols=True, sanitizeFrags=True) 137 | 138 | return ligands 139 | 140 | 141 | def remove_water(m): 142 | """Removes water molecules from an rdkit mol.""" 143 | parts = Chem.GetMolFrags(m, asMols=True, sanitizeFrags=False) 144 | valid = [k for k in parts if not Chem.MolToSmiles(k, allHsExplicit=True) == '[OH2]'] 145 | 146 | assert len(valid) > 0, 'error: molecule contains only water' 147 | 148 | merged = valid[0] 149 | for part in valid[1:]: 150 | merged = Chem.CombineMols(merged, part) 151 | 152 | return merged 153 | 154 | 155 | def load_receptor(rec_path): 156 | """Loads a receptor from a pdb file and retrieves atomic information. 157 | 158 | Args: 159 | rec_path: Path to a pdb file. 160 | """ 161 | rec = Chem.MolFromPDBFile(rec_path, sanitize=False) 162 | rec = remove_water(rec) 163 | 164 | return rec 165 | 166 | 167 | # def load_receptor_ob(rec_path): 168 | # rec = next(pybel.readfile('pdb', rec_path)) 169 | # valid = [r for r in rec.residues if r.name != 'HOH'] 170 | 171 | # # map partial charge into byte range 172 | # def conv_charge(x): 173 | # x = max(x,-0.5) 174 | # x = min(x,0.5) 175 | # x += 0.5 176 | # x *= 255 177 | # x = int(x) 178 | # return x 179 | 180 | # coords = [] 181 | # types = [] 182 | # for v in valid: 183 | # coords += [k.coords for k in v.atoms] 184 | # types += [( 185 | # k.atomicnum, 186 | # int(k.OBAtom.IsAromatic()), 187 | # int(k.OBAtom.IsHbondDonor()), 188 | # int(k.OBAtom.IsHbondAcceptor()), 189 | # conv_charge(k.OBAtom.GetPartialCharge()) 190 | # ) for k in v.atoms] 191 | 192 | # return np.array(coords), np.array(types) 193 | 194 | 195 | def load_receptor_ob(rec_path): 196 | rec = load_receptor(rec_path) 197 | 198 | coords = get_coords(rec) 199 | types = np.array(get_types(rec)) 200 | types = np.concatenate([ 201 | types.reshape(-1,1), 202 | np.zeros((len(types), 4)) 203 | 
], 1) 204 | 205 | return coords, types 206 | 207 | 208 | def get_connection_point(frag): 209 | '''return the coordinates of the dummy atom as a numpy array [x,y,z]''' 210 | dummy_idx = get_types(frag).index(0) 211 | coords = get_coords(frag)[dummy_idx] 212 | 213 | return coords 214 | 215 | 216 | def frag_dist_to_receptor(rec, frag): 217 | '''compute the minimum distance between the fragment connection point and any receptor atom''' 218 | rec_coords = rec.GetConformer().GetPositions() 219 | conn = get_connection_point(frag) 220 | 221 | dist = np.sum((rec_coords - conn) ** 2, axis=1) 222 | min_dist = np.sqrt(np.min(dist)) 223 | 224 | return min_dist 225 | 226 | 227 | def frag_dist_to_receptor_raw(coords, frag): 228 | '''compute the minimum distance between the fragment connection point and any receptor atom''' 229 | rec_coords = np.array(coords) 230 | conn = get_connection_point(frag) 231 | 232 | dist = np.sum((rec_coords - conn) ** 2, axis=1) 233 | min_dist = np.sqrt(np.min(dist)) 234 | 235 | return min_dist 236 | 237 | 238 | def mol_array(mol): 239 | '''convert an rdkit mol to an array of coordinates and atom types''' 240 | coords = get_coords(mol) 241 | types = np.array(get_types(mol)).reshape(-1,1) 242 | 243 | arr = np.concatenate([coords, types], axis=1) 244 | 245 | return arr 246 | 247 | 248 | def desc_mol_array(mol, atom_fn): 249 | '''user-defined atomic mapping function''' 250 | coords = get_coords(mol) 251 | atoms = list(mol.GetAtoms()) 252 | types = np.array([atom_fn(x) for x in atoms]).reshape(-1,1) 253 | 254 | arr = np.concatenate([coords, types], axis=1) 255 | 256 | return arr 257 | 258 | 259 | def desc_mol_array_ob(atoms, atom_fn): 260 | coords = np.array([k[0] for k in atoms]) 261 | types = np.array([atom_fn(k[1]) for k in atoms]).reshape(-1,1) 262 | 263 | # arr = np.concatenate([coords, types], axis=1) 264 | 265 | return coords, types 266 | 267 | 268 | def mol_to_points(mol, atom_types=[6,7,8,9,15,16,17,35,53]): 269 | '''convert an rdkit mol to an array of
coordinates and layers''' 270 | coords = get_coords(mol) 271 | 272 | types = get_types(mol) 273 | layers = np.array([(atom_types.index(k) if k in atom_types else -1) for k in types]) 274 | 275 | # filter by existing layer 276 | coords = coords[layers != -1] 277 | layers = layers[layers != -1].reshape(-1,1) 278 | 279 | return coords, layers 280 | 281 | 282 | def merge_smiles(sma, smb): 283 | '''merge two SMILES fragment strings by combining at the dummy connection point''' 284 | a = Chem.MolFromSmiles(sma, sanitize=False) 285 | b = Chem.MolFromSmiles(smb, sanitize=False) 286 | 287 | # merge molecules 288 | c = Chem.CombineMols(a,b) 289 | 290 | # find dummy atoms 291 | da,db = np.where(np.array([k.GetAtomicNum() for k in c.GetAtoms()]) == 0)[0] 292 | 293 | # find neighbors to connect 294 | na = c.GetAtomWithIdx(int(da)).GetNeighbors()[0].GetIdx() 295 | nb = c.GetAtomWithIdx(int(db)).GetNeighbors()[0].GetIdx() 296 | 297 | e = Chem.EditableMol(c) 298 | for d in sorted([da,db])[::-1]: 299 | e.RemoveAtom(int(d)) 300 | 301 | # adjust atom indices 302 | na -= int(da < na) + int(db < na) 303 | nb -= int(da < nb) + int(db < nb) 304 | 305 | e.AddBond(na,nb,Chem.rdchem.BondType.SINGLE) 306 | 307 | r = e.GetMol() 308 | 309 | sm = Chem.MolToSmiles(Chem.RemoveHs(r, sanitize=False), isomericSmiles=False) 310 | 311 | return sm 312 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.6.20 2 | chardet==3.0.4 3 | click==7.1.2 4 | configparser==5.0.0 5 | docker-pycreds==0.4.0 6 | future==0.18.2 7 | gitdb==4.0.5 8 | GitPython==3.1.7 9 | gql==0.2.0 10 | graphql-core==1.1 11 | h5py==3.10.0 12 | idna==2.10 13 | llvmlite==0.41.1 14 | numba==0.58.1 15 | numpy==1.24.4 16 | nvidia-ml-py3==7.352.0 17 | pathtools==0.1.2 18 | Pillow==7.2.0 19 | promise==2.3 20 | psutil==5.7.2 21 | python-dateutil==2.8.1 22 | PyYAML==5.3.1 23 | requests==2.24.0 24 |
sentry-sdk==0.16.3 25 | shortuuid==1.0.1 26 | six==1.15.0 27 | smmap==3.0.4 28 | subprocess32==3.5.4 29 | tqdm==4.48.2 30 | urllib3==1.25.10 31 | wandb==0.9.4 32 | watchdog==0.10.3 33 | rdkit==2023.09.2 34 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Data processing scripts 3 | 4 | ## `make_fingerprints.py` 5 | 6 | Utility script to generate fingerprints for a set of SMILES strings. By precomputing the fingerprints for all the fragments in our dataset, we can speed up training. 7 | 8 | To use, pass in the `moad.h5` file and specify the fingerprint type and output path. 9 | 10 | Supported fingerprints: 11 | - `rdk`: RDKFingerprint (2048 bits) 12 | - `rdk10`: RDKFingerprint (path size 10) (2048 bits) 13 | - `morgan`: Morgan fingerprint (r=2) (2048 bits) 14 | - `gobbi2d`: Gobbi 2D pharmacophore fingerprint (folded to 2048 bits) 15 | 16 | Usage: 17 | 18 | ``` 19 | usage: make_fingerprints.py [-h] -f FRAGMENTS -fp {rdk,rdk10,morgan,gobbi2d} 20 | [-o OUTPUT] 21 | 22 | optional arguments: 23 | -h, --help show this help message and exit 24 | -f FRAGMENTS, --fragments FRAGMENTS 25 | Path to fragments.h5 containing "frag_smiles" array 26 | -fp {rdk,rdk10,morgan,gobbi2d}, --fingerprint {rdk,rdk10,morgan,gobbi2d} 27 | Which fingerprint type to generate 28 | -o OUTPUT, --output OUTPUT 29 | Output file path (.h5) 30 | ``` 31 | 32 | ## MOAD Dataset 33 | 34 | For instructions on working with MOAD data, see [`README_MOAD.md`](./README_MOAD.md). 35 | -------------------------------------------------------------------------------- /scripts/README_MOAD.md: -------------------------------------------------------------------------------- 1 | 2 | # MOAD Data 3 | 4 | This README describes how to process raw MOAD data for use in training the DeepFrag model.
Note that we already provide fully processed datasets for 5 | training, so this section mostly serves as a description of how to process future versions of the MOAD dataset or as an example for researchers looking 6 | to accomplish similar things. 7 | 8 | See [`data/README.md`](../data/README.md) for instructions on how to download the packed `moad.h5` data. 9 | 10 | See [`config/moad_partitions.py`](../config/moad_partitions.py) for a pre-computed MOAD TRAIN/VAL/TEST split (generated with seed 7). 11 | 12 | # 1. Download MOAD datasets 13 | 14 | Go to https://bindingmoad.org/Home/download and download `every_part_a.zip`, `every_part_b.zip`, and the "Binding data" file (`every.csv`). 15 | 16 | This README assumes these files are stored in `$MOAD_DIR`. 17 | 18 | # 2. Unpack MOAD datasets 19 | 20 | Run: 21 | 22 | ```sh 23 | $ unzip every_part_a.zip 24 | ... 25 | $ unzip every_part_b.zip 26 | ... 27 | ``` 28 | 29 | # 3. Process MOAD pdb files 30 | 31 | The MOAD dataset contains ligand/receptor structures combined in a single pdb file (named with a `.bio` extension). In this step, we will separate the receptor and each ligand into individual files. 32 | 33 | Run: 34 | 35 | ```sh 36 | $ cd $MOAD_DIR && mkdir split 37 | $ python3 scripts/split_moad.py \ 38 | -d $MOAD_DIR/BindingMOAD_2020 \ 39 | -c $MOAD_DIR/every.csv \ 40 | -o $MOAD_DIR/split \ 41 | -n 42 | ``` 43 | 44 | # 4. Generate packed data files 45 | 46 | For training purposes, we pack all of the relevant information into an h5 file so we can load it entirely in memory during training. 47 | 48 | This step will produce several similar `.h5` files that can be combined later. 49 | 50 | Run: 51 | 52 | ```sh 53 | $ cd $MOAD_DIR && mkdir packed 54 | $ python3 scripts/process_moad.py \ 55 | -d $MOAD_DIR/split \ 56 | -c $MOAD_DIR/every.csv \ 57 | -o $MOAD_DIR/packed/moad.h5 \ 58 | -n \ 59 | -s 60 | ``` 61 | 62 | # 5. Merge packed data files
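The merge in this step is not a plain concatenation: each intermediate `.h5` file stores row indices into its own coordinate arrays, so `merge_moad.py` must shift those indices by the sizes of the arrays merged so far (see its `append` function). A minimal numpy sketch of that offset bookkeeping, using simplified, hypothetical field names rather than the real MOAD schema:

```python
import numpy as np

# Two "files": a coords array plus a lookup of (start, end) row
# ranges that index into that same file's coords array.
a = {'coords': np.arange(6).reshape(3, 2), 'lookup': np.array([[0, 3]])}
b = {'coords': np.arange(4).reshape(2, 2), 'lookup': np.array([[0, 2]])}

def merge(dat, other):
    # Shift the incoming row indices by the number of rows already
    # present, then concatenate every field.
    off = dat['coords'].shape[0]
    shifted = {k: v.copy() for k, v in other.items()}
    shifted['lookup'] += off
    return {k: np.concatenate((dat[k], shifted[k]), axis=0) for k in dat}

merged = merge(a, b)
print(merged['lookup'].tolist())  # [[0, 3], [3, 5]]
```

Forgetting the shift would leave the second file's lookup pointing at the first file's rows, which is why `merge_moad.py` applies the offsets before concatenating.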
63 | 64 | ```sh 65 | $ python3 scripts/merge_moad.py \ 66 | -i $MOAD_DIR/packed \ 67 | -o moad.h5 68 | ``` 69 | 70 | # 6. Generate MOAD Training splits 71 | 72 | ```sh 73 | $ python3 scripts/moad_training_splits.py \ 74 | -c $MOAD_DIR/every.csv \ 75 | -s 7 \ 76 | -o moad_partitions.py 77 | ``` 78 | -------------------------------------------------------------------------------- /scripts/make_fingerprints.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 
14 | 15 | 16 | ''' 17 | Utility script to generate fingerprints for SMILES strings 18 | ''' 19 | import argparse 20 | 21 | import h5py 22 | import numpy as np 23 | import tqdm 24 | 25 | from rdkit.Chem import rdMolDescriptors 26 | import rdkit.Chem.AllChem as Chem 27 | from rdkit.Chem.Pharm2D import Generate, Gobbi_Pharm2D 28 | 29 | 30 | def fold_to(bv, size=2048): 31 | '''fold a SparseBitVect to a certain length''' 32 | fp = np.zeros(size) 33 | 34 | for b in list(bv.GetOnBits()): 35 | fp[b % size] = 1 36 | 37 | return fp 38 | 39 | 40 | def rdkfingerprint(m): 41 | '''rdkfingerprint as 2048-len bit array''' 42 | fp = Chem.rdmolops.RDKFingerprint(m) 43 | n_fp = list(map(int, list(fp.ToBitString()))) 44 | return n_fp 45 | 46 | 47 | def rdkfingerprint10(m): 48 | '''rdkfingerprint as 2048-len bit array (maxPath=10)''' 49 | fp = Chem.rdmolops.RDKFingerprint(m, maxPath=10) 50 | n_fp = list(map(int, list(fp.ToBitString()))) 51 | return n_fp 52 | 53 | 54 | def morganfingerprint(m): 55 | '''morgan fingerprint as 2048-len bit array''' 56 | m.UpdatePropertyCache(strict=False) 57 | Chem.rdmolops.FastFindRings(m) 58 | fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(m, 2) 59 | n_fp = list(map(int, list(fp.ToBitString()))) 60 | return n_fp 61 | 62 | 63 | def gobbi2d(m): 64 | '''gobbi 2d pharmacophore as 2048-len bit array''' 65 | m.UpdatePropertyCache(strict=False) 66 | Chem.rdmolops.FastFindRings(m) 67 | bv = Generate.Gen2DFingerprint(m, Gobbi_Pharm2D.factory) 68 | n_fp = fold_to(bv, size=2048) 69 | return n_fp 70 | 71 | 72 | FINGERPRINTS = { 73 | 'rdk': (rdkfingerprint, 2048), 74 | 'rdk10': (rdkfingerprint10, 2048), 75 | 'morgan': (morganfingerprint, 2048), 76 | 'gobbi2d': (gobbi2d, 2048), 77 | } 78 | 79 | 80 | def process(fragments_path, fp_func, fp_size, out_path): 81 | # open fragments file 82 | f = h5py.File(fragments_path, 'r') 83 | smiles = f['frag_smiles'][()] 84 | f.close() 85 | 86 | # deduplicate smiles strings 87 | all_smiles = list(set(smiles)) 88 | n_smiles =
np.array(all_smiles) 89 | 90 | n_fingerprints = np.zeros((len(all_smiles), fp_size)) 91 | 92 | for i in tqdm.tqdm(range(len(all_smiles))): 93 | # generate fingerprint 94 | m = Chem.MolFromSmiles(all_smiles[i].decode('ascii'), sanitize=False) 95 | n_fingerprints[i] = fp_func(m) 96 | 97 | # save 98 | with h5py.File(out_path, 'w') as f: 99 | f['fingerprints'] = n_fingerprints 100 | f['smiles'] = n_smiles 101 | 102 | print('Done!') 103 | 104 | 105 | def main(): 106 | parser = argparse.ArgumentParser() 107 | 108 | parser.add_argument('-f', '--fragments', required=True, help='Path to fragments.h5 containing "frag_smiles" array') 109 | parser.add_argument('-fp', '--fingerprint', required=True, choices=[k for k in FINGERPRINTS], help='Which fingerprint type to generate') 110 | parser.add_argument('-o', '--output', default='fingerprints.h5', help='Output file path (.h5)') 111 | 112 | args = parser.parse_args() 113 | 114 | fn, size = FINGERPRINTS[args.fingerprint] 115 | process(args.fragments, fn, size, out_path=args.output) 116 | 117 | 118 | if __name__=='__main__': 119 | main() 120 | -------------------------------------------------------------------------------- /scripts/merge_moad.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License.
14 | 15 | 16 | import argparse 17 | import os 18 | 19 | import h5py 20 | import numpy as np 21 | 22 | 23 | def unpack(path): 24 | f = h5py.File(path, 'r') 25 | 26 | dat = {} 27 | for k in f.keys(): 28 | dat[k] = f[k][()] 29 | 30 | f.close() 31 | 32 | return dat 33 | 34 | def append(dat, other): 35 | frag_coord_off = dat['frag_data'].shape[0] 36 | frag_lig_smi_off = dat['frag_lig_smi'].shape[0] 37 | rec_coord_off = dat['rec_coords'].shape[0] 38 | 39 | # Update fragment coords. 40 | other['frag_lookup']['f1'] += frag_coord_off 41 | other['frag_lookup']['f2'] += frag_coord_off 42 | other['frag_lookup']['f3'] += frag_coord_off 43 | other['frag_lookup']['f4'] += frag_coord_off 44 | 45 | # Update receptor coords. 46 | other['rec_lookup']['f1'] += rec_coord_off 47 | other['rec_lookup']['f2'] += rec_coord_off 48 | 49 | # Update ligand index. 50 | other['frag_lig_idx'] += frag_lig_smi_off 51 | 52 | # Concatenate everything. 53 | for k in dat: 54 | dat[k] = np.concatenate((dat[k], other[k]), axis=0) 55 | 56 | def cat_all(paths): 57 | dat = unpack(paths[0]) 58 | 59 | for i in range(1, len(paths)): 60 | append(dat, unpack(paths[i])) 61 | 62 | return dat 63 | 64 | def main(): 65 | parser = argparse.ArgumentParser() 66 | 67 | parser.add_argument('-i', '--input', required=True, help='Path to folder containing intermediate .h5 fragments') 68 | parser.add_argument('-o', '--output', default='moad.h5', help='Output file path (.h5)') 69 | 70 | args = parser.parse_args() 71 | 72 | inp = args.input 73 | paths = [x for x in os.listdir(inp) if x.endswith('.h5')] 74 | paths = [os.path.join(inp, x) for x in paths] 75 | 76 | print('Merging:') 77 | for k in paths: 78 | print('- %s' % k) 79 | 80 | full = cat_all(paths) 81 | 82 | f = h5py.File(args.output, 'w') 83 | for k in full: 84 | f[k] = full[k] 85 | f.close() 86 | 87 | print('Done!') 88 | 89 | if __name__=='__main__': 90 | main() 91 | --------------------------------------------------------------------------------
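A note on the merge logic in `merge_moad.py` above: each `*_lookup` table stores flat start/end row indices into a shared coordinate array, so concatenating two intermediate files only requires shifting the second file's indices by the first file's array lengths before concatenation. The following standalone sketch (toy arrays; the `f1`/`f2` field names mirror the auto-generated structured-dtype fields used above, everything else is illustrative) shows the rebasing:

```python
import numpy as np

# Two toy "files": each has a flat coords array plus a lookup table of
# (id, start, end) rows whose indices point into that coords array,
# mirroring the rec_lookup / rec_coords pairing in merge_moad.py.
dt = np.dtype([('f0', 'S4'), ('f1', '<i8'), ('f2', '<i8')])

a_coords = np.zeros((5, 3))
a_lookup = np.array([(b'A', 0, 5)], dtype=dt)

b_coords = np.zeros((3, 3))
b_lookup = np.array([(b'B', 0, 3)], dtype=dt)

# Rebase the second file's indices by the first file's row count, so
# they remain valid after concatenation.
off = a_coords.shape[0]
b_lookup['f1'] += off
b_lookup['f2'] += off

coords = np.concatenate((a_coords, b_coords), axis=0)
lookup = np.concatenate((a_lookup, b_lookup), axis=0)

# Entry B's rows now live at coords[5:8] in the merged array.
```

The same shift is applied per lookup table (`frag_lookup` against `frag_data`, `rec_lookup` against `rec_coords`), which is why the merge never has to rewrite the coordinate data itself.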
/scripts/moad_training_splits.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | import argparse 17 | import os 18 | 19 | import numpy as np 20 | 21 | from moad_util import parse_moad 22 | 23 | 24 | def all_smi(families): 25 | smi = [] 26 | for f in families: 27 | for t in f.targets: 28 | for lig in t.ligands: 29 | smi.append(lig[1]) 30 | return set(smi) 31 | 32 | def do_split(families, pa=0.6): 33 | fa = [] 34 | fb = [] 35 | 36 | for f in families: 37 | if np.random.rand() < pa: 38 | fa.append(f) 39 | else: 40 | fb.append(f) 41 | 42 | return (fa, fb) 43 | 44 | def split_smi(smi): 45 | l = list(smi) 46 | sz = len(l) 47 | 48 | np.random.shuffle(l) 49 | 50 | return (set(l[:sz//2]), set(l[sz//2:])) 51 | 52 | def split_smi3(smi): 53 | l = list(smi) 54 | sz = len(l) 55 | 56 | np.random.shuffle(l) 57 | 58 | v = sz//3 59 | return (set(l[:v]), set(l[v:v*2]), set(l[v*2:])) 60 | 61 | def get_ids(fam): 62 | ids = [] 63 | for f in fam: 64 | for t in f.targets: 65 | ids.append(t.pdb_id) 66 | return ids 67 | 68 | def gen_split(csv, output, seed=7): 69 | np.random.seed(seed) 70 | 71 | moad_families, moad_targets = parse_moad(csv) 72 | 73 | train, other = do_split(moad_families, pa=0.6) 74 | val, test = do_split(other, pa=0.5) 75 | 76 | train_sum = np.sum([len(x.targets) for x in train]) 77 | val_sum = 
np.sum([len(x.targets) for x in val]) 78 | test_sum = np.sum([len(x.targets) for x in test]) 79 | 80 | print('[Targets] (Train: %d) (Val: %d) (Test %d)' % (train_sum, val_sum, test_sum)) 81 | 82 | train_smi = all_smi(train) 83 | val_smi = all_smi(val) 84 | test_smi = all_smi(test) 85 | 86 | train_smi_uniq = train_smi - (val_smi | test_smi) 87 | val_smi_uniq = val_smi - (train_smi | test_smi) 88 | test_smi_uniq = test_smi - (val_smi | train_smi) 89 | 90 | print('[Unique ligands] (Train: %d) (Val: %d) (Test %d)' % ( 91 | len(train_smi_uniq), len(val_smi_uniq), len(test_smi_uniq))) 92 | 93 | print('[Total unique ligands] %d' % len(train_smi | val_smi | test_smi)) 94 | 95 | split_train_val = (train_smi & val_smi) - test_smi 96 | split_train_test = (train_smi & test_smi) - val_smi 97 | split_val_test = (val_smi & test_smi) - train_smi 98 | 99 | split_all = (train_smi & val_smi & test_smi) 100 | 101 | split_train_val_a, split_train_val_b = split_smi(split_train_val) 102 | split_train_test_a, split_train_test_b = split_smi(split_train_test) 103 | split_val_test_a, split_val_test_b = split_smi(split_val_test) 104 | 105 | split_all_train, split_all_val, split_all_test = split_smi3(split_all) 106 | 107 | train_full = (train_smi_uniq | split_train_val_a | split_train_test_a | split_all_train) 108 | val_full = (val_smi_uniq | split_train_val_b | split_val_test_a | split_all_val) 109 | test_full = (test_smi_uniq | split_train_test_b | split_val_test_b | split_all_test) 110 | 111 | print('[Full ligands] (Train: %d) (Val: %d) (Test %d)' % ( 112 | len(train_full), len(val_full), len(test_full))) 113 | 114 | mixed = (train_full & val_full) | (val_full & test_full) | (train_full & test_full) 115 | 116 | train_ids = sorted(get_ids(train)) 117 | val_ids = sorted(get_ids(val)) 118 | test_ids = sorted(get_ids(test)) 119 | 120 | train_s = sorted(train_full) 121 | val_s = sorted(val_full) 122 | test_s = sorted(test_full) 123 | 124 | # Format as a python file. 
125 | out = '' 126 | out += 'TRAIN = ' + repr(train_ids).replace(' ','') + '\n' 127 | out += 'TRAIN_SMI = ' + repr(train_s).replace(' ','') + '\n' 128 | out += 'VAL = ' + repr(val_ids).replace(' ','') + '\n' 129 | out += 'VAL_SMI = ' + repr(val_s).replace(' ','') + '\n' 130 | out += 'TEST = ' + repr(test_ids).replace(' ','') + '\n' 131 | out += 'TEST_SMI = ' + repr(test_s).replace(' ','') + '\n' 132 | 133 | open(output, 'w').write(out) 134 | 135 | def main(): 136 | parser = argparse.ArgumentParser() 137 | 138 | parser.add_argument('-c', '--csv', required=True, help='Path to every.csv file') 139 | parser.add_argument('-s', '--seed', required=False, default=7, type=int, help='Integer seed') 140 | parser.add_argument('-o', '--output', default='moad_partitions.py', help='Output file path (.py)') 141 | 142 | args = parser.parse_args() 143 | 144 | gen_split(args.csv, args.output, args.seed) 145 | print('Done!') 146 | 147 | if __name__=='__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /scripts/moad_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 
14 | 15 | 16 | # MOAD csv parsing utility 17 | 18 | class Family(object): 19 | def __init__(self): 20 | self.targets = [] 21 | 22 | def __repr__(self): 23 | return 'F(%d)' % len(self.targets) 24 | 25 | class Protein(object): 26 | def __init__(self, pdb_id): 27 | # (chain, smi) 28 | self.pdb_id = pdb_id.upper() 29 | self.ligands = [] 30 | 31 | def __repr__(self): 32 | return '%s(%d)' % (self.pdb_id, len(self.ligands)) 33 | 34 | def parse_moad(csv): 35 | csv_dat = open(csv, 'r').read().strip().split('\n') 36 | csv_dat = [x.split(',') for x in csv_dat] 37 | 38 | families = [] 39 | 40 | curr_f = None 41 | curr_t = None 42 | 43 | for line in csv_dat: 44 | if line[0] != '': 45 | # new class 46 | continue 47 | elif line[1] != '': 48 | # new family 49 | if curr_t != None: 50 | curr_f.targets.append(curr_t) 51 | if curr_f != None: 52 | families.append(curr_f) 53 | curr_f = Family() 54 | curr_t = Protein(line[2]) 55 | elif line[2] != '': 56 | # new target 57 | if curr_t != None: 58 | curr_f.targets.append(curr_t) 59 | curr_t = Protein(line[2]) 60 | elif line[3] != '': 61 | # new ligand 62 | if line[4] != 'valid': 63 | continue 64 | curr_t.ligands.append((line[3], line[9])) 65 | 66 | curr_f.targets.append(curr_t) 67 | families.append(curr_f) 68 | 69 | by_target = {} 70 | for f in families: 71 | for t in f.targets: 72 | by_target[t.pdb_id] = t 73 | 74 | return families, by_target 75 | -------------------------------------------------------------------------------- /scripts/process_moad.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. 
You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | ''' 17 | Utility script to convert the MOAD dataset into a packed format 18 | ''' 19 | import sys 20 | import argparse 21 | import os 22 | import re 23 | import multiprocessing 24 | import threading 25 | 26 | from moad_util import parse_moad 27 | 28 | import rdkit.Chem.AllChem as Chem 29 | from rdkit.Chem.Descriptors import ExactMolWt 30 | import molvs 31 | import h5py 32 | import tqdm 33 | import numpy as np 34 | 35 | # add leadopt to path 36 | sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)) 37 | 38 | from leadopt import util 39 | 40 | 41 | # Data Format 42 | # 43 | # Receptor data: 44 | # - rec_lookup: [id][start][end] 45 | # - rec_coords: [x][y][z] 46 | # - rec_types: [num][is_hacc][is_hdon][is_aro][pcharge] 47 | # 48 | # Fragment data: 49 | # - frag_lookup: [id][fstart][fend][pstart][pend] 50 | # - frag_lig_id: [lig_id] 51 | # - frag_coords: [x][y][z] 52 | # - frag_types: [num] 53 | # - frag_smiles: [smiles] 54 | # - frag_mass: [mass] 55 | # - frag_dist: [dist] 56 | 57 | LOAD_TIMEOUT = 60 58 | 59 | 60 | u = molvs.charge.Uncharger() 61 | t = molvs.tautomer.TautomerCanonicalizer() 62 | 63 | 64 | REPLACE = [ 65 | ('C(=O)[O-]', 'C(=O)O'), 66 | ('N=[N+]=N', 'N=[N+]=[N-]'), 67 | ('[N+](=O)O', '[N+](=O)[O-]'), 68 | ('S(O)(O)O', '[S+2](O)(O)O'), 69 | ] 70 | 71 | 72 | def basic_replace(sm): 73 | m = Chem.MolFromSmiles(sm, False) 74 | 75 | for a,b in REPLACE: 76 | m = Chem.ReplaceSubstructs( 77 | m, 78 | Chem.MolFromSmiles(a, sanitize=False), 79 | Chem.MolFromSmiles(b, 
sanitize=False), 80 | replaceAll=True 81 | )[0] 82 | 83 | return Chem.MolToSmiles(m) 84 | 85 | 86 | def neutralize_smiles(sm): 87 | m = Chem.MolFromSmiles(sm) 88 | m = u.uncharge(m) 89 | m = t.canonicalize(m) 90 | sm = Chem.MolToSmiles(m) 91 | sm = basic_replace(sm) 92 | 93 | try: 94 | return molvs.standardize_smiles(sm) 95 | except: 96 | print(sm) 97 | return sm 98 | 99 | 100 | def load_example(base, rec_id, target): 101 | 102 | rec_path = os.path.join(base, '%s_rec.pdb' % rec_id) 103 | 104 | # Load receptor data. 105 | rec_coords, rec_types = util.load_receptor_ob(rec_path) 106 | 107 | # (frag_data, parent_data, smiles, mass, dist, lig_off) 108 | fragments = [] 109 | 110 | # (smi) 111 | lig_smiles = [] 112 | 113 | 114 | lig_off = 0 115 | for lig in target.ligands: 116 | 117 | lig_path = os.path.join(base, '%s_%s.pdb' % (rec_id, lig[0].replace(' ','_'))) 118 | try: 119 | lig_mol = Chem.MolFromPDBFile(lig_path, True) 120 | except: 121 | continue 122 | 123 | if lig_mol is None: 124 | continue 125 | 126 | lig_smi = lig[1] 127 | lig_smiles.append(lig_smi) 128 | 129 | ref = Chem.MolFromSmiles(lig_smi) 130 | lig_fixed = Chem.AssignBondOrdersFromTemplate(ref, lig_mol) 131 | 132 | splits = util.generate_fragments(lig_fixed) 133 | 134 | for parent, frag in splits: 135 | frag_data = util.mol_array(frag) 136 | parent_data = util.mol_array(parent) 137 | 138 | frag_smi = Chem.MolToSmiles( 139 | frag, 140 | isomericSmiles=False, 141 | kekuleSmiles=False, 142 | canonical=True, 143 | allHsExplicit=False 144 | ) 145 | 146 | frag_smi = neutralize_smiles(frag_smi) 147 | 148 | frag.UpdatePropertyCache(strict=False) 149 | mass = ExactMolWt(frag) 150 | 151 | dist = util.frag_dist_to_receptor_raw(rec_coords, frag) 152 | 153 | fragments.append((frag_data, parent_data, frag_smi, mass, dist, lig_off)) 154 | 155 | lig_off += 1 156 | 157 | return (rec_coords, rec_types, fragments, lig_smiles) 158 | 159 | def do_thread(out, args): 160 | try: 161 | out[0] = load_example(*args) 162 | except: 
163 | out[0] = None 164 | 165 | def multi_load(packed): 166 | out = [None] 167 | 168 | t = threading.Thread(target=do_thread, args=(out, packed)) 169 | t.start() 170 | t.join(timeout=LOAD_TIMEOUT) 171 | 172 | if t.is_alive(): 173 | print('timeout', packed[1]) 174 | 175 | return (packed, out[0]) 176 | 177 | 178 | def process(work, processed, moad_csv, out_path='moad.h5', num_cores=1): 179 | '''Process MOAD data and save to a packed format. 180 | 181 | Args: 182 | - out_path: where to save the .h5 packed data 183 | ''' 184 | rec_lookup = [] # (id, start, end) 185 | rec_coords = [] # (x,y,z) 186 | rec_types = [] # (num, aro, hdon, hacc, pcharge) 187 | 188 | frag_lookup = [] # (id, f_start, f_end, p_start, p_end) 189 | frag_lig_idx = [] # (lig_idx) 190 | frag_lig_smi = [] # (lig_smi) 191 | frag_data = [] # (x,y,z,type) 192 | frag_smiles = [] # (frag_smi) 193 | frag_mass = [] # (mass) 194 | frag_dist = [] # (dist) 195 | 196 | # Data pointers. 197 | rec_i = 0 198 | frag_i = 0 199 | 200 | # Multiprocess. 201 | with multiprocessing.Pool(num_cores) as p: 202 | with tqdm.tqdm(total=len(work)) as pbar: 203 | for w, res in p.imap_unordered(multi_load, work): 204 | pbar.update() 205 | 206 | if res == None: 207 | print('[!] Failed: %s' % w[1]) 208 | continue 209 | 210 | rcoords, rtypes, fragments, ex_lig_smiles = res 211 | 212 | if len(fragments) == 0: 213 | print('Empty', w[1]) 214 | continue 215 | 216 | rec_id = w[1] 217 | 218 | # Add receptor info. 219 | rec_start = rec_i 220 | rec_end = rec_i + rcoords.shape[0] 221 | rec_i += rcoords.shape[0] 222 | 223 | rec_coords.append(rcoords) 224 | rec_types.append(rtypes) 225 | rec_lookup.append((rec_id.encode('ascii'), rec_start, rec_end)) 226 | 227 | lig_idx = len(frag_lig_smi) 228 | 229 | # Add fragment info. 
230 | for fdat, pdat, frag_smi, mass, dist, lig_off in fragments: 231 | frag_start = frag_i 232 | frag_end = frag_i + fdat.shape[0] 233 | frag_i += fdat.shape[0] 234 | 235 | parent_start = frag_i 236 | parent_end = frag_i + pdat.shape[0] 237 | frag_i += pdat.shape[0] 238 | 239 | frag_data.append(fdat) 240 | frag_data.append(pdat) 241 | 242 | frag_lookup.append((rec_id.encode('ascii'), frag_start, frag_end, parent_start, parent_end)) 243 | frag_lig_idx.append(lig_idx+lig_off) 244 | frag_smiles.append(frag_smi) 245 | frag_mass.append(mass) 246 | frag_dist.append(dist) 247 | 248 | # Add ligand smiles. 249 | frag_lig_smi += ex_lig_smiles 250 | 251 | # Convert to numpy format. 252 | print('Convert numpy...', flush=True) 253 | n_rec_lookup = np.array(rec_lookup, dtype=' 3 and lig_resnum == '1': 41 | # assume peptide, take the whole chain 42 | sel = m.select('chain %s' % lig_chain) 43 | else: 44 | # assume small molecule, single residue 45 | sel = m.select('chain %s and resnum = %s' % (lig_chain, lig_resnum)) 46 | 47 | ligands.append((lig[0], sel)) 48 | 49 | return rec, ligands 50 | 51 | 52 | def do_proc(packed): 53 | out_dir, rec_name, path, target = packed 54 | 55 | try: 56 | rec, ligands = load_example(path, target) 57 | 58 | # Save receptor. 59 | prody.writePDB(os.path.join(out_dir, rec_name + '_rec.pdb'), rec) 60 | 61 | # Save ligands. 62 | for lig_name, lig_sel in ligands: 63 | if lig_sel is None: 64 | continue 65 | lig_name = lig_name.replace(' ', '_') 66 | prody.writePDB(os.path.join(out_dir, rec_name + '_' + lig_name + '.pdb'), lig_sel) 67 | 68 | except Exception as e: 69 | print('failed', path) 70 | print(e) 71 | 72 | return None 73 | 74 | 75 | def load_all(moad_dir, moad_csv, out_dir, num_cores=1): 76 | computed = [] 77 | 78 | if os.path.exists(out_dir): 79 | names = os.listdir(out_dir) 80 | names = [x.split('_rec')[0] for x in names if '_rec' in x] 81 | computed = names 82 | else: 83 | os.mkdir(out_dir) 84 | 85 | # Load MOAD csv. 
86 | moad_families, moad_targets = parse_moad(moad_csv) 87 | 88 | # Collect input files. 89 | files = [] 90 | for fname in os.listdir(moad_dir): 91 | if fname.startswith('.'): 92 | continue 93 | files.append(os.path.join(moad_dir, fname)) 94 | 95 | files = sorted(files) 96 | print('[*] Loading %d files...' % len(files)) 97 | 98 | failed = [] 99 | info = {} 100 | 101 | # (path, target) 102 | work = [] 103 | 104 | for path in tqdm.tqdm(files): 105 | rec_name = os.path.split(path)[-1].replace('.','_') 106 | if rec_name in computed: 107 | # Skip computed. 108 | continue 109 | 110 | pdb_id = rec_name.split('_')[0].upper() 111 | target = moad_targets[pdb_id] 112 | 113 | work.append((out_dir, rec_name, path, target)) 114 | 115 | print('[*] Starting...') 116 | with multiprocessing.Pool(num_cores) as p: 117 | with tqdm.tqdm(total=len(work)) as pbar: 118 | for r in p.imap_unordered(do_proc, work): 119 | pbar.update() 120 | 121 | print('[*] Done.') 122 | 123 | 124 | def run(): 125 | parser = argparse.ArgumentParser() 126 | 127 | parser.add_argument('-d', '--dataset', required=True, help='Path to MOAD folder') 128 | parser.add_argument('-c', '--csv', required=True, help='Path to MOAD "every.csv"') 129 | parser.add_argument('-o', '--output', default='./processed', help='Output directory') 130 | parser.add_argument('-n', '--num_cores', default=1, type=int, help='Number of cores') 131 | 132 | args = parser.parse_args() 133 | 134 | load_all(args.dataset, args.csv, args.output, args.num_cores) 135 | 136 | 137 | if __name__=='__main__': 138 | run() 139 | -------------------------------------------------------------------------------- /test_installation.sh: -------------------------------------------------------------------------------- 1 | rm -rf test_installation 2 | mkdir test_installation 3 | cd test_installation 4 | wget https://files.rcsb.org/view/1XDN.pdb 5 | cat 1XDN.pdb | grep -v ATP > receptor.pdb 6 | cat 1XDN.pdb | grep ATP > ligand.pdb 7 | 8 | # Remove terminal phosphate 
9 | cat ligand.pdb | grep -v "O1G" | grep -v "PG" | grep -v "O2G" | grep -v "O3G" > ligand2.pdb 10 | 11 | cd ../ 12 | python deepfrag.py --receptor test_installation/receptor.pdb --ligand test_installation/ligand2.pdb --cx 44.807 --cy 16.562 --cz 14.092 13 | 14 | 15 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Jacob Durrant 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy 5 | # of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | """ 17 | Utility script to launch training jobs. 
18 | """ 19 | 20 | import argparse 21 | import json 22 | 23 | from leadopt.model_conf import MODELS 24 | 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--save_path', help='location to save the model') 29 | parser.add_argument('--wandb_project', help='set this to log run to wandb') 30 | parser.add_argument('--configuration', help='path to a configuration args.json file') 31 | 32 | subparsers = parser.add_subparsers(dest='version') 33 | 34 | for m in MODELS: 35 | sub = subparsers.add_parser(m) 36 | MODELS[m].setup_parser(sub) 37 | 38 | args = parser.parse_args() 39 | args_dict = args.__dict__ 40 | 41 | if args.configuration is None and args.version is None: 42 | parser.print_help() 43 | exit(0) 44 | elif args.configuration is not None: 45 | _args = {} 46 | try: 47 | _args = json.loads(open(args.configuration, 'r').read()) 48 | except Exception as e: 49 | print('Error reading configuration file: %s' % args.configuration) 50 | print(e) 51 | exit(-1) 52 | args_dict.update(_args) 53 | elif args.version is not None: 54 | pass 55 | else: 56 | print('You can specify a model or configuration file but not both.') 57 | exit(-1) 58 | 59 | # Initialize model. 60 | model_type = args_dict['version'] 61 | model = MODELS[model_type](args_dict) 62 | 63 | model.train(args.save_path) 64 | 65 | 66 | if __name__=='__main__': 67 | main() 68 | --------------------------------------------------------------------------------