├── .gitignore ├── LICENSE ├── README.md ├── assets └── cover-large.png ├── bin └── .gitignore ├── configs ├── test │ ├── abopt_singlecdr.yml │ ├── codesign_multicdrs.yml │ ├── codesign_single.yml │ ├── fixbb.yml │ └── strpred.yml └── train │ ├── codesign_fv.yml │ ├── codesign_multicdrs.yml │ ├── codesign_single.yml │ ├── fixbb.yml │ └── strpred.yml ├── data ├── .gitignore ├── examples │ ├── 3QHF_Fv.pdb │ ├── 7DK2_AB_C.pdb │ └── Omicron_RBD.pdb └── sabdab_summary_all.tsv ├── design_dock.py ├── design_pdb.py ├── design_testset.py ├── diffab ├── datasets │ ├── __init__.py │ ├── _base.py │ ├── custom.py │ └── sabdab.py ├── models │ ├── __init__.py │ ├── _base.py │ └── diffab.py ├── modules │ ├── common │ │ ├── geometry.py │ │ ├── layers.py │ │ ├── so3.py │ │ ├── structure.py │ │ └── topology.py │ ├── diffusion │ │ ├── dpm_full.py │ │ └── transition.py │ └── encoders │ │ ├── ga.py │ │ ├── pair.py │ │ └── residue.py ├── tools │ ├── dock │ │ ├── base.py │ │ └── hdock.py │ ├── eval │ │ ├── __main__.py │ │ ├── base.py │ │ ├── energy.py │ │ ├── run.py │ │ └── similarity.py │ ├── relax │ │ ├── __main__.py │ │ ├── base.py │ │ ├── openmm_relaxer.py │ │ ├── pyrosetta_relaxer.py │ │ └── run.py │ ├── renumber │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── run.py │ └── runner │ │ ├── design_for_pdb.py │ │ └── design_for_testset.py └── utils │ ├── data.py │ ├── inference.py │ ├── misc.py │ ├── protein │ ├── constants.py │ ├── parsers.py │ └── writers.py │ ├── train.py │ └── transforms │ ├── __init__.py │ ├── _base.py │ ├── mask.py │ ├── merge.py │ ├── patch.py │ └── select_atom.py ├── env.yaml ├── streamlit_demo.py ├── train.py └── trained_models └── .gitignore /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | /playgrounds* 132 | /logs* 133 | /results* 134 | /*.csv 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffAb 2 | 3 | ![cover-large](./assets/cover-large.png) 4 | 5 | Antigen-Specific Antibody Design and Optimization with Diffusion-Based Generative Models for Protein Structures (NeurIPS 2022) 6 | 7 | [[Paper]](https://www.biorxiv.org/content/10.1101/2022.07.10.499510.abstract)[[Demo]](https://huggingface.co/spaces/luost26/DiffAb) 8 | 9 | ## Install 10 | 11 | ### Environment 12 | 13 | ```bash 14 | conda env create -f env.yaml -n diffab 15 | conda activate diffab 16 | ``` 17 | 18 | The default `cudatoolkit` version is 11.3. You may change it in [`env.yaml`](./env.yaml). 19 | 20 | ### Datasets and Trained Weights 21 | 22 | Protein structures in the `SAbDab` dataset can be downloaded [**here**](https://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/archive/all/). Extract `all_structures.zip` into the `data` folder. 23 | 24 | The `data` folder contains a snapshot of the dataset index (`sabdab_summary_all.tsv`). You may replace the index with the latest version [**here**](https://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/summary/all/). 25 | 26 | Trained model weights are available [**here** (Hugging Face)](https://huggingface.co/luost26/DiffAb/tree/main) or [**here** (Google Drive)](https://drive.google.com/drive/folders/15ANqouWRTG2UmQS_p0ErSsrKsU4HmNQc?usp=sharing). 27 | 28 | ### [Optional] HDOCK 29 | 30 | HDOCK is required to design CDRs for antigens without bound antibody frameworks. Please download HDOCK [**here**](http://huanglab.phys.hust.edu.cn/software/hdocklite/) and put the `hdock` and `createpl` programs into the [`bin`](./bin) folder. 31 | 32 | ### [Optional] PyRosetta 33 | 34 | PyRosetta is required to relax the generated structures and compute binding energy. Please follow the instructions [**here**](https://www.pyrosetta.org/downloads) to install. 35 | 36 | ### [Optional] Ray 37 | 38 | Ray is required to relax and evaluate the generated antibodies. Please install Ray using the following command: 39 | 40 | ```bash 41 | pip install -U ray 42 | ``` 43 | 44 | ## Design Antibodies 45 | 46 | Five design modes are available. Each mode corresponds to a config file in the `configs/test` folder: 47 | 48 | | Config File | Description | 49 | | ------------------------ | ------------------------------------------------------------ | 50 | | `codesign_single.yml` | Sample both the **sequence** and **structure** of **one** CDR. | 51 | | `codesign_multicdrs.yml` | Sample both the **sequence** and **structure** of **all** the CDRs simultaneously. | 52 | | `abopt_singlecdr.yml` | Optimize the **sequence** and **structure** of **one** CDR. | 53 | | `fixbb.yml` | Sample only the **sequence** of **one** CDR (fixed-backbone sequence design). | 54 | | `strpred.yml` | Sample only the **structure** of **one** CDR (structure prediction). | 55 | 56 | ### Antibody-Antigen Complex 57 | 58 | Below is the usage of `design_pdb.py`, which samples CDRs for antibody-antigen complexes. The full list of options can be found in [`diffab/tools/runner/design_for_pdb.py`](diffab/tools/runner/design_for_pdb.py). 59 | 60 | ```bash 61 | python design_pdb.py \ 62 | <path-to-pdb> \ 63 | --heavy <heavy-chain-id> \ 64 | --light <light-chain-id> \ 65 | --config <path-to-config> 66 | ``` 67 | 68 | The `--heavy` and `--light` options can be omitted, as the script can automatically identify the chains with AbNumber and ANARCI.
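If automatic identification fails, you can list the chain IDs present in your PDB file and pass them explicitly. Below is a minimal sketch using Biopython (already a dependency of this project); the helper function is illustrative and not part of the repository:

```python
from Bio.PDB import PDBParser

def list_chain_ids(pdb_path):
    """Print the chain IDs and residue counts found in the first model of a PDB file."""
    model = PDBParser(QUIET=True).get_structure('input', pdb_path)[0]
    for chain in model:
        num_residues = sum(1 for _ in chain.get_residues())
        print(f'Chain {chain.id}: {num_residues} residues')

list_chain_ids('./data/examples/7DK2_AB_C.pdb')
```

The printed IDs can then be supplied via `--heavy` and `--light`.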
69 | 70 | The example below designs each of the six CDRs separately for the `7DK2_AB_C` antibody-antigen complex. 71 | 72 | ```bash 73 | python design_pdb.py ./data/examples/7DK2_AB_C.pdb \ 74 | --config ./configs/test/codesign_single.yml 75 | ``` 76 | 77 | ### Antigen Only 78 | 79 | HDOCK is required to design antibodies for antigens without bound antibody structures (see the HDOCK installation instructions above). Below is the usage of `design_dock.py`. 80 | 81 | ```bash 82 | python design_dock.py \ 83 | --antigen <antigen-pdb> \ 84 | --antibody <antibody-template-pdb> \ 85 | --config <path-to-config> 86 | ``` 87 | 88 | The `--antibody` option may be omitted; the default antibody template is [`3QHF_Fv.pdb`](data/examples/3QHF_Fv.pdb). The full list of options can be found in the script. 89 | 90 | Below is an example that designs antibodies for the SARS-CoV-2 Omicron RBD. 91 | 92 | ```bash 93 | python design_dock.py \ 94 | --antigen ./data/examples/Omicron_RBD.pdb \ 95 | --config ./configs/test/codesign_multicdrs.yml 96 | ``` 97 | 98 | ## Train 99 | 100 | ```bash 101 | python train.py ./configs/train/<config-file-name>.yml 102 | ``` 103 | 104 | ## Reference 105 | 106 | ```bibtex 107 | @inproceedings{luo2022antigenspecific, 108 | title={Antigen-Specific Antibody Design and Optimization with Diffusion-Based Generative Models for Protein Structures}, 109 | author={Shitong Luo and Yufeng Su and Xingang Peng and Sheng Wang and Jian Peng and Jianzhu Ma}, 110 | booktitle={Advances in Neural Information Processing Systems}, 111 | editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, 112 | year={2022}, 113 | url={https://openreview.net/forum?id=jSorGn2Tjg} 114 | } 115 | ``` 116 | -------------------------------------------------------------------------------- /assets/cover-large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luost26/diffab/c3e2966601bf8025025ab87717b31b08fdd4834e/assets/cover-large.png -------------------------------------------------------------------------------- /bin/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /configs/test/abopt_singlecdr.yml: -------------------------------------------------------------------------------- 1 | mode: abopt 2 | model: 3 | checkpoint: ./trained_models/codesign_single.pt 4 | 5 | sampling: 6 | seed: 2022 7 | sample_structure: true 8 | sample_sequence: true 9 | cdrs: 10 | - H_CDR3 11 | num_samples: 100 12 | optimize_steps: 13 | - 1 14 | - 2 15 | - 4 16 | - 8 17 | - 16 18 | - 32 19 | - 64 20 | 21 | dataset: 22 | test: 23 | type: sabdab 24 | summary_path: ./data/sabdab_summary_all.tsv 25 | chothia_dir: ./data/all_structures/chothia 26 | processed_dir: ./data/processed 27 | split: test 28 | -------------------------------------------------------------------------------- /configs/test/codesign_multicdrs.yml: -------------------------------------------------------------------------------- 1 | mode: multiple_cdrs 2 | model: 3 | checkpoint: ./trained_models/codesign_multicdrs.pt 4 | 5 | sampling: 6 | seed: 2022 7 | sample_structure: true 8 | sample_sequence: true 9 | cdrs: 10 | - H_CDR1 11 | - H_CDR2 12 | - H_CDR3 13 | - L_CDR1 14 | - L_CDR2 15 | - L_CDR3 16 | num_samples: 100 17 | 18 | dataset: 19 | test: 20 | type: sabdab 21 | summary_path: ./data/sabdab_summary_all.tsv 22 | chothia_dir: ./data/all_structures/chothia 23 | processed_dir: ./data/processed 24 | split: test 25 |
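The test configs above (and the remaining ones in `configs/test`) share the same layout: a `mode`, a `model.checkpoint` pointing to a trained weight file, a `sampling` block (seed, what to sample, target CDRs, number of samples), and a `dataset.test` block describing the SAbDab test split. The sketch below reads such a file with PyYAML and EasyDict; it is only an illustration, and the project's own config utilities (see `diffab/utils/misc.py`) may differ in details:

```python
import yaml
from easydict import EasyDict

def load_yaml_config(path):
    # Parse a YAML config into an attribute-accessible dictionary.
    with open(path, 'r') as f:
        return EasyDict(yaml.safe_load(f))

cfg = load_yaml_config('./configs/test/codesign_multicdrs.yml')
print(cfg.mode)              # multiple_cdrs
print(cfg.model.checkpoint)  # ./trained_models/codesign_multicdrs.pt
print(cfg.sampling.cdrs)     # ['H_CDR1', 'H_CDR2', 'H_CDR3', 'L_CDR1', 'L_CDR2', 'L_CDR3']
```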
-------------------------------------------------------------------------------- /configs/test/codesign_single.yml: -------------------------------------------------------------------------------- 1 | mode: single_cdr 2 | model: 3 | checkpoint: ./trained_models/codesign_single.pt 4 | 5 | sampling: 6 | seed: 2022 7 | sample_structure: true 8 | sample_sequence: true 9 | cdrs: 10 | - H_CDR1 11 | - H_CDR2 12 | - H_CDR3 13 | - L_CDR1 14 | - L_CDR2 15 | - L_CDR3 16 | num_samples: 100 17 | 18 | dataset: 19 | test: 20 | type: sabdab 21 | summary_path: ./data/sabdab_summary_all.tsv 22 | chothia_dir: ./data/all_structures/chothia 23 | processed_dir: ./data/processed 24 | split: test 25 | -------------------------------------------------------------------------------- /configs/test/fixbb.yml: -------------------------------------------------------------------------------- 1 | mode: single_cdr 2 | model: 3 | checkpoint: ./trained_models/fixbb.pt 4 | 5 | sampling: 6 | seed: 2022 7 | sample_structure: false 8 | sample_sequence: true 9 | cdrs: 10 | - H_CDR1 11 | - H_CDR2 12 | - H_CDR3 13 | - L_CDR1 14 | - L_CDR2 15 | - L_CDR3 16 | num_samples: 100 17 | 18 | dataset: 19 | test: 20 | type: sabdab 21 | summary_path: ./data/sabdab_summary_all.tsv 22 | chothia_dir: ./data/all_structures/chothia 23 | processed_dir: ./data/processed 24 | split: test 25 | -------------------------------------------------------------------------------- /configs/test/strpred.yml: -------------------------------------------------------------------------------- 1 | mode: single_cdr 2 | model: 3 | checkpoint: ./trained_models/structure_pred.pt 4 | 5 | sampling: 6 | seed: 2022 7 | sample_structure: true 8 | sample_sequence: false 9 | cdrs: 10 | - H_CDR1 11 | - H_CDR2 12 | - H_CDR3 13 | - L_CDR1 14 | - L_CDR2 15 | - L_CDR3 16 | num_samples: 100 17 | 18 | dataset: 19 | test: 20 | type: sabdab 21 | summary_path: ./data/sabdab_summary_all.tsv 22 | chothia_dir: ./data/all_structures/chothia 23 | processed_dir: ./data/processed 24 | split: test 25 | -------------------------------------------------------------------------------- /configs/train/codesign_fv.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: diffab 3 | res_feat_dim: 128 4 | pair_feat_dim: 64 5 | diffusion: 6 | num_steps: 100 7 | eps_net_opt: 8 | num_layers: 6 9 | train_structure: true 10 | train_sequence: true 11 | 12 | train: 13 | loss_weights: 14 | rot: 1.0 15 | pos: 1.0 16 | seq: 1.0 17 | max_iters: 200_000 18 | val_freq: 1000 19 | batch_size: 16 20 | seed: 2022 21 | max_grad_norm: 100.0 22 | optimizer: 23 | type: adam 24 | lr: 1.e-4 25 | weight_decay: 0.0 26 | beta1: 0.9 27 | beta2: 0.999 28 | scheduler: 29 | type: plateau 30 | factor: 0.8 31 | patience: 10 32 | min_lr: 5.e-6 33 | 34 | dataset: 35 | train: 36 | type: sabdab 37 | summary_path: ./data/sabdab_summary_all.tsv 38 | chothia_dir: ./data/all_structures/chothia 39 | processed_dir: ./data/processed 40 | split: train 41 | transform: 42 | - type: mask_antibody 43 | - type: merge_chains 44 | - type: patch_around_anchor 45 | val: 46 | type: sabdab 47 | summary_path: ./data/sabdab_summary_all.tsv 48 | chothia_dir: ./data/all_structures/chothia 49 | processed_dir: ./data/processed 50 | split: val 51 | transform: 52 | - type: mask_antibody 53 | - type: merge_chains 54 | - type: patch_around_anchor 55 | -------------------------------------------------------------------------------- /configs/train/codesign_multicdrs.yml: 
-------------------------------------------------------------------------------- 1 | model: 2 | type: diffab 3 | res_feat_dim: 128 4 | pair_feat_dim: 64 5 | diffusion: 6 | num_steps: 100 7 | eps_net_opt: 8 | num_layers: 6 9 | train_structure: true 10 | train_sequence: true 11 | 12 | train: 13 | loss_weights: 14 | rot: 1.0 15 | pos: 1.0 16 | seq: 1.0 17 | max_iters: 200_000 18 | val_freq: 1000 19 | batch_size: 16 20 | seed: 2022 21 | max_grad_norm: 100.0 22 | optimizer: 23 | type: adam 24 | lr: 1.e-4 25 | weight_decay: 0.0 26 | beta1: 0.9 27 | beta2: 0.999 28 | scheduler: 29 | type: plateau 30 | factor: 0.8 31 | patience: 10 32 | min_lr: 5.e-6 33 | 34 | dataset: 35 | train: 36 | type: sabdab 37 | summary_path: ./data/sabdab_summary_all.tsv 38 | chothia_dir: ./data/all_structures/chothia 39 | processed_dir: ./data/processed 40 | split: train 41 | transform: 42 | - type: mask_multiple_cdrs 43 | - type: merge_chains 44 | - type: patch_around_anchor 45 | val: 46 | type: sabdab 47 | summary_path: ./data/sabdab_summary_all.tsv 48 | chothia_dir: ./data/all_structures/chothia 49 | processed_dir: ./data/processed 50 | split: val 51 | transform: 52 | - type: mask_single_cdr # Mask only CDR3 at validation 53 | selection: CDR3 54 | - type: merge_chains 55 | - type: patch_around_anchor 56 | -------------------------------------------------------------------------------- /configs/train/codesign_single.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: diffab 3 | res_feat_dim: 128 4 | pair_feat_dim: 64 5 | diffusion: 6 | num_steps: 100 7 | eps_net_opt: 8 | num_layers: 6 9 | train_structure: true 10 | train_sequence: true 11 | 12 | train: 13 | loss_weights: 14 | rot: 1.0 15 | pos: 1.0 16 | seq: 1.0 17 | max_iters: 200_000 18 | val_freq: 1000 19 | batch_size: 16 20 | seed: 2022 21 | max_grad_norm: 100.0 22 | optimizer: 23 | type: adam 24 | lr: 1.e-4 25 | weight_decay: 0.0 26 | beta1: 0.9 27 | beta2: 0.999 28 | scheduler: 29 | type: plateau 30 | factor: 0.8 31 | patience: 10 32 | min_lr: 5.e-6 33 | 34 | dataset: 35 | train: 36 | type: sabdab 37 | summary_path: ./data/sabdab_summary_all.tsv 38 | chothia_dir: ./data/all_structures/chothia 39 | processed_dir: ./data/processed 40 | split: train 41 | transform: 42 | - type: mask_single_cdr 43 | - type: merge_chains 44 | - type: patch_around_anchor 45 | val: 46 | type: sabdab 47 | summary_path: ./data/sabdab_summary_all.tsv 48 | chothia_dir: ./data/all_structures/chothia 49 | processed_dir: ./data/processed 50 | split: val 51 | transform: 52 | - type: mask_single_cdr 53 | selection: CDR3 54 | - type: merge_chains 55 | - type: patch_around_anchor 56 | -------------------------------------------------------------------------------- /configs/train/fixbb.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: diffab 3 | resolution: backbone+CB 4 | res_feat_dim: 128 5 | pair_feat_dim: 64 6 | diffusion: 7 | num_steps: 100 8 | eps_net_opt: 9 | num_layers: 6 10 | train_structure: false 11 | train_sequence: true 12 | 13 | train: 14 | loss_weights: 15 | rot: 1.0 16 | pos: 1.0 17 | seq: 1.0 18 | max_iters: 200_000 19 | val_freq: 1000 20 | batch_size: 16 21 | seed: 2022 22 | max_grad_norm: 100.0 23 | optimizer: 24 | type: adam 25 | lr: 1.e-4 26 | weight_decay: 0.0 27 | beta1: 0.9 28 | beta2: 0.999 29 | scheduler: 30 | type: plateau 31 | factor: 0.8 32 | patience: 10 33 | min_lr: 5.e-6 34 | 35 | dataset: 36 | train: 37 | type: sabdab 38 | summary_path: 
./data/sabdab_summary_all.tsv 39 | chothia_dir: ./data/all_structures/chothia 40 | processed_dir: ./data/processed 41 | split: train 42 | transform: 43 | - type: mask_single_cdr 44 | - type: merge_chains 45 | - type: patch_around_anchor 46 | val: 47 | type: sabdab 48 | summary_path: ./data/sabdab_summary_all.tsv 49 | chothia_dir: ./data/all_structures/chothia 50 | processed_dir: ./data/processed 51 | split: val 52 | transform: 53 | - type: mask_single_cdr 54 | selection: CDR3 55 | - type: merge_chains 56 | - type: patch_around_anchor 57 | -------------------------------------------------------------------------------- /configs/train/strpred.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: diffab 3 | res_feat_dim: 128 4 | pair_feat_dim: 64 5 | diffusion: 6 | num_steps: 100 7 | eps_net_opt: 8 | num_layers: 6 9 | train_structure: true 10 | train_sequence: false 11 | 12 | train: 13 | loss_weights: 14 | rot: 1.0 15 | pos: 1.0 16 | seq: 1.0 17 | max_iters: 200_000 18 | val_freq: 1000 19 | batch_size: 16 20 | seed: 2022 21 | max_grad_norm: 100.0 22 | optimizer: 23 | type: adam 24 | lr: 1.e-4 25 | weight_decay: 0.0 26 | beta1: 0.9 27 | beta2: 0.999 28 | scheduler: 29 | type: plateau 30 | factor: 0.8 31 | patience: 10 32 | min_lr: 5.e-6 33 | 34 | dataset: 35 | train: 36 | type: sabdab 37 | summary_path: ./data/sabdab_summary_all.tsv 38 | chothia_dir: ./data/all_structures/chothia 39 | processed_dir: ./data/processed 40 | split: train 41 | transform: 42 | - type: mask_single_cdr 43 | - type: merge_chains 44 | - type: patch_around_anchor 45 | val: 46 | type: sabdab 47 | summary_path: ./data/sabdab_summary_all.tsv 48 | chothia_dir: ./data/all_structures/chothia 49 | processed_dir: ./data/processed 50 | split: val 51 | transform: 52 | - type: mask_single_cdr 53 | selection: CDR3 54 | - type: merge_chains 55 | - type: patch_around_anchor 56 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !sabdab_summary_all.tsv 4 | !examples 5 | !examples/* 6 | -------------------------------------------------------------------------------- /design_dock.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | from diffab.tools.dock.hdock import HDockAntibody 5 | from diffab.tools.runner.design_for_pdb import args_factory, design_for_pdb 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--antigen', type=str, required=True) 11 | parser.add_argument('--antibody', type=str, default='./data/examples/3QHF_Fv.pdb') 12 | parser.add_argument('--heavy', type=str, default='H', help='Chain id of the heavy chain.') 13 | parser.add_argument('--light', type=str, default='L', help='Chain id of the light chain.') 14 | parser.add_argument('--hdock_bin', type=str, default='./bin/hdock') 15 | parser.add_argument('--createpl_bin', type=str, default='./bin/createpl') 16 | parser.add_argument('-n', '--num_docks', type=int, default=10) 17 | parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml') 18 | parser.add_argument('-o', '--out_root', type=str, default='./results') 19 | parser.add_argument('-t', '--tag', type=str, default='') 20 | parser.add_argument('-s', '--seed', type=int, default=None) 21 | parser.add_argument('-d', '--device', type=str, default='cuda') 22 | 
parser.add_argument('-b', '--batch_size', type=int, default=16) 23 | args = parser.parse_args() 24 | 25 | hdock_missing = [] 26 | if not os.path.exists(args.hdock_bin): 27 | hdock_missing.append(args.hdock_bin) 28 | if not os.path.exists(args.createpl_bin): 29 | hdock_missing.append(args.createpl_bin) 30 | if len(hdock_missing) > 0: 31 | print("[WARNING] The following HDOCK applications are missing:") 32 | for f in hdock_missing: 33 | print(f" > {f}") 34 | print("Please download HDOCK from http://huanglab.phys.hust.edu.cn/software/hdocklite/ " 35 | "and put `hdock` and `createpl` at the paths above.") 36 | exit() 37 | 38 | antigen_name = os.path.basename(os.path.splitext(args.antigen)[0]) 39 | docked_pdb_dir = os.path.splitext(args.antigen)[0] + '_dock' 40 | os.makedirs(docked_pdb_dir, exist_ok=True) 41 | docked_pdb_paths = [] 42 | for fname in os.listdir(docked_pdb_dir): 43 | if fname.endswith('.pdb'): 44 | docked_pdb_paths.append(os.path.join(docked_pdb_dir, fname)) 45 | if len(docked_pdb_paths) < args.num_docks: 46 | with HDockAntibody() as dock_session: 47 | dock_session.set_antigen(args.antigen) 48 | dock_session.set_antibody(args.antibody) 49 | docked_tmp_paths = dock_session.dock() 50 | for i, tmp_path in enumerate(docked_tmp_paths[:args.num_docks]): 51 | dest_path = os.path.join(docked_pdb_dir, f"{antigen_name}_Ab_{i:04d}.pdb") 52 | shutil.copyfile(tmp_path, dest_path) 53 | print(f'[INFO] Copy {tmp_path} -> {dest_path}') 54 | docked_pdb_paths.append(dest_path) 55 | 56 | for pdb_path in docked_pdb_paths: 57 | current_args = dict(vars(args)) # Copy the argument dict so the tag suffix is appended only once per docked structure. 58 | current_args['tag'] += antigen_name 59 | design_args = args_factory( 60 | pdb_path = pdb_path, 61 | **current_args, 62 | ) 63 | design_for_pdb(design_args) 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /design_pdb.py: -------------------------------------------------------------------------------- 1 | from diffab.tools.runner.design_for_pdb import args_from_cmdline, design_for_pdb 2 | 3 | if __name__ == '__main__': 4 | design_for_pdb(args_from_cmdline()) 5 | -------------------------------------------------------------------------------- /design_testset.py: -------------------------------------------------------------------------------- 1 | from diffab.tools.runner.design_for_testset import main 2 | 3 | if __name__ == '__main__': 4 | main() 5 | -------------------------------------------------------------------------------- /diffab/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .sabdab import SAbDabDataset 2 | from .custom import CustomDataset 3 | 4 | from ._base import get_dataset 5 | -------------------------------------------------------------------------------- /diffab/datasets/_base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, ConcatDataset 2 | from diffab.utils.transforms import get_transform 3 | 4 | 5 | _DATASET_DICT = {} 6 | 7 | 8 | def register_dataset(name): 9 | def decorator(cls): 10 | _DATASET_DICT[name] = cls 11 | return cls 12 | return decorator 13 | 14 | 15 | def get_dataset(cfg): 16 | transform = get_transform(cfg.transform) if 'transform' in cfg else None 17 | return _DATASET_DICT[cfg.type](cfg, transform=transform) 18 | 19 | 20 | @register_dataset('concat') 21 | def get_concat_dataset(cfg): 22 | datasets = [get_dataset(d) for d in cfg.datasets] 23 | return ConcatDataset(datasets)
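# Illustrative sketch (not part of the repository): the registry above lets a YAML
# config pick a dataset class through its `type` field. A hypothetical dataset would
# be registered and constructed like this:
#
#     @register_dataset('my_toy_set')
#     class MyToyDataset(Dataset):
#         def __init__(self, cfg, transform=None):
#             super().__init__()
#             self.transform = transform
#         def __len__(self):
#             return 4
#         def __getitem__(self, idx):
#             item = {'index': idx}
#             return self.transform(item) if self.transform is not None else item
#
#     # cfg is an EasyDict-style object, e.g. EasyDict({'type': 'my_toy_set'});
#     # get_dataset() dispatches on cfg.type and forwards the optional transform.
#     dataset = get_dataset(EasyDict({'type': 'my_toy_set'}))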
24 | 25 | 26 | @register_dataset('balanced_concat') 27 | class BalancedConcatDataset(Dataset): 28 | 29 | def __init__(self, cfg, transform=None): 30 | super().__init__() 31 | assert transform is None, 'transform is not supported.' 32 | self.datasets = [get_dataset(d) for d in cfg.datasets] 33 | self.max_size = max([len(d) for d in self.datasets]) 34 | 35 | def __len__(self): 36 | return self.max_size * len(self.datasets) 37 | 38 | def __getitem__(self, idx): 39 | dataset_idx = idx // self.max_size 40 | return self.datasets[dataset_idx][idx % len(self.datasets[dataset_idx])] 41 | -------------------------------------------------------------------------------- /diffab/datasets/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import joblib 4 | import pickle 5 | import lmdb 6 | from Bio import PDB 7 | from Bio.PDB import PDBExceptions 8 | from torch.utils.data import Dataset 9 | from tqdm.auto import tqdm 10 | 11 | from ..utils.protein import parsers 12 | from .sabdab import _label_heavy_chain_cdr, _label_light_chain_cdr 13 | from ._base import register_dataset 14 | 15 | 16 | def preprocess_antibody_structure(task): 17 | pdb_path = task['pdb_path'] 18 | H_id = task.get('heavy_id', 'H') 19 | L_id = task.get('light_id', 'L') 20 | 21 | parser = PDB.PDBParser(QUIET=True) 22 | model = parser.get_structure(task['id'], pdb_path)[0] # Use the task id as the structure id (not the built-in `id`). 23 | 24 | all_chain_ids = [c.id for c in model] 25 | 26 | parsed = { 27 | 'id': task['id'], 28 | 'heavy': None, 29 | 'heavy_seqmap': None, 30 | 'light': None, 31 | 'light_seqmap': None, 32 | 'antigen': None, 33 | 'antigen_seqmap': None, 34 | } 35 | try: 36 | if H_id in all_chain_ids: 37 | ( 38 | parsed['heavy'], 39 | parsed['heavy_seqmap'] 40 | ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure( 41 | model[H_id], 42 | max_resseq = 113 # Chothia, end of Heavy chain Fv 43 | )) 44 | 45 | if L_id in all_chain_ids: 46 | ( 47 | parsed['light'], 48 | parsed['light_seqmap'] 49 | ) = _label_light_chain_cdr(*parsers.parse_biopython_structure( 50 | model[L_id], 51 | max_resseq = 106 # Chothia, end of Light chain Fv 52 | )) 53 | 54 | if parsed['heavy'] is None and parsed['light'] is None: 55 | raise ValueError( 56 | f'Neither a valid antibody H-chain nor a valid L-chain was found. ' 57 | f'Please ensure that the chain id of the heavy chain is "{H_id}" ' 58 | f'and the id of the light chain is "{L_id}".'
59 | ) 60 | 61 | 62 | ag_chain_ids = [cid for cid in all_chain_ids if cid not in (H_id, L_id)] 63 | if len(ag_chain_ids) > 0: 64 | chains = [model[c] for c in ag_chain_ids] 65 | ( 66 | parsed['antigen'], 67 | parsed['antigen_seqmap'] 68 | ) = parsers.parse_biopython_structure(chains) 69 | 70 | except ( 71 | PDBExceptions.PDBConstructionException, 72 | parsers.ParsingException, 73 | KeyError, 74 | ValueError, 75 | ) as e: 76 | logging.warning('[{}] {}: {}'.format( 77 | task['id'], 78 | e.__class__.__name__, 79 | str(e) 80 | )) 81 | return None 82 | 83 | return parsed 84 | 85 | 86 | @register_dataset('custom') 87 | class CustomDataset(Dataset): 88 | 89 | MAP_SIZE = 32*(1024*1024*1024) # 32GB 90 | 91 | def __init__(self, structure_dir, transform=None, reset=False): 92 | super().__init__() 93 | self.structure_dir = structure_dir 94 | self.transform = transform 95 | 96 | self.db_conn = None 97 | self.db_ids = None 98 | self._load_structures(reset) 99 | 100 | @property 101 | def _cache_db_path(self): 102 | return os.path.join(self.structure_dir, 'structure_cache.lmdb') 103 | 104 | def _connect_db(self): 105 | self._close_db() 106 | self.db_conn = lmdb.open( 107 | self._cache_db_path, 108 | map_size=self.MAP_SIZE, 109 | create=False, 110 | subdir=False, 111 | readonly=True, 112 | lock=False, 113 | readahead=False, 114 | meminit=False, 115 | ) 116 | with self.db_conn.begin() as txn: 117 | keys = [k.decode() for k in txn.cursor().iternext(values=False)] 118 | self.db_ids = keys 119 | 120 | def _close_db(self): 121 | if self.db_conn is not None: 122 | self.db_conn.close() 123 | self.db_conn = None 124 | self.db_ids = None 125 | 126 | def _load_structures(self, reset): 127 | all_pdbs = [] 128 | for fname in os.listdir(self.structure_dir): 129 | if not fname.endswith('.pdb'): continue 130 | all_pdbs.append(fname) 131 | 132 | if reset or not os.path.exists(self._cache_db_path): 133 | todo_pdbs = all_pdbs 134 | else: 135 | self._connect_db() 136 | processed_pdbs = self.db_ids 137 | self._close_db() 138 | todo_pdbs = list(set(all_pdbs) - set(processed_pdbs)) 139 | 140 | if len(todo_pdbs) > 0: 141 | self._preprocess_structures(todo_pdbs) 142 | 143 | def _preprocess_structures(self, pdb_list): 144 | tasks = [] 145 | for pdb_fname in pdb_list: 146 | pdb_path = os.path.join(self.structure_dir, pdb_fname) 147 | tasks.append({ 148 | 'id': pdb_fname, 149 | 'pdb_path': pdb_path, 150 | }) 151 | 152 | data_list = joblib.Parallel( 153 | n_jobs = max(joblib.cpu_count() // 2, 1), 154 | )( 155 | joblib.delayed(preprocess_antibody_structure)(task) 156 | for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess') 157 | ) 158 | 159 | db_conn = lmdb.open( 160 | self._cache_db_path, 161 | map_size = self.MAP_SIZE, 162 | create=True, 163 | subdir=False, 164 | readonly=False, 165 | ) 166 | ids = [] 167 | with db_conn.begin(write=True, buffers=True) as txn: 168 | for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'): 169 | if data is None: 170 | continue 171 | ids.append(data['id']) 172 | txn.put(data['id'].encode('utf-8'), pickle.dumps(data)) 173 | 174 | def __len__(self): 175 | return len(self.db_ids) 176 | 177 | def __getitem__(self, index): 178 | self._connect_db() 179 | id = self.db_ids[index] 180 | with self.db_conn.begin() as txn: 181 | data = pickle.loads(txn.get(id.encode())) 182 | if self.transform is not None: 183 | data = self.transform(data) 184 | return data 185 | 186 | 187 | if __name__ == '__main__': 188 | import argparse 189 | parser = argparse.ArgumentParser() 190 | 
parser.add_argument('--dir', type=str, default='./data/custom') 191 | parser.add_argument('--reset', action='store_true', default=False) 192 | args = parser.parse_args() 193 | 194 | dataset = CustomDataset( 195 | structure_dir = args.dir, 196 | reset = args.reset, 197 | ) 198 | print(dataset[0]) 199 | print(len(dataset)) 200 | -------------------------------------------------------------------------------- /diffab/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffab import DiffusionAntibodyDesign 2 | 3 | from ._base import get_model 4 | -------------------------------------------------------------------------------- /diffab/models/_base.py: -------------------------------------------------------------------------------- 1 | 2 | _MODEL_DICT = {} 3 | 4 | 5 | def register_model(name): 6 | def decorator(cls): 7 | _MODEL_DICT[name] = cls 8 | return cls 9 | return decorator 10 | 11 | 12 | def get_model(cfg): 13 | return _MODEL_DICT[cfg.type](cfg) 14 | -------------------------------------------------------------------------------- /diffab/models/diffab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from diffab.modules.common.geometry import construct_3d_basis 5 | from diffab.modules.common.so3 import rotation_to_so3vec 6 | from diffab.modules.encoders.residue import ResidueEmbedding 7 | from diffab.modules.encoders.pair import PairEmbedding 8 | from diffab.modules.diffusion.dpm_full import FullDPM 9 | from diffab.utils.protein.constants import max_num_heavyatoms, BBHeavyAtom 10 | from ._base import register_model 11 | 12 | 13 | resolution_to_num_atoms = { 14 | 'backbone+CB': 5, 15 | 'full': max_num_heavyatoms 16 | } 17 | 18 | 19 | @register_model('diffab') 20 | class DiffusionAntibodyDesign(nn.Module): 21 | 22 | def __init__(self, cfg): 23 | super().__init__() 24 | self.cfg = cfg 25 | 26 | num_atoms = resolution_to_num_atoms[cfg.get('resolution', 'full')] 27 | self.residue_embed = ResidueEmbedding(cfg.res_feat_dim, num_atoms) 28 | self.pair_embed = PairEmbedding(cfg.pair_feat_dim, num_atoms) 29 | 30 | self.diffusion = FullDPM( 31 | cfg.res_feat_dim, 32 | cfg.pair_feat_dim, 33 | **cfg.diffusion, 34 | ) 35 | 36 | def encode(self, batch, remove_structure, remove_sequence): 37 | """ 38 | Returns: 39 | res_feat: (N, L, res_feat_dim) 40 | pair_feat: (N, L, L, pair_feat_dim) 41 | """ 42 | # This is used throughout embedding and encoding layers 43 | # to avoid data leakage. 
44 | context_mask = torch.logical_and( 45 | batch['mask_heavyatom'][:, :, BBHeavyAtom.CA], 46 | ~batch['generate_flag'] # Context means ``not generated'' 47 | ) 48 | 49 | structure_mask = context_mask if remove_structure else None 50 | sequence_mask = context_mask if remove_sequence else None 51 | 52 | res_feat = self.residue_embed( 53 | aa = batch['aa'], 54 | res_nb = batch['res_nb'], 55 | chain_nb = batch['chain_nb'], 56 | pos_atoms = batch['pos_heavyatom'], 57 | mask_atoms = batch['mask_heavyatom'], 58 | fragment_type = batch['fragment_type'], 59 | structure_mask = structure_mask, 60 | sequence_mask = sequence_mask, 61 | ) 62 | 63 | pair_feat = self.pair_embed( 64 | aa = batch['aa'], 65 | res_nb = batch['res_nb'], 66 | chain_nb = batch['chain_nb'], 67 | pos_atoms = batch['pos_heavyatom'], 68 | mask_atoms = batch['mask_heavyatom'], 69 | structure_mask = structure_mask, 70 | sequence_mask = sequence_mask, 71 | ) 72 | 73 | R = construct_3d_basis( 74 | batch['pos_heavyatom'][:, :, BBHeavyAtom.CA], 75 | batch['pos_heavyatom'][:, :, BBHeavyAtom.C], 76 | batch['pos_heavyatom'][:, :, BBHeavyAtom.N], 77 | ) 78 | p = batch['pos_heavyatom'][:, :, BBHeavyAtom.CA] 79 | 80 | return res_feat, pair_feat, R, p 81 | 82 | def forward(self, batch): 83 | mask_generate = batch['generate_flag'] 84 | mask_res = batch['mask'] 85 | res_feat, pair_feat, R_0, p_0 = self.encode( 86 | batch, 87 | remove_structure = self.cfg.get('train_structure', True), 88 | remove_sequence = self.cfg.get('train_sequence', True) 89 | ) 90 | v_0 = rotation_to_so3vec(R_0) 91 | s_0 = batch['aa'] 92 | 93 | loss_dict = self.diffusion( 94 | v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, 95 | denoise_structure = self.cfg.get('train_structure', True), 96 | denoise_sequence = self.cfg.get('train_sequence', True), 97 | ) 98 | return loss_dict 99 | 100 | @torch.no_grad() 101 | def sample( 102 | self, 103 | batch, 104 | sample_opt={ 105 | 'sample_structure': True, 106 | 'sample_sequence': True, 107 | } 108 | ): 109 | mask_generate = batch['generate_flag'] 110 | mask_res = batch['mask'] 111 | res_feat, pair_feat, R_0, p_0 = self.encode( 112 | batch, 113 | remove_structure = sample_opt.get('sample_structure', True), 114 | remove_sequence = sample_opt.get('sample_sequence', True) 115 | ) 116 | v_0 = rotation_to_so3vec(R_0) 117 | s_0 = batch['aa'] 118 | traj = self.diffusion.sample(v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, **sample_opt) 119 | return traj 120 | 121 | @torch.no_grad() 122 | def optimize( 123 | self, 124 | batch, 125 | opt_step, 126 | optimize_opt={ 127 | 'sample_structure': True, 128 | 'sample_sequence': True, 129 | } 130 | ): 131 | mask_generate = batch['generate_flag'] 132 | mask_res = batch['mask'] 133 | res_feat, pair_feat, R_0, p_0 = self.encode( 134 | batch, 135 | remove_structure = optimize_opt.get('sample_structure', True), 136 | remove_sequence = optimize_opt.get('sample_sequence', True) 137 | ) 138 | v_0 = rotation_to_so3vec(R_0) 139 | s_0 = batch['aa'] 140 | 141 | traj = self.diffusion.optimize(v_0, p_0, s_0, opt_step, res_feat, pair_feat, mask_generate, mask_res, **optimize_opt) 142 | return traj 143 | -------------------------------------------------------------------------------- /diffab/modules/common/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def mask_zero(mask, value): 7 | return torch.where(mask, value, torch.zeros_like(value)) 8 | 9 | 10 | def 
clampped_one_hot(x, num_classes): 11 | mask = (x >= 0) & (x < num_classes) # (N, L) 12 | x = x.clamp(min=0, max=num_classes-1) 13 | y = F.one_hot(x, num_classes) * mask[...,None] # (N, L, C) 14 | return y 15 | 16 | 17 | class DistanceToBins(nn.Module): 18 | 19 | def __init__(self, dist_min=0.0, dist_max=20.0, num_bins=64, use_onehot=False): 20 | super().__init__() 21 | self.dist_min = dist_min 22 | self.dist_max = dist_max 23 | self.num_bins = num_bins 24 | self.use_onehot = use_onehot 25 | 26 | if use_onehot: 27 | offset = torch.linspace(dist_min, dist_max, self.num_bins) 28 | else: 29 | offset = torch.linspace(dist_min, dist_max, self.num_bins-1) # 1 overflow flag 30 | self.coeff = -0.5 / ((offset[1] - offset[0]) * 0.2).item() ** 2 # `*0.2`: makes it not too blurred 31 | self.register_buffer('offset', offset) 32 | 33 | @property 34 | def out_channels(self): 35 | return self.num_bins 36 | 37 | def forward(self, dist, dim, normalize=True): 38 | """ 39 | Args: 40 | dist: (N, *, 1, *) 41 | Returns: 42 | (N, *, num_bins, *) 43 | """ 44 | assert dist.size()[dim] == 1 45 | offset_shape = [1] * len(dist.size()) 46 | offset_shape[dim] = -1 47 | 48 | if self.use_onehot: 49 | diff = torch.abs(dist - self.offset.view(*offset_shape)) # (N, *, num_bins, *) 50 | bin_idx = torch.argmin(diff, dim=dim, keepdim=True) # (N, *, 1, *) 51 | y = torch.zeros_like(diff).scatter_(dim=dim, index=bin_idx, value=1.0) 52 | else: 53 | overflow_symb = (dist >= self.dist_max).float() # (N, *, 1, *) 54 | y = dist - self.offset.view(*offset_shape) # (N, *, num_bins-1, *) 55 | y = torch.exp(self.coeff * torch.pow(y, 2)) # (N, *, num_bins-1, *) 56 | y = torch.cat([y, overflow_symb], dim=dim) # (N, *, num_bins, *) 57 | if normalize: 58 | y = y / y.sum(dim=dim, keepdim=True) 59 | 60 | return y 61 | 62 | 63 | class PositionalEncoding(nn.Module): 64 | 65 | def __init__(self, num_funcs=6): 66 | super().__init__() 67 | self.num_funcs = num_funcs 68 | self.register_buffer('freq_bands', 2.0 ** torch.linspace(0.0, num_funcs-1, num_funcs)) 69 | 70 | def get_out_dim(self, in_dim): 71 | return in_dim * (2 * self.num_funcs + 1) 72 | 73 | def forward(self, x): 74 | """ 75 | Args: 76 | x: (..., d). 77 | """ 78 | shape = list(x.shape[:-1]) + [-1] 79 | x = x.unsqueeze(-1) # (..., d, 1) 80 | code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1) # (..., d, 2f+1) 81 | code = code.reshape(shape) 82 | return code 83 | 84 | 85 | class AngularEncoding(nn.Module): 86 | 87 | def __init__(self, num_funcs=3): 88 | super().__init__() 89 | self.num_funcs = num_funcs 90 | self.register_buffer('freq_bands', torch.FloatTensor( 91 | [i+1 for i in range(num_funcs)] + [1./(i+1) for i in range(num_funcs)] 92 | )) 93 | 94 | def get_out_dim(self, in_dim): 95 | return in_dim * (1 + 2*2*self.num_funcs) 96 | 97 | def forward(self, x): 98 | """ 99 | Args: 100 | x: (..., d). 101 | """ 102 | shape = list(x.shape[:-1]) + [-1] 103 | x = x.unsqueeze(-1) # (..., d, 1) 104 | code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1) # (..., d, 2f+1) 105 | code = code.reshape(shape) 106 | return code 107 | 108 | 109 | class LayerNorm(nn.Module): 110 | 111 | def __init__(self, 112 | normal_shape, 113 | gamma=True, 114 | beta=True, 115 | epsilon=1e-10): 116 | """Layer normalization layer 117 | See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) 118 | :param normal_shape: The shape of the input tensor or the last dimension of the input tensor. 
119 | :param gamma: Add a scale parameter if it is True. 120 | :param beta: Add an offset parameter if it is True. 121 | :param epsilon: Epsilon for calculating variance. 122 | """ 123 | super().__init__() 124 | if isinstance(normal_shape, int): 125 | normal_shape = (normal_shape,) 126 | else: 127 | normal_shape = (normal_shape[-1],) 128 | self.normal_shape = torch.Size(normal_shape) 129 | self.epsilon = epsilon 130 | if gamma: 131 | self.gamma = nn.Parameter(torch.Tensor(*normal_shape)) 132 | else: 133 | self.register_parameter('gamma', None) 134 | if beta: 135 | self.beta = nn.Parameter(torch.Tensor(*normal_shape)) 136 | else: 137 | self.register_parameter('beta', None) 138 | self.reset_parameters() 139 | 140 | def reset_parameters(self): 141 | if self.gamma is not None: 142 | self.gamma.data.fill_(1) 143 | if self.beta is not None: 144 | self.beta.data.zero_() 145 | 146 | def forward(self, x): 147 | mean = x.mean(dim=-1, keepdim=True) 148 | var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) 149 | std = (var + self.epsilon).sqrt() 150 | y = (x - mean) / std 151 | if self.gamma is not None: 152 | y *= self.gamma 153 | if self.beta is not None: 154 | y += self.beta 155 | return y 156 | 157 | def extra_repr(self): 158 | return 'normal_shape={}, gamma={}, beta={}, epsilon={}'.format( 159 | self.normal_shape, self.gamma is not None, self.beta is not None, self.epsilon, 160 | ) 161 | -------------------------------------------------------------------------------- /diffab/modules/common/so3.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .geometry import quaternion_to_rotation_matrix 8 | 9 | 10 | def log_rotation(R): 11 | trace = R[..., range(3), range(3)].sum(-1) 12 | if torch.is_grad_enabled(): 13 | # The derivative of acos at -1.0 is -inf, so to stablize the gradient, we use -0.9999 14 | min_cos = -0.999 15 | else: 16 | min_cos = -1.0 17 | cos_theta = ( (trace-1) / 2 ).clamp_min(min=min_cos) 18 | sin_theta = torch.sqrt(1 - cos_theta**2) 19 | theta = torch.acos(cos_theta) 20 | coef = ((theta+1e-8)/(2*sin_theta+2e-8))[..., None, None] 21 | logR = coef * (R - R.transpose(-1, -2)) 22 | return logR 23 | 24 | 25 | def skewsym_to_so3vec(S): 26 | x = S[..., 1, 2] 27 | y = S[..., 2, 0] 28 | z = S[..., 0, 1] 29 | w = torch.stack([x,y,z], dim=-1) 30 | return w 31 | 32 | 33 | def so3vec_to_skewsym(w): 34 | x, y, z = torch.unbind(w, dim=-1) 35 | o = torch.zeros_like(x) 36 | S = torch.stack([ 37 | o, z, -y, 38 | -z, o, x, 39 | y, -x, o, 40 | ], dim=-1).reshape(w.shape[:-1] + (3, 3)) 41 | return S 42 | 43 | 44 | def exp_skewsym(S): 45 | x = torch.linalg.norm(skewsym_to_so3vec(S), dim=-1) 46 | I = torch.eye(3).to(S).view([1 for _ in range(S.dim()-2)] + [3, 3]) 47 | 48 | sinx, cosx = torch.sin(x), torch.cos(x) 49 | b = (sinx + 1e-8) / (x + 1e-8) 50 | c = (1-cosx + 1e-8) / (x**2 + 2e-8) # lim_{x->0} (1-cosx)/(x^2) = 0.5 51 | 52 | S2 = S @ S 53 | return I + b[..., None, None]*S + c[..., None, None]*S2 54 | 55 | 56 | def so3vec_to_rotation(w): 57 | return exp_skewsym(so3vec_to_skewsym(w)) 58 | 59 | 60 | def rotation_to_so3vec(R): 61 | logR = log_rotation(R) 62 | w = skewsym_to_so3vec(logR) 63 | return w 64 | 65 | 66 | def random_uniform_so3(size, device='cpu'): 67 | q = F.normalize(torch.randn(list(size)+[4,], device=device), dim=-1) # (..., 4) 68 | return rotation_to_so3vec(quaternion_to_rotation_matrix(q)) 69 | 70 | 71 | class 
ApproxAngularDistribution(nn.Module): 72 | 73 | def __init__(self, stddevs, std_threshold=0.1, num_bins=8192, num_iters=1024): 74 | super().__init__() 75 | self.std_threshold = std_threshold 76 | self.num_bins = num_bins 77 | self.num_iters = num_iters 78 | self.register_buffer('stddevs', torch.FloatTensor(stddevs)) 79 | self.register_buffer('approx_flag', self.stddevs <= std_threshold) 80 | self._precompute_histograms() 81 | 82 | @staticmethod 83 | def _pdf(x, e, L): 84 | """ 85 | Args: 86 | x: (N, ) 87 | e: Float 88 | L: Integer 89 | """ 90 | x = x[:, None] # (N, *) 91 | c = ((1 - torch.cos(x)) / math.pi) # (N, *) 92 | l = torch.arange(0, L)[None, :] # (*, L) 93 | a = (2*l+1) * torch.exp(-l*(l+1)*(e**2)) # (*, L) 94 | b = (torch.sin( (l+0.5)* x ) + 1e-6) / (torch.sin( x / 2 ) + 1e-6) # (N, L) 95 | 96 | f = (c * a * b).sum(dim=1) 97 | return f 98 | 99 | def _precompute_histograms(self): 100 | X, Y = [], [] 101 | for std in self.stddevs: 102 | std = std.item() 103 | x = torch.linspace(0, math.pi, self.num_bins) # (n_bins,) 104 | y = self._pdf(x, std, self.num_iters) # (n_bins,) 105 | y = torch.nan_to_num(y).clamp_min(0) 106 | X.append(x) 107 | Y.append(y) 108 | self.register_buffer('X', torch.stack(X, dim=0)) # (n_stddevs, n_bins) 109 | self.register_buffer('Y', torch.stack(Y, dim=0)) # (n_stddevs, n_bins) 110 | 111 | def sample(self, std_idx): 112 | """ 113 | Args: 114 | std_idx: Indices of standard deviation. 115 | Returns: 116 | samples: Angular samples [0, PI), same size as std. 117 | """ 118 | size = std_idx.size() 119 | std_idx = std_idx.flatten() # (N,) 120 | 121 | # Samples from histogram 122 | prob = self.Y[std_idx] # (N, n_bins) 123 | bin_idx = torch.multinomial(prob[:, :-1], num_samples=1).squeeze(-1) # (N,) 124 | bin_start = self.X[std_idx, bin_idx] # (N,) 125 | bin_width = self.X[std_idx, bin_idx+1] - self.X[std_idx, bin_idx] 126 | samples_hist = bin_start + torch.rand_like(bin_start) * bin_width # (N,) 127 | 128 | # Samples from Gaussian approximation 129 | mean_gaussian = self.stddevs[std_idx]*2 130 | std_gaussian = self.stddevs[std_idx] 131 | samples_gaussian = mean_gaussian + torch.randn_like(mean_gaussian) * std_gaussian 132 | samples_gaussian = samples_gaussian.abs() % math.pi 133 | 134 | # Choose from histogram or Gaussian 135 | gaussian_flag = self.approx_flag[std_idx] 136 | samples = torch.where(gaussian_flag, samples_gaussian, samples_hist) 137 | 138 | return samples.reshape(size) 139 | 140 | 141 | def random_normal_so3(std_idx, angular_distrib, device='cpu'): 142 | size = std_idx.size() 143 | u = F.normalize(torch.randn(list(size)+[3,], device=device), dim=-1) 144 | theta = angular_distrib.sample(std_idx) 145 | w = u * theta[..., None] 146 | return w 147 | -------------------------------------------------------------------------------- /diffab/modules/common/structure.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, Linear, LayerNorm, Sequential, ReLU 3 | 4 | from ..common.geometry import compose_rotation_and_translation, quaternion_to_rotation_matrix, repr_6d_to_rotation_matrix 5 | 6 | 7 | class FrameRotationTranslationPrediction(Module): 8 | 9 | def __init__(self, feat_dim, rot_repr, nn_type='mlp'): 10 | super().__init__() 11 | assert rot_repr in ('quaternion', '6d') 12 | self.rot_repr = rot_repr 13 | if rot_repr == 'quaternion': 14 | out_dim = 3 + 3 15 | elif rot_repr == '6d': 16 | out_dim = 6 + 3 17 | 18 | if nn_type == 'linear': 19 | self.nn = Linear(feat_dim, out_dim) 20 | 
elif nn_type == 'mlp': 21 | self.nn = Sequential( 22 | Linear(feat_dim, feat_dim), ReLU(), 23 | Linear(feat_dim, feat_dim), ReLU(), 24 | Linear(feat_dim, out_dim) 25 | ) 26 | else: 27 | raise ValueError('Unknown nn_type: %s' % nn_type) 28 | 29 | def forward(self, x): 30 | y = self.nn(x) # (..., d+3) 31 | if self.rot_repr == 'quaternion': 32 | quaternion = torch.cat([torch.ones_like(y[..., :1]), y[..., 0:3]], dim=-1) 33 | R_delta = quaternion_to_rotation_matrix(quaternion) 34 | t_delta = y[..., 3:6] 35 | return R_delta, t_delta 36 | elif self.rot_repr == '6d': 37 | R_delta = repr_6d_to_rotation_matrix(y[..., 0:6]) 38 | t_delta = y[..., 6:9] 39 | return R_delta, t_delta 40 | 41 | 42 | class FrameUpdate(Module): 43 | 44 | def __init__(self, node_feat_dim, rot_repr='quaternion', rot_tran_nn_type='mlp'): 45 | super().__init__() 46 | self.transition_mlp = Sequential( 47 | Linear(node_feat_dim, node_feat_dim), ReLU(), 48 | Linear(node_feat_dim, node_feat_dim), ReLU(), 49 | Linear(node_feat_dim, node_feat_dim), 50 | ) 51 | self.transition_layer_norm = LayerNorm(node_feat_dim) 52 | 53 | self.rot_tran = FrameRotationTranslationPrediction(node_feat_dim, rot_repr, nn_type=rot_tran_nn_type) 54 | 55 | def forward(self, R, t, x, mask_generate): 56 | """ 57 | Args: 58 | R: Frame basis matrices, (N, L, 3, 3_index). 59 | t: Frame external (absolute) coordinates, (N, L, 3). Unit: Angstrom. 60 | x: Node-wise features, (N, L, F). 61 | mask_generate: Masks, (N, L). 62 | Returns: 63 | R': Updated basis matrices, (N, L, 3, 3_index). 64 | t': Updated coordinates, (N, L, 3). 65 | """ 66 | x = self.transition_layer_norm(x + self.transition_mlp(x)) 67 | 68 | R_delta, t_delta = self.rot_tran(x) # (N, L, 3, 3), (N, L, 3) 69 | R_new, t_new = compose_rotation_and_translation(R, t, R_delta, t_delta) 70 | 71 | mask_R = mask_generate[:, :, None, None].expand_as(R) 72 | mask_t = mask_generate[:, :, None].expand_as(t) 73 | 74 | R_new = torch.where(mask_R, R_new, R) 75 | t_new = torch.where(mask_t, t_new, t) 76 | 77 | return R_new, t_new 78 | -------------------------------------------------------------------------------- /diffab/modules/common/topology.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def get_consecutive_flag(chain_nb, res_nb, mask): 6 | """ 7 | Args: 8 | chain_nb, res_nb 9 | Returns: 10 | consec: A flag tensor indicating whether residue-i is connected to residue-(i+1), 11 | BoolTensor, (B, L-1)[b, i]. 
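            Residue i is flagged as connected to residue i+1 only when both residues lie on the
            same chain, their residue numbers differ by exactly 1, and residue i is valid under `mask`.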
12 | """ 13 | d_res_nb = (res_nb[:, 1:] - res_nb[:, :-1]).abs() # (B, L-1) 14 | same_chain = (chain_nb[:, 1:] == chain_nb[:, :-1]) 15 | consec = torch.logical_and(d_res_nb == 1, same_chain) 16 | consec = torch.logical_and(consec, mask[:, :-1]) 17 | return consec 18 | 19 | 20 | def get_terminus_flag(chain_nb, res_nb, mask): 21 | consec = get_consecutive_flag(chain_nb, res_nb, mask) 22 | N_term_flag = F.pad(torch.logical_not(consec), pad=(1, 0), value=1) 23 | C_term_flag = F.pad(torch.logical_not(consec), pad=(0, 1), value=1) 24 | return N_term_flag, C_term_flag 25 | -------------------------------------------------------------------------------- /diffab/modules/diffusion/transition.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from diffab.modules.common.layers import clampped_one_hot 7 | from diffab.modules.common.so3 import ApproxAngularDistribution, random_normal_so3, so3vec_to_rotation, rotation_to_so3vec 8 | 9 | 10 | class VarianceSchedule(nn.Module): 11 | 12 | def __init__(self, num_steps=100, s=0.01): 13 | super().__init__() 14 | T = num_steps 15 | t = torch.arange(0, num_steps+1, dtype=torch.float) 16 | f_t = torch.cos( (np.pi / 2) * ((t/T) + s) / (1 + s) ) ** 2 17 | alpha_bars = f_t / f_t[0] 18 | 19 | betas = 1 - (alpha_bars[1:] / alpha_bars[:-1]) 20 | betas = torch.cat([torch.zeros([1]), betas], dim=0) 21 | betas = betas.clamp_max(0.999) 22 | 23 | sigmas = torch.zeros_like(betas) 24 | for i in range(1, betas.size(0)): 25 | sigmas[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] 26 | sigmas = torch.sqrt(sigmas) 27 | 28 | self.register_buffer('betas', betas) 29 | self.register_buffer('alpha_bars', alpha_bars) 30 | self.register_buffer('alphas', 1 - betas) 31 | self.register_buffer('sigmas', sigmas) 32 | 33 | 34 | class PositionTransition(nn.Module): 35 | 36 | def __init__(self, num_steps, var_sched_opt={}): 37 | super().__init__() 38 | self.var_sched = VarianceSchedule(num_steps, **var_sched_opt) 39 | 40 | def add_noise(self, p_0, mask_generate, t): 41 | """ 42 | Args: 43 | p_0: (N, L, 3). 44 | mask_generate: (N, L). 45 | t: (N,). 46 | """ 47 | alpha_bar = self.var_sched.alpha_bars[t] 48 | 49 | c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) 50 | c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) 51 | 52 | e_rand = torch.randn_like(p_0) 53 | p_noisy = c0*p_0 + c1*e_rand 54 | p_noisy = torch.where(mask_generate[..., None].expand_as(p_0), p_noisy, p_0) 55 | 56 | return p_noisy, e_rand 57 | 58 | def denoise(self, p_t, eps_p, mask_generate, t): 59 | # IMPORTANT: 60 | # clampping alpha is to fix the instability issue at the first step (t=T) 61 | # it seems like a problem with the ``improved ddpm''. 
62 | alpha = self.var_sched.alphas[t].clamp_min( 63 | self.var_sched.alphas[-2] 64 | ) 65 | alpha_bar = self.var_sched.alpha_bars[t] 66 | sigma = self.var_sched.sigmas[t].view(-1, 1, 1) 67 | 68 | c0 = ( 1.0 / torch.sqrt(alpha + 1e-8) ).view(-1, 1, 1) 69 | c1 = ( (1 - alpha) / torch.sqrt(1 - alpha_bar + 1e-8) ).view(-1, 1, 1) 70 | 71 | z = torch.where( 72 | (t > 1)[:, None, None].expand_as(p_t), 73 | torch.randn_like(p_t), 74 | torch.zeros_like(p_t), 75 | ) 76 | 77 | p_next = c0 * (p_t - c1 * eps_p) + sigma * z 78 | p_next = torch.where(mask_generate[..., None].expand_as(p_t), p_next, p_t) 79 | return p_next 80 | 81 | 82 | class RotationTransition(nn.Module): 83 | 84 | def __init__(self, num_steps, var_sched_opt={}, angular_distrib_fwd_opt={}, angular_distrib_inv_opt={}): 85 | super().__init__() 86 | self.var_sched = VarianceSchedule(num_steps, **var_sched_opt) 87 | 88 | # Forward (perturb) 89 | c1 = torch.sqrt(1 - self.var_sched.alpha_bars) # (T,). 90 | self.angular_distrib_fwd = ApproxAngularDistribution(c1.tolist(), **angular_distrib_fwd_opt) 91 | 92 | # Inverse (generate) 93 | sigma = self.var_sched.sigmas 94 | self.angular_distrib_inv = ApproxAngularDistribution(sigma.tolist(), **angular_distrib_inv_opt) 95 | 96 | self.register_buffer('_dummy', torch.empty([0, ])) 97 | 98 | def add_noise(self, v_0, mask_generate, t): 99 | """ 100 | Args: 101 | v_0: (N, L, 3). 102 | mask_generate: (N, L). 103 | t: (N,). 104 | """ 105 | N, L = mask_generate.size() 106 | alpha_bar = self.var_sched.alpha_bars[t] 107 | c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) 108 | c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) 109 | 110 | # Noise rotation 111 | e_scaled = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_fwd, device=self._dummy.device) # (N, L, 3) 112 | e_normal = e_scaled / (c1 + 1e-8) 113 | E_scaled = so3vec_to_rotation(e_scaled) # (N, L, 3, 3) 114 | 115 | # Scaled true rotation 116 | R0_scaled = so3vec_to_rotation(c0 * v_0) # (N, L, 3, 3) 117 | 118 | R_noisy = E_scaled @ R0_scaled 119 | v_noisy = rotation_to_so3vec(R_noisy) 120 | v_noisy = torch.where(mask_generate[..., None].expand_as(v_0), v_noisy, v_0) 121 | 122 | return v_noisy, e_scaled 123 | 124 | def denoise(self, v_t, v_next, mask_generate, t): 125 | N, L = mask_generate.size() 126 | e = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_inv, device=self._dummy.device) # (N, L, 3) 127 | e = torch.where( 128 | (t > 1)[:, None, None].expand(N, L, 3), 129 | e, 130 | torch.zeros_like(e) # Simply denoise and don't add noise at the last step 131 | ) 132 | E = so3vec_to_rotation(e) 133 | 134 | R_next = E @ so3vec_to_rotation(v_next) 135 | v_next = rotation_to_so3vec(R_next) 136 | v_next = torch.where(mask_generate[..., None].expand_as(v_next), v_next, v_t) 137 | 138 | return v_next 139 | 140 | 141 | class AminoacidCategoricalTransition(nn.Module): 142 | 143 | def __init__(self, num_steps, num_classes=20, var_sched_opt={}): 144 | super().__init__() 145 | self.num_classes = num_classes 146 | self.var_sched = VarianceSchedule(num_steps, **var_sched_opt) 147 | 148 | @staticmethod 149 | def _sample(c): 150 | """ 151 | Args: 152 | c: (N, L, K). 153 | Returns: 154 | x: (N, L). 155 | """ 156 | N, L, K = c.size() 157 | c = c.view(N*L, K) + 1e-8 158 | x = torch.multinomial(c, 1).view(N, L) 159 | return x 160 | 161 | def add_noise(self, x_0, mask_generate, t): 162 | """ 163 | Args: 164 | x_0: (N, L) 165 | mask_generate: (N, L). 166 | t: (N,). 167 | Returns: 168 | c_t: Probability, (N, L, K). 169 | x_t: Sample, LongTensor, (N, L). 
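            The forward process mixes the one-hot distribution with the uniform distribution over the
            K classes: c_t = alpha_bar_t * onehot(x_0) + (1 - alpha_bar_t) / K, and x_t is sampled from
            c_t at the generated positions.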
170 | """ 171 | N, L = x_0.size() 172 | K = self.num_classes 173 | c_0 = clampped_one_hot(x_0, num_classes=K).float() # (N, L, K). 174 | alpha_bar = self.var_sched.alpha_bars[t][:, None, None] # (N, 1, 1) 175 | c_noisy = (alpha_bar*c_0) + ( (1-alpha_bar)/K ) 176 | c_t = torch.where(mask_generate[..., None].expand(N,L,K), c_noisy, c_0) 177 | x_t = self._sample(c_t) 178 | return c_t, x_t 179 | 180 | def posterior(self, x_t, x_0, t): 181 | """ 182 | Args: 183 | x_t: Category LongTensor (N, L) or Probability FloatTensor (N, L, K). 184 | x_0: Category LongTensor (N, L) or Probability FloatTensor (N, L, K). 185 | t: (N,). 186 | Returns: 187 | theta: Posterior probability at (t-1)-th step, (N, L, K). 188 | """ 189 | K = self.num_classes 190 | 191 | if x_t.dim() == 3: 192 | c_t = x_t # When x_t is probability distribution. 193 | else: 194 | c_t = clampped_one_hot(x_t, num_classes=K).float() # (N, L, K) 195 | 196 | if x_0.dim() == 3: 197 | c_0 = x_0 # When x_0 is probability distribution. 198 | else: 199 | c_0 = clampped_one_hot(x_0, num_classes=K).float() # (N, L, K) 200 | 201 | alpha = self.var_sched.alpha_bars[t][:, None, None] # (N, 1, 1) 202 | alpha_bar = self.var_sched.alpha_bars[t][:, None, None] # (N, 1, 1) 203 | 204 | theta = ((alpha*c_t) + (1-alpha)/K) * ((alpha_bar*c_0) + (1-alpha_bar)/K) # (N, L, K) 205 | theta = theta / (theta.sum(dim=-1, keepdim=True) + 1e-8) 206 | return theta 207 | 208 | def denoise(self, x_t, c_0_pred, mask_generate, t): 209 | """ 210 | Args: 211 | x_t: (N, L). 212 | c_0_pred: Normalized probability predicted by networks, (N, L, K). 213 | mask_generate: (N, L). 214 | t: (N,). 215 | Returns: 216 | post: Posterior probability at (t-1)-th step, (N, L, K). 217 | x_next: Sample at (t-1)-th step, LongTensor, (N, L). 218 | """ 219 | c_t = clampped_one_hot(x_t, num_classes=self.num_classes).float() # (N, L, K) 220 | post = self.posterior(c_t, c_0_pred, t=t) # (N, L, K) 221 | post = torch.where(mask_generate[..., None].expand(post.size()), post, c_t) 222 | x_next = self._sample(post) 223 | return post, x_next 224 | -------------------------------------------------------------------------------- /diffab/modules/encoders/ga.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from diffab.modules.common.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm 7 | from diffab.modules.common.layers import mask_zero, LayerNorm 8 | from diffab.utils.protein.constants import BBHeavyAtom 9 | 10 | 11 | def _alpha_from_logits(logits, mask, inf=1e5): 12 | """ 13 | Args: 14 | logits: Logit matrices, (N, L_i, L_j, num_heads). 15 | mask: Masks, (N, L). 16 | Returns: 17 | alpha: Attention weights. 
18 | """ 19 | N, L, _, _ = logits.size() 20 | mask_row = mask.view(N, L, 1, 1).expand_as(logits) # (N, L, *, *) 21 | mask_pair = mask_row * mask_row.permute(0, 2, 1, 3) # (N, L, L, *) 22 | 23 | logits = torch.where(mask_pair, logits, logits - inf) 24 | alpha = torch.softmax(logits, dim=2) # (N, L, L, num_heads) 25 | alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha)) 26 | return alpha 27 | 28 | 29 | def _heads(x, n_heads, n_ch): 30 | """ 31 | Args: 32 | x: (..., num_heads * num_channels) 33 | Returns: 34 | (..., num_heads, num_channels) 35 | """ 36 | s = list(x.size())[:-1] + [n_heads, n_ch] 37 | return x.view(*s) 38 | 39 | 40 | class GABlock(nn.Module): 41 | 42 | def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8, 43 | num_value_points=8, num_heads=12, bias=False): 44 | super().__init__() 45 | self.node_feat_dim = node_feat_dim 46 | self.pair_feat_dim = pair_feat_dim 47 | self.value_dim = value_dim 48 | self.query_key_dim = query_key_dim 49 | self.num_query_points = num_query_points 50 | self.num_value_points = num_value_points 51 | self.num_heads = num_heads 52 | 53 | # Node 54 | self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias) 55 | self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias) 56 | self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias) 57 | 58 | # Pair 59 | self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias) 60 | 61 | # Spatial 62 | self.spatial_coef = nn.Parameter(torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)), 63 | requires_grad=True) 64 | self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias) 65 | self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias) 66 | self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias) 67 | 68 | # Output 69 | self.out_transform = nn.Linear( 70 | in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + ( 71 | num_heads * num_value_points * (3 + 3 + 1)), 72 | out_features=node_feat_dim, 73 | ) 74 | 75 | self.layer_norm_1 = LayerNorm(node_feat_dim) 76 | self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(), 77 | nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(), 78 | nn.Linear(node_feat_dim, node_feat_dim)) 79 | self.layer_norm_2 = LayerNorm(node_feat_dim) 80 | 81 | def _node_logits(self, x): 82 | query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch) 83 | key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch) 84 | logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) * 85 | (1 / np.sqrt(self.query_key_dim))).sum(-1) # (N, L, L, num_heads) 86 | return logits_node 87 | 88 | def _pair_logits(self, z): 89 | logits_pair = self.proj_pair_bias(z) 90 | return logits_pair 91 | 92 | def _spatial_logits(self, R, t, x): 93 | N, L, _ = t.size() 94 | 95 | # Query 96 | query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points, 97 | 3) # (N, L, n_heads * n_pnts, 3) 98 | query_points = local_to_global(R, t, query_points) # Global query coordinates, (N, L, n_heads * n_pnts, 3) 99 | query_s = query_points.reshape(N, L, self.num_heads, -1) # (N, L, n_heads, n_pnts*3) 100 | 101 | # Key 102 | key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points, 103 | 3) # (N, L, 3, n_heads * n_pnts) 104 | 
key_points = local_to_global(R, t, key_points)  # Global key coordinates, (N, L, n_heads * n_pnts, 3)
105 |         key_s = key_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)
106 | 
107 |         # Q-K Product
108 |         sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1)  # (N, L, L, n_heads)
109 |         gamma = F.softplus(self.spatial_coef)
110 |         logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points)))
111 |                                         / 2)  # (N, L, L, n_heads)
112 |         return logits_spatial
113 | 
114 |     def _pair_aggregation(self, alpha, z):
115 |         N, L = z.shape[:2]
116 |         feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2)  # (N, L, L, n_heads, C)
117 |         feat_p2n = feat_p2n.sum(dim=2)  # (N, L, n_heads, C)
118 |         return feat_p2n.reshape(N, L, -1)
119 | 
120 |     def _node_aggregation(self, alpha, x):
121 |         N, L = x.shape[:2]
122 |         value_l = _heads(self.proj_value(x), self.num_heads, self.value_dim)  # (N, L, n_heads, v_ch)
123 |         feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1)  # (N, L, L, n_heads, 1) * (N, 1, L, n_heads, v_ch)
124 |         feat_node = feat_node.sum(dim=2)  # (N, L, n_heads, v_ch)
125 |         return feat_node.reshape(N, L, -1)
126 | 
127 |     def _spatial_aggregation(self, alpha, R, t, x):
128 |         N, L, _ = t.size()
129 |         value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points,
130 |                               3)  # (N, L, n_heads * n_v_pnts, 3)
131 |         value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points,
132 |                                                                    3))  # (N, L, n_heads, n_v_pnts, 3)
133 |         aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \
134 |                       value_points.unsqueeze(1)  # (N, *, L, n_heads, n_pnts, 3)
135 |         aggr_points = aggr_points.sum(dim=2)  # (N, L, n_heads, n_pnts, 3)
136 | 
137 |         feat_points = global_to_local(R, t, aggr_points)  # (N, L, n_heads, n_pnts, 3)
138 |         feat_distance = feat_points.norm(dim=-1)  # (N, L, n_heads, n_pnts)
139 |         feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4)  # (N, L, n_heads, n_pnts, 3)
140 | 
141 |         feat_spatial = torch.cat([
142 |             feat_points.reshape(N, L, -1),
143 |             feat_distance.reshape(N, L, -1),
144 |             feat_direction.reshape(N, L, -1),
145 |         ], dim=-1)
146 | 
147 |         return feat_spatial
148 | 
149 |     def forward(self, R, t, x, z, mask):
150 |         """
151 |         Args:
152 |             R:  Frame basis matrices, (N, L, 3, 3_index).
153 |             t:  Frame external (absolute) coordinates, (N, L, 3).
154 |             x:  Node-wise features, (N, L, F).
155 |             z:  Pair-wise features, (N, L, L, C).
156 |             mask:   Masks, (N, L).
157 |         Returns:
158 |             x': Updated node-wise features, (N, L, F).
159 |         """
160 |         # Attention logits
161 |         logits_node = self._node_logits(x)
162 |         logits_pair = self._pair_logits(z)
163 |         logits_spatial = self._spatial_logits(R, t, x)
164 |         # Sum up the logits and apply `softmax`.
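        # The sqrt(1/3) factor keeps the variance of the summed logits comparable to that of a
        # single term, the same weighting used in invariant point attention.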
165 | logits_sum = logits_node + logits_pair + logits_spatial 166 | alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask) # (N, L, L, n_heads) 167 | 168 | # Aggregate features 169 | feat_p2n = self._pair_aggregation(alpha, z) 170 | feat_node = self._node_aggregation(alpha, x) 171 | feat_spatial = self._spatial_aggregation(alpha, R, t, x) 172 | 173 | # Finally 174 | feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1)) # (N, L, F) 175 | feat_all = mask_zero(mask.unsqueeze(-1), feat_all) 176 | x_updated = self.layer_norm_1(x + feat_all) 177 | x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated)) 178 | return x_updated 179 | 180 | 181 | class GAEncoder(nn.Module): 182 | 183 | def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt={}): 184 | super(GAEncoder, self).__init__() 185 | self.blocks = nn.ModuleList([ 186 | GABlock(node_feat_dim, pair_feat_dim, **ga_block_opt) 187 | for _ in range(num_layers) 188 | ]) 189 | 190 | def forward(self, R, t, res_feat, pair_feat, mask): 191 | for i, block in enumerate(self.blocks): 192 | res_feat = block(R, t, res_feat, pair_feat, mask) 193 | return res_feat 194 | -------------------------------------------------------------------------------- /diffab/modules/encoders/pair.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from diffab.modules.common.geometry import angstrom_to_nm, pairwise_dihedrals 6 | from diffab.modules.common.layers import AngularEncoding 7 | from diffab.utils.protein.constants import BBHeavyAtom, AA 8 | 9 | 10 | class PairEmbedding(nn.Module): 11 | 12 | def __init__(self, feat_dim, max_num_atoms, max_aa_types=22, max_relpos=32): 13 | super().__init__() 14 | self.max_num_atoms = max_num_atoms 15 | self.max_aa_types = max_aa_types 16 | self.max_relpos = max_relpos 17 | self.aa_pair_embed = nn.Embedding(self.max_aa_types*self.max_aa_types, feat_dim) 18 | self.relpos_embed = nn.Embedding(2*max_relpos+1, feat_dim) 19 | 20 | self.aapair_to_distcoef = nn.Embedding(self.max_aa_types*self.max_aa_types, max_num_atoms*max_num_atoms) 21 | nn.init.zeros_(self.aapair_to_distcoef.weight) 22 | self.distance_embed = nn.Sequential( 23 | nn.Linear(max_num_atoms*max_num_atoms, feat_dim), nn.ReLU(), 24 | nn.Linear(feat_dim, feat_dim), nn.ReLU(), 25 | ) 26 | 27 | self.dihedral_embed = AngularEncoding() 28 | feat_dihed_dim = self.dihedral_embed.get_out_dim(2) # Phi and Psi 29 | 30 | infeat_dim = feat_dim+feat_dim+feat_dim+feat_dihed_dim 31 | self.out_mlp = nn.Sequential( 32 | nn.Linear(infeat_dim, feat_dim), nn.ReLU(), 33 | nn.Linear(feat_dim, feat_dim), nn.ReLU(), 34 | nn.Linear(feat_dim, feat_dim), 35 | ) 36 | 37 | def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, structure_mask=None, sequence_mask=None): 38 | """ 39 | Args: 40 | aa: (N, L). 41 | res_nb: (N, L). 42 | chain_nb: (N, L). 43 | pos_atoms: (N, L, A, 3) 44 | mask_atoms: (N, L, A) 45 | structure_mask: (N, L) 46 | sequence_mask: (N, L), mask out unknown amino acids to generate. 
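            When the masks are given (training time), amino-acid identities at generated positions are
            replaced by UNK and distance / dihedral features involving unknown structure are zeroed, so
            the pair features leak no information about the region being generated.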
47 | 48 | Returns: 49 | (N, L, L, feat_dim) 50 | """ 51 | N, L = aa.size() 52 | 53 | # Remove other atoms 54 | pos_atoms = pos_atoms[:, :, :self.max_num_atoms] 55 | mask_atoms = mask_atoms[:, :, :self.max_num_atoms] 56 | 57 | mask_residue = mask_atoms[:, :, BBHeavyAtom.CA] # (N, L) 58 | mask_pair = mask_residue[:, :, None] * mask_residue[:, None, :] 59 | pair_structure_mask = structure_mask[:, :, None] * structure_mask[:, None, :] if structure_mask is not None else None 60 | 61 | # Pair identities 62 | if sequence_mask is not None: 63 | # Avoid data leakage at training time 64 | aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK)) 65 | aa_pair = aa[:,:,None]*self.max_aa_types + aa[:,None,:] # (N, L, L) 66 | feat_aapair = self.aa_pair_embed(aa_pair) 67 | 68 | # Relative sequential positions 69 | same_chain = (chain_nb[:, :, None] == chain_nb[:, None, :]) 70 | relpos = torch.clamp( 71 | res_nb[:,:,None] - res_nb[:,None,:], 72 | min=-self.max_relpos, max=self.max_relpos, 73 | ) # (N, L, L) 74 | feat_relpos = self.relpos_embed(relpos + self.max_relpos) * same_chain[:,:,:,None] 75 | 76 | # Distances 77 | d = angstrom_to_nm(torch.linalg.norm( 78 | pos_atoms[:,:,None,:,None] - pos_atoms[:,None,:,None,:], 79 | dim = -1, ord = 2, 80 | )).reshape(N, L, L, -1) # (N, L, L, A*A) 81 | c = F.softplus(self.aapair_to_distcoef(aa_pair)) # (N, L, L, A*A) 82 | d_gauss = torch.exp(-1 * c * d**2) 83 | mask_atom_pair = (mask_atoms[:,:,None,:,None] * mask_atoms[:,None,:,None,:]).reshape(N, L, L, -1) 84 | feat_dist = self.distance_embed(d_gauss * mask_atom_pair) 85 | if pair_structure_mask is not None: 86 | # Avoid data leakage at training time 87 | feat_dist = feat_dist * pair_structure_mask[:, :, :, None] 88 | 89 | # Orientations 90 | dihed = pairwise_dihedrals(pos_atoms) # (N, L, L, 2) 91 | feat_dihed = self.dihedral_embed(dihed) 92 | if pair_structure_mask is not None: 93 | # Avoid data leakage at training time 94 | feat_dihed = feat_dihed * pair_structure_mask[:, :, :, None] 95 | 96 | # All 97 | feat_all = torch.cat([feat_aapair, feat_relpos, feat_dist, feat_dihed], dim=-1) 98 | feat_all = self.out_mlp(feat_all) # (N, L, L, F) 99 | feat_all = feat_all * mask_pair[:, :, :, None] 100 | 101 | return feat_all 102 | 103 | -------------------------------------------------------------------------------- /diffab/modules/encoders/residue.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from diffab.modules.common.geometry import construct_3d_basis, global_to_local, get_backbone_dihedral_angles 5 | from diffab.modules.common.layers import AngularEncoding 6 | from diffab.utils.protein.constants import BBHeavyAtom, AA 7 | 8 | 9 | class ResidueEmbedding(nn.Module): 10 | 11 | def __init__(self, feat_dim, max_num_atoms, max_aa_types=22): 12 | super().__init__() 13 | self.max_num_atoms = max_num_atoms 14 | self.max_aa_types = max_aa_types 15 | self.aatype_embed = nn.Embedding(self.max_aa_types, feat_dim) 16 | self.dihed_embed = AngularEncoding() 17 | self.type_embed = nn.Embedding(10, feat_dim, padding_idx=0) # 1: Heavy, 2: Light, 3: Ag 18 | infeat_dim = feat_dim + (self.max_aa_types*max_num_atoms*3) + self.dihed_embed.get_out_dim(3) + feat_dim 19 | self.mlp = nn.Sequential( 20 | nn.Linear(infeat_dim, feat_dim * 2), nn.ReLU(), 21 | nn.Linear(feat_dim * 2, feat_dim), nn.ReLU(), 22 | nn.Linear(feat_dim, feat_dim), nn.ReLU(), 23 | nn.Linear(feat_dim, feat_dim) 24 | ) 25 | 26 | def forward(self, aa, res_nb, chain_nb, 
pos_atoms, mask_atoms, fragment_type, structure_mask=None, sequence_mask=None): 27 | """ 28 | Args: 29 | aa: (N, L). 30 | res_nb: (N, L). 31 | chain_nb: (N, L). 32 | pos_atoms: (N, L, A, 3). 33 | mask_atoms: (N, L, A). 34 | fragment_type: (N, L). 35 | structure_mask: (N, L), mask out unknown structures to generate. 36 | sequence_mask: (N, L), mask out unknown amino acids to generate. 37 | """ 38 | N, L = aa.size() 39 | mask_residue = mask_atoms[:, :, BBHeavyAtom.CA] # (N, L) 40 | 41 | # Remove other atoms 42 | pos_atoms = pos_atoms[:, :, :self.max_num_atoms] 43 | mask_atoms = mask_atoms[:, :, :self.max_num_atoms] 44 | 45 | # Amino acid identity features 46 | if sequence_mask is not None: 47 | # Avoid data leakage at training time 48 | aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK)) 49 | aa_feat = self.aatype_embed(aa) # (N, L, feat) 50 | 51 | # Coordinate features 52 | R = construct_3d_basis( 53 | pos_atoms[:, :, BBHeavyAtom.CA], 54 | pos_atoms[:, :, BBHeavyAtom.C], 55 | pos_atoms[:, :, BBHeavyAtom.N] 56 | ) 57 | t = pos_atoms[:, :, BBHeavyAtom.CA] 58 | crd = global_to_local(R, t, pos_atoms) # (N, L, A, 3) 59 | crd_mask = mask_atoms[:, :, :, None].expand_as(crd) 60 | crd = torch.where(crd_mask, crd, torch.zeros_like(crd)) 61 | 62 | aa_expand = aa[:, :, None, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3) 63 | rng_expand = torch.arange(0, self.max_aa_types)[None, None, :, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3).to(aa_expand) 64 | place_mask = (aa_expand == rng_expand) 65 | crd_expand = crd[:, :, None, :, :].expand(N, L, self.max_aa_types, self.max_num_atoms, 3) 66 | crd_expand = torch.where(place_mask, crd_expand, torch.zeros_like(crd_expand)) 67 | crd_feat = crd_expand.reshape(N, L, self.max_aa_types*self.max_num_atoms*3) 68 | if structure_mask is not None: 69 | # Avoid data leakage at training time 70 | crd_feat = crd_feat * structure_mask[:, :, None] 71 | 72 | # Backbone dihedral features 73 | bb_dihedral, mask_bb_dihed = get_backbone_dihedral_angles(pos_atoms, chain_nb=chain_nb, res_nb=res_nb, mask=mask_residue) 74 | dihed_feat = self.dihed_embed(bb_dihedral[:, :, :, None]) * mask_bb_dihed[:, :, :, None] # (N, L, 3, dihed/3) 75 | dihed_feat = dihed_feat.reshape(N, L, -1) 76 | if structure_mask is not None: 77 | # Avoid data leakage at training time 78 | dihed_mask = torch.logical_and( 79 | structure_mask, 80 | torch.logical_and( 81 | torch.roll(structure_mask, shifts=+1, dims=1), 82 | torch.roll(structure_mask, shifts=-1, dims=1) 83 | ), 84 | ) # Avoid slight data leakage via dihedral angles of anchor residues 85 | dihed_feat = dihed_feat * dihed_mask[:, :, None] 86 | 87 | # Type feature 88 | type_feat = self.type_embed(fragment_type) # (N, L, feat) 89 | 90 | out_feat = self.mlp(torch.cat([aa_feat, crd_feat, dihed_feat, type_feat], dim=-1)) # (N, L, F) 91 | out_feat = out_feat * mask_residue[:, :, None] 92 | return out_feat 93 | -------------------------------------------------------------------------------- /diffab/tools/dock/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import List 3 | 4 | 5 | FilePath = str 6 | 7 | 8 | class DockingEngine(abc.ABC): 9 | 10 | @abc.abstractmethod 11 | def __enter__(self): 12 | pass 13 | 14 | @abc.abstractmethod 15 | def __exit__(self, typ, value, traceback): 16 | pass 17 | 18 | @abc.abstractmethod 19 | def set_receptor(self, pdb_path: FilePath): 20 | pass 21 | 22 | @abc.abstractmethod 23 | def 
set_ligand(self, pdb_path: FilePath): 24 | pass 25 | 26 | @abc.abstractmethod 27 | def dock(self) -> List[FilePath]: 28 | pass 29 | -------------------------------------------------------------------------------- /diffab/tools/dock/hdock.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | import subprocess 5 | import dataclasses as dc 6 | from typing import List, Optional 7 | from Bio import PDB 8 | from Bio.PDB import Model as PDBModel 9 | 10 | from diffab.tools.renumber import renumber as renumber_chothia 11 | from .base import DockingEngine 12 | 13 | 14 | def fix_docked_pdb(pdb_path): 15 | fixed = [] 16 | with open(pdb_path, 'r') as f: 17 | for ln in f.readlines(): 18 | if (ln.startswith('ATOM') or ln.startswith('HETATM')) and len(ln) == 56: 19 | fixed.append( ln[:-1] + ' 1.00 0.00 \n' ) 20 | else: 21 | fixed.append(ln) 22 | with open(pdb_path, 'w') as f: 23 | f.write(''.join(fixed)) 24 | 25 | 26 | class HDock(DockingEngine): 27 | 28 | def __init__( 29 | self, 30 | hdock_bin='./bin/hdock', 31 | createpl_bin='./bin/createpl', 32 | ): 33 | super().__init__() 34 | self.hdock_bin = os.path.realpath(hdock_bin) 35 | self.createpl_bin = os.path.realpath(createpl_bin) 36 | self.tmpdir = tempfile.TemporaryDirectory() 37 | 38 | self._has_receptor = False 39 | self._has_ligand = False 40 | 41 | self._receptor_chains = [] 42 | self._ligand_chains = [] 43 | 44 | def __enter__(self): 45 | return self 46 | 47 | def __exit__(self, typ, value, traceback): 48 | self.tmpdir.cleanup() 49 | 50 | def set_receptor(self, pdb_path): 51 | shutil.copyfile(pdb_path, os.path.join(self.tmpdir.name, 'receptor.pdb')) 52 | self._has_receptor = True 53 | 54 | def set_ligand(self, pdb_path): 55 | shutil.copyfile(pdb_path, os.path.join(self.tmpdir.name, 'ligand.pdb')) 56 | self._has_ligand = True 57 | 58 | def _dump_complex_pdb(self): 59 | parser = PDB.PDBParser(QUIET=True) 60 | model_receptor = parser.get_structure(None, os.path.join(self.tmpdir.name, 'receptor.pdb'))[0] 61 | docked_pdb_path = os.path.join(self.tmpdir.name, 'ligand_docked.pdb') 62 | fix_docked_pdb(docked_pdb_path) 63 | structure_ligdocked = parser.get_structure(None, docked_pdb_path) 64 | 65 | pdb_io = PDB.PDBIO() 66 | paths = [] 67 | for i, model_ligdocked in enumerate(structure_ligdocked): 68 | model_complex = PDBModel.Model(0) 69 | for chain in model_receptor: 70 | model_complex.add(chain.copy()) 71 | for chain in model_ligdocked: 72 | model_complex.add(chain.copy()) 73 | pdb_io.set_structure(model_complex) 74 | save_path = os.path.join(self.tmpdir.name, f"complex_{i}.pdb") 75 | pdb_io.save(save_path) 76 | paths.append(save_path) 77 | return paths 78 | 79 | def dock(self): 80 | if not (self._has_receptor and self._has_ligand): 81 | raise ValueError('Missing receptor or ligand.') 82 | subprocess.run( 83 | [self.hdock_bin, "receptor.pdb", "ligand.pdb"], 84 | cwd=self.tmpdir.name, check=True 85 | ) 86 | subprocess.run( 87 | [self.createpl_bin, "Hdock.out", "ligand_docked.pdb"], 88 | cwd=self.tmpdir.name, check=True 89 | ) 90 | return self._dump_complex_pdb() 91 | 92 | 93 | @dc.dataclass 94 | class DockSite: 95 | chain: str 96 | resseq: int 97 | 98 | 99 | class HDockAntibody(HDock): 100 | 101 | def __init__(self, *args, **kwargs): 102 | super().__init__(*args, **kwargs) 103 | self._heavy_chain_id = None 104 | self._epitope_sites: Optional[List[DockSite]] = None 105 | 106 | def set_ligand(self, pdb_path): 107 | raise NotImplementedError('Please use set_antibody') 108 | 
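    # set_receptor / set_ligand are disabled for this antibody-specific engine in favor of
    # set_antigen / set_antibody. A rough usage sketch (placeholder paths), mirroring the
    # __main__ example at the bottom of this file:
    #   with HDockAntibody() as dock:
    #       dock.set_antigen('antigen.pdb', [DockSite('A', 991)])
    #       dock.set_antibody('antibody_fv.pdb')
    #       complex_paths = dock.dock()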
109 | def set_receptor(self, pdb_path): 110 | raise NotImplementedError('Please use set_antigen') 111 | 112 | def set_antigen(self, pdb_path, epitope_sites: Optional[List[DockSite]]=None): 113 | super().set_receptor(pdb_path) 114 | self._epitope_sites = epitope_sites 115 | 116 | def set_antibody(self, pdb_path): 117 | heavy_chains, _ = renumber_chothia(pdb_path, os.path.join(self.tmpdir.name, 'ligand.pdb')) 118 | self._has_ligand = True 119 | self._heavy_chain_id = heavy_chains[0] 120 | 121 | def _prepare_lsite(self): 122 | lsite_content = f"95-102:{self._heavy_chain_id}\n" # Chothia CDR H3 123 | with open(os.path.join(self.tmpdir.name, 'lsite.txt'), 'w') as f: 124 | f.write(lsite_content) 125 | print(f"[INFO] lsite content: {lsite_content}") 126 | 127 | def _prepare_rsite(self): 128 | rsite_content = "" 129 | for site in self._epitope_sites: 130 | rsite_content += f"{site.resseq}:{site.chain}\n" 131 | with open(os.path.join(self.tmpdir.name, 'rsite.txt'), 'w') as f: 132 | f.write(rsite_content) 133 | print(f"[INFO] rsite content: {rsite_content}") 134 | 135 | def dock(self): 136 | if not (self._has_receptor and self._has_ligand): 137 | raise ValueError('Missing receptor or ligand.') 138 | self._prepare_lsite() 139 | 140 | cmd_hdock = [self.hdock_bin, "receptor.pdb", "ligand.pdb", "-lsite", "lsite.txt"] 141 | if self._epitope_sites is not None: 142 | self._prepare_rsite() 143 | cmd_hdock += ["-rsite", "rsite.txt"] 144 | subprocess.run( 145 | cmd_hdock, 146 | cwd=self.tmpdir.name, check=True 147 | ) 148 | 149 | cmd_pl = [self.createpl_bin, "Hdock.out", "ligand_docked.pdb", "-lsite", "lsite.txt"] 150 | if self._epitope_sites is not None: 151 | self._prepare_rsite() 152 | cmd_pl += ["-rsite", "rsite.txt"] 153 | subprocess.run( 154 | cmd_pl, 155 | cwd=self.tmpdir.name, check=True 156 | ) 157 | return self._dump_complex_pdb() 158 | 159 | 160 | if __name__ == '__main__': 161 | with HDockAntibody('hdock', 'createpl') as dock: 162 | dock.set_antigen('./data/dock/receptor.pdb', [DockSite('A', 991)]) 163 | dock.set_antibody('./data/example_dock/3qhf_fv.pdb') 164 | print(dock.dock()) 165 | -------------------------------------------------------------------------------- /diffab/tools/eval/__main__.py: -------------------------------------------------------------------------------- 1 | from .run import main 2 | 3 | if __name__ == '__main__': 4 | main() 5 | -------------------------------------------------------------------------------- /diffab/tools/eval/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import shelve 5 | from Bio import PDB 6 | from typing import Optional, Tuple, List 7 | from dataclasses import dataclass, field 8 | 9 | 10 | @dataclass 11 | class EvalTask: 12 | in_path: str 13 | ref_path: str 14 | info: dict 15 | structure: str 16 | name: str 17 | method: str 18 | cdr: str 19 | ab_chains: List 20 | 21 | residue_first: Optional[Tuple] = None 22 | residue_last: Optional[Tuple] = None 23 | 24 | scores: dict = field(default_factory=dict) 25 | 26 | def get_gen_biopython_model(self): 27 | parser = PDB.PDBParser(QUIET=True) 28 | return parser.get_structure(self.in_path, self.in_path)[0] 29 | 30 | def get_ref_biopython_model(self): 31 | parser = PDB.PDBParser(QUIET=True) 32 | return parser.get_structure(self.ref_path, self.ref_path)[0] 33 | 34 | def save_to_db(self, db: shelve.Shelf): 35 | db[self.in_path] = self 36 | 37 | def to_report_dict(self): 38 | return { 39 | 'method': self.method, 40 | 'structure': 
self.structure, 41 | 'cdr': self.cdr, 42 | 'filename': os.path.basename(self.in_path), 43 | **self.scores 44 | } 45 | 46 | 47 | class TaskScanner: 48 | 49 | def __init__(self, root, postfix=None, db: Optional[shelve.Shelf]=None): 50 | super().__init__() 51 | self.root = root 52 | self.postfix = postfix 53 | self.visited = set() 54 | self.db = db 55 | if db is not None: 56 | for k in db.keys(): 57 | self.visited.add(k) 58 | 59 | def _get_metadata(self, fpath): 60 | json_path = os.path.join( 61 | os.path.dirname(os.path.dirname(fpath)), 62 | 'metadata.json' 63 | ) 64 | tag_name = os.path.basename(os.path.dirname(fpath)) 65 | method_name = os.path.basename( 66 | os.path.dirname(os.path.dirname(os.path.dirname(fpath))) 67 | ) 68 | try: 69 | antibody_chains = set() 70 | info = None 71 | with open(json_path, 'r') as f: 72 | metadata = json.load(f) 73 | for item in metadata['items']: 74 | if item['tag'] == tag_name: 75 | info = item 76 | antibody_chains.add(item['residue_first'][0]) 77 | if info is not None: 78 | info['antibody_chains'] = list(antibody_chains) 79 | info['structure'] = metadata['identifier'] 80 | info['method'] = method_name 81 | return info 82 | except (json.JSONDecodeError, FileNotFoundError) as e: 83 | return None 84 | 85 | def scan(self) -> List[EvalTask]: 86 | tasks = [] 87 | if self.postfix is None or not self.postfix: 88 | input_fname_pattern = '^\d+\.pdb$' 89 | ref_fname = 'REF1.pdb' 90 | else: 91 | input_fname_pattern = f'^\d+\_{self.postfix}\.pdb$' 92 | ref_fname = f'REF1_{self.postfix}.pdb' 93 | for parent, _, files in os.walk(self.root): 94 | for fname in files: 95 | fpath = os.path.join(parent, fname) 96 | if not re.match(input_fname_pattern, fname): 97 | continue 98 | if os.path.getsize(fpath) == 0: 99 | continue 100 | if fpath in self.visited: 101 | continue 102 | 103 | # Path to the reference structure 104 | ref_path = os.path.join(parent, ref_fname) 105 | if not os.path.exists(ref_path): 106 | continue 107 | 108 | # CDR information 109 | info = self._get_metadata(fpath) 110 | if info is None: 111 | continue 112 | tasks.append(EvalTask( 113 | in_path = fpath, 114 | ref_path = ref_path, 115 | info = info, 116 | structure = info['structure'], 117 | name = info['name'], 118 | method = info['method'], 119 | cdr = info['tag'], 120 | ab_chains = info['antibody_chains'], 121 | residue_first = info.get('residue_first', None), 122 | residue_last = info.get('residue_last', None), 123 | )) 124 | self.visited.add(fpath) 125 | return tasks 126 | -------------------------------------------------------------------------------- /diffab/tools/eval/energy.py: -------------------------------------------------------------------------------- 1 | # pyright: reportMissingImports=false 2 | import pyrosetta 3 | from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover 4 | pyrosetta.init(' '.join([ 5 | '-mute', 'all', 6 | '-use_input_sc', 7 | '-ignore_unrecognized_res', 8 | '-ignore_zero_occupancy', 'false', 9 | '-load_PDB_components', 'false', 10 | '-relax:default_repeats', '2', 11 | '-no_fconfig', 12 | ])) 13 | 14 | from tools.eval.base import EvalTask 15 | 16 | 17 | def pyrosetta_interface_energy(pdb_path, interface): 18 | pose = pyrosetta.pose_from_pdb(pdb_path) 19 | mover = InterfaceAnalyzerMover(interface) 20 | mover.set_pack_separated(True) 21 | mover.apply(pose) 22 | return pose.scores['dG_separated'] 23 | 24 | 25 | def eval_interface_energy(task: EvalTask): 26 | model_gen = task.get_gen_biopython_model() 27 | antigen_chains = set() 28 | for chain in model_gen: 29 | if 
chain.id not in task.ab_chains: 30 | antigen_chains.add(chain.id) 31 | antigen_chains = ''.join(list(antigen_chains)) 32 | antibody_chains = ''.join(task.ab_chains) 33 | interface = f"{antibody_chains}_{antigen_chains}" 34 | 35 | dG_gen = pyrosetta_interface_energy(task.in_path, interface) 36 | dG_ref = pyrosetta_interface_energy(task.ref_path, interface) 37 | 38 | task.scores.update({ 39 | 'dG_gen': dG_gen, 40 | 'dG_ref': dG_ref, 41 | 'ddG': dG_gen - dG_ref 42 | }) 43 | return task 44 | -------------------------------------------------------------------------------- /diffab/tools/eval/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import ray 4 | import shelve 5 | import time 6 | import pandas as pd 7 | from typing import Mapping 8 | 9 | from tools.eval.base import EvalTask, TaskScanner 10 | from tools.eval.similarity import eval_similarity 11 | from tools.eval.energy import eval_interface_energy 12 | 13 | 14 | @ray.remote(num_cpus=1) 15 | def evaluate(task, args): 16 | funcs = [] 17 | funcs.append(eval_similarity) 18 | if not args.no_energy: 19 | funcs.append(eval_interface_energy) 20 | for f in funcs: 21 | task = f(task) 22 | return task 23 | 24 | 25 | def dump_db(db: Mapping[str, EvalTask], path): 26 | table = [] 27 | for task in db.values(): 28 | if 'abopt' in path and task.scores['seqid'] >= 100.0: 29 | # In abopt (Antibody Optimization) mode, ignore sequences identical to the wild-type 30 | continue 31 | table.append(task.to_report_dict()) 32 | table = pd.DataFrame(table) 33 | table.to_csv(path, index=False, float_format='%.6f') 34 | return table 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--root', type=str, default='./results') 40 | parser.add_argument('--pfx', type=str, default='rosetta') 41 | parser.add_argument('--no_energy', action='store_true', default=False) 42 | args = parser.parse_args() 43 | ray.init() 44 | 45 | db_path = os.path.join(args.root, 'evaluation_db') 46 | with shelve.open(db_path) as db: 47 | scanner = TaskScanner(root=args.root, postfix=args.pfx, db=db) 48 | 49 | while True: 50 | tasks = scanner.scan() 51 | futures = [evaluate.remote(t, args) for t in tasks] 52 | if len(futures) > 0: 53 | print(f'Submitted {len(futures)} tasks.') 54 | while len(futures) > 0: 55 | done_ids, futures = ray.wait(futures, num_returns=1) 56 | for done_id in done_ids: 57 | done_task = ray.get(done_id) 58 | done_task.save_to_db(db) 59 | print(f'Remaining {len(futures)}. 
Finished {done_task.in_path}') 60 | db.sync() 61 | 62 | dump_db(db, os.path.join(args.root, 'summary.csv')) 63 | time.sleep(1.0) 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /diffab/tools/eval/similarity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from Bio.PDB import PDBParser, Selection 3 | from Bio.PDB.Polypeptide import three_to_one 4 | from Bio import pairwise2 5 | from Bio.Align import substitution_matrices 6 | 7 | from diffab.tools.eval.base import EvalTask 8 | 9 | 10 | def reslist_rmsd(res_list1, res_list2): 11 | res_short, res_long = (res_list1, res_list2) if len(res_list1) < len(res_list2) else (res_list2, res_list1) 12 | M, N = len(res_short), len(res_long) 13 | 14 | def d(i, j): 15 | coord_i = np.array(res_short[i]['CA'].get_coord()) 16 | coord_j = np.array(res_long[j]['CA'].get_coord()) 17 | return ((coord_i - coord_j) ** 2).sum() 18 | 19 | SD = np.full([M, N], np.inf) 20 | for i in range(M): 21 | j = N - (M - i) 22 | SD[i, j] = sum([ d(i+k, j+k) for k in range(N-j) ]) 23 | 24 | for j in range(N): 25 | SD[M-1, j] = d(M-1, j) 26 | 27 | for i in range(M-2, -1, -1): 28 | for j in range((N-(M-i))-1, -1, -1): 29 | SD[i, j] = min( 30 | d(i, j) + SD[i+1, j+1], 31 | SD[i, j+1] 32 | ) 33 | 34 | min_SD = SD[0, :N-M+1].min() 35 | best_RMSD = np.sqrt(min_SD / M) 36 | return best_RMSD 37 | 38 | 39 | def entity_to_seq(entity): 40 | seq = '' 41 | mapping = [] 42 | for res in Selection.unfold_entities(entity, 'R'): 43 | try: 44 | seq += three_to_one(res.get_resname()) 45 | mapping.append(res.get_id()) 46 | except KeyError: 47 | pass 48 | assert len(seq) == len(mapping) 49 | return seq, mapping 50 | 51 | 52 | def reslist_seqid(res_list1, res_list2): 53 | seq1, _ = entity_to_seq(res_list1) 54 | seq2, _ = entity_to_seq(res_list2) 55 | _, seq_id = align_sequences(seq1, seq2) 56 | return seq_id 57 | 58 | 59 | def align_sequences(sequence_A, sequence_B, **kwargs): 60 | """ 61 | Performs a global pairwise alignment between two sequences 62 | using the BLOSUM62 matrix and the Needleman-Wunsch algorithm 63 | as implemented in Biopython. Returns the alignment, the sequence 64 | identity and the residue mapping between both original sequences. 65 | """ 66 | 67 | def _calculate_identity(sequenceA, sequenceB): 68 | """ 69 | Returns the percentage of identical characters between two sequences. 70 | Assumes the sequences are aligned. 
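        Identity is computed over the full alignment length, so gap columns count toward the denominator.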
71 | """ 72 | 73 | sa, sb, sl = sequenceA, sequenceB, len(sequenceA) 74 | matches = [sa[i] == sb[i] for i in range(sl)] 75 | seq_id = (100 * sum(matches)) / sl 76 | return seq_id 77 | 78 | # gapless_sl = sum([1 for i in range(sl) if (sa[i] != '-' and sb[i] != '-')]) 79 | # gap_id = (100 * sum(matches)) / gapless_sl 80 | # return (seq_id, gap_id) 81 | 82 | # 83 | matrix = kwargs.get('matrix', substitution_matrices.load("BLOSUM62")) 84 | gap_open = kwargs.get('gap_open', -10.0) 85 | gap_extend = kwargs.get('gap_extend', -0.5) 86 | 87 | alns = pairwise2.align.globalds(sequence_A, sequence_B, 88 | matrix, gap_open, gap_extend, 89 | penalize_end_gaps=(False, False) ) 90 | 91 | best_aln = alns[0] 92 | aligned_A, aligned_B, score, begin, end = best_aln 93 | 94 | # Calculate sequence identity 95 | seq_id = _calculate_identity(aligned_A, aligned_B) 96 | return (aligned_A, aligned_B), seq_id 97 | 98 | 99 | def extract_reslist(model, residue_first, residue_last): 100 | assert residue_first[0] == residue_last[0] 101 | residue_first, residue_last = tuple(residue_first), tuple(residue_last) 102 | 103 | chain_id = residue_first[0] 104 | pos_first, pos_last = residue_first[1:], residue_last[1:] 105 | chain = model[chain_id] 106 | reslist = [] 107 | for res in Selection.unfold_entities(chain, 'R'): 108 | pos_current = (res.id[1], res.id[2]) 109 | if pos_first <= pos_current <= pos_last: 110 | reslist.append(res) 111 | return reslist 112 | 113 | 114 | def eval_similarity(task: EvalTask): 115 | model_gen = task.get_gen_biopython_model() 116 | model_ref = task.get_ref_biopython_model() 117 | 118 | reslist_gen = extract_reslist(model_gen, task.residue_first, task.residue_last) 119 | reslist_ref = extract_reslist(model_ref, task.residue_first, task.residue_last) 120 | 121 | task.scores.update({ 122 | 'rmsd': reslist_rmsd(reslist_gen, reslist_ref), 123 | 'seqid': reslist_seqid(reslist_gen, reslist_ref), 124 | }) 125 | return task 126 | -------------------------------------------------------------------------------- /diffab/tools/relax/__main__.py: -------------------------------------------------------------------------------- 1 | from .run import main 2 | 3 | if __name__ == '__main__': 4 | main() 5 | -------------------------------------------------------------------------------- /diffab/tools/relax/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | from typing import Optional, Tuple, List 5 | from dataclasses import dataclass 6 | 7 | 8 | @dataclass 9 | class RelaxTask: 10 | in_path: str 11 | current_path: str 12 | info: dict 13 | status: str 14 | 15 | flexible_residue_first: Optional[Tuple] = None 16 | flexible_residue_last: Optional[Tuple] = None 17 | 18 | def get_in_path_with_tag(self, tag): 19 | name, ext = os.path.splitext(self.in_path) 20 | new_path = f'{name}_{tag}{ext}' 21 | return new_path 22 | 23 | def set_current_path_tag(self, tag): 24 | new_path = self.get_in_path_with_tag(tag) 25 | self.current_path = new_path 26 | return new_path 27 | 28 | def check_current_path_exists(self): 29 | ok = os.path.exists(self.current_path) 30 | if not ok: 31 | self.mark_failure() 32 | if os.path.getsize(self.current_path) == 0: 33 | ok = False 34 | self.mark_failure() 35 | os.unlink(self.current_path) 36 | return ok 37 | 38 | def update_if_finished(self, tag): 39 | out_path = self.get_in_path_with_tag(tag) 40 | if os.path.exists(out_path) and os.path.getsize(out_path) > 0: 41 | # print('Already finished', out_path) 42 | 
self.set_current_path_tag(tag) 43 | self.mark_success() 44 | return True 45 | return False 46 | 47 | def can_proceed(self): 48 | self.check_current_path_exists() 49 | return self.status != 'failed' 50 | 51 | def mark_success(self): 52 | self.status = 'success' 53 | 54 | def mark_failure(self): 55 | self.status = 'failed' 56 | 57 | 58 | 59 | class TaskScanner: 60 | 61 | def __init__(self, root, final_postfix=None): 62 | super().__init__() 63 | self.root = root 64 | self.visited = set() 65 | self.final_postfix = final_postfix 66 | 67 | def _get_metadata(self, fpath): 68 | json_path = os.path.join( 69 | os.path.dirname(os.path.dirname(fpath)), 70 | 'metadata.json' 71 | ) 72 | tag_name = os.path.basename(os.path.dirname(fpath)) 73 | try: 74 | with open(json_path, 'r') as f: 75 | metadata = json.load(f) 76 | for item in metadata['items']: 77 | if item['tag'] == tag_name: 78 | return item 79 | except (json.JSONDecodeError, FileNotFoundError) as e: 80 | return None 81 | return None 82 | 83 | def scan(self) -> List[RelaxTask]: 84 | tasks = [] 85 | input_fname_pattern = '(^\d+\.pdb$|^REF\d\.pdb$)' 86 | for parent, _, files in os.walk(self.root): 87 | for fname in files: 88 | fpath = os.path.join(parent, fname) 89 | if not re.match(input_fname_pattern, fname): 90 | continue 91 | if os.path.getsize(fpath) == 0: 92 | continue 93 | if fpath in self.visited: 94 | continue 95 | 96 | # If finished 97 | if self.final_postfix is not None: 98 | fpath_name, fpath_ext = os.path.splitext(fpath) 99 | fpath_final = f"{fpath_name}_{self.final_postfix}{fpath_ext}" 100 | if os.path.exists(fpath_final): 101 | continue 102 | 103 | # Get metadata 104 | info = self._get_metadata(fpath) 105 | if info is None: 106 | continue 107 | 108 | tasks.append(RelaxTask( 109 | in_path = fpath, 110 | current_path = fpath, 111 | info = info, 112 | status = 'created', 113 | flexible_residue_first = info.get('residue_first', None), 114 | flexible_residue_last = info.get('residue_last', None), 115 | )) 116 | self.visited.add(fpath) 117 | return tasks 118 | -------------------------------------------------------------------------------- /diffab/tools/relax/openmm_relaxer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import io 4 | import logging 5 | import pdbfixer 6 | import openmm 7 | from openmm import app as openmm_app 8 | from openmm import unit 9 | ENERGY = unit.kilocalories_per_mole 10 | LENGTH = unit.angstroms 11 | 12 | from diffab.tools.relax.base import RelaxTask 13 | 14 | 15 | def current_milli_time(): 16 | return round(time.time() * 1000) 17 | 18 | 19 | def _is_in_the_range(ch_rs_ic, flexible_residue_first, flexible_residue_last): 20 | if ch_rs_ic[0] != flexible_residue_first[0]: return False 21 | r_first, r_last = tuple(flexible_residue_first[1:]), tuple(flexible_residue_last[1:]) 22 | rs_ic = ch_rs_ic[1:] 23 | return r_first <= rs_ic <= r_last 24 | 25 | 26 | class ForceFieldMinimizer(object): 27 | 28 | def __init__(self, stiffness=10.0, max_iterations=0, tolerance=2.39*unit.kilocalories_per_mole, platform='CUDA'): 29 | super().__init__() 30 | self.stiffness = stiffness 31 | self.max_iterations = max_iterations 32 | self.tolerance = tolerance 33 | assert platform in ('CUDA', 'CPU') 34 | self.platform = platform 35 | 36 | def _fix(self, pdb_str): 37 | fixer = pdbfixer.PDBFixer(pdbfile=io.StringIO(pdb_str)) 38 | fixer.findNonstandardResidues() 39 | fixer.replaceNonstandardResidues() 40 | 41 | fixer.findMissingResidues() 42 | fixer.findMissingAtoms() 43 | 
fixer.addMissingAtoms(seed=0) 44 | fixer.addMissingHydrogens() 45 | 46 | out_handle = io.StringIO() 47 | openmm_app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, keepIds=True) 48 | return out_handle.getvalue() 49 | 50 | def _get_pdb_string(self, topology, positions): 51 | with io.StringIO() as f: 52 | openmm_app.PDBFile.writeFile(topology, positions, f, keepIds=True) 53 | return f.getvalue() 54 | 55 | def _minimize(self, pdb_str, flexible_residue_first=None, flexible_residue_last=None): 56 | pdb = openmm_app.PDBFile(io.StringIO(pdb_str)) 57 | 58 | force_field = openmm_app.ForceField("amber99sb.xml") 59 | constraints = openmm_app.HBonds 60 | system = force_field.createSystem(pdb.topology, constraints=constraints) 61 | 62 | # Add constraints to non-generated regions 63 | force = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)") 64 | force.addGlobalParameter("k", self.stiffness) 65 | for p in ["x0", "y0", "z0"]: 66 | force.addPerParticleParameter(p) 67 | 68 | if flexible_residue_first is not None and flexible_residue_last is not None: 69 | for i, a in enumerate(pdb.topology.atoms()): 70 | ch_rs_ic = (a.residue.chain.id, int(a.residue.id), a.residue.insertionCode) 71 | if not _is_in_the_range(ch_rs_ic, flexible_residue_first, flexible_residue_last) and a.element.name != "hydrogen": 72 | force.addParticle(i, pdb.positions[i]) 73 | 74 | system.addForce(force) 75 | 76 | # Set up the integrator and simulation 77 | integrator = openmm.LangevinIntegrator(0, 0.01, 0.0) 78 | platform = openmm.Platform.getPlatformByName("CUDA") 79 | simulation = openmm_app.Simulation(pdb.topology, system, integrator, platform) 80 | simulation.context.setPositions(pdb.positions) 81 | 82 | # Perform minimization 83 | ret = {} 84 | state = simulation.context.getState(getEnergy=True, getPositions=True) 85 | ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY) 86 | ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) 87 | 88 | simulation.minimizeEnergy(maxIterations=self.max_iterations, tolerance=self.tolerance) 89 | 90 | state = simulation.context.getState(getEnergy=True, getPositions=True) 91 | ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY) 92 | ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) 93 | ret["min_pdb"] = self._get_pdb_string(simulation.topology, state.getPositions()) 94 | 95 | return ret['min_pdb'], ret 96 | 97 | def _add_energy_remarks(self, pdb_str, ret): 98 | pdb_lines = pdb_str.splitlines() 99 | pdb_lines.insert(1, "REMARK 1 FINAL ENERGY: {:.3f} KCAL/MOL".format(ret['efinal'])) 100 | pdb_lines.insert(1, "REMARK 1 INITIAL ENERGY: {:.3f} KCAL/MOL".format(ret['einit'])) 101 | return "\n".join(pdb_lines) 102 | 103 | def __call__(self, pdb_str, flexible_residue_first=None, flexible_residue_last=None, return_info=True): 104 | if '\n' not in pdb_str and pdb_str.lower().endswith(".pdb"): 105 | with open(pdb_str) as f: 106 | pdb_str = f.read() 107 | 108 | pdb_fixed = self._fix(pdb_str) 109 | pdb_min, ret = self._minimize(pdb_fixed, flexible_residue_first, flexible_residue_last) 110 | pdb_min = self._add_energy_remarks(pdb_min, ret) 111 | if return_info: 112 | return pdb_min, ret 113 | else: 114 | return pdb_min 115 | 116 | 117 | def run_openmm(task: RelaxTask): 118 | if not task.can_proceed() : 119 | return task 120 | if task.update_if_finished('openmm'): 121 | return task 122 | 123 | try: 124 | minimizer = ForceFieldMinimizer() 125 | with open(task.current_path, 'r') as f: 126 | pdb_str = f.read() 127 | 128 | 
pdb_min = minimizer(
129 |             pdb_str = pdb_str,
130 |             flexible_residue_first = task.flexible_residue_first,
131 |             flexible_residue_last = task.flexible_residue_last,
132 |             return_info = False,
133 |         )
134 |         out_path = task.set_current_path_tag('openmm')
135 |         with open(out_path, 'w') as f:
136 |             f.write(pdb_min)
137 |         task.mark_success()
138 |     except ValueError as e:
139 |         logging.warning(
140 |             f'{e.__class__.__name__}: {str(e)} ({task.current_path})'
141 |         )
142 |         task.mark_failure()
143 |     return task
144 | 
145 | 
--------------------------------------------------------------------------------
/diffab/tools/relax/pyrosetta_relaxer.py:
--------------------------------------------------------------------------------
1 | # pyright: reportMissingImports=false
2 | import os
3 | import time
4 | import pyrosetta
5 | from pyrosetta.rosetta.protocols.relax import FastRelax
6 | from pyrosetta.rosetta.core.pack.task import TaskFactory
7 | from pyrosetta.rosetta.core.pack.task import operation
8 | from pyrosetta.rosetta.core.select import residue_selector as selections
9 | from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action
10 | pyrosetta.init(' '.join([
11 |     '-mute', 'all',
12 |     '-use_input_sc',
13 |     '-ignore_unrecognized_res',
14 |     '-ignore_zero_occupancy', 'false',
15 |     '-load_PDB_components', 'false',
16 |     '-relax:default_repeats', '2',
17 |     '-no_fconfig',
18 | ]))
19 | 
20 | from diffab.tools.relax.base import RelaxTask
21 | 
22 | 
23 | def current_milli_time():
24 |     return round(time.time() * 1000)
25 | 
26 | 
27 | def parse_residue_position(p):
28 |     icode = None
29 |     if not p[-1].isnumeric():   # Has ICODE
30 |         icode = p[-1]
31 |         p = p[:-1]  # Strip the insertion code so the residue number parses below
32 |     for i, c in enumerate(p):
33 |         if c.isnumeric():
34 |             break
35 |     chain = p[:i]
36 |     resseq = int(p[i:])
37 | 
38 |     if icode is not None:
39 |         return chain, resseq, icode
40 |     else:
41 |         return chain, resseq
42 | 
43 | 
44 | def get_scorefxn(scorefxn_name: str):
45 |     """
46 |     Gets the scorefxn with appropriate corrections.
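    Recognized names include the ref2015 / beta_july15, beta_nov16, genpot and talaris families;
    the matching `corrections:*` boolean options are set before the score function is created.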
47 | Taken from: https://gist.github.com/matteoferla/b33585f3aeab58b8424581279e032550 48 | """ 49 | import pyrosetta 50 | 51 | corrections = { 52 | 'beta_july15': False, 53 | 'beta_nov16': False, 54 | 'gen_potential': False, 55 | 'restore_talaris_behavior': False, 56 | } 57 | if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name: 58 | # beta_july15 is ref2015 59 | corrections['beta_july15'] = True 60 | elif 'beta_nov16' in scorefxn_name: 61 | corrections['beta_nov16'] = True 62 | elif 'genpot' in scorefxn_name: 63 | corrections['gen_potential'] = True 64 | pyrosetta.rosetta.basic.options.set_boolean_option('corrections:beta_july15', True) 65 | elif 'talaris' in scorefxn_name: #2013 and 2014 66 | corrections['restore_talaris_behavior'] = True 67 | else: 68 | pass 69 | for corr, value in corrections.items(): 70 | pyrosetta.rosetta.basic.options.set_boolean_option(f'corrections:{corr}', value) 71 | return pyrosetta.create_score_function(scorefxn_name) 72 | 73 | 74 | class RelaxRegion(object): 75 | 76 | def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True): 77 | super().__init__() 78 | self.scorefxn = get_scorefxn(scorefxn) 79 | self.fast_relax = FastRelax() 80 | self.fast_relax.set_scorefxn(self.scorefxn) 81 | self.fast_relax.max_iter(max_iter) 82 | assert subset in ('all', 'target', 'nbrs') 83 | self.subset = subset 84 | self.move_bb = move_bb 85 | 86 | def __call__(self, pdb_path, flexible_residue_first, flexible_residue_last): 87 | pose = pyrosetta.pose_from_pdb(pdb_path) 88 | start_t = current_milli_time() 89 | original_pose = pose.clone() 90 | 91 | tf = TaskFactory() 92 | tf.push_back(operation.InitializeFromCommandline()) 93 | tf.push_back(operation.RestrictToRepacking()) # Only allow residues to repack. No design at any position. 
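        # Further below, a MoveMapFactory enables backbone movement only for the flexible (generated)
        # span and side-chain movement only for the selected subset before FastRelax is applied.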
94 | 95 | # Create selector for the region to be relaxed 96 | # Turn off design and repacking on irrelevant positions 97 | if flexible_residue_first[-1] == ' ': 98 | flexible_residue_first = flexible_residue_first[:-1] 99 | if flexible_residue_last[-1] == ' ': 100 | flexible_residue_last = flexible_residue_last[:-1] 101 | if self.subset != 'all': 102 | gen_selector = selections.ResidueIndexSelector() 103 | gen_selector.set_index_range( 104 | pose.pdb_info().pdb2pose(*flexible_residue_first), 105 | pose.pdb_info().pdb2pose(*flexible_residue_last), 106 | ) 107 | nbr_selector = selections.NeighborhoodResidueSelector() 108 | nbr_selector.set_focus_selector(gen_selector) 109 | nbr_selector.set_include_focus_in_subset(True) 110 | 111 | if self.subset == 'nbrs': 112 | subset_selector = nbr_selector 113 | elif self.subset == 'target': 114 | subset_selector = gen_selector 115 | 116 | prevent_repacking_rlt = operation.PreventRepackingRLT() 117 | prevent_subset_repacking = operation.OperateOnResidueSubset( 118 | prevent_repacking_rlt, 119 | subset_selector, 120 | flip_subset=True, 121 | ) 122 | tf.push_back(prevent_subset_repacking) 123 | 124 | scorefxn = self.scorefxn 125 | fr = self.fast_relax 126 | 127 | pose = original_pose.clone() 128 | pos_list = pyrosetta.rosetta.utility.vector1_unsigned_long() 129 | for pos in range(pose.pdb_info().pdb2pose(*flexible_residue_first), pose.pdb_info().pdb2pose(*flexible_residue_last)+1): 130 | pos_list.append(pos) 131 | # basic_idealize(pose, pos_list, scorefxn, fast=True) 132 | 133 | mmf = MoveMapFactory() 134 | if self.move_bb: 135 | mmf.add_bb_action(move_map_action.mm_enable, gen_selector) 136 | mmf.add_chi_action(move_map_action.mm_enable, subset_selector) 137 | mm = mmf.create_movemap_from_pose(pose) 138 | 139 | fr.set_movemap(mm) 140 | fr.set_task_factory(tf) 141 | fr.apply(pose) 142 | 143 | e_before = scorefxn(original_pose) 144 | e_relax = scorefxn(pose) 145 | # print('\n\n[Finished in %.2f secs]' % ((current_milli_time() - start_t) / 1000)) 146 | # print(' > Energy (before): %.4f' % scorefxn(original_pose)) 147 | # print(' > Energy (optimized): %.4f' % scorefxn(pose)) 148 | return pose, e_before, e_relax 149 | 150 | 151 | def run_pyrosetta(task: RelaxTask): 152 | if not task.can_proceed() : 153 | return task 154 | if task.update_if_finished('rosetta'): 155 | return task 156 | 157 | minimizer = RelaxRegion() 158 | pose_min, _, _ = minimizer( 159 | pdb_path = task.current_path, 160 | flexible_residue_first = task.flexible_residue_first, 161 | flexible_residue_last = task.flexible_residue_last, 162 | ) 163 | 164 | out_path = task.set_current_path_tag('rosetta') 165 | pose_min.dump_pdb(out_path) 166 | task.mark_success() 167 | return task 168 | 169 | 170 | def run_pyrosetta_fixbb(task: RelaxTask): 171 | if not task.can_proceed() : 172 | return task 173 | if task.update_if_finished('fixbb'): 174 | return task 175 | 176 | minimizer = RelaxRegion(move_bb=False) 177 | pose_min, _, _ = minimizer( 178 | pdb_path = task.current_path, 179 | flexible_residue_first = task.flexible_residue_first, 180 | flexible_residue_last = task.flexible_residue_last, 181 | ) 182 | 183 | out_path = task.set_current_path_tag('fixbb') 184 | pose_min.dump_pdb(out_path) 185 | task.mark_success() 186 | return task 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /diffab/tools/relax/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ray 3 | import time 4 | 5 | from 
diffab.tools.relax.openmm_relaxer import run_openmm 6 | from diffab.tools.relax.pyrosetta_relaxer import run_pyrosetta, run_pyrosetta_fixbb 7 | from diffab.tools.relax.base import TaskScanner 8 | 9 | 10 | @ray.remote(num_gpus=1/8, num_cpus=1) 11 | def run_openmm_remote(task): 12 | return run_openmm(task) 13 | 14 | 15 | @ray.remote(num_cpus=1) 16 | def run_pyrosetta_remote(task): 17 | return run_pyrosetta(task) 18 | 19 | 20 | @ray.remote(num_cpus=1) 21 | def run_pyrosetta_fixbb_remote(task): 22 | return run_pyrosetta_fixbb(task) 23 | 24 | 25 | @ray.remote 26 | def pipeline_openmm_pyrosetta(task): 27 | funcs = [ 28 | run_openmm_remote, 29 | run_pyrosetta_remote, 30 | ] 31 | for fn in funcs: 32 | task = fn.remote(task) 33 | return ray.get(task) 34 | 35 | 36 | @ray.remote 37 | def pipeline_pyrosetta(task): 38 | funcs = [ 39 | run_pyrosetta_remote, 40 | ] 41 | for fn in funcs: 42 | task = fn.remote(task) 43 | return ray.get(task) 44 | 45 | 46 | @ray.remote 47 | def pipeline_pyrosetta_fixbb(task): 48 | funcs = [ 49 | run_pyrosetta_fixbb_remote, 50 | ] 51 | for fn in funcs: 52 | task = fn.remote(task) 53 | return ray.get(task) 54 | 55 | 56 | pipeline_dict = { 57 | 'openmm_pyrosetta': pipeline_openmm_pyrosetta, 58 | 'pyrosetta': pipeline_pyrosetta, 59 | 'pyrosetta_fixbb': pipeline_pyrosetta_fixbb, 60 | } 61 | 62 | 63 | def main(): 64 | ray.init() 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--root', type=str, default='./results') 67 | parser.add_argument('--pipeline', type=lambda s: pipeline_dict[s], default=pipeline_openmm_pyrosetta) 68 | args = parser.parse_args() 69 | 70 | final_pfx = 'fixbb' if args.pipeline == pipeline_pyrosetta_fixbb else 'rosetta' 71 | scanner = TaskScanner(args.root, final_postfix=final_pfx) 72 | while True: 73 | tasks = scanner.scan() 74 | futures = [args.pipeline.remote(t) for t in tasks] 75 | if len(futures) > 0: 76 | print(f'Submitted {len(futures)} tasks.') 77 | while len(futures) > 0: 78 | done_ids, futures = ray.wait(futures, num_returns=1) 79 | for done_id in done_ids: 80 | done_task = ray.get(done_id) 81 | print(f'Remaining {len(futures)}. 
Finished {done_task.current_path}') 82 | time.sleep(1.0) 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /diffab/tools/renumber/__init__.py: -------------------------------------------------------------------------------- 1 | from .run import renumber 2 | -------------------------------------------------------------------------------- /diffab/tools/renumber/__main__.py: -------------------------------------------------------------------------------- 1 | from .run import main 2 | 3 | if __name__ == '__main__': 4 | main() 5 | -------------------------------------------------------------------------------- /diffab/tools/renumber/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import abnumber 3 | from Bio import PDB 4 | from Bio.PDB import Model, Chain, Residue, Selection 5 | from Bio.Data import SCOPData 6 | from typing import List, Tuple 7 | 8 | 9 | def biopython_chain_to_sequence(chain: Chain.Chain): 10 | residue_list = Selection.unfold_entities(chain, 'R') 11 | seq = ''.join([SCOPData.protein_letters_3to1.get(r.resname, 'X') for r in residue_list]) 12 | return seq, residue_list 13 | 14 | 15 | def assign_number_to_sequence(seq): 16 | abchain = abnumber.Chain(seq, scheme='chothia') 17 | offset = seq.index(abchain.seq) 18 | if not (offset >= 0): 19 | raise ValueError( 20 | 'The identified Fv sequence is not a subsequence of the original sequence.' 21 | ) 22 | 23 | numbers = [None for _ in range(len(seq))] 24 | for i, (pos, aa) in enumerate(abchain): 25 | resseq = pos.number 26 | icode = pos.letter if pos.letter else ' ' 27 | numbers[i+offset] = (resseq, icode) 28 | return numbers, abchain 29 | 30 | 31 | def renumber_biopython_chain(chain_id, residue_list: List[Residue.Residue], numbers: List[Tuple[int, str]]): 32 | chain = Chain.Chain(chain_id) 33 | for residue, number in zip(residue_list, numbers): 34 | if number is None: 35 | continue 36 | residue = residue.copy() 37 | new_id = (residue.id[0], number[0], number[1]) 38 | residue.id = new_id 39 | chain.add(residue) 40 | return chain 41 | 42 | 43 | def renumber(in_pdb, out_pdb, return_other_chains=False): 44 | parser = PDB.PDBParser(QUIET=True) 45 | structure = parser.get_structure(None, in_pdb) 46 | model = structure[0] 47 | model_new = Model.Model(0) 48 | 49 | heavy_chains, light_chains, other_chains = [], [], [] 50 | 51 | for chain in model: 52 | try: 53 | seq, reslist = biopython_chain_to_sequence(chain) 54 | numbers, abchain = assign_number_to_sequence(seq) 55 | chain_new = renumber_biopython_chain(chain.id, reslist, numbers) 56 | print(f'[INFO] Renumbered chain {chain_new.id} ({abchain.chain_type})') 57 | if abchain.chain_type == 'H': 58 | heavy_chains.append(chain_new.id) 59 | elif abchain.chain_type in ('K', 'L'): 60 | light_chains.append(chain_new.id) 61 | except abnumber.ChainParseError as e: 62 | print(f'[INFO] Chain {chain.id} does not contain valid Fv: {str(e)}') 63 | chain_new = chain.copy() 64 | other_chains.append(chain_new.id) 65 | model_new.add(chain_new) 66 | 67 | pdb_io = PDB.PDBIO() 68 | pdb_io.set_structure(model_new) 69 | pdb_io.save(out_pdb) 70 | if return_other_chains: 71 | return heavy_chains, light_chains, other_chains 72 | else: 73 | return heavy_chains, light_chains 74 | 75 | 76 | def main(): 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('in_pdb', type=str) 79 | parser.add_argument('out_pdb', type=str) 80 | args = parser.parse_args() 81 | 82 | 
renumber(args.in_pdb, args.out_pdb) 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /diffab/tools/runner/design_for_pdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import copy 4 | import json 5 | from tqdm.auto import tqdm 6 | from torch.utils.data import DataLoader 7 | 8 | from diffab.datasets.custom import preprocess_antibody_structure 9 | from diffab.models import get_model 10 | from diffab.modules.common.geometry import reconstruct_backbone_partially 11 | from diffab.modules.common.so3 import so3vec_to_rotation 12 | from diffab.utils.inference import RemoveNative 13 | from diffab.utils.protein.writers import save_pdb 14 | from diffab.utils.train import recursive_to 15 | from diffab.utils.misc import * 16 | from diffab.utils.data import * 17 | from diffab.utils.transforms import * 18 | from diffab.utils.inference import * 19 | from diffab.tools.renumber import renumber as renumber_antibody 20 | 21 | 22 | def create_data_variants(config, structure_factory): 23 | structure = structure_factory() 24 | structure_id = structure['id'] 25 | 26 | data_variants = [] 27 | if config.mode == 'single_cdr': 28 | cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) 29 | for cdr_name in cdrs: 30 | transform = Compose([ 31 | MaskSingleCDR(cdr_name, augmentation=False), 32 | MergeChains(), 33 | ]) 34 | data_var = transform(structure_factory()) 35 | residue_first, residue_last = get_residue_first_last(data_var) 36 | data_variants.append({ 37 | 'data': data_var, 38 | 'name': f'{structure_id}-{cdr_name}', 39 | 'tag': f'{cdr_name}', 40 | 'cdr': cdr_name, 41 | 'residue_first': residue_first, 42 | 'residue_last': residue_last, 43 | }) 44 | elif config.mode == 'multiple_cdrs': 45 | cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) 46 | transform = Compose([ 47 | MaskMultipleCDRs(selection=cdrs, augmentation=False), 48 | MergeChains(), 49 | ]) 50 | data_var = transform(structure_factory()) 51 | data_variants.append({ 52 | 'data': data_var, 53 | 'name': f'{structure_id}-MultipleCDRs', 54 | 'tag': 'MultipleCDRs', 55 | 'cdrs': cdrs, 56 | 'residue_first': None, 57 | 'residue_last': None, 58 | }) 59 | elif config.mode == 'full': 60 | transform = Compose([ 61 | MaskAntibody(), 62 | MergeChains(), 63 | ]) 64 | data_var = transform(structure_factory()) 65 | data_variants.append({ 66 | 'data': data_var, 67 | 'name': f'{structure_id}-Full', 68 | 'tag': 'Full', 69 | 'residue_first': None, 70 | 'residue_last': None, 71 | }) 72 | elif config.mode == 'abopt': 73 | cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) 74 | for cdr_name in cdrs: 75 | transform = Compose([ 76 | MaskSingleCDR(cdr_name, augmentation=False), 77 | MergeChains(), 78 | ]) 79 | data_var = transform(structure_factory()) 80 | residue_first, residue_last = get_residue_first_last(data_var) 81 | for opt_step in config.sampling.optimize_steps: 82 | data_variants.append({ 83 | 'data': data_var, 84 | 'name': f'{structure_id}-{cdr_name}-O{opt_step}', 85 | 'tag': f'{cdr_name}-O{opt_step}', 86 | 'cdr': cdr_name, 87 | 'opt_step': opt_step, 88 | 'residue_first': residue_first, 89 | 'residue_last': residue_last, 90 | }) 91 | else: 92 | raise ValueError(f'Unknown mode: {config.mode}.') 93 | return data_variants 94 | 95 | 96 | def design_for_pdb(args): 97 | # Load configs 98 | config, config_name = 
load_config(args.config) 99 | seed_all(args.seed if args.seed is not None else config.sampling.seed) 100 | 101 | # Structure loading 102 | data_id = os.path.basename(args.pdb_path) 103 | if args.no_renumber: 104 | pdb_path = args.pdb_path 105 | else: 106 | in_pdb_path = args.pdb_path 107 | out_pdb_path = os.path.splitext(in_pdb_path)[0] + '_chothia.pdb' 108 | heavy_chains, light_chains = renumber_antibody(in_pdb_path, out_pdb_path) 109 | pdb_path = out_pdb_path 110 | 111 | if args.heavy is None and len(heavy_chains) > 0: 112 | args.heavy = heavy_chains[0] 113 | if args.light is None and len(light_chains) > 0: 114 | args.light = light_chains[0] 115 | if args.heavy is None and args.light is None: 116 | raise ValueError("Neither the heavy chain id (--heavy) nor the light chain id (--light) is specified.") 117 | get_structure = lambda: preprocess_antibody_structure({ 118 | 'id': data_id, 119 | 'pdb_path': pdb_path, 120 | 'heavy_id': args.heavy, 121 | # If the input is a nanobody, the light chain will be ignored 122 | 'light_id': args.light, 123 | }) 124 | 125 | # Logging 126 | structure_ = get_structure() 127 | structure_id = structure_['id'] 128 | tag_postfix = '_%s' % args.tag if args.tag else '' 129 | log_dir = get_new_log_dir( 130 | os.path.join(args.out_root, config_name + tag_postfix), 131 | prefix=data_id 132 | ) 133 | logger = get_logger('sample', log_dir) 134 | logger.info(f'Data ID: {structure_["id"]}') 135 | logger.info(f'Results will be saved to {log_dir}') 136 | data_native = MergeChains()(structure_) 137 | save_pdb(data_native, os.path.join(log_dir, 'reference.pdb')) 138 | 139 | # Load checkpoint and model 140 | logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint)) 141 | ckpt = torch.load(config.model.checkpoint, map_location='cpu') 142 | cfg_ckpt = ckpt['config'] 143 | model = get_model(cfg_ckpt.model).to(args.device) 144 | lsd = model.load_state_dict(ckpt['model']) 145 | logger.info(str(lsd)) 146 | 147 | # Make data variants 148 | data_variants = create_data_variants( 149 | config = config, 150 | structure_factory = get_structure, 151 | ) 152 | 153 | # Save metadata 154 | metadata = { 155 | 'identifier': structure_id, 156 | 'index': data_id, 157 | 'config': args.config, 158 | 'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants], 159 | } 160 | with open(os.path.join(log_dir, 'metadata.json'), 'w') as f: 161 | json.dump(metadata, f, indent=2) 162 | 163 | # Start sampling 164 | collate_fn = PaddingCollate(eight=False) 165 | inference_tfm = [ PatchAroundAnchor(), ] 166 | if 'abopt' not in config.mode: # Don't remove native CDR in optimization mode 167 | inference_tfm.append(RemoveNative( 168 | remove_structure = config.sampling.sample_structure, 169 | remove_sequence = config.sampling.sample_sequence, 170 | )) 171 | inference_tfm = Compose(inference_tfm) 172 | 173 | for variant in data_variants: 174 | os.makedirs(os.path.join(log_dir, variant['tag']), exist_ok=True) 175 | logger.info(f"Start sampling for: {variant['tag']}") 176 | 177 | save_pdb(data_native, os.path.join(log_dir, variant['tag'], 'REF1.pdb')) # w/ OpenMM minimization 178 | 179 | data_cropped = inference_tfm( 180 | copy.deepcopy(variant['data']) 181 | ) 182 | data_list_repeat = [ data_cropped ] * config.sampling.num_samples 183 | loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) 184 | 185 | count = 0 186 | for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True): 187 | torch.set_grad_enabled(False)
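# A hedged reading of the sampling loop that follows: autograd stays disabled for the rest of
# the loop and the model runs in eval mode. `traj_batch` is indexed by diffusion step, with
# index 0 holding the final (fully denoised) state; each entry packs
# (SO(3) rotation vectors, CA translations, amino-acid types), which is why traj_batch[0][0]
# is passed through so3vec_to_rotation and traj_batch[0][2] supplies the designed sequence
# when the backbone is rebuilt with reconstruct_backbone_partially.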
188 | model.eval() 189 | batch = recursive_to(batch, args.device) 190 | if 'abopt' in config.mode: 191 | # Antibody optimization starting from native 192 | traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={ 193 | 'pbar': True, 194 | 'sample_structure': config.sampling.sample_structure, 195 | 'sample_sequence': config.sampling.sample_sequence, 196 | }) 197 | else: 198 | # De novo design 199 | traj_batch = model.sample(batch, sample_opt={ 200 | 'pbar': True, 201 | 'sample_structure': config.sampling.sample_structure, 202 | 'sample_sequence': config.sampling.sample_sequence, 203 | }) 204 | 205 | aa_new = traj_batch[0][2] # 0: Last sampling step. 2: Amino acid. 206 | pos_atom_new, mask_atom_new = reconstruct_backbone_partially( 207 | pos_ctx = batch['pos_heavyatom'], 208 | R_new = so3vec_to_rotation(traj_batch[0][0]), 209 | t_new = traj_batch[0][1], 210 | aa = aa_new, 211 | chain_nb = batch['chain_nb'], 212 | res_nb = batch['res_nb'], 213 | mask_atoms = batch['mask_heavyatom'], 214 | mask_recons = batch['generate_flag'], 215 | ) 216 | aa_new = aa_new.cpu() 217 | pos_atom_new = pos_atom_new.cpu() 218 | mask_atom_new = mask_atom_new.cpu() 219 | 220 | for i in range(aa_new.size(0)): 221 | data_tmpl = variant['data'] 222 | aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], data_cropped['patch_idx']) 223 | mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], data_cropped['patch_idx']) 224 | pos_ha = ( 225 | apply_patch_to_tensor( 226 | data_tmpl['pos_heavyatom'], 227 | pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(), 228 | data_cropped['patch_idx'] 229 | ) 230 | ) 231 | 232 | save_path = os.path.join(log_dir, variant['tag'], '%04d.pdb' % (count, )) 233 | save_pdb({ 234 | 'chain_nb': data_tmpl['chain_nb'], 235 | 'chain_id': data_tmpl['chain_id'], 236 | 'resseq': data_tmpl['resseq'], 237 | 'icode': data_tmpl['icode'], 238 | # Generated 239 | 'aa': aa, 240 | 'mask_heavyatom': mask_ha, 241 | 'pos_heavyatom': pos_ha, 242 | }, path=save_path) 243 | # save_pdb({ 244 | # 'chain_nb': data_cropped['chain_nb'], 245 | # 'chain_id': data_cropped['chain_id'], 246 | # 'resseq': data_cropped['resseq'], 247 | # 'icode': data_cropped['icode'], 248 | # # Generated 249 | # 'aa': aa_new[i], 250 | # 'mask_heavyatom': mask_atom_new[i], 251 | # 'pos_heavyatom': pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(), 252 | # }, path=os.path.join(log_dir, variant['tag'], '%04d_patch.pdb' % (count, ))) 253 | count += 1 254 | 255 | logger.info('Finished.\n') 256 | 257 | 258 | def args_from_cmdline(): 259 | parser = argparse.ArgumentParser() 260 | parser.add_argument('pdb_path', type=str) 261 | parser.add_argument('--heavy', type=str, default=None, help='Chain id of the heavy chain.') 262 | parser.add_argument('--light', type=str, default=None, help='Chain id of the light chain.') 263 | parser.add_argument('--no_renumber', action='store_true', default=False) 264 | parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml') 265 | parser.add_argument('-o', '--out_root', type=str, default='./results') 266 | parser.add_argument('-t', '--tag', type=str, default='') 267 | parser.add_argument('-s', '--seed', type=int, default=None) 268 | parser.add_argument('-d', '--device', type=str, default='cuda') 269 | parser.add_argument('-b', '--batch_size', type=int, default=16) 270 | args = parser.parse_args() 271 | return args 272 | 273 | 274 | def args_factory(**kwargs): 275 | default_args = EasyDict( 276 | heavy = 'H', 277 | light = 'L', 
278 | no_renumber = False, 279 | config = './configs/test/codesign_single.yml', 280 | out_root = './results', 281 | tag = '', 282 | seed = None, 283 | device = 'cuda', 284 | batch_size = 16 285 | ) 286 | default_args.update(kwargs) 287 | return default_args 288 | 289 | 290 | if __name__ == '__main__': 291 | design_for_pdb(args_from_cmdline()) 292 | -------------------------------------------------------------------------------- /diffab/tools/runner/design_for_testset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import copy 4 | import json 5 | from tqdm.auto import tqdm 6 | from torch.utils.data import DataLoader 7 | 8 | from diffab.datasets import get_dataset 9 | from diffab.models import get_model 10 | from diffab.modules.common.geometry import reconstruct_backbone_partially 11 | from diffab.modules.common.so3 import so3vec_to_rotation 12 | from diffab.utils.inference import RemoveNative 13 | from diffab.utils.protein.writers import save_pdb 14 | from diffab.utils.train import recursive_to 15 | from diffab.utils.misc import * 16 | from diffab.utils.data import * 17 | from diffab.utils.transforms import * 18 | from diffab.utils.inference import * 19 | 20 | 21 | def create_data_variants(config, structure_factory): 22 | structure = structure_factory() 23 | structure_id = structure['id'] 24 | 25 | data_variants = [] 26 | if config.mode == 'single_cdr': 27 | cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) 28 | for cdr_name in cdrs: 29 | transform = Compose([ 30 | MaskSingleCDR(cdr_name, augmentation=False), 31 | MergeChains(), 32 | ]) 33 | data_var = transform(structure_factory()) 34 | residue_first, residue_last = get_residue_first_last(data_var) 35 | data_variants.append({ 36 | 'data': data_var, 37 | 'name': f'{structure_id}-{cdr_name}', 38 | 'tag': f'{cdr_name}', 39 | 'cdr': cdr_name, 40 | 'residue_first': residue_first, 41 | 'residue_last': residue_last, 42 | }) 43 | elif config.mode == 'multiple_cdrs': 44 | cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) 45 | transform = Compose([ 46 | MaskMultipleCDRs(selection=cdrs, augmentation=False), 47 | MergeChains(), 48 | ]) 49 | data_var = transform(structure_factory()) 50 | data_variants.append({ 51 | 'data': data_var, 52 | 'name': f'{structure_id}-MultipleCDRs', 53 | 'tag': 'MultipleCDRs', 54 | 'cdrs': cdrs, 55 | 'residue_first': None, 56 | 'residue_last': None, 57 | }) 58 | elif config.mode == 'full': 59 | transform = Compose([ 60 | MaskAntibody(), 61 | MergeChains(), 62 | ]) 63 | data_var = transform(structure_factory()) 64 | data_variants.append({ 65 | 'data': data_var, 66 | 'name': f'{structure_id}-Full', 67 | 'tag': 'Full', 68 | 'residue_first': None, 69 | 'residue_last': None, 70 | }) 71 | elif config.mode == 'abopt': 72 | cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) 73 | for cdr_name in cdrs: 74 | transform = Compose([ 75 | MaskSingleCDR(cdr_name, augmentation=False), 76 | MergeChains(), 77 | ]) 78 | data_var = transform(structure_factory()) 79 | residue_first, residue_last = get_residue_first_last(data_var) 80 | for opt_step in config.sampling.optimize_steps: 81 | data_variants.append({ 82 | 'data': data_var, 83 | 'name': f'{structure_id}-{cdr_name}-O{opt_step}', 84 | 'tag': f'{cdr_name}-O{opt_step}', 85 | 'cdr': cdr_name, 86 | 'opt_step': opt_step, 87 | 'residue_first': residue_first, 88 | 'residue_last': residue_last, 89 | }) 90 | else: 91 | raise 
ValueError(f'Unknown mode: {config.mode}.') 92 | return data_variants 93 | 94 | def main(): 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument('index', type=int) 97 | parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml') 98 | parser.add_argument('-o', '--out_root', type=str, default='./results') 99 | parser.add_argument('-t', '--tag', type=str, default='') 100 | parser.add_argument('-s', '--seed', type=int, default=None) 101 | parser.add_argument('-d', '--device', type=str, default='cuda') 102 | parser.add_argument('-b', '--batch_size', type=int, default=16) 103 | args = parser.parse_args() 104 | 105 | # Load configs 106 | config, config_name = load_config(args.config) 107 | seed_all(args.seed if args.seed is not None else config.sampling.seed) 108 | 109 | # Testset 110 | dataset = get_dataset(config.dataset.test) 111 | get_structure = lambda: dataset[args.index] 112 | 113 | # Logging 114 | structure_ = get_structure() 115 | structure_id = structure_['id'] 116 | tag_postfix = '_%s' % args.tag if args.tag else '' 117 | log_dir = get_new_log_dir(os.path.join(args.out_root, config_name + tag_postfix), prefix='%04d_%s' % (args.index, structure_['id'])) 118 | logger = get_logger('sample', log_dir) 119 | logger.info('Data ID: %s' % structure_['id']) 120 | data_native = MergeChains()(structure_) 121 | save_pdb(data_native, os.path.join(log_dir, 'reference.pdb')) 122 | 123 | # Load checkpoint and model 124 | logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint)) 125 | ckpt = torch.load(config.model.checkpoint, map_location='cpu') 126 | cfg_ckpt = ckpt['config'] 127 | model = get_model(cfg_ckpt.model).to(args.device) 128 | lsd = model.load_state_dict(ckpt['model']) 129 | logger.info(str(lsd)) 130 | 131 | # Make data variants 132 | data_variants = create_data_variants( 133 | config = config, 134 | structure_factory = get_structure, 135 | ) 136 | 137 | # Save metadata 138 | metadata = { 139 | 'identifier': structure_id, 140 | 'index': args.index, 141 | 'config': args.config, 142 | 'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants], 143 | } 144 | with open(os.path.join(log_dir, 'metadata.json'), 'w') as f: 145 | json.dump(metadata, f, indent=2) 146 | 147 | # Start sampling 148 | collate_fn = PaddingCollate(eight=False) 149 | inference_tfm = [ PatchAroundAnchor(), ] 150 | if 'abopt' not in config.mode: # Don't remove native CDR in optimization mode 151 | inference_tfm.append(RemoveNative( 152 | remove_structure = config.sampling.sample_structure, 153 | remove_sequence = config.sampling.sample_sequence, 154 | )) 155 | inference_tfm = Compose(inference_tfm) 156 | 157 | for variant in data_variants: 158 | os.makedirs(os.path.join(log_dir, variant['tag']), exist_ok=True) 159 | logger.info(f"Start sampling for: {variant['tag']}") 160 | 161 | save_pdb(data_native, os.path.join(log_dir, variant['tag'], 'REF1.pdb')) # w/ OpenMM minimization 162 | 163 | data_cropped = inference_tfm( 164 | copy.deepcopy(variant['data']) 165 | ) 166 | data_list_repeat = [ data_cropped ] * config.sampling.num_samples 167 | loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) 168 | 169 | count = 0 170 | for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True): 171 | torch.set_grad_enabled(False) 172 | model.eval() 173 | batch = recursive_to(batch, args.device) 174 | if 'abopt' in config.mode: 175 | # Antibody optimization starting from native 176 | 
traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={ 177 | 'pbar': True, 178 | 'sample_structure': config.sampling.sample_structure, 179 | 'sample_sequence': config.sampling.sample_sequence, 180 | }) 181 | else: 182 | # De novo design 183 | traj_batch = model.sample(batch, sample_opt={ 184 | 'pbar': True, 185 | 'sample_structure': config.sampling.sample_structure, 186 | 'sample_sequence': config.sampling.sample_sequence, 187 | }) 188 | 189 | aa_new = traj_batch[0][2] # 0: Last sampling step. 2: Amino acid. 190 | pos_atom_new, mask_atom_new = reconstruct_backbone_partially( 191 | pos_ctx = batch['pos_heavyatom'], 192 | R_new = so3vec_to_rotation(traj_batch[0][0]), 193 | t_new = traj_batch[0][1], 194 | aa = aa_new, 195 | chain_nb = batch['chain_nb'], 196 | res_nb = batch['res_nb'], 197 | mask_atoms = batch['mask_heavyatom'], 198 | mask_recons = batch['generate_flag'], 199 | ) 200 | aa_new = aa_new.cpu() 201 | pos_atom_new = pos_atom_new.cpu() 202 | mask_atom_new = mask_atom_new.cpu() 203 | 204 | for i in range(aa_new.size(0)): 205 | data_tmpl = variant['data'] 206 | aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], data_cropped['patch_idx']) 207 | mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], data_cropped['patch_idx']) 208 | pos_ha = ( 209 | apply_patch_to_tensor( 210 | data_tmpl['pos_heavyatom'], 211 | pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(), 212 | data_cropped['patch_idx'] 213 | ) 214 | ) 215 | 216 | save_path = os.path.join(log_dir, variant['tag'], '%04d.pdb' % (count, )) 217 | save_pdb({ 218 | 'chain_nb': data_tmpl['chain_nb'], 219 | 'chain_id': data_tmpl['chain_id'], 220 | 'resseq': data_tmpl['resseq'], 221 | 'icode': data_tmpl['icode'], 222 | # Generated 223 | 'aa': aa, 224 | 'mask_heavyatom': mask_ha, 225 | 'pos_heavyatom': pos_ha, 226 | }, path=save_path) 227 | # save_pdb({ 228 | # 'chain_nb': data_cropped['chain_nb'], 229 | # 'chain_id': data_cropped['chain_id'], 230 | # 'resseq': data_cropped['resseq'], 231 | # 'icode': data_cropped['icode'], 232 | # # Generated 233 | # 'aa': aa_new[i], 234 | # 'mask_heavyatom': mask_atom_new[i], 235 | # 'pos_heavyatom': pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(), 236 | # }, path=os.path.join(log_dir, variant['tag'], '%04d_patch.pdb' % (count, ))) 237 | count += 1 238 | 239 | logger.info('Finished.\n') 240 | 241 | 242 | if __name__ == '__main__': 243 | main() 244 | -------------------------------------------------------------------------------- /diffab/utils/data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data._utils.collate import default_collate 4 | 5 | 6 | DEFAULT_PAD_VALUES = { 7 | 'aa': 21, 8 | 'chain_id': ' ', 9 | 'icode': ' ', 10 | } 11 | 12 | DEFAULT_NO_PADDING = { 13 | 'origin', 14 | } 15 | 16 | class PaddingCollate(object): 17 | 18 | def __init__(self, length_ref_key='aa', pad_values=DEFAULT_PAD_VALUES, no_padding=DEFAULT_NO_PADDING, eight=True): 19 | super().__init__() 20 | self.length_ref_key = length_ref_key 21 | self.pad_values = pad_values 22 | self.no_padding = no_padding 23 | self.eight = eight 24 | 25 | @staticmethod 26 | def _pad_last(x, n, value=0): 27 | if isinstance(x, torch.Tensor): 28 | assert x.size(0) <= n 29 | if x.size(0) == n: 30 | return x 31 | pad_size = [n - x.size(0)] + list(x.shape[1:]) 32 | pad = torch.full(pad_size, fill_value=value).to(x) 33 | return torch.cat([x, pad], dim=0) 34 | elif isinstance(x, list): 35 | pad = 
[value] * (n - len(x)) 36 | return x + pad 37 | else: 38 | return x 39 | 40 | @staticmethod 41 | def _get_pad_mask(l, n): 42 | return torch.cat([ 43 | torch.ones([l], dtype=torch.bool), 44 | torch.zeros([n-l], dtype=torch.bool) 45 | ], dim=0) 46 | 47 | @staticmethod 48 | def _get_common_keys(list_of_dict): 49 | keys = set(list_of_dict[0].keys()) 50 | for d in list_of_dict[1:]: 51 | keys = keys.intersection(d.keys()) 52 | return keys 53 | 54 | 55 | def _get_pad_value(self, key): 56 | if key not in self.pad_values: 57 | return 0 58 | return self.pad_values[key] 59 | 60 | def __call__(self, data_list): 61 | max_length = max([data[self.length_ref_key].size(0) for data in data_list]) 62 | keys = self._get_common_keys(data_list) 63 | 64 | if self.eight: 65 | max_length = math.ceil(max_length / 8) * 8 66 | data_list_padded = [] 67 | for data in data_list: 68 | data_padded = { 69 | k: self._pad_last(v, max_length, value=self._get_pad_value(k)) if k not in self.no_padding else v 70 | for k, v in data.items() 71 | if k in keys 72 | } 73 | data_padded['mask'] = self._get_pad_mask(data[self.length_ref_key].size(0), max_length) 74 | data_list_padded.append(data_padded) 75 | return default_collate(data_list_padded) 76 | 77 | 78 | def apply_patch_to_tensor(x_full, x_patch, patch_idx): 79 | """ 80 | Args: 81 | x_full: (N, ...) 82 | x_patch: (M, ...) 83 | patch_idx: (M, ) 84 | Returns: 85 | (N, ...) 86 | """ 87 | x_full = x_full.clone() 88 | x_full[patch_idx] = x_patch 89 | return x_full 90 | -------------------------------------------------------------------------------- /diffab/utils/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .protein import constants 3 | 4 | 5 | def find_cdrs(structure): 6 | cdrs = [] 7 | if structure['heavy'] is not None: 8 | flag = structure['heavy']['cdr_flag'] 9 | if int(constants.CDR.H1) in flag: 10 | cdrs.append('H_CDR1') 11 | if int(constants.CDR.H2) in flag: 12 | cdrs.append('H_CDR2') 13 | if int(constants.CDR.H3) in flag: 14 | cdrs.append('H_CDR3') 15 | 16 | if structure['light'] is not None: 17 | flag = structure['light']['cdr_flag'] 18 | if int(constants.CDR.L1) in flag: 19 | cdrs.append('L_CDR1') 20 | if int(constants.CDR.L2) in flag: 21 | cdrs.append('L_CDR2') 22 | if int(constants.CDR.L3) in flag: 23 | cdrs.append('L_CDR3') 24 | 25 | return cdrs 26 | 27 | 28 | def get_residue_first_last(data): 29 | loop_flag = data['generate_flag'] 30 | loop_idx = torch.arange(loop_flag.size(0))[loop_flag] 31 | idx_first, idx_last = loop_idx.min().item(), loop_idx.max().item() 32 | residue_first = (data['chain_id'][idx_first], data['resseq'][idx_first].item(), data['icode'][idx_first]) 33 | residue_last = (data['chain_id'][idx_last], data['resseq'][idx_last].item(), data['icode'][idx_last]) 34 | return residue_first, residue_last 35 | 36 | 37 | class RemoveNative(object): 38 | 39 | def __init__(self, remove_structure, remove_sequence): 40 | super().__init__() 41 | self.remove_structure = remove_structure 42 | self.remove_sequence = remove_sequence 43 | 44 | def __call__(self, data): 45 | generate_flag = data['generate_flag'].clone() 46 | if self.remove_sequence: 47 | data['aa'] = torch.where( 48 | generate_flag, 49 | torch.full_like(data['aa'], fill_value=int(constants.AA.UNK)), # Is loop 50 | data['aa'] 51 | ) 52 | 53 | if self.remove_structure: 54 | data['pos_heavyatom'] = torch.where( 55 | generate_flag[:, None, None].expand(data['pos_heavyatom'].shape), 56 | torch.randn_like(data['pos_heavyatom']) * 10, 57 | 
data['pos_heavyatom'] 58 | ) 59 | 60 | return data -------------------------------------------------------------------------------- /diffab/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import logging 5 | from typing import OrderedDict 6 | import torch 7 | import torch.linalg 8 | import numpy as np 9 | import yaml 10 | from easydict import EasyDict 11 | from glob import glob 12 | 13 | 14 | class BlackHole(object): 15 | def __setattr__(self, name, value): 16 | pass 17 | 18 | def __call__(self, *args, **kwargs): 19 | return self 20 | 21 | def __getattr__(self, name): 22 | return self 23 | 24 | 25 | class Counter(object): 26 | def __init__(self, start=0): 27 | super().__init__() 28 | self.now = start 29 | 30 | def step(self, delta=1): 31 | prev = self.now 32 | self.now += delta 33 | return prev 34 | 35 | 36 | def get_logger(name, log_dir=None): 37 | logger = logging.getLogger(name) 38 | logger.setLevel(logging.DEBUG) 39 | formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s') 40 | 41 | stream_handler = logging.StreamHandler() 42 | stream_handler.setLevel(logging.DEBUG) 43 | stream_handler.setFormatter(formatter) 44 | logger.addHandler(stream_handler) 45 | 46 | if log_dir is not None: 47 | file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt')) 48 | file_handler.setLevel(logging.DEBUG) 49 | file_handler.setFormatter(formatter) 50 | logger.addHandler(file_handler) 51 | 52 | return logger 53 | 54 | 55 | def get_new_log_dir(root='./logs', prefix='', tag=''): 56 | fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime()) 57 | if prefix != '': 58 | fn = prefix + '_' + fn 59 | if tag != '': 60 | fn = fn + '_' + tag 61 | log_dir = os.path.join(root, fn) 62 | os.makedirs(log_dir) 63 | return log_dir 64 | 65 | 66 | def seed_all(seed): 67 | torch.backends.cudnn.deterministic = True 68 | torch.manual_seed(seed) 69 | torch.cuda.manual_seed_all(seed) 70 | np.random.seed(seed) 71 | random.seed(seed) 72 | 73 | 74 | def inf_iterator(iterable): 75 | iterator = iterable.__iter__() 76 | while True: 77 | try: 78 | yield iterator.__next__() 79 | except StopIteration: 80 | iterator = iterable.__iter__() 81 | 82 | 83 | def log_hyperparams(writer, args): 84 | from torch.utils.tensorboard.summary import hparams 85 | vars_args = {k: v if isinstance(v, str) else repr(v) for k, v in vars(args).items()} 86 | exp, ssi, sei = hparams(vars_args, {}) 87 | writer.file_writer.add_summary(exp) 88 | writer.file_writer.add_summary(ssi) 89 | writer.file_writer.add_summary(sei) 90 | 91 | 92 | def int_tuple(argstr): 93 | return tuple(map(int, argstr.split(','))) 94 | 95 | 96 | def str_tuple(argstr): 97 | return tuple(argstr.split(',')) 98 | 99 | 100 | def get_checkpoint_path(folder, it=None): 101 | if it is not None: 102 | return os.path.join(folder, '%d.pt' % it), it 103 | all_iters = list(map(lambda x: int(os.path.basename(x[:-3])), glob(os.path.join(folder, '*.pt')))) 104 | all_iters.sort() 105 | return os.path.join(folder, '%d.pt' % all_iters[-1]), all_iters[-1] 106 | 107 | 108 | def load_config(config_path): 109 | with open(config_path, 'r') as f: 110 | config = EasyDict(yaml.safe_load(f)) 111 | config_name = os.path.basename(config_path)[:os.path.basename(config_path).rfind('.')] 112 | return config, config_name 113 | 114 | 115 | def extract_weights(weights: OrderedDict, prefix): 116 | extracted = OrderedDict() 117 | for k, v in weights.items(): 118 | if k.startswith(prefix): 119 | 
extracted.update({ 120 | k[len(prefix):]: v 121 | }) 122 | return extracted 123 | 124 | 125 | def current_milli_time(): 126 | return round(time.time() * 1000) 127 | -------------------------------------------------------------------------------- /diffab/utils/protein/constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import enum 3 | 4 | class CDR(enum.IntEnum): 5 | H1 = 1 6 | H2 = 2 7 | H3 = 3 8 | L1 = 4 9 | L2 = 5 10 | L3 = 6 11 | 12 | 13 | class ChothiaCDRRange: 14 | H1 = (26, 32) 15 | H2 = (52, 56) 16 | H3 = (95, 102) 17 | 18 | L1 = (24, 34) 19 | L2 = (50, 56) 20 | L3 = (89, 97) 21 | 22 | @classmethod 23 | def to_cdr(cls, chain_type, resseq): 24 | assert chain_type in ('H', 'L') 25 | if chain_type == 'H': 26 | if cls.H1[0] <= resseq <= cls.H1[1]: 27 | return CDR.H1 28 | elif cls.H2[0] <= resseq <= cls.H2[1]: 29 | return CDR.H2 30 | elif cls.H3[0] <= resseq <= cls.H3[1]: 31 | return CDR.H3 32 | elif chain_type == 'L': 33 | if cls.L1[0] <= resseq <= cls.L1[1]: # Chothia VH-CDR1 34 | return CDR.L1 35 | elif cls.L2[0] <= resseq <= cls.L2[1]: 36 | return CDR.L2 37 | elif cls.L3[0] <= resseq <= cls.L3[1]: 38 | return CDR.L3 39 | 40 | 41 | class Fragment(enum.IntEnum): 42 | Heavy = 1 43 | Light = 2 44 | Antigen = 3 45 | 46 | ## 47 | # Residue identities 48 | """ 49 | This is part of the OpenMM molecular simulation toolkit originating from 50 | Simbios, the NIH National Center for Physics-Based Simulation of 51 | Biological Structures at Stanford, funded under the NIH Roadmap for 52 | Medical Research, grant U54 GM072970. See https://simtk.org. 53 | 54 | Portions copyright (c) 2013 Stanford University and the Authors. 55 | Authors: Peter Eastman 56 | Contributors: 57 | 58 | Permission is hereby granted, free of charge, to any person obtaining a 59 | copy of this software and associated documentation files (the "Software"), 60 | to deal in the Software without restriction, including without limitation 61 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 62 | and/or sell copies of the Software, and to permit persons to whom the 63 | Software is furnished to do so, subject to the following conditions: 64 | 65 | The above copyright notice and this permission notice shall be included in 66 | all copies or substantial portions of the Software. 67 | 68 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 69 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 70 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 71 | THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 72 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 73 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE 74 | USE OR OTHER DEALINGS IN THE SOFTWARE. 
75 | """ 76 | non_standard_residue_substitutions = { 77 | '2AS':'ASP', '3AH':'HIS', '5HP':'GLU', 'ACL':'ARG', 'AGM':'ARG', 'AIB':'ALA', 'ALM':'ALA', 'ALO':'THR', 'ALY':'LYS', 'ARM':'ARG', 78 | 'ASA':'ASP', 'ASB':'ASP', 'ASK':'ASP', 'ASL':'ASP', 'ASQ':'ASP', 'AYA':'ALA', 'BCS':'CYS', 'BHD':'ASP', 'BMT':'THR', 'BNN':'ALA', 79 | 'BUC':'CYS', 'BUG':'LEU', 'C5C':'CYS', 'C6C':'CYS', 'CAS':'CYS', 'CCS':'CYS', 'CEA':'CYS', 'CGU':'GLU', 'CHG':'ALA', 'CLE':'LEU', 'CME':'CYS', 80 | 'CSD':'ALA', 'CSO':'CYS', 'CSP':'CYS', 'CSS':'CYS', 'CSW':'CYS', 'CSX':'CYS', 'CXM':'MET', 'CY1':'CYS', 'CY3':'CYS', 'CYG':'CYS', 81 | 'CYM':'CYS', 'CYQ':'CYS', 'DAH':'PHE', 'DAL':'ALA', 'DAR':'ARG', 'DAS':'ASP', 'DCY':'CYS', 'DGL':'GLU', 'DGN':'GLN', 'DHA':'ALA', 82 | 'DHI':'HIS', 'DIL':'ILE', 'DIV':'VAL', 'DLE':'LEU', 'DLY':'LYS', 'DNP':'ALA', 'DPN':'PHE', 'DPR':'PRO', 'DSN':'SER', 'DSP':'ASP', 83 | 'DTH':'THR', 'DTR':'TRP', 'DTY':'TYR', 'DVA':'VAL', 'EFC':'CYS', 'FLA':'ALA', 'FME':'MET', 'GGL':'GLU', 'GL3':'GLY', 'GLZ':'GLY', 84 | 'GMA':'GLU', 'GSC':'GLY', 'HAC':'ALA', 'HAR':'ARG', 'HIC':'HIS', 'HIP':'HIS', 'HMR':'ARG', 'HPQ':'PHE', 'HTR':'TRP', 'HYP':'PRO', 85 | 'IAS':'ASP', 'IIL':'ILE', 'IYR':'TYR', 'KCX':'LYS', 'LLP':'LYS', 'LLY':'LYS', 'LTR':'TRP', 'LYM':'LYS', 'LYZ':'LYS', 'MAA':'ALA', 'MEN':'ASN', 86 | 'MHS':'HIS', 'MIS':'SER', 'MLE':'LEU', 'MPQ':'GLY', 'MSA':'GLY', 'MSE':'MET', 'MVA':'VAL', 'NEM':'HIS', 'NEP':'HIS', 'NLE':'LEU', 87 | 'NLN':'LEU', 'NLP':'LEU', 'NMC':'GLY', 'OAS':'SER', 'OCS':'CYS', 'OMT':'MET', 'PAQ':'TYR', 'PCA':'GLU', 'PEC':'CYS', 'PHI':'PHE', 88 | 'PHL':'PHE', 'PR3':'CYS', 'PRR':'ALA', 'PTR':'TYR', 'PYX':'CYS', 'SAC':'SER', 'SAR':'GLY', 'SCH':'CYS', 'SCS':'CYS', 'SCY':'CYS', 89 | 'SEL':'SER', 'SEP':'SER', 'SET':'SER', 'SHC':'CYS', 'SHR':'LYS', 'SMC':'CYS', 'SOC':'CYS', 'STY':'TYR', 'SVA':'SER', 'TIH':'ALA', 90 | 'TPL':'TRP', 'TPO':'THR', 'TPQ':'ALA', 'TRG':'LYS', 'TRO':'TRP', 'TYB':'TYR', 'TYI':'TYR', 'TYQ':'TYR', 'TYS':'TYR', 'TYY':'TYR' 91 | } 92 | 93 | 94 | ressymb_to_resindex = { 95 | 'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 96 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 97 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 98 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, 99 | 'X': 20, 100 | } 101 | 102 | 103 | class AA(enum.IntEnum): 104 | ALA = 0; CYS = 1; ASP = 2; GLU = 3; PHE = 4 105 | GLY = 5; HIS = 6; ILE = 7; LYS = 8; LEU = 9 106 | MET = 10; ASN = 11; PRO = 12; GLN = 13; ARG = 14 107 | SER = 15; THR = 16; VAL = 17; TRP = 18; TYR = 19 108 | UNK = 20 109 | 110 | @classmethod 111 | def _missing_(cls, value): 112 | if isinstance(value, str) and len(value) == 3: # three representation 113 | if value in non_standard_residue_substitutions: 114 | value = non_standard_residue_substitutions[value] 115 | if value in cls._member_names_: 116 | return getattr(cls, value) 117 | elif isinstance(value, str) and len(value) == 1: # one representation 118 | if value in ressymb_to_resindex: 119 | return cls(ressymb_to_resindex[value]) 120 | 121 | return super()._missing_(value) 122 | 123 | def __str__(self): 124 | return self.name 125 | 126 | @classmethod 127 | def is_aa(cls, value): 128 | return (value in ressymb_to_resindex) or \ 129 | (value in non_standard_residue_substitutions) or \ 130 | (value in cls._member_names_) or \ 131 | (value in cls._member_map_.values()) 132 | 133 | 134 | num_aa_types = len(AA) 135 | 136 | ## 137 | # Atom identities 138 | 139 | class BBHeavyAtom(enum.IntEnum): 140 | N = 0; CA = 1; C = 2; O = 3; CB = 4; OXT=14; 141 | 142 | max_num_heavyatoms = 15 143 | 144 | # Copyright 2021 
DeepMind Technologies Limited 145 | # 146 | # Licensed under the Apache License, Version 2.0 (the "License"); 147 | # you may not use this file except in compliance with the License. 148 | # You may obtain a copy of the License at 149 | # 150 | # http://www.apache.org/licenses/LICENSE-2.0 151 | # 152 | # Unless required by applicable law or agreed to in writing, software 153 | # distributed under the License is distributed on an "AS IS" BASIS, 154 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 155 | # See the License for the specific language governing permissions and 156 | # limitations under the License. 157 | restype_to_heavyatom_names = { 158 | AA.ALA: ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', '', 'OXT'], 159 | AA.ARG: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', '', 'OXT'], 160 | AA.ASN: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', '', 'OXT'], 161 | AA.ASP: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', '', 'OXT'], 162 | AA.CYS: ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', '', 'OXT'], 163 | AA.GLN: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', '', 'OXT'], 164 | AA.GLU: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', '', 'OXT'], 165 | AA.GLY: ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', '', 'OXT'], 166 | AA.HIS: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', '', 'OXT'], 167 | AA.ILE: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', '', 'OXT'], 168 | AA.LEU: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', '', 'OXT'], 169 | AA.LYS: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', '', 'OXT'], 170 | AA.MET: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', '', 'OXT'], 171 | AA.PHE: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', '', 'OXT'], 172 | AA.PRO: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', '', 'OXT'], 173 | AA.SER: ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', '', 'OXT'], 174 | AA.THR: ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', '', 'OXT'], 175 | AA.TRP: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT'], 176 | AA.TYR: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', '', 'OXT'], 177 | AA.VAL: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', '', 'OXT'], 178 | AA.UNK: ['', '', '', '', '', '', '', '', '', '', '', '', '', '', ''], 179 | } 180 | for names in restype_to_heavyatom_names.values(): assert len(names) == max_num_heavyatoms 181 | 182 | 183 | backbone_atom_coordinates = { 184 | AA.ALA: [ 185 | (-0.525, 1.363, 0.0), # N 186 | (0.0, 0.0, 0.0), # CA 187 | (1.526, -0.0, -0.0), # C 188 | ], 189 | AA.ARG: [ 190 | (-0.524, 1.362, -0.0), # N 191 | (0.0, 0.0, 0.0), # CA 192 | (1.525, -0.0, -0.0), # C 193 | ], 194 | AA.ASN: [ 195 | (-0.536, 1.357, 0.0), # N 196 | (0.0, 0.0, 0.0), # CA 197 | (1.526, -0.0, -0.0), # C 198 | ], 199 | AA.ASP: [ 200 | (-0.525, 1.362, -0.0), # N 201 | (0.0, 0.0, 0.0), # CA 202 | (1.527, 0.0, -0.0), # C 203 | ], 204 | AA.CYS: [ 205 | (-0.522, 1.362, -0.0), # N 206 | (0.0, 0.0, 0.0), # CA 207 | (1.524, 0.0, 0.0), # C 208 | ], 209 | AA.GLN: [ 210 | (-0.526, 1.361, -0.0), # N 211 | (0.0, 0.0, 0.0), # CA 212 | (1.526, 0.0, 0.0), # C 213 | ], 214 | AA.GLU: [ 215 | (-0.528, 1.361, 0.0), # N 216 | 
(0.0, 0.0, 0.0), # CA 217 | (1.526, -0.0, -0.0), # C 218 | ], 219 | AA.GLY: [ 220 | (-0.572, 1.337, 0.0), # N 221 | (0.0, 0.0, 0.0), # CA 222 | (1.517, -0.0, -0.0), # C 223 | ], 224 | AA.HIS: [ 225 | (-0.527, 1.36, 0.0), # N 226 | (0.0, 0.0, 0.0), # CA 227 | (1.525, 0.0, 0.0), # C 228 | ], 229 | AA.ILE: [ 230 | (-0.493, 1.373, -0.0), # N 231 | (0.0, 0.0, 0.0), # CA 232 | (1.527, -0.0, -0.0), # C 233 | ], 234 | AA.LEU: [ 235 | (-0.52, 1.363, 0.0), # N 236 | (0.0, 0.0, 0.0), # CA 237 | (1.525, -0.0, -0.0), # C 238 | ], 239 | AA.LYS: [ 240 | (-0.526, 1.362, -0.0), # N 241 | (0.0, 0.0, 0.0), # CA 242 | (1.526, 0.0, 0.0), # C 243 | ], 244 | AA.MET: [ 245 | (-0.521, 1.364, -0.0), # N 246 | (0.0, 0.0, 0.0), # CA 247 | (1.525, 0.0, 0.0), # C 248 | ], 249 | AA.PHE: [ 250 | (-0.518, 1.363, 0.0), # N 251 | (0.0, 0.0, 0.0), # CA 252 | (1.524, 0.0, -0.0), # C 253 | ], 254 | AA.PRO: [ 255 | (-0.566, 1.351, -0.0), # N 256 | (0.0, 0.0, 0.0), # CA 257 | (1.527, -0.0, 0.0), # C 258 | ], 259 | AA.SER: [ 260 | (-0.529, 1.36, -0.0), # N 261 | (0.0, 0.0, 0.0), # CA 262 | (1.525, -0.0, -0.0), # C 263 | ], 264 | AA.THR: [ 265 | (-0.517, 1.364, 0.0), # N 266 | (0.0, 0.0, 0.0), # CA 267 | (1.526, 0.0, -0.0), # C 268 | ], 269 | AA.TRP: [ 270 | (-0.521, 1.363, 0.0), # N 271 | (0.0, 0.0, 0.0), # CA 272 | (1.525, -0.0, 0.0), # C 273 | ], 274 | AA.TYR: [ 275 | (-0.522, 1.362, 0.0), # N 276 | (0.0, 0.0, 0.0), # CA 277 | (1.524, -0.0, -0.0), # C 278 | ], 279 | AA.VAL: [ 280 | (-0.494, 1.373, -0.0), # N 281 | (0.0, 0.0, 0.0), # CA 282 | (1.527, -0.0, -0.0), # C 283 | ], 284 | } 285 | 286 | bb_oxygen_coordinate = { 287 | AA.ALA: (2.153, -1.062, 0.0), 288 | AA.ARG: (2.151, -1.062, 0.0), 289 | AA.ASN: (2.151, -1.062, 0.0), 290 | AA.ASP: (2.153, -1.062, 0.0), 291 | AA.CYS: (2.149, -1.062, 0.0), 292 | AA.GLN: (2.152, -1.062, 0.0), 293 | AA.GLU: (2.152, -1.062, 0.0), 294 | AA.GLY: (2.143, -1.062, 0.0), 295 | AA.HIS: (2.15, -1.063, 0.0), 296 | AA.ILE: (2.154, -1.062, 0.0), 297 | AA.LEU: (2.15, -1.063, 0.0), 298 | AA.LYS: (2.152, -1.062, 0.0), 299 | AA.MET: (2.15, -1.062, 0.0), 300 | AA.PHE: (2.15, -1.062, 0.0), 301 | AA.PRO: (2.148, -1.066, 0.0), 302 | AA.SER: (2.151, -1.062, 0.0), 303 | AA.THR: (2.152, -1.062, 0.0), 304 | AA.TRP: (2.152, -1.062, 0.0), 305 | AA.TYR: (2.151, -1.062, 0.0), 306 | AA.VAL: (2.154, -1.062, 0.0), 307 | } 308 | 309 | backbone_atom_coordinates_tensor = torch.zeros([21, 3, 3]) 310 | bb_oxygen_coordinate_tensor = torch.zeros([21, 3]) 311 | 312 | def make_coordinate_tensors(): 313 | for restype, atom_coords in backbone_atom_coordinates.items(): 314 | for atom_id, atom_coord in enumerate(atom_coords): 315 | backbone_atom_coordinates_tensor[restype][atom_id] = torch.FloatTensor(atom_coord) 316 | 317 | for restype, bb_oxy_coord in bb_oxygen_coordinate.items(): 318 | bb_oxygen_coordinate_tensor[restype] = torch.FloatTensor(bb_oxy_coord) 319 | make_coordinate_tensors() 320 | -------------------------------------------------------------------------------- /diffab/utils/protein/parsers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from Bio.PDB import Selection 3 | from Bio.PDB.Residue import Residue 4 | from easydict import EasyDict 5 | 6 | from .constants import ( 7 | AA, max_num_heavyatoms, 8 | restype_to_heavyatom_names, 9 | BBHeavyAtom 10 | ) 11 | 12 | 13 | class ParsingException(Exception): 14 | pass 15 | 16 | 17 | def _get_residue_heavyatom_info(res: Residue): 18 | pos_heavyatom = torch.zeros([max_num_heavyatoms, 3], dtype=torch.float) 19 | 
mask_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.bool) 20 | restype = AA(res.get_resname()) 21 | for idx, atom_name in enumerate(restype_to_heavyatom_names[restype]): 22 | if atom_name == '': continue 23 | if atom_name in res: 24 | pos_heavyatom[idx] = torch.tensor(res[atom_name].get_coord().tolist(), dtype=pos_heavyatom.dtype) 25 | mask_heavyatom[idx] = True 26 | return pos_heavyatom, mask_heavyatom 27 | 28 | 29 | def parse_biopython_structure(entity, unknown_threshold=1.0, max_resseq=None): 30 | chains = Selection.unfold_entities(entity, 'C') 31 | chains.sort(key=lambda c: c.get_id()) 32 | data = EasyDict({ 33 | 'chain_id': [], 34 | 'resseq': [], 'icode': [], 'res_nb': [], 35 | 'aa': [], 36 | 'pos_heavyatom': [], 'mask_heavyatom': [], 37 | }) 38 | tensor_types = { 39 | 'resseq': torch.LongTensor, 40 | 'res_nb': torch.LongTensor, 41 | 'aa': torch.LongTensor, 42 | 'pos_heavyatom': torch.stack, 43 | 'mask_heavyatom': torch.stack, 44 | } 45 | 46 | count_aa, count_unk = 0, 0 47 | 48 | for i, chain in enumerate(chains): 49 | seq_this = 0 # Renumbering residues 50 | residues = Selection.unfold_entities(chain, 'R') 51 | residues.sort(key=lambda res: (res.get_id()[1], res.get_id()[2])) # Sort residues by resseq-icode 52 | for _, res in enumerate(residues): 53 | resseq_this = int(res.get_id()[1]) 54 | if max_resseq is not None and resseq_this > max_resseq: 55 | continue 56 | 57 | resname = res.get_resname() 58 | if not AA.is_aa(resname): continue 59 | if not (res.has_id('CA') and res.has_id('C') and res.has_id('N')): continue 60 | restype = AA(resname) 61 | count_aa += 1 62 | if restype == AA.UNK: 63 | count_unk += 1 64 | continue 65 | 66 | # Chain info 67 | data.chain_id.append(chain.get_id()) 68 | 69 | # Residue types 70 | data.aa.append(restype) # Will be automatically cast to torch.long 71 | 72 | # Heavy atoms 73 | pos_heavyatom, mask_heavyatom = _get_residue_heavyatom_info(res) 74 | data.pos_heavyatom.append(pos_heavyatom) 75 | data.mask_heavyatom.append(mask_heavyatom) 76 | 77 | # Sequential number 78 | resseq_this = int(res.get_id()[1]) 79 | icode_this = res.get_id()[2] 80 | if seq_this == 0: 81 | seq_this = 1 82 | else: 83 | d_CA_CA = torch.linalg.norm(data.pos_heavyatom[-2][BBHeavyAtom.CA] - data.pos_heavyatom[-1][BBHeavyAtom.CA], ord=2).item() 84 | if d_CA_CA <= 4.0: 85 | seq_this += 1 86 | else: 87 | d_resseq = resseq_this - data.resseq[-1] 88 | seq_this += max(2, d_resseq) 89 | 90 | data.resseq.append(resseq_this) 91 | data.icode.append(icode_this) 92 | data.res_nb.append(seq_this) 93 | 94 | if len(data.aa) == 0: 95 | raise ParsingException('No parsed residues.') 96 | 97 | if (count_unk / count_aa) >= unknown_threshold: 98 | raise ParsingException( 99 | f'Too many unknown residues, threshold {unknown_threshold:.2f}.' 
100 | ) 101 | 102 | seq_map = {} 103 | for i, (chain_id, resseq, icode) in enumerate(zip(data.chain_id, data.resseq, data.icode)): 104 | seq_map[(chain_id, resseq, icode)] = i 105 | 106 | for key, convert_fn in tensor_types.items(): 107 | data[key] = convert_fn(data[key]) 108 | 109 | return data, seq_map 110 | -------------------------------------------------------------------------------- /diffab/utils/protein/writers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | from Bio import BiopythonWarning 4 | from Bio.PDB import PDBIO 5 | from Bio.PDB.StructureBuilder import StructureBuilder 6 | 7 | from .constants import AA, restype_to_heavyatom_names 8 | 9 | 10 | def save_pdb(data, path=None): 11 | """ 12 | Args: 13 | data: A dict that contains: `chain_nb`, `chain_id`, `aa`, `resseq`, `icode`, 14 | `pos_heavyatom`, `mask_heavyatom`. 15 | """ 16 | 17 | def _mask_select(v, mask): 18 | if isinstance(v, str): 19 | return ''.join([s for i, s in enumerate(v) if mask[i]]) 20 | elif isinstance(v, list): 21 | return [s for i, s in enumerate(v) if mask[i]] 22 | elif isinstance(v, torch.Tensor): 23 | return v[mask] 24 | else: 25 | return v 26 | 27 | def _build_chain(builder, aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, chain_id_ch, resseq_ch, icode_ch): 28 | builder.init_chain(chain_id_ch[0]) 29 | builder.init_seg(' ') 30 | 31 | for aa_res, pos_allatom_res, mask_allatom_res, resseq_res, icode_res in \ 32 | zip(aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, resseq_ch, icode_ch): 33 | if not AA.is_aa(aa_res.item()): 34 | print('[Warning] Unknown amino acid type at %d%s: %r' % (resseq_res.item(), icode_res, aa_res.item())) 35 | continue 36 | restype = AA(aa_res.item()) 37 | builder.init_residue( 38 | resname = str(restype), 39 | field = ' ', 40 | resseq = resseq_res.item(), 41 | icode = icode_res, 42 | ) 43 | 44 | for i, atom_name in enumerate(restype_to_heavyatom_names[restype]): 45 | if atom_name == '': continue # No expected atom 46 | if (~mask_allatom_res[i]).any(): continue # Atom is missing 47 | if len(atom_name) == 1: fullname = ' %s ' % atom_name 48 | elif len(atom_name) == 2: fullname = ' %s ' % atom_name 49 | elif len(atom_name) == 3: fullname = ' %s' % atom_name 50 | else: fullname = atom_name # len == 4 51 | builder.init_atom(atom_name, pos_allatom_res[i].tolist(), 0.0, 1.0, ' ', fullname,) 52 | 53 | warnings.simplefilter('ignore', BiopythonWarning) 54 | builder = StructureBuilder() 55 | builder.init_structure(0) 56 | builder.init_model(0) 57 | 58 | unique_chain_nb = data['chain_nb'].unique().tolist() 59 | for ch_nb in unique_chain_nb: 60 | mask = (data['chain_nb'] == ch_nb) 61 | aa = _mask_select(data['aa'], mask) 62 | pos_heavyatom = _mask_select(data['pos_heavyatom'], mask) 63 | mask_heavyatom = _mask_select(data['mask_heavyatom'], mask) 64 | chain_id = _mask_select(data['chain_id'], mask) 65 | resseq = _mask_select(data['resseq'], mask) 66 | icode = _mask_select(data['icode'], mask) 67 | 68 | _build_chain(builder, aa, pos_heavyatom, mask_heavyatom, chain_id, resseq, icode) 69 | 70 | structure = builder.get_structure() 71 | if path is not None: 72 | io = PDBIO() 73 | io.set_structure(structure) 74 | io.save(path) 75 | return structure 76 | -------------------------------------------------------------------------------- /diffab/utils/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from easydict import EasyDict 4 | 5 | from .misc import 
BlackHole 6 | 7 | 8 | def get_optimizer(cfg, model): 9 | if cfg.type == 'adam': 10 | return torch.optim.Adam( 11 | model.parameters(), 12 | lr=cfg.lr, 13 | weight_decay=cfg.weight_decay, 14 | betas=(cfg.beta1, cfg.beta2, ) 15 | ) 16 | else: 17 | raise NotImplementedError('Optimizer not supported: %s' % cfg.type) 18 | 19 | 20 | def get_scheduler(cfg, optimizer): 21 | if cfg.type is None: 22 | return BlackHole() 23 | elif cfg.type == 'plateau': 24 | return torch.optim.lr_scheduler.ReduceLROnPlateau( 25 | optimizer, 26 | factor=cfg.factor, 27 | patience=cfg.patience, 28 | min_lr=cfg.min_lr, 29 | ) 30 | elif cfg.type == 'multistep': 31 | return torch.optim.lr_scheduler.MultiStepLR( 32 | optimizer, 33 | milestones=cfg.milestones, 34 | gamma=cfg.gamma, 35 | ) 36 | elif cfg.type == 'exp': 37 | return torch.optim.lr_scheduler.ExponentialLR( 38 | optimizer, 39 | gamma=cfg.gamma, 40 | ) 41 | elif cfg.type is None: 42 | return BlackHole() 43 | else: 44 | raise NotImplementedError('Scheduler not supported: %s' % cfg.type) 45 | 46 | 47 | def get_warmup_sched(cfg, optimizer): 48 | if cfg is None: return BlackHole() 49 | lambdas = [lambda it : (it / cfg.max_iters) if it <= cfg.max_iters else 1 for _ in optimizer.param_groups] 50 | warmup_sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lambdas) 51 | return warmup_sched 52 | 53 | 54 | def log_losses(out, it, tag, logger=BlackHole(), writer=BlackHole(), others={}): 55 | logstr = '[%s] Iter %05d' % (tag, it) 56 | logstr += ' | loss %.4f' % out['overall'].item() 57 | for k, v in out.items(): 58 | if k == 'overall': continue 59 | logstr += ' | loss(%s) %.4f' % (k, v.item()) 60 | for k, v in others.items(): 61 | logstr += ' | %s %2.4f' % (k, v) 62 | logger.info(logstr) 63 | 64 | for k, v in out.items(): 65 | if k == 'overall': 66 | writer.add_scalar('%s/loss' % tag, v, it) 67 | else: 68 | writer.add_scalar('%s/loss_%s' % (tag, k), v, it) 69 | for k, v in others.items(): 70 | writer.add_scalar('%s/%s' % (tag, k), v, it) 71 | writer.flush() 72 | 73 | 74 | class ValidationLossTape(object): 75 | 76 | def __init__(self): 77 | super().__init__() 78 | self.accumulate = {} 79 | self.others = {} 80 | self.total = 0 81 | 82 | def update(self, out, n, others={}): 83 | self.total += n 84 | for k, v in out.items(): 85 | if k not in self.accumulate: 86 | self.accumulate[k] = v.clone().detach() 87 | else: 88 | self.accumulate[k] += v.clone().detach() 89 | 90 | for k, v in others.items(): 91 | if k not in self.others: 92 | self.others[k] = v.clone().detach() 93 | else: 94 | self.others[k] += v.clone().detach() 95 | 96 | 97 | def log(self, it, logger=BlackHole(), writer=BlackHole(), tag='val'): 98 | avg = EasyDict({k:v / self.total for k, v in self.accumulate.items()}) 99 | avg_others = EasyDict({k:v / self.total for k, v in self.others.items()}) 100 | log_losses(avg, it, tag, logger, writer, others=avg_others) 101 | return avg['overall'] 102 | 103 | 104 | def recursive_to(obj, device): 105 | if isinstance(obj, torch.Tensor): 106 | if device == 'cpu': 107 | return obj.cpu() 108 | try: 109 | return obj.cuda(device=device, non_blocking=True) 110 | except RuntimeError: 111 | return obj.to(device) 112 | elif isinstance(obj, list): 113 | return [recursive_to(o, device=device) for o in obj] 114 | elif isinstance(obj, tuple): 115 | return tuple(recursive_to(o, device=device) for o in obj) 116 | elif isinstance(obj, dict): 117 | return {k: recursive_to(v, device=device) for k, v in obj.items()} 118 | 119 | else: 120 | return obj 121 | 122 | 123 | def 
reweight_loss_by_sequence_length(length, max_length, mode='sqrt'): 124 | if mode == 'sqrt': 125 | w = np.sqrt(length / max_length) 126 | elif mode == 'linear': 127 | w = length / max_length 128 | elif mode is None: 129 | w = 1.0 130 | else: 131 | raise ValueError('Unknown reweighting mode: %s' % mode) 132 | return w 133 | 134 | 135 | def sum_weighted_losses(losses, weights): 136 | """ 137 | Args: 138 | losses: Dict of scalar tensors. 139 | weights: Dict of weights. 140 | """ 141 | loss = 0 142 | for k in losses.keys(): 143 | if weights is None: 144 | loss = loss + losses[k] 145 | else: 146 | loss = loss + weights[k] * losses[k] 147 | return loss 148 | 149 | 150 | def count_parameters(model): 151 | return sum(p.numel() for p in model.parameters()) 152 | -------------------------------------------------------------------------------- /diffab/utils/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Transforms 2 | from .mask import MaskSingleCDR, MaskMultipleCDRs, MaskAntibody 3 | from .merge import MergeChains 4 | from .patch import PatchAroundAnchor 5 | 6 | # Factory 7 | from ._base import get_transform, Compose 8 | -------------------------------------------------------------------------------- /diffab/utils/transforms/_base.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torchvision.transforms import Compose 4 | 5 | 6 | _TRANSFORM_DICT = {} 7 | 8 | 9 | def register_transform(name): 10 | def decorator(cls): 11 | _TRANSFORM_DICT[name] = cls 12 | return cls 13 | return decorator 14 | 15 | 16 | def get_transform(cfg): 17 | if cfg is None or len(cfg) == 0: 18 | return None 19 | tfms = [] 20 | for t_dict in cfg: 21 | t_dict = copy.deepcopy(t_dict) 22 | cls = _TRANSFORM_DICT[t_dict.pop('type')] 23 | tfms.append(cls(**t_dict)) 24 | return Compose(tfms) 25 | 26 | 27 | def _index_select(v, index, n): 28 | if isinstance(v, torch.Tensor) and v.size(0) == n: 29 | return v[index] 30 | elif isinstance(v, list) and len(v) == n: 31 | return [v[i] for i in index] 32 | else: 33 | return v 34 | 35 | 36 | def _index_select_data(data, index): 37 | return { 38 | k: _index_select(v, index, data['aa'].size(0)) 39 | for k, v in data.items() 40 | } 41 | 42 | 43 | def _mask_select(v, mask): 44 | if isinstance(v, torch.Tensor) and v.size(0) == mask.size(0): 45 | return v[mask] 46 | elif isinstance(v, list) and len(v) == mask.size(0): 47 | return [v[i] for i, b in enumerate(mask) if b] 48 | else: 49 | return v 50 | 51 | 52 | def _mask_select_data(data, mask): 53 | return { 54 | k: _mask_select(v, mask) 55 | for k, v in data.items() 56 | } 57 | -------------------------------------------------------------------------------- /diffab/utils/transforms/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from typing import List, Optional 4 | 5 | from ..protein import constants 6 | from ._base import register_transform 7 | 8 | 9 | def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2): 10 | first, last = continuous_flag_to_range(flag) 11 | length = flag.sum().item() 12 | if (length - 2*shrink_limit) < min_length: 13 | shrink_limit = 0 14 | first_ext = max(0, first-random.randint(-shrink_limit, extend_limit)) 15 | last_ext = min(last+random.randint(-shrink_limit, extend_limit), flag.size(0)-1) 16 | flag_ext = flag.clone() 17 | flag_ext[first_ext : last_ext+1] = True 18 | return flag_ext 19 | 
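# Illustration (hypothetical example values, for reference): for
#     flag = torch.tensor([False, False, True, True, True, True, False])
# continuous_flag_to_range(flag) returns (2, 5), the first and last True indices.
# random_shrink_extend(flag) clones `flag` and sets the sampled range
# [first_ext, last_ext] to True: the span may be widened by up to `extend_limit`
# positions on either side, while offsets drawn inward (up to `shrink_limit`)
# never unset flags that are already True, and inward sampling is disabled
# whenever it could leave fewer than `min_length` flagged positions.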
20 | 21 | def continuous_flag_to_range(flag): 22 | first = (torch.arange(0, flag.size(0))[flag]).min().item() 23 | last = (torch.arange(0, flag.size(0))[flag]).max().item() 24 | return first, last 25 | 26 | 27 | @register_transform('mask_single_cdr') 28 | class MaskSingleCDR(object): 29 | 30 | def __init__(self, selection=None, augmentation=True): 31 | super().__init__() 32 | cdr_str_to_enum = { 33 | 'H1': constants.CDR.H1, 34 | 'H2': constants.CDR.H2, 35 | 'H3': constants.CDR.H3, 36 | 'L1': constants.CDR.L1, 37 | 'L2': constants.CDR.L2, 38 | 'L3': constants.CDR.L3, 39 | 'H_CDR1': constants.CDR.H1, 40 | 'H_CDR2': constants.CDR.H2, 41 | 'H_CDR3': constants.CDR.H3, 42 | 'L_CDR1': constants.CDR.L1, 43 | 'L_CDR2': constants.CDR.L2, 44 | 'L_CDR3': constants.CDR.L3, 45 | 'CDR3': 'CDR3', # H3 first, then fallback to L3 46 | } 47 | assert selection is None or selection in cdr_str_to_enum 48 | self.selection = cdr_str_to_enum.get(selection, None) 49 | self.augmentation = augmentation 50 | 51 | def perform_masking_(self, data, selection=None): 52 | cdr_flag = data['cdr_flag'] 53 | 54 | if selection is None: 55 | cdr_all = cdr_flag[cdr_flag > 0].unique().tolist() 56 | cdr_to_mask = random.choice(cdr_all) 57 | else: 58 | cdr_to_mask = selection 59 | 60 | cdr_to_mask_flag = (cdr_flag == cdr_to_mask) 61 | if self.augmentation: 62 | cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag) 63 | 64 | cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag) 65 | left_idx = max(0, cdr_first-1) 66 | right_idx = min(data['aa'].size(0)-1, cdr_last+1) 67 | anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool) 68 | anchor_flag[left_idx] = True 69 | anchor_flag[right_idx] = True 70 | 71 | data['generate_flag'] = cdr_to_mask_flag 72 | data['anchor_flag'] = anchor_flag 73 | 74 | def __call__(self, structure): 75 | if self.selection is None: 76 | ab_data = [] 77 | if structure['heavy'] is not None: 78 | ab_data.append(structure['heavy']) 79 | if structure['light'] is not None: 80 | ab_data.append(structure['light']) 81 | data_to_mask = random.choice(ab_data) 82 | sel = None 83 | elif self.selection in (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3, ): 84 | data_to_mask = structure['heavy'] 85 | sel = int(self.selection) 86 | elif self.selection in (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3, ): 87 | data_to_mask = structure['light'] 88 | sel = int(self.selection) 89 | elif self.selection == 'CDR3': 90 | if structure['heavy'] is not None: 91 | data_to_mask = structure['heavy'] 92 | sel = constants.CDR.H3 93 | else: 94 | data_to_mask = structure['light'] 95 | sel = constants.CDR.L3 96 | 97 | self.perform_masking_(data_to_mask, selection=sel) 98 | return structure 99 | 100 | 101 | @register_transform('mask_multiple_cdrs') 102 | class MaskMultipleCDRs(object): 103 | 104 | def __init__(self, selection: Optional[List[str]]=None, augmentation=True): 105 | super().__init__() 106 | cdr_str_to_enum = { 107 | 'H1': constants.CDR.H1, 108 | 'H2': constants.CDR.H2, 109 | 'H3': constants.CDR.H3, 110 | 'L1': constants.CDR.L1, 111 | 'L2': constants.CDR.L2, 112 | 'L3': constants.CDR.L3, 113 | 'H_CDR1': constants.CDR.H1, 114 | 'H_CDR2': constants.CDR.H2, 115 | 'H_CDR3': constants.CDR.H3, 116 | 'L_CDR1': constants.CDR.L1, 117 | 'L_CDR2': constants.CDR.L2, 118 | 'L_CDR3': constants.CDR.L3, 119 | } 120 | if selection is not None: 121 | self.selection = [cdr_str_to_enum[s] for s in selection] 122 | else: 123 | self.selection = None 124 | self.augmentation = augmentation 125 | 126 | def mask_one_cdr_(self, 
data, cdr_to_mask): 127 | cdr_flag = data['cdr_flag'] 128 | 129 | cdr_to_mask_flag = (cdr_flag == cdr_to_mask) 130 | if self.augmentation: 131 | cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag) 132 | 133 | cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag) 134 | left_idx = max(0, cdr_first-1) 135 | right_idx = min(data['aa'].size(0)-1, cdr_last+1) 136 | anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool) 137 | anchor_flag[left_idx] = True 138 | anchor_flag[right_idx] = True 139 | 140 | if 'generate_flag' not in data: 141 | data['generate_flag'] = cdr_to_mask_flag 142 | data['anchor_flag'] = anchor_flag 143 | else: 144 | data['generate_flag'] |= cdr_to_mask_flag 145 | data['anchor_flag'] |= anchor_flag 146 | 147 | def mask_for_one_chain_(self, data): 148 | cdr_flag = data['cdr_flag'] 149 | cdr_all = cdr_flag[cdr_flag > 0].unique().tolist() 150 | 151 | num_cdrs_to_mask = random.randint(1, len(cdr_all)) 152 | 153 | if self.selection is not None: 154 | cdrs_to_mask = list(set(cdr_all).intersection(self.selection)) 155 | else: 156 | random.shuffle(cdr_all) 157 | cdrs_to_mask = cdr_all[:num_cdrs_to_mask] 158 | 159 | for cdr_to_mask in cdrs_to_mask: 160 | self.mask_one_cdr_(data, cdr_to_mask) 161 | 162 | def __call__(self, structure): 163 | if structure['heavy'] is not None: 164 | self.mask_for_one_chain_(structure['heavy']) 165 | if structure['light'] is not None: 166 | self.mask_for_one_chain_(structure['light']) 167 | return structure 168 | 169 | 170 | @register_transform('mask_antibody') 171 | class MaskAntibody(object): 172 | 173 | def mask_ab_chain_(self, data): 174 | data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool) 175 | 176 | def __call__(self, structure): 177 | pos_ab_alpha = [] 178 | if structure['heavy'] is not None: 179 | self.mask_ab_chain_(structure['heavy']) 180 | pos_ab_alpha.append( 181 | structure['heavy']['pos_heavyatom'][:, constants.BBHeavyAtom.CA] 182 | ) 183 | if structure['light'] is not None: 184 | self.mask_ab_chain_(structure['light']) 185 | pos_ab_alpha.append( 186 | structure['light']['pos_heavyatom'][:, constants.BBHeavyAtom.CA] 187 | ) 188 | pos_ab_alpha = torch.cat(pos_ab_alpha, dim=0) # (L_Ab, 3) 189 | 190 | if structure['antigen'] is not None: 191 | pos_ag_alpha = structure['antigen']['pos_heavyatom'][:, constants.BBHeavyAtom.CA] 192 | ag_ab_dist = torch.cdist(pos_ag_alpha, pos_ab_alpha) # (L_Ag, L_Ab) 193 | nn_ab_dist = ag_ab_dist.min(dim=1)[0] # (L_Ag) 194 | contact_flag = (nn_ab_dist <= 6.0) # (L_Ag) 195 | if contact_flag.sum().item() == 0: 196 | contact_flag[nn_ab_dist.argmin()] = True 197 | 198 | anchor_idx = torch.multinomial(contact_flag.float(), num_samples=1).item() 199 | anchor_flag = torch.zeros(structure['antigen']['aa'].shape, dtype=torch.bool) 200 | anchor_flag[anchor_idx] = True 201 | structure['antigen']['anchor_flag'] = anchor_flag 202 | structure['antigen']['contact_flag'] = contact_flag 203 | 204 | return structure 205 | 206 | 207 | @register_transform('remove_antigen') 208 | class RemoveAntigen: 209 | 210 | def __call__(self, structure): 211 | structure['antigen'] = None 212 | structure['antigen_seqmap'] = None 213 | return structure 214 | -------------------------------------------------------------------------------- /diffab/utils/transforms/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..protein import constants 4 | from ._base import register_transform 5 | 6 | 7 | @register_transform('merge_chains') 8 | class 
MergeChains(object): 9 | 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def assign_chain_number_(self, data_list): 14 | chains = set() 15 | for data in data_list: 16 | chains.update(data['chain_id']) 17 | chains = {c: i for i, c in enumerate(chains)} 18 | 19 | for data in data_list: 20 | data['chain_nb'] = torch.LongTensor([ 21 | chains[c] for c in data['chain_id'] 22 | ]) 23 | 24 | def _data_attr(self, data, name): 25 | if name in ('generate_flag', 'anchor_flag') and name not in data: 26 | return torch.zeros(data['aa'].shape, dtype=torch.bool) 27 | else: 28 | return data[name] 29 | 30 | def __call__(self, structure): 31 | data_list = [] 32 | if structure['heavy'] is not None: 33 | structure['heavy']['fragment_type'] = torch.full_like( 34 | structure['heavy']['aa'], 35 | fill_value = constants.Fragment.Heavy, 36 | ) 37 | data_list.append(structure['heavy']) 38 | 39 | if structure['light'] is not None: 40 | structure['light']['fragment_type'] = torch.full_like( 41 | structure['light']['aa'], 42 | fill_value = constants.Fragment.Light, 43 | ) 44 | data_list.append(structure['light']) 45 | 46 | if structure['antigen'] is not None: 47 | structure['antigen']['fragment_type'] = torch.full_like( 48 | structure['antigen']['aa'], 49 | fill_value = constants.Fragment.Antigen, 50 | ) 51 | structure['antigen']['cdr_flag'] = torch.zeros_like( 52 | structure['antigen']['aa'], 53 | ) 54 | data_list.append(structure['antigen']) 55 | 56 | self.assign_chain_number_(data_list) 57 | 58 | list_props = { 59 | 'chain_id': [], 60 | 'icode': [], 61 | } 62 | tensor_props = { 63 | 'chain_nb': [], 64 | 'resseq': [], 65 | 'res_nb': [], 66 | 'aa': [], 67 | 'pos_heavyatom': [], 68 | 'mask_heavyatom': [], 69 | 'generate_flag': [], 70 | 'cdr_flag': [], 71 | 'anchor_flag': [], 72 | 'fragment_type': [], 73 | } 74 | 75 | for data in data_list: 76 | for k in list_props.keys(): 77 | list_props[k].append(self._data_attr(data, k)) 78 | for k in tensor_props.keys(): 79 | tensor_props[k].append(self._data_attr(data, k)) 80 | 81 | list_props = {k: sum(v, start=[]) for k, v in list_props.items()} 82 | tensor_props = {k: torch.cat(v, dim=0) for k, v in tensor_props.items()} 83 | data_out = { 84 | **list_props, 85 | **tensor_props, 86 | } 87 | return data_out 88 | 89 | -------------------------------------------------------------------------------- /diffab/utils/transforms/patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ._base import _mask_select_data, register_transform 4 | from ..protein import constants 5 | 6 | 7 | @register_transform('patch_around_anchor') 8 | class PatchAroundAnchor(object): 9 | 10 | def __init__(self, initial_patch_size=128, antigen_size=128): 11 | super().__init__() 12 | self.initial_patch_size = initial_patch_size 13 | self.antigen_size = antigen_size 14 | 15 | def _center(self, data, origin): 16 | origin = origin.reshape(1, 1, 3) 17 | data['pos_heavyatom'] -= origin # (L, A, 3) 18 | data['pos_heavyatom'] = data['pos_heavyatom'] * data['mask_heavyatom'][:, :, None] 19 | data['origin'] = origin.reshape(3) 20 | return data 21 | 22 | def __call__(self, data): 23 | anchor_flag = data['anchor_flag'] # (L,) 24 | anchor_points = data['pos_heavyatom'][anchor_flag, constants.BBHeavyAtom.CA] # (n_anchors, 3) 25 | antigen_mask = (data['fragment_type'] == constants.Fragment.Antigen) 26 | antibody_mask = torch.logical_not(antigen_mask) 27 | 28 | if anchor_flag.sum().item() == 0: 29 | # Generating full antibody-Fv, no antigen given 30 | 
data_patch = _mask_select_data( 31 | data = data, 32 | mask = antibody_mask, 33 | ) 34 | data_patch = self._center( 35 | data_patch, 36 | origin = data_patch['pos_heavyatom'][:, constants.BBHeavyAtom.CA].mean(dim=0) 37 | ) 38 | return data_patch 39 | 40 | pos_alpha = data['pos_heavyatom'][:, constants.BBHeavyAtom.CA] # (L, 3) 41 | dist_anchor = torch.cdist(pos_alpha, anchor_points).min(dim=1)[0] # (L, ) 42 | initial_patch_idx = torch.topk( 43 | dist_anchor, 44 | k = min(self.initial_patch_size, dist_anchor.size(0)), 45 | largest=False, 46 | )[1] # (initial_patch_size, ) 47 | 48 | dist_anchor_antigen = dist_anchor.masked_fill( 49 | mask = antibody_mask, # Fill antibody with +inf 50 | value = float('+inf') 51 | ) # (L, ) 52 | antigen_patch_idx = torch.topk( 53 | dist_anchor_antigen, 54 | k = min(self.antigen_size, antigen_mask.sum().item()), 55 | largest=False, sorted=True 56 | )[1] # (ag_size, ) 57 | 58 | patch_mask = torch.logical_or( 59 | data['generate_flag'], 60 | data['anchor_flag'], 61 | ) 62 | patch_mask[initial_patch_idx] = True 63 | patch_mask[antigen_patch_idx] = True 64 | 65 | patch_idx = torch.arange(0, patch_mask.shape[0])[patch_mask] 66 | 67 | data_patch = _mask_select_data(data, patch_mask) 68 | data_patch = self._center( 69 | data_patch, 70 | origin = anchor_points.mean(dim=0) 71 | ) 72 | data_patch['patch_idx'] = patch_idx 73 | return data_patch 74 | -------------------------------------------------------------------------------- /diffab/utils/transforms/select_atom.py: -------------------------------------------------------------------------------- 1 | 2 | from ._base import register_transform 3 | 4 | 5 | @register_transform('select_atom') 6 | class SelectAtom(object): 7 | 8 | def __init__(self, resolution): 9 | super().__init__() 10 | assert resolution in ('full', 'backbone') 11 | self.resolution = resolution 12 | 13 | def __call__(self, data): 14 | if self.resolution == 'full': 15 | data['pos_atoms'] = data['pos_heavyatom'][:, :] 16 | data['mask_atoms'] = data['mask_heavyatom'][:, :] 17 | elif self.resolution == 'backbone': 18 | data['pos_atoms'] = data['pos_heavyatom'][:, :5] 19 | data['mask_atoms'] = data['mask_heavyatom'][:, :5] 20 | return data 21 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: diffab 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - bioconda 6 | - defaults 7 | dependencies: 8 | - python=3.8 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - cudatoolkit=11.3.1 12 | - joblib 13 | - python-lmdb 14 | - tqdm 15 | - easydict 16 | - pyyaml 17 | - tensorboard 18 | - biopython=1.78 19 | - abnumber=0.3.0 20 | - mmseqs2 21 | - pdbfixer 22 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch 5 | import torch.utils.tensorboard 6 | from torch.nn.utils import clip_grad_norm_ 7 | from torch.utils.data import DataLoader 8 | from tqdm.auto import tqdm 9 | torch.backends.cuda.matmul.allow_tf32 = True 10 | torch.backends.cudnn.allow_tf32 = True 11 | 12 | from diffab.datasets import get_dataset 13 | from diffab.models import get_model 14 | from diffab.utils.misc import * 15 | from diffab.utils.data import * 16 | from diffab.utils.train import * 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | 
parser.add_argument('config', type=str) 22 | parser.add_argument('--logdir', type=str, default='./logs') 23 | parser.add_argument('--debug', action='store_true', default=False) 24 | parser.add_argument('--device', type=str, default='cuda') 25 | parser.add_argument('--num_workers', type=int, default=8) 26 | parser.add_argument('--tag', type=str, default='') 27 | parser.add_argument('--resume', type=str, default=None) 28 | parser.add_argument('--finetune', type=str, default=None) 29 | args = parser.parse_args() 30 | 31 | # Load configs 32 | config, config_name = load_config(args.config) 33 | seed_all(config.train.seed) 34 | 35 | # Logging 36 | if args.debug: 37 | logger = get_logger('train', None) 38 | writer = BlackHole() 39 | else: 40 | if args.resume: 41 | log_dir = os.path.dirname(os.path.dirname(args.resume)) 42 | else: 43 | log_dir = get_new_log_dir(args.logdir, prefix=config_name, tag=args.tag) 44 | ckpt_dir = os.path.join(log_dir, 'checkpoints') 45 | if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) 46 | logger = get_logger('train', log_dir) 47 | writer = torch.utils.tensorboard.SummaryWriter(log_dir) 48 | tensorboard_trace_handler = torch.profiler.tensorboard_trace_handler(log_dir) 49 | if not os.path.exists(os.path.join(log_dir, os.path.basename(args.config))): 50 | shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config))) 51 | logger.info(args) 52 | logger.info(config) 53 | 54 | # Data 55 | logger.info('Loading dataset...') 56 | train_dataset = get_dataset(config.dataset.train) 57 | val_dataset = get_dataset(config.dataset.val) 58 | train_iterator = inf_iterator(DataLoader( 59 | train_dataset, 60 | batch_size=config.train.batch_size, 61 | collate_fn=PaddingCollate(), 62 | shuffle=True, 63 | num_workers=args.num_workers 64 | )) 65 | val_loader = DataLoader(val_dataset, batch_size=config.train.batch_size, collate_fn=PaddingCollate(), shuffle=False, num_workers=args.num_workers) 66 | logger.info('Train %d | Val %d' % (len(train_dataset), len(val_dataset))) 67 | 68 | # Model 69 | logger.info('Building model...') 70 | model = get_model(config.model).to(args.device) 71 | logger.info('Number of parameters: %d' % count_parameters(model)) 72 | 73 | # Optimizer & scheduler 74 | optimizer = get_optimizer(config.train.optimizer, model) 75 | scheduler = get_scheduler(config.train.scheduler, optimizer) 76 | optimizer.zero_grad() 77 | it_first = 1 78 | 79 | # Resume 80 | if args.resume is not None or args.finetune is not None: 81 | ckpt_path = args.resume if args.resume is not None else args.finetune 82 | logger.info('Resuming from checkpoint: %s' % ckpt_path) 83 | ckpt = torch.load(ckpt_path, map_location=args.device) 84 | it_first = ckpt['iteration'] # + 1 85 | model.load_state_dict(ckpt['model']) 86 | logger.info('Resuming optimizer states...') 87 | optimizer.load_state_dict(ckpt['optimizer']) 88 | logger.info('Resuming scheduler states...') 89 | scheduler.load_state_dict(ckpt['scheduler']) 90 | 91 | # Train 92 | def train(it): 93 | time_start = current_milli_time() 94 | model.train() 95 | 96 | # Prepare data 97 | batch = recursive_to(next(train_iterator), args.device) 98 | 99 | # Forward 100 | # if args.debug: torch.set_anomaly_enabled(True) 101 | loss_dict = model(batch) 102 | loss = sum_weighted_losses(loss_dict, config.train.loss_weights) 103 | loss_dict['overall'] = loss 104 | time_forward_end = current_milli_time() 105 | 106 | # Backward 107 | loss.backward() 108 | orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm) 109 | 
optimizer.step() 110 | optimizer.zero_grad() 111 | time_backward_end = current_milli_time() 112 | 113 | # Logging 114 | log_losses(loss_dict, it, 'train', logger, writer, others={ 115 | 'grad': orig_grad_norm, 116 | 'lr': optimizer.param_groups[0]['lr'], 117 | 'time_forward': (time_forward_end - time_start) / 1000, 118 | 'time_backward': (time_backward_end - time_forward_end) / 1000, 119 | }) 120 | 121 | if not torch.isfinite(loss): 122 | logger.error('NaN or Inf detected.') 123 | torch.save({ 124 | 'config': config, 125 | 'model': model.state_dict(), 126 | 'optimizer': optimizer.state_dict(), 127 | 'scheduler': scheduler.state_dict(), 128 | 'iteration': it, 129 | 'batch': recursive_to(batch, 'cpu'), 130 | }, os.path.join(log_dir, 'checkpoint_nan_%d.pt' % it)) 131 | raise KeyboardInterrupt() 132 | 133 | # Validate 134 | def validate(it): 135 | loss_tape = ValidationLossTape() 136 | with torch.no_grad(): 137 | model.eval() 138 | for i, batch in enumerate(tqdm(val_loader, desc='Validate', dynamic_ncols=True)): 139 | # Prepare data 140 | batch = recursive_to(batch, args.device) 141 | # Forward 142 | loss_dict = model(batch) 143 | loss = sum_weighted_losses(loss_dict, config.train.loss_weights) 144 | loss_dict['overall'] = loss 145 | 146 | loss_tape.update(loss_dict, 1) 147 | 148 | avg_loss = loss_tape.log(it, logger, writer, 'val') 149 | # Trigger scheduler 150 | if config.train.scheduler.type == 'plateau': 151 | scheduler.step(avg_loss) 152 | else: 153 | scheduler.step() 154 | return avg_loss 155 | 156 | try: 157 | for it in range(it_first, config.train.max_iters + 1): 158 | train(it) 159 | if it % config.train.val_freq == 0: 160 | avg_val_loss = validate(it) 161 | if not args.debug: 162 | ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it) 163 | torch.save({ 164 | 'config': config, 165 | 'model': model.state_dict(), 166 | 'optimizer': optimizer.state_dict(), 167 | 'scheduler': scheduler.state_dict(), 168 | 'iteration': it, 169 | 'avg_val_loss': avg_val_loss, 170 | }, ckpt_path) 171 | except KeyboardInterrupt: 172 | logger.info('Terminating...') 173 | -------------------------------------------------------------------------------- /trained_models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | --------------------------------------------------------------------------------
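Usage sketch: the transform classes registered above are normally assembled through `get_transform` in `diffab/utils/transforms/_base.py`. A minimal, hand-written pipeline might look like the following; the three config entries and their parameter values are illustrative assumptions chosen to match the constructors shown above, not the shipped training configs.

from diffab.utils.transforms import get_transform

# Each dict's 'type' must match a name registered via @register_transform;
# the remaining keys are passed to that transform's constructor.
transform = get_transform([
    {'type': 'mask_single_cdr', 'selection': 'H_CDR3', 'augmentation': False},
    {'type': 'merge_chains'},
    {'type': 'patch_around_anchor', 'initial_patch_size': 128, 'antigen_size': 128},
])

# `transform` is a torchvision Compose; applying it to a parsed structure dict
# (with 'heavy', 'light' and 'antigen' entries) yields a merged, centered patch:
# data = transform(structure)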