├── .gitignore ├── LICENSE ├── README.md ├── docs ├── boltz1_pred_figure.png ├── evaluation.md ├── plot_casp.png ├── plot_test.png ├── prediction.md └── training.md ├── examples ├── ligand.fasta ├── ligand.yaml ├── msa │ ├── seq1.a3m │ └── seq2.a3m ├── multimer.yaml ├── pocket.yaml ├── prot.fasta ├── prot.yaml ├── prot_custom_msa.yaml └── prot_no_msa.yaml ├── pyproject.toml ├── scripts ├── eval │ ├── aggregate_evals.py │ └── run_evals.py ├── process │ ├── README.md │ ├── ccd.py │ ├── cluster.py │ ├── mmcif.py │ ├── msa.py │ ├── rcsb.py │ └── requirements.txt └── train │ ├── README.md │ ├── assets │ ├── casp15_ids.txt │ ├── test_ids.txt │ └── validation_ids.txt │ ├── configs │ ├── confidence.yaml │ ├── full.yaml │ └── structure.yaml │ └── train.py ├── src └── boltz │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── const.py │ ├── crop │ │ ├── __init__.py │ │ ├── boltz.py │ │ └── cropper.py │ ├── feature │ │ ├── __init__.py │ │ ├── featurizer.py │ │ ├── pad.py │ │ └── symmetry.py │ ├── filter │ │ ├── __init__.py │ │ ├── dynamic │ │ │ ├── __init__.py │ │ │ ├── date.py │ │ │ ├── filter.py │ │ │ ├── max_residues.py │ │ │ ├── resolution.py │ │ │ ├── size.py │ │ │ └── subset.py │ │ └── static │ │ │ ├── __init__.py │ │ │ ├── filter.py │ │ │ ├── ligand.py │ │ │ └── polymer.py │ ├── module │ │ ├── __init__.py │ │ ├── inference.py │ │ └── training.py │ ├── msa │ │ ├── __init__.py │ │ └── mmseqs2.py │ ├── parse │ │ ├── __init__.py │ │ ├── a3m.py │ │ ├── csv.py │ │ ├── fasta.py │ │ ├── schema.py │ │ └── yaml.py │ ├── sample │ │ ├── __init__.py │ │ ├── cluster.py │ │ ├── distillation.py │ │ ├── random.py │ │ └── sampler.py │ ├── tokenize │ │ ├── __init__.py │ │ ├── boltz.py │ │ └── tokenizer.py │ ├── types.py │ └── write │ │ ├── __init__.py │ │ ├── mmcif.py │ │ ├── pdb.py │ │ ├── utils.py │ │ └── writer.py │ ├── main.py │ └── model │ ├── __init__.py │ ├── layers │ ├── __init__.py │ ├── attention.py │ ├── dropout.py │ ├── initialize.py │ ├── outer_product_mean.py │ ├── 
pair_averaging.py │ ├── transition.py │ ├── triangular_attention │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── primitives.py │ │ └── utils.py │ └── triangular_mult.py │ ├── loss │ ├── __init__.py │ ├── confidence.py │ ├── diffusion.py │ ├── distogram.py │ └── validation.py │ ├── model.py │ ├── modules │ ├── __init__.py │ ├── confidence.py │ ├── confidence_utils.py │ ├── diffusion.py │ ├── encoders.py │ ├── transformers.py │ ├── trunk.py │ └── utils.py │ └── optim │ ├── __init__.py │ ├── ema.py │ └── scheduler.py └── tests ├── model └── layers │ ├── test_outer_product_mean.py │ └── test_triangle_attention.py ├── test_regression.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Boltz-1: 2 | 3 | Democratizing Biomolecular Interaction Modeling 4 |

5 | 6 | ![](docs/boltz1_pred_figure.png) 7 | 8 | Boltz-1 is the state-of-the-art open-source model that predicts the 3D structure of proteins, RNA, DNA, and small molecules; it handles modified residues, covalent ligands and glycans, as well as conditioning the generation on pocket residues. 9 | 10 | For more information about the model, see our [technical report](https://doi.org/10.1101/2024.11.19.624167). 11 | 12 | ## Installation 13 | Install boltz with PyPI (recommended): 14 | 15 | ``` 16 | pip install boltz -U 17 | ``` 18 | 19 | or directly from GitHub for daily updates: 20 | 21 | ``` 22 | git clone https://github.com/jwohlwend/boltz.git 23 | cd boltz; pip install -e . 24 | ``` 25 | > Note: we recommend installing boltz in a fresh python environment 26 | 27 | ## Inference 28 | 29 | You can run inference using Boltz-1 with: 30 | 31 | ``` 32 | boltz predict input_path --use_msa_server 33 | ``` 34 | 35 | Boltz currently accepts three input formats: 36 | 37 | 1. Fasta file, for most use cases 38 | 39 | 2. A comprehensive YAML schema, for more complex use cases 40 | 41 | 3. A directory containing files of the above formats, for batched processing 42 | 43 | To see all available options: `boltz predict --help` and for more information on these input formats, see our [prediction instructions](docs/prediction.md). 44 | 45 | ## Evaluation 46 | 47 | To encourage reproducibility and facilitate comparison with other models, we provide the evaluation scripts and predictions for Boltz-1, Chai-1 and AlphaFold3 on our test benchmark dataset as well as CASP15. These datasets are created to contain biomolecules different from the training data; to benchmark the performance of these models, we run them with the same input MSAs and same number of recycling and diffusion steps. More details on these evaluations can be found in our [evaluation instructions](docs/evaluation.md). 
48 | 49 | ![Test set evaluations](docs/plot_test.png) 50 | ![CASP15 set evaluations](docs/plot_casp.png) 51 | 52 | 53 | ## Training 54 | 55 | If you're interested in retraining the model, see our [training instructions](docs/training.md). 56 | 57 | ## Contributing 58 | 59 | We welcome external contributions and are eager to engage with the community. Connect with us on our [Slack channel](https://join.slack.com/t/boltz-community/shared_invite/zt-2w0bw6dtt-kZU4png9HUgprx9NK2xXZw) to discuss advancements, share insights, and foster collaboration around Boltz-1. 60 | 61 | ## Coming very soon 62 | 63 | - [x] Auto-generated MSAs using MMseqs2 64 | - [x] More examples 65 | - [x] Support for custom paired MSA 66 | - [x] Confidence model checkpoint 67 | - [x] Chunking for lower memory usage 68 | - [x] Pocket conditioning support 69 | - [x] Full data processing pipeline 70 | - [ ] Colab notebook for inference 71 | - [ ] Kernel integration 72 | 73 | ## License 74 | 75 | Our model and code are released under MIT License, and can be freely used for both academic and commercial purposes. 
76 | 77 | 78 | ## Cite 79 | 80 | If you use this code or the models in your research, please cite the following paper: 81 | 82 | ```bibtex 83 | @article{wohlwend2024boltz1, 84 | author = {Wohlwend, Jeremy and Corso, Gabriele and Passaro, Saro and Reveiz, Mateo and Leidal, Ken and Swiderski, Wojtek and Portnoi, Tally and Chinn, Itamar and Silterra, Jacob and Jaakkola, Tommi and Barzilay, Regina}, 85 | title = {Boltz-1: Democratizing Biomolecular Interaction Modeling}, 86 | year = {2024}, 87 | doi = {10.1101/2024.11.19.624167}, 88 | journal = {bioRxiv} 89 | } 90 | ``` 91 | 92 | In addition if you use the automatic MSA generation, please cite: 93 | 94 | ```bibtex 95 | @article{mirdita2022colabfold, 96 | title={ColabFold: making protein folding accessible to all}, 97 | author={Mirdita, Milot and Sch{\"u}tze, Konstantin and Moriwaki, Yoshitaka and Heo, Lim and Ovchinnikov, Sergey and Steinegger, Martin}, 98 | journal={Nature methods}, 99 | year={2022}, 100 | } 101 | ``` 102 | -------------------------------------------------------------------------------- /docs/boltz1_pred_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/docs/boltz1_pred_figure.png -------------------------------------------------------------------------------- /docs/evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | To encourage reproducibility and facilitate comparison with other models, we provide the evaluation scripts and predictions for Boltz-1, Chai-1, and AlphaFold3 on our test benchmark dataset as well as CASP15. These datasets are created to contain biomolecules different from the training data and to benchmark the performance of these models we run them with the same input MSAs and the same number of recycling and diffusion steps. 
4 | 5 | ![Test set evaluations](../docs/plot_test.png) 6 | ![CASP15 set evaluations](../docs/plot_casp.png) 7 | 8 | 9 | ## Evaluation files 10 | 11 | You can download all the MSAs, input files, output files and evaluation outputs for Boltz-1, Chai-1, and AlphaFold3 from this [Google Drive folder](https://drive.google.com/file/d/1JvHlYUMINOaqPTunI9wBYrfYniKgVmxf/view?usp=sharing). 12 | 13 | The files are organized as follows: 14 | 15 | ``` 16 | boltz_results_final/ 17 | ├── inputs/ # Input files for every model 18 | ├── casp15/ 19 | ├── af3 20 | ├── boltz 21 | ├── chai 22 | └── msa 23 | └── test/ 24 | ├── targets/ # Target files from PDB 25 | ├── casp15 26 | └── test 27 | ├── outputs/ # Output files for every model 28 | ├── casp15 29 | └── test 30 | ├── evals/ # Output of evaluation script for every target 31 | ├── casp15 32 | └── test 33 | ├── results_casp.csv # Summary of evaluation results for CASP15 34 | └── results_test.csv # Summary of evaluation results for test set 35 | ``` 36 | 37 | ## Evaluation setup 38 | 39 | We evaluate the model on two datasets: 40 | - PDB test set: 541 targets after our validation cut-off date and at most 40% sequence similarity for proteins, 80% Tanimoto for ligands. 41 | - CASP15: 66 difficult targets from the CASP 2022 competition. 42 | 43 | We benchmark Boltz-1 against Chai-1 and AF3, other state-of-the-art structure prediction models, but much more closed source in terms of model code, training and data pipeline. Note that we remove overlap with our validation set, but we cannot ensure that there is no overlap with AF3 or Chai-1 validation set as those were not published. 44 | 45 | For fair comparison we compare the models with the following setup: 46 | - Same MSAs. 47 | - Same parameters: 10 recycling steps, 200 sampling steps, 5 samples. 48 | - We compare our oracle and top-1 numbers among the 5 samples. 49 | 50 | 51 | ## Evaluation script 52 | 53 | We also provide the scripts we used to evaluate the models and aggregate results. 
The evaluations were run through [OpenStructure](https://openstructure.org/docs/2.9.0/) version 2.8.0 (it is important to use the specific version for reproducing the results). You can find these scripts at `scripts/eval/run_evals.py` and `scripts/eval/aggregate_evals.py`. -------------------------------------------------------------------------------- /docs/plot_casp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/docs/plot_casp.png -------------------------------------------------------------------------------- /docs/plot_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/docs/plot_test.png -------------------------------------------------------------------------------- /examples/ligand.fasta: -------------------------------------------------------------------------------- 1 | >A|protein|./examples/msa/seq1.a3m 2 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ 3 | >B|protein|./examples/msa/seq1.a3m 4 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ 5 | >C|ccd 6 | SAH 7 | >D|ccd 8 | SAH 9 | >E|smiles 10 | N[C@@H](Cc1ccc(O)cc1)C(=O)O 11 | 
>F|smiles 12 | N[C@@H](Cc1ccc(O)cc1)C(=O)O -------------------------------------------------------------------------------- /examples/ligand.yaml: -------------------------------------------------------------------------------- 1 | version: 1 # Optional, defaults to 1 2 | sequences: 3 | - protein: 4 | id: [A, B] 5 | sequence: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ 6 | msa: ./examples/msa/seq1.a3m 7 | - ligand: 8 | id: [C, D] 9 | ccd: SAH 10 | - ligand: 11 | id: [E, F] 12 | smiles: N[C@@H](Cc1ccc(O)cc1)C(=O)O 13 | -------------------------------------------------------------------------------- /examples/multimer.yaml: -------------------------------------------------------------------------------- 1 | version: 1 # Optional, defaults to 1 2 | sequences: 3 | - protein: 4 | id: A 5 | sequence: MAHHHHHHVAVDAVSFTLLQDQLQSVLDTLSEREAGVVRLRFGLTDGQPRTLDEIGQVYGVTRERIRQIESKTMSKLRHPSRSQVLRDYLDGSSGSGTPEERLLRAIFGEKA 6 | - protein: 7 | id: B 8 | sequence: MRYAFAAEATTCNAFWRNVDMTVTALYEVPLGVCTQDPDRWTTTPDDEAKTLCRACPRRWLCARDAVESAGAEGLWAGVVIPESGRARAFALGQLRSLAERNGYPVRDHRVSAQSA 9 | -------------------------------------------------------------------------------- /examples/pocket.yaml: -------------------------------------------------------------------------------- 1 | sequences: 2 | - protein: 3 | id: [A1] 4 | sequence: 
MYNMRRLSLSPTFSMGFHLLVTVSLLFSHVDHVIAETEMEGEGNETGECTGSYYCKKGVILPIWEPQDPSFGDKIARATVYFVAMVYMFLGVSIIADRFMSSIEVITSQEKEITIKKPNGETTKTTVRIWNETVSNLTLMALGSSAPEILLSVIEVCGHNFTAGDLGPSTIVGSAAFNMFIIIALCVYVVPDGETRKIKHLRVFFVTAAWSIFAYTWLYIILSVISPGVVEVWEGLLTFFFFPICVVFAWVADRRLLFYKYVYKRYRAGKQRGMIIEHEGDRPSSKTEIEMDGKVVNSHVENFLDGALVLEVDERDQDDEEARREMARILKELKQKHPDKEIEQLIELANYQVLSQQQKSRAFYRIQATRLMTGAGNILKRHAADQARKAVSMHEVNTEVTENDPVSKIFFEQGTYQCLENCGTVALTIIRRGGDLTNTVFVDFRTEDGTANAGSDYEFTEGTVVFKPGDTQKEIRVGIIDDDIFEEDENFLVHLSNVKVSSEASEDGILEANHVSTLACLGSPSTATVTIFDDDHAGIFTFEEPVTHVSESIGIMEVKVLRTSGARGNVIVPYKTIEGTARGGGEDFEDTCGELEFQNDEIVKIITIRIFDREEYEKECSFSLVLEEPKWIRRGMKGGFTITDEYDDKQPLTSKEEEERRIAEMGRPILGEHTKLEVIIEESYEFKSTVDKLIKKTNLALVVGTNSWREQFIEAITVSAGEDDDDDECGEEKLPSCFDYVMHFLTVFWKVLFAFVPPTEYWNGWACFIVSILMIGLLTAFIGDLASHFGCTIGLKDSVTAVVFVALGTSVPDTFASKVAATQDQYADASIGNVTGSNAVNVFLGIGVAWSIAAIYHAANGEQFKVSPGTLAFSVTLFTIFAFINVGVLLYRRRPEIGGELGGPRTAKLLTSCLFVLLWLLYIFFSSLEAYCHIKGF 5 | - ligand: 6 | ccd: EKY 7 | id: [B1] 8 | constraints: 9 | - pocket: 10 | binder: B1 11 | contacts: [ [ A1, 829 ], [ A1, 138 ] ] 12 | 13 | -------------------------------------------------------------------------------- /examples/prot.fasta: -------------------------------------------------------------------------------- 1 | >A|protein|./examples/msa/seq2.a3m 2 | QLEDSEVEAVAKGLEEMYANGVTEDNFKNYVKNNFAQQEISSVEEELNVNISDSCVANKIKDEFFAMISISAIVKAAQKKAWKELAVTVLRFAKANGLKTNAIIVAGQLALWAVQCG -------------------------------------------------------------------------------- /examples/prot.yaml: -------------------------------------------------------------------------------- 1 | version: 1 # Optional, defaults to 1 2 | sequences: 3 | - protein: 4 | id: A 5 | sequence: QLEDSEVEAVAKGLEEMYANGVTEDNFKNYVKNNFAQQEISSVEEELNVNISDSCVANKIKDEFFAMISISAIVKAAQKKAWKELAVTVLRFAKANGLKTNAIIVAGQLALWAVQCG 6 | 7 | -------------------------------------------------------------------------------- /examples/prot_custom_msa.yaml: 
-------------------------------------------------------------------------------- 1 | version: 1 # Optional, defaults to 1 2 | sequences: 3 | - protein: 4 | id: A 5 | sequence: QLEDSEVEAVAKGLEEMYANGVTEDNFKNYVKNNFAQQEISSVEEELNVNISDSCVANKIKDEFFAMISISAIVKAAQKKAWKELAVTVLRFAKANGLKTNAIIVAGQLALWAVQCG 6 | msa: ./examples/msa/seq2.a3m 7 | 8 | -------------------------------------------------------------------------------- /examples/prot_no_msa.yaml: -------------------------------------------------------------------------------- 1 | version: 1 # Optional, defaults to 1 2 | sequences: 3 | - protein: 4 | id: A 5 | sequence: QLEDSEVEAVAKGLEEMYANGVTEDNFKNYVKNNFAQQEISSVEEELNVNISDSCVANKIKDEFFAMISISAIVKAAQKKAWKELAVTVLRFAKANGLKTNAIIVAGQLALWAVQCG 6 | msa: empty 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "boltz" 7 | version = "0.4.0" 8 | requires-python = ">=3.9" 9 | description = "Boltz-1" 10 | readme = "README.md" 11 | dependencies = [ 12 | "torch>=2.2", 13 | "numpy==1.26.3", 14 | "hydra-core==1.3.2", 15 | "pytorch-lightning==2.4.0", 16 | "rdkit>=2024.3.2", 17 | "dm-tree==0.1.8", 18 | "requests==2.32.3", 19 | "pandas>=2.2.2", 20 | "types-requests", 21 | "einops==0.8.0", 22 | "einx==0.3.0", 23 | "fairscale==0.4.13", 24 | "mashumaro==3.14", 25 | "modelcif==1.2", 26 | "wandb==0.18.7", 27 | "click==8.1.7", 28 | "pyyaml==6.0.2", 29 | "biopython==1.84", 30 | "scipy==1.13.1", 31 | ] 32 | 33 | [project.scripts] 34 | boltz = "boltz.main:cli" 35 | 36 | [project.optional-dependencies] 37 | lint = ["ruff"] 38 | test = ["pytest", "requests"] 39 | 40 | [tool.ruff] 41 | src = ["src"] 42 | extend-exclude = ["conf.py"] 43 | target-version = "py39" 44 | lint.select = ["ALL"] 45 | lint.ignore = [ 46 | "COM812", # Conflicts with the 
formatter 47 | "ISC001", # Conflicts with the formatter 48 | "ANN101", # "missing-type-self" 49 | "RET504", # Unnecessary assignment to `x` before `return` statementRuff 50 | "S101", # Use of `assert` detected 51 | "D100", # Missing docstring in public module 52 | "D104", # Missing docstring in public package 53 | "PT001", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715 54 | "PT004", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715 55 | "PT005", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715 56 | "PT023", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715 57 | "FBT001", 58 | "FBT002", 59 | "PLR0913", # Too many arguments to init (> 5) 60 | ] 61 | 62 | [tool.ruff.lint.per-file-ignores] 63 | "**/__init__.py" = [ 64 | "F401", # Imported but unused 65 | "F403", # Wildcard imports 66 | ] 67 | "docs/**" = [ 68 | "INP001", # Requires __init__.py but folder is not a package. 69 | ] 70 | "scripts/**" = [ 71 | "INP001", # Requires __init__.py but folder is not a package. 
"""Run OpenStructure evaluations of predicted structures against references."""

import argparse
import concurrent.futures
import subprocess
from pathlib import Path

from tqdm import tqdm

# Shell templates for the dockerized OpenStructure comparison actions.
# https://openstructure.org/docs/2.7/actions/#ost-compare-structures

OST_COMPARE_STRUCTURE = r"""
#!/bin/bash
# https://openstructure.org/docs/2.7/actions/#ost-compare-structures

IMAGE_NAME=openstructure-0.2.8

command="compare-structures \
    -m {model_file} \
    -r {reference_file} \
    --fault-tolerant \
    --min-pep-length 4 \
    --min-nuc-length 4 \
    -o {output_path} \
    --lddt --bb-lddt --qs-score --dockq \
    --ics --ips --rigid-scores --patch-scores --tm-score"

sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command
"""


OST_COMPARE_LIGAND = r"""
#!/bin/bash
# https://openstructure.org/docs/2.7/actions/#ost-compare-structures

IMAGE_NAME=openstructure-0.2.8

command="compare-ligand-structures \
    -m {model_file} \
    -r {reference_file} \
    --fault-tolerant \
    --lddt-pli --rmsd \
    --substructure-match \
    -o {output_path}"

sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command
"""


def _run_comparison(
    template: str,
    pred: str,
    reference: str,
    out_path: Path,
    mount: str,
    executable: str,
) -> None:
    """Run a single dockerized OpenStructure comparison.

    Failures are reported instead of raised so that one bad target does not
    abort the whole evaluation run. (The original code silently discarded
    the return code and all output.)
    """
    result = subprocess.run(
        template.format(
            model_file=pred,
            reference_file=reference,
            output_path=str(out_path),
            mount=mount,
        ),
        shell=True,  # noqa: S602
        check=False,
        executable=executable,
        capture_output=True,
    )
    if result.returncode != 0:
        print(  # noqa: T201
            f"Evaluation failed for {out_path.name} (exit {result.returncode})"
        )


def evaluate_structure(
    name: str,
    pred: str,
    reference: str,
    outdir: str,
    mount: str,
    executable: str = "/bin/bash",
) -> None:
    """Evaluate one prediction against its reference structure.

    Writes ``{name}.json`` (polymer metrics) and ``{name}_ligand.json``
    (ligand metrics) into ``outdir``. Existing result files are skipped,
    so interrupted runs can be resumed.

    Parameters
    ----------
    name : str
        Identifier used for the output file names.
    pred : str
        Path to the predicted structure file.
    reference : str
        Path to the reference structure file.
    outdir : str
        Directory where the JSON results are written.
    mount : str
        Host path mounted into the docker container.
    executable : str, optional
        Shell used to run the docker command, defaults to /bin/bash.

    """
    # Evaluate polymer metrics
    out_path = Path(outdir) / f"{name}.json"
    if out_path.exists():
        print(  # noqa: T201
            f"Skipping recomputation of {name} as protein json file already exists"
        )
    else:
        _run_comparison(
            OST_COMPARE_STRUCTURE, str(pred), str(reference), out_path, mount, executable
        )

    # Evaluate ligand metrics
    out_path = Path(outdir) / f"{name}_ligand.json"
    if out_path.exists():
        print(f"Skipping recomputation of {name} as ligand json file already exists")  # noqa: T201
    else:
        _run_comparison(
            OST_COMPARE_LIGAND, str(pred), str(reference), out_path, mount, executable
        )


def main(args: argparse.Namespace) -> None:
    """Evaluate every prediction in ``args.data`` against ``args.pdb`` references."""
    # Aggregate the predictions and references
    files = list(args.data.iterdir())
    names = {f.stem.lower(): f for f in files}

    # Create the output directory
    args.outdir.mkdir(parents=True, exist_ok=True)

    first_item = True
    with concurrent.futures.ThreadPoolExecutor(args.max_workers) as executor:
        futures = []
        for name, folder in names.items():
            for model_id in range(5):
                # Locate the prediction file for this model/format.
                # argparse restricts --format/--testset to known values, so
                # pred_path/ref_path can no longer be left unbound here.
                if args.format == "af3":
                    pred_path = folder / f"seed-1_sample-{model_id}" / "model.cif"
                elif args.format == "chai":
                    pred_path = folder / f"pred.model_idx_{model_id}.cif"
                else:  # boltz
                    name_file = (
                        f"{name[0].upper()}{name[1:]}"
                        if args.testset == "casp"
                        else name.lower()
                    )
                    pred_path = folder / f"{name_file}_model_{model_id}.cif"

                # Locate the reference file for this test set.
                if args.testset == "casp":
                    ref_path = args.pdb / f"{name[0].upper()}{name[1:]}.cif"
                else:  # test
                    ref_path = args.pdb / f"{name.lower()}.cif.gz"

                if first_item:
                    # Evaluate the first item synchronously before fanning
                    # out: ensures that the docker image is downloaded once.
                    evaluate_structure(
                        name=f"{name}_model_{model_id}",
                        pred=str(pred_path),
                        reference=str(ref_path),
                        outdir=str(args.outdir),
                        mount=args.mount,
                        executable=args.executable,
                    )
                    first_item = False
                else:
                    future = executor.submit(
                        evaluate_structure,
                        name=f"{name}_model_{model_id}",
                        pred=str(pred_path),
                        reference=str(ref_path),
                        outdir=str(args.outdir),
                        mount=args.mount,
                        executable=args.executable,
                    )
                    futures.append(future)

        # Wait for all tasks to complete
        with tqdm(total=len(futures)) as pbar:
            for _ in concurrent.futures.as_completed(futures):
                pbar.update(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("data", type=Path)
    parser.add_argument("pdb", type=Path)
    parser.add_argument("outdir", type=Path)
    # choices= fails fast on unknown values; previously an unrecognized
    # format/testset crashed with UnboundLocalError deep inside main().
    parser.add_argument("--format", type=str, default="af3", choices=["af3", "chai", "boltz"])
    parser.add_argument("--testset", type=str, default="casp", choices=["casp", "test"])
    parser.add_argument("--mount", type=str)
    parser.add_argument("--executable", type=str, default="/bin/bash")
    parser.add_argument("--max-workers", type=int, default=32)
    args = parser.parse_args()
    main(args)
"""Compute conformers and symmetries for all the CCD molecules."""

import argparse
import multiprocessing
import pickle
import sys
from functools import partial
from pathlib import Path

import pandas as pd
import rdkit
from p_tqdm import p_uimap
from pdbeccdutils.core import ccd_reader
from pdbeccdutils.core.component import ConformerType
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit.Chem.rdchem import Conformer, Mol


def load_molecules(components: str) -> list[Mol]:
    """Load the CCD components file.

    Parameters
    ----------
    components : str
        Path to the CCD components file.

    Returns
    -------
    list[Mol]
        The parsed molecules, each tagged with its CCD code in the
        ``PDB_NAME`` property.

    """
    # Use a separate name for the parsed mapping instead of shadowing
    # the `components` path argument, as the original did.
    parsed: dict[str, ccd_reader.CCDReaderResult]
    parsed = ccd_reader.read_pdb_components_file(components)

    mols = []
    for name, component in parsed.items():
        mol = component.component.mol
        mol.SetProp("PDB_NAME", name)
        mols.append(mol)

    return mols


def compute_3d(mol: Mol, version: str = "v3") -> bool:
    """Generate 3D coordinates using the ETKDG method.

    Taken from `pdbeccdutils.core.component.Component`.

    Parameters
    ----------
    mol: Mol
        The RDKit molecule to process.
    version: str, optional
        The ETKDG version, defaults to v3. Any value other than "v3"
        falls back to ETKDGv2 (preserved from the original behavior).

    Returns
    -------
    bool
        Whether computation was successful.

    """
    if version == "v3":
        options = AllChem.ETKDGv3()
    else:
        # "v2" and any unrecognized version both use ETKDGv2; the original
        # had an identical duplicated elif branch for "v2".
        options = AllChem.ETKDGv2()

    options.clearConfs = False
    conf_id = -1

    try:
        conf_id = AllChem.EmbedMolecule(mol, options)
        AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000)
    except RuntimeError:
        pass  # Force field issue here
    except ValueError:
        pass  # Sanitization issue here

    if conf_id != -1:
        conformer = mol.GetConformer(conf_id)
        conformer.SetProp("name", ConformerType.Computed.name)
        conformer.SetProp("coord_generation", f"ETKDG{version}")
        return True

    return False


def get_conformer(mol: Mol, c_type: ConformerType) -> Conformer:
    """Retrieve an rdkit object for a deemed conformer.

    Taken from `pdbeccdutils.core.component.Component`.

    Parameters
    ----------
    mol: Mol
        The molecule to process.
    c_type: ConformerType
        The conformer type to extract.

    Returns
    -------
    Conformer
        The desired conformer, if any.

    Raises
    ------
    ValueError
        If there are no conformers of the given type.

    """
    for c in mol.GetConformers():
        try:
            if c.GetProp("name") == c_type.name:
                return c
        except KeyError:  # noqa: PERF203
            # Conformer has no "name" property; skip it.
            pass

    msg = f"Conformer {c_type.name} does not exist."
    raise ValueError(msg)
128 | 129 | Parameters 130 | ---------- 131 | mol : Mol 132 | The molecule to process 133 | 134 | Returns 135 | ------- 136 | list[list[int]] 137 | The symmetries as a list of index permutations 138 | 139 | """ 140 | mol = AllChem.RemoveHs(mol) 141 | idx_map = {} 142 | atom_idx = 0 143 | for i, atom in enumerate(mol.GetAtoms()): 144 | # Skip if leaving atoms 145 | if int(atom.GetProp("leaving_atom")): 146 | continue 147 | idx_map[i] = atom_idx 148 | atom_idx += 1 149 | 150 | # Calculate self permutations 151 | permutations = [] 152 | raw_permutations = mol.GetSubstructMatches(mol, uniquify=False) 153 | for raw_permutation in raw_permutations: 154 | # Filter out permutations with leaving atoms 155 | try: 156 | if {raw_permutation[idx] for idx in idx_map} == set(idx_map.keys()): 157 | permutation = [ 158 | idx_map[idx] for idx in raw_permutation if idx in idx_map 159 | ] 160 | permutations.append(permutation) 161 | except Exception: # noqa: S110, PERF203, BLE001 162 | pass 163 | serialized_permutations = pickle.dumps(permutations) 164 | mol.SetProp("symmetries", serialized_permutations.hex()) 165 | return permutations 166 | 167 | 168 | def process(mol: Mol, output: str) -> tuple[str, str]: 169 | """Process a CCD component. 
170 | 171 | Parameters 172 | ---------- 173 | mol : Mol 174 | The molecule to process 175 | output : str 176 | The directory to save the molecules 177 | 178 | Returns 179 | ------- 180 | str 181 | The name of the component 182 | str 183 | The result of the conformer generation 184 | 185 | """ 186 | # Get name 187 | name = mol.GetProp("PDB_NAME") 188 | 189 | # Check if single atom 190 | if mol.GetNumAtoms() == 1: 191 | result = "single" 192 | else: 193 | # Get the 3D conformer 194 | try: 195 | # Try to generate a 3D conformer with RDKit 196 | success = compute_3d(mol, version="v3") 197 | if success: 198 | _ = get_conformer(mol, ConformerType.Computed) 199 | result = "computed" 200 | 201 | # Otherwise, default to the ideal coordinates 202 | else: 203 | _ = get_conformer(mol, ConformerType.Ideal) 204 | result = "ideal" 205 | except ValueError: 206 | result = "failed" 207 | 208 | # Dump the molecule 209 | path = Path(output) / f"{name}.pkl" 210 | with path.open("wb") as f: 211 | pickle.dump(mol, f) 212 | 213 | # Output the results 214 | return name, result 215 | 216 | 217 | def main(args: argparse.Namespace) -> None: 218 | """Process conformers.""" 219 | # Set property saving 220 | rdkit.Chem.SetDefaultPickleProperties(rdkit.Chem.PropertyPickleOptions.AllProps) 221 | 222 | # Load components 223 | print("Loading components") # noqa: T201 224 | molecules = load_molecules(args.components) 225 | 226 | # Reset stdout and stderr, as pdbccdutils messes with them 227 | sys.stdout = sys.__stdout__ 228 | sys.stderr = sys.__stderr__ 229 | 230 | # Disable rdkit warnings 231 | blocker = rdBase.BlockLogs() # noqa: F841 232 | 233 | # Setup processing function 234 | outdir = Path(args.outdir) 235 | outdir.mkdir(parents=True, exist_ok=True) 236 | mol_output = outdir / "mols" 237 | mol_output.mkdir(parents=True, exist_ok=True) 238 | process_fn = partial(process, output=str(mol_output)) 239 | 240 | # Process the files in parallel 241 | print("Processing components") # noqa: T201 242 | 
metadata = [] 243 | num_processes = min(max(1, args.num_processes), multiprocessing.cpu_count()) 244 | for name, result in p_uimap( 245 | process_fn, 246 | molecules, 247 | num_cpus=num_processes, 248 | ): 249 | metadata.append({"name": name, "result": result}) 250 | 251 | # Load and group outputs 252 | molecules = {} 253 | for item in metadata: 254 | if item["result"] == "failed": 255 | continue 256 | 257 | # Load the mol file 258 | path = mol_output / f"{item['name']}.pkl" 259 | with path.open("rb") as f: 260 | mol = pickle.load(f) # noqa: S301 261 | molecules[item["name"]] = mol 262 | 263 | # Dump metadata 264 | path = outdir / "results.csv" 265 | metadata = pd.DataFrame(metadata) 266 | metadata.to_csv(path) 267 | 268 | # Dump the components 269 | path = outdir / "ccd.pkl" 270 | with path.open("wb") as f: 271 | pickle.dump(molecules, f) 272 | 273 | 274 | if __name__ == "__main__": 275 | parser = argparse.ArgumentParser() 276 | parser.add_argument("--components", type=str) 277 | parser.add_argument("--outdir", type=str) 278 | parser.add_argument( 279 | "--num_processes", 280 | type=int, 281 | default=multiprocessing.cpu_count(), 282 | ) 283 | args = parser.parse_args() 284 | main(args) 285 | -------------------------------------------------------------------------------- /scripts/process/cluster.py: -------------------------------------------------------------------------------- 1 | """Create a mapping from structure and chain ID to MSA indices.""" 2 | 3 | import argparse 4 | import hashlib 5 | import json 6 | import pickle 7 | import subprocess 8 | from pathlib import Path 9 | 10 | import pandas as pd 11 | from Bio import SeqIO 12 | 13 | 14 | def hash_sequence(seq: str) -> str: 15 | """Hash a sequence.""" 16 | return hashlib.sha256(seq.encode()).hexdigest() 17 | 18 | 19 | def main(args: argparse.Namespace) -> None: 20 | """Create clustering.""" 21 | # Set output directory 22 | outdir = Path(args.outdir) 23 | outdir.mkdir(parents=True, exist_ok=True) 24 | 25 | # 
Split the sequences into proteins and nucleotides 26 | with Path(args.sequences).open("r") as f: 27 | data = list(SeqIO.parse(f, "fasta")) 28 | 29 | proteins = set() 30 | nucleotides = set() 31 | 32 | for seq in data: 33 | if set(seq.seq).issubset({"A", "C", "G", "T", "U"}): 34 | nucleotides.add(str(seq.seq)) 35 | else: 36 | proteins.add(str(seq.seq)) 37 | 38 | # Run mmseqs on the protein data 39 | proteins = [f">{hash_sequence(seq)}\n{seq}" for seq in proteins] 40 | with (outdir / "proteins.fasta").open("w") as f: 41 | f.write("\n".join(proteins)) 42 | 43 | subprocess.run( 44 | f"{args.mmseqs} easy-cluster {outdir / 'proteins.fasta'} {outdir / 'clust_prot'} {outdir / 'tmp'} --min-seq-id 0.4", # noqa: E501 45 | shell=True, # noqa: S602 46 | check=True, 47 | ) 48 | 49 | # Load protein clusters 50 | clustering_path = outdir / "clust_prot_cluster.tsv" 51 | protein_data = pd.read_csv(clustering_path, sep="\t", header=None) 52 | clusters = protein_data[0] 53 | items = protein_data[1] 54 | clustering = dict(zip(list(items), list(clusters))) 55 | 56 | # Each unique rna sequence is given an id 57 | visited = {} 58 | for nucl in nucleotides: 59 | nucl_id = hash_sequence(nucl) 60 | if nucl not in visited: 61 | clustering[nucl_id] = nucl_id 62 | visited[nucl] = nucl_id 63 | else: 64 | clustering[nucl_id] = visited[nucl] 65 | 66 | # Load ligand data 67 | with Path(args.ccd).open("rb") as handle: 68 | ligand_data = pickle.load(handle) # noqa: S301 69 | 70 | # Each unique ligand CCD is given an id 71 | visited = {} 72 | for ccd_code in ligand_data: 73 | clustering[ccd_code] = ccd_code 74 | 75 | # Save clustering 76 | with (outdir / "clustering.json").open("w") as handle: 77 | json.dump(clustering, handle) 78 | 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument( 83 | "--sequences", 84 | type=str, 85 | help="Input to protein fasta.", 86 | required=True, 87 | ) 88 | parser.add_argument( 89 | "--ccd", 90 | type=str, 91 | help="Input 
to rna fasta.", 92 | required=True, 93 | ) 94 | parser.add_argument( 95 | "--outdir", 96 | type=str, 97 | help="Output directory.", 98 | required=True, 99 | ) 100 | parser.add_argument( 101 | "--mmseqs", 102 | type=str, 103 | help="Path to mmseqs program.", 104 | default="mmseqs", 105 | ) 106 | args = parser.parse_args() 107 | main(args) 108 | -------------------------------------------------------------------------------- /scripts/process/msa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | from dataclasses import asdict 4 | from functools import partial 5 | from pathlib import Path 6 | from typing import Any 7 | 8 | import numpy as np 9 | from redis import Redis 10 | from tqdm import tqdm 11 | 12 | from boltz.data.parse.a3m import parse_a3m 13 | 14 | 15 | class Resource: 16 | """A shared resource for processing.""" 17 | 18 | def __init__(self, host: str, port: int) -> None: 19 | """Initialize the redis database.""" 20 | self._redis = Redis(host=host, port=port) 21 | 22 | def get(self, key: str) -> Any: # noqa: ANN401 23 | """Get an item from the Redis database.""" 24 | return self._redis.get(key) 25 | 26 | def __getitem__(self, key: str) -> Any: # noqa: ANN401 27 | """Get an item from the resource.""" 28 | out = self.get(key) 29 | if out is None: 30 | raise KeyError(key) 31 | return out 32 | 33 | 34 | def process_msa( 35 | data: list, 36 | host: str, 37 | port: int, 38 | outdir: str, 39 | max_seqs: int, 40 | ) -> None: 41 | """Run processing in a worker thread.""" 42 | outdir = Path(outdir) 43 | resource = Resource(host=host, port=port) 44 | for path in tqdm(data, total=len(data)): 45 | out_path = outdir / f"{path.stem}.npz" 46 | if not out_path.exists(): 47 | msa = parse_a3m(path, resource, max_seqs) 48 | np.savez_compressed(out_path, **asdict(msa)) 49 | 50 | 51 | def process(args) -> None: 52 | """Run the data processing task.""" 53 | # Create output directory 54 | 
args.outdir.mkdir(parents=True, exist_ok=True) 55 | 56 | # Check if we can run in parallel 57 | num_processes = min(args.num_processes, multiprocessing.cpu_count()) 58 | parallel = num_processes > 1 59 | 60 | # Load shared data from redis 61 | print("Loading shared data from Redis...") 62 | shared_data = Resource(host=args.redis_host, port=args.redis_port) 63 | 64 | # Get data points 65 | print("Fetching data...") 66 | data = args.msadir.rglob("*.a3m*") 67 | print(f"Found {len(data)} MSA's.") 68 | 69 | # Randomly permute the data 70 | random = np.random.RandomState() 71 | permute = random.permutation(len(data)) 72 | data = [data[i] for i in permute] 73 | 74 | # Run processing 75 | if parallel: 76 | # Create processing function 77 | fn = partial( 78 | process_msa, 79 | host=args.redis_host, 80 | port=args.redis_port, 81 | outdir=args.outdir, 82 | ) 83 | 84 | # Split the data into random chunks 85 | size = len(data) // num_processes 86 | chunks = [data[i : i + size] for i in range(0, len(data), size)] 87 | 88 | # Run processing in parallel 89 | with multiprocessing.Pool(num_processes) as pool: # noqa: SIM117 90 | with tqdm(total=len(chunks)) as pbar: 91 | for _ in pool.imap_unordered(fn, chunks): 92 | pbar.update() 93 | else: 94 | for item in tqdm(data, total=len(data)): 95 | process_msa(item, shared_data, args.outdir) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser(description="Process MSA data.") 100 | parser.add_argument( 101 | "--msadir", 102 | type=Path, 103 | required=True, 104 | help="The MSA data directory.", 105 | ) 106 | parser.add_argument( 107 | "--outdir", 108 | type=Path, 109 | default="data", 110 | help="The output directory.", 111 | ) 112 | parser.add_argument( 113 | "--num-processes", 114 | type=int, 115 | default=multiprocessing.cpu_count(), 116 | help="The number of processes.", 117 | ) 118 | parser.add_argument( 119 | "--redis-host", 120 | type=str, 121 | default="localhost", 122 | help="The Redis host.", 123 | ) 
124 | parser.add_argument( 125 | "--redis-port", 126 | type=int, 127 | default=7777, 128 | help="The Redis port.", 129 | ) 130 | parser.add_argument( 131 | "--max-seqs", 132 | type=int, 133 | default=16384, 134 | help="The maximum number of sequences.", 135 | ) 136 | args = parser.parse_args() 137 | process(args) 138 | -------------------------------------------------------------------------------- /scripts/process/requirements.txt: -------------------------------------------------------------------------------- 1 | gemmi 2 | pdbeccdutils 3 | redis 4 | scikit-learn 5 | p_tqdm -------------------------------------------------------------------------------- /scripts/train/README.md: -------------------------------------------------------------------------------- 1 | Please see our [training instructions](../../docs/training.md). -------------------------------------------------------------------------------- /scripts/train/assets/casp15_ids.txt: -------------------------------------------------------------------------------- 1 | T1112 2 | T1118v1 3 | T1154 4 | T1137s1 5 | T1188 6 | T1157s1 7 | T1137s6 8 | R1117 9 | H1106 10 | T1106s2 11 | R1149 12 | T1158 13 | T1137s2 14 | T1145 15 | T1121 16 | T1123 17 | T1113 18 | R1156 19 | T1114s1 20 | T1183 21 | R1107 22 | T1137s7 23 | T1124 24 | T1178 25 | T1147 26 | R1128 27 | T1161 28 | R1108 29 | T1194 30 | T1185s2 31 | T1176 32 | T1158v3 33 | T1137s4 34 | T1160 35 | T1120 36 | H1185 37 | T1134s1 38 | T1119 39 | H1151 40 | T1137s8 41 | T1133 42 | T1187 43 | H1157 44 | T1122 45 | T1104 46 | T1158v2 47 | T1137s5 48 | T1129s2 49 | T1174 50 | T1157s2 51 | T1155 52 | T1158v4 53 | T1152 54 | T1137s9 55 | T1134s2 56 | T1125 57 | R1116 58 | H1134 59 | R1136 60 | T1159 61 | T1137s3 62 | T1185s1 63 | T1179 64 | T1106s1 65 | T1132 66 | T1185s4 67 | T1114s3 68 | T1114s2 69 | T1151s2 70 | T1158v1 71 | R1117v2 72 | T1173 73 | -------------------------------------------------------------------------------- 
/scripts/train/assets/test_ids.txt: -------------------------------------------------------------------------------- 1 | 8BZ4 2 | 8URN 3 | 7U71 4 | 7Z64 5 | 7Y3Z 6 | 8SOT 7 | 8GH8 8 | 8IIB 9 | 7U08 10 | 8EB5 11 | 8G49 12 | 8K7Y 13 | 7QQD 14 | 8EIL 15 | 8JQE 16 | 8V1K 17 | 7ZRZ 18 | 7YN2 19 | 8D40 20 | 8RXO 21 | 8SXS 22 | 7UDL 23 | 8ADD 24 | 7Z3I 25 | 7YUK 26 | 7XWY 27 | 8F9Y 28 | 8WO7 29 | 8C27 30 | 8I3J 31 | 8HVC 32 | 8SXU 33 | 8K1I 34 | 8FTV 35 | 8ERC 36 | 8DVQ 37 | 8DTQ 38 | 8J12 39 | 8D0P 40 | 8POG 41 | 8HN0 42 | 7QPK 43 | 8AGR 44 | 8GXR 45 | 8K7X 46 | 8BL6 47 | 8HAW 48 | 8SRO 49 | 8HHM 50 | 8C26 51 | 7SPQ 52 | 8SME 53 | 7XGV 54 | 8GTY 55 | 8Q42 56 | 8BRY 57 | 8HDV 58 | 8B3Z 59 | 7XNJ 60 | 8EEL 61 | 8IOI 62 | 8Q70 63 | 8Y4U 64 | 8ANT 65 | 8IUB 66 | 8D49 67 | 8CPQ 68 | 8BAT 69 | 8E2B 70 | 8IWP 71 | 8IJT 72 | 7Y01 73 | 8CJG 74 | 8HML 75 | 8WU2 76 | 8VRM 77 | 8J1J 78 | 8DAJ 79 | 8SUT 80 | 8PTJ 81 | 8IVZ 82 | 8SDZ 83 | 7YDQ 84 | 8JU7 85 | 8K34 86 | 8B6Q 87 | 8F7N 88 | 8IBZ 89 | 7WOI 90 | 8R7D 91 | 8T65 92 | 8IQC 93 | 8SIU 94 | 8QK8 95 | 8HIG 96 | 7Y43 97 | 8IN8 98 | 8IBW 99 | 8GOY 100 | 7ZAO 101 | 8J9G 102 | 7ZCA 103 | 8HIO 104 | 8EFZ 105 | 8IQ8 106 | 8OQ0 107 | 8HHL 108 | 7XMW 109 | 8GI1 110 | 8AYR 111 | 7ZCB 112 | 8BRD 113 | 8IN6 114 | 8I3F 115 | 8HIU 116 | 8ER5 117 | 8WIL 118 | 7YPR 119 | 8UA2 120 | 8BW6 121 | 8IL8 122 | 8J3R 123 | 8K1F 124 | 8OHI 125 | 8WCT 126 | 8AN0 127 | 8BDQ 128 | 7FCT 129 | 8J69 130 | 8HTX 131 | 8PE3 132 | 8K5U 133 | 8AXT 134 | 8PSO 135 | 8JHR 136 | 8GY0 137 | 8QCW 138 | 8K3D 139 | 8P6J 140 | 8J0Q 141 | 7XS3 142 | 8DHJ 143 | 8EIN 144 | 7WKP 145 | 8GAQ 146 | 7WRN 147 | 8AHD 148 | 7SC4 149 | 8B3E 150 | 8AAS 151 | 8UZ8 152 | 8Q1K 153 | 8K5K 154 | 8B45 155 | 8PT7 156 | 7ZPN 157 | 8UQ9 158 | 8TJG 159 | 8TN8 160 | 8B2E 161 | 7XFZ 162 | 8FW7 163 | 8B3W 164 | 7T4W 165 | 8SVA 166 | 7YL4 167 | 8GLD 168 | 8OEI 169 | 8GMX 170 | 8OWF 171 | 8FNR 172 | 8IRQ 173 | 8JDG 174 | 7UXA 175 | 8TKA 176 | 7YH1 177 | 8HUZ 178 | 8TA2 179 | 8E5D 180 | 7YUN 181 | 7UOI 
182 | 7WMY 183 | 8AA9 184 | 8ISZ 185 | 8EXA 186 | 8E7F 187 | 8B2S 188 | 8TP8 189 | 8GSY 190 | 7XRX 191 | 8SY3 192 | 8CIL 193 | 8WBR 194 | 7XF1 195 | 7YPO 196 | 8AXF 197 | 7QNL 198 | 8OYY 199 | 7R1N 200 | 8H5S 201 | 8B6U 202 | 8IBX 203 | 8Q43 204 | 8OW8 205 | 7XSG 206 | 8U0M 207 | 8IOO 208 | 8HR5 209 | 8BVK 210 | 8P0C 211 | 7TL6 212 | 8J48 213 | 8S0U 214 | 8K8A 215 | 8G53 216 | 7XYO 217 | 8POF 218 | 8U1K 219 | 8HF2 220 | 8K4L 221 | 8JAH 222 | 8KGZ 223 | 8BNB 224 | 7UG2 225 | 8A0A 226 | 8Q3Z 227 | 8XBI 228 | 8JNM 229 | 8GPS 230 | 8K1R 231 | 8Q66 232 | 7YLQ 233 | 7YNX 234 | 8IMD 235 | 7Y8H 236 | 8OXU 237 | 8BVE 238 | 8B4E 239 | 8V14 240 | 7R5I 241 | 8IR2 242 | 8UK7 243 | 8EBB 244 | 7XCC 245 | 8AEP 246 | 7YDW 247 | 8XX9 248 | 7VS6 249 | 8K3F 250 | 8CQM 251 | 7XH4 252 | 8BH9 253 | 7VXT 254 | 8SM9 255 | 8HGU 256 | 8PSQ 257 | 8SSU 258 | 8VXA 259 | 8GSX 260 | 8GHZ 261 | 8BJ3 262 | 8C9V 263 | 8T66 264 | 7XPC 265 | 8RH3 266 | 8CMQ 267 | 8AGG 268 | 8ERM 269 | 8P6M 270 | 8BUX 271 | 7S2J 272 | 8G32 273 | 8AXJ 274 | 8CID 275 | 8CPK 276 | 8P5Q 277 | 8HP8 278 | 7YUJ 279 | 8PT2 280 | 7YK3 281 | 7YYG 282 | 8ABV 283 | 7XL7 284 | 7YLZ 285 | 8JWS 286 | 8IW5 287 | 8SM6 288 | 8BBZ 289 | 8EOV 290 | 8PXC 291 | 7UWV 292 | 8A9N 293 | 7YH5 294 | 8DEO 295 | 7X2X 296 | 8W7P 297 | 8B5W 298 | 8CIH 299 | 8RB4 300 | 8HLG 301 | 8J8H 302 | 8UA5 303 | 7YKM 304 | 8S9W 305 | 7YPD 306 | 8GA6 307 | 7YPQ 308 | 8X7X 309 | 8HI8 310 | 8H7A 311 | 8C4D 312 | 8XAT 313 | 8W8S 314 | 8HM4 315 | 8H3Z 316 | 7W91 317 | 8GPP 318 | 8TNM 319 | 7YSI 320 | 8OML 321 | 8BBR 322 | 7YOJ 323 | 8JZX 324 | 8I3X 325 | 8AU6 326 | 8ITO 327 | 7SFY 328 | 8B6P 329 | 7Y8S 330 | 8ESL 331 | 8DSP 332 | 8CLZ 333 | 8F72 334 | 8QLD 335 | 8K86 336 | 8G8E 337 | 8QDO 338 | 8ANU 339 | 8PT6 340 | 8F5D 341 | 8DQ6 342 | 8IFK 343 | 8OJN 344 | 8SSC 345 | 7QRR 346 | 8E55 347 | 7TPU 348 | 7UQU 349 | 8HFP 350 | 7XGT 351 | 8A39 352 | 8CB2 353 | 8ACR 354 | 8G5S 355 | 7TZL 356 | 8T4R 357 | 8H18 358 | 7UI4 359 | 8Q41 360 | 8K76 361 | 7WUY 362 | 8VXC 363 | 
8GYG 364 | 8IMS 365 | 8IKS 366 | 8X51 367 | 7Y7O 368 | 8PX4 369 | 8BF8 370 | 7XMJ 371 | 8GDW 372 | 7YTU 373 | 8CH4 374 | 7XHZ 375 | 7YH4 376 | 8PSN 377 | 8A16 378 | 8FBJ 379 | 7Y9G 380 | 8JI2 381 | 7YR9 382 | 8SW0 383 | 8A90 384 | 8X6V 385 | 8H8P 386 | 7WJU 387 | 8PSS 388 | 8HL8 389 | 8FJD 390 | 8PM4 391 | 7UK8 392 | 8DX0 393 | 8PHB 394 | 8FBN 395 | 8FXF 396 | 8GKH 397 | 8ENR 398 | 8PTH 399 | 8CBV 400 | 8GKV 401 | 8CQO 402 | 8OK3 403 | 8GSR 404 | 8TPK 405 | 8H1J 406 | 8QFL 407 | 8CHW 408 | 7V34 409 | 8HE2 410 | 7ZIE 411 | 8A50 412 | 7Z8E 413 | 8ILL 414 | 7WWC 415 | 7XVI 416 | 8Q2A 417 | 8HNO 418 | 8PR6 419 | 7XCA 420 | 7XGS 421 | 8H55 422 | 8FJE 423 | 7UNH 424 | 8AY2 425 | 8ARD 426 | 8HBR 427 | 8EWG 428 | 8D4A 429 | 8FIT 430 | 8E5E 431 | 8PMU 432 | 8F5G 433 | 8AMU 434 | 8CPN 435 | 7QPL 436 | 8EHN 437 | 8SQU 438 | 8F70 439 | 8FX9 440 | 7UR2 441 | 8T1M 442 | 7ZDS 443 | 7YH2 444 | 8B6A 445 | 8CHX 446 | 8G0N 447 | 8GY4 448 | 7YKG 449 | 8BH8 450 | 8BVI 451 | 7XF2 452 | 8BFY 453 | 8IA3 454 | 8JW3 455 | 8OQJ 456 | 8TFS 457 | 7Y1S 458 | 8HBB 459 | 8AF9 460 | 8IP1 461 | 7XZ3 462 | 8T0P 463 | 7Y16 464 | 8BRP 465 | 8JNX 466 | 8JP0 467 | 8EC3 468 | 8PZH 469 | 7URP 470 | 8B4D 471 | 8JFR 472 | 8GYR 473 | 7XFS 474 | 8SMQ 475 | 7WNH 476 | 8H0L 477 | 8OWI 478 | 8HFC 479 | 7X6G 480 | 8FKL 481 | 8PAG 482 | 8UPI 483 | 8D4B 484 | 8BCK 485 | 8JFU 486 | 8FUQ 487 | 8IF8 488 | 8PAQ 489 | 8HDU 490 | 8W9O 491 | 8ACA 492 | 7YIA 493 | 7ZFR 494 | 7Y9A 495 | 8TTO 496 | 7YFX 497 | 8B2H 498 | 8PSU 499 | 8ACC 500 | 8JMR 501 | 8IHA 502 | 7UYX 503 | 8DWJ 504 | 8BY5 505 | 8EZW 506 | 8A82 507 | 8TVL 508 | 8R79 509 | 8R8A 510 | 8AHZ 511 | 8AYV 512 | 8JHU 513 | 8Q44 514 | 8ARE 515 | 8OLJ 516 | 7Y95 517 | 7XP0 518 | 8EX9 519 | 8BID 520 | 8Q40 521 | 7QSJ 522 | 7UBA 523 | 7XFU 524 | 8OU1 525 | 8G2V 526 | 8YA7 527 | 8GMZ 528 | 8T8L 529 | 8CK0 530 | 7Y4H 531 | 8IOM 532 | 7ZLQ 533 | 8BZ2 534 | 8B4C 535 | 8DZJ 536 | 8CEG 537 | 8IBY 538 | 8T3J 539 | 8IVI 540 | 8ITN 541 | 8CR7 542 | 8TGH 543 | 8OKH 544 | 7UI8 545 
| 8EHT 546 | 8ADC 547 | 8T4C 548 | 7XBJ 549 | 8CLU 550 | 7QA1 551 | -------------------------------------------------------------------------------- /scripts/train/assets/validation_ids.txt: -------------------------------------------------------------------------------- 1 | 7UTN 2 | 7F9H 3 | 7TZV 4 | 7ZHH 5 | 7SOV 6 | 7EOF 7 | 7R8H 8 | 8AW3 9 | 7F2F 10 | 8BAO 11 | 7BCB 12 | 7D8T 13 | 7D3T 14 | 7BHY 15 | 7YZ7 16 | 8DC2 17 | 7SOW 18 | 8CTL 19 | 7SOS 20 | 7V6W 21 | 7Z55 22 | 7NQF 23 | 7VTN 24 | 7KSP 25 | 7BJQ 26 | 7YZC 27 | 7Y3L 28 | 7TDX 29 | 7R8I 30 | 7OYK 31 | 7TZ1 32 | 7KIJ 33 | 7T8K 34 | 7KII 35 | 7YZA 36 | 7VP4 37 | 7KIK 38 | 7M5W 39 | 7Q94 40 | 7BCA 41 | 7YZB 42 | 7OG0 43 | 7VTI 44 | 7SOP 45 | 7S03 46 | 7YZG 47 | 7TXC 48 | 7VP5 49 | 7Y3I 50 | 7TDW 51 | 8B0R 52 | 7R8G 53 | 7FEF 54 | 7VP1 55 | 7VP3 56 | 7RGU 57 | 7DV2 58 | 7YZD 59 | 7OFZ 60 | 7Y3K 61 | 7TEC 62 | 7WQ5 63 | 7VP2 64 | 7EDB 65 | 7VP7 66 | 7PDV 67 | 7XHT 68 | 7R6R 69 | 8CSH 70 | 8CSZ 71 | 7V9O 72 | 7Q1C 73 | 8EDC 74 | 7PWI 75 | 7FI1 76 | 7ESI 77 | 7F0Y 78 | 7EYR 79 | 7ZVA 80 | 7WEG 81 | 7E4N 82 | 7U5Q 83 | 7FAV 84 | 7LJ2 85 | 7S6F 86 | 7B3N 87 | 7V4P 88 | 7AJO 89 | 7WH1 90 | 8DQP 91 | 7STT 92 | 7VQ7 93 | 7E4J 94 | 7RIS 95 | 7FH8 96 | 7BMW 97 | 7RD0 98 | 7V54 99 | 7LKC 100 | 7OU1 101 | 7QOD 102 | 7PX1 103 | 7EBY 104 | 7U1V 105 | 7PLP 106 | 7T8N 107 | 7SJK 108 | 7RGB 109 | 7TEM 110 | 7UG9 111 | 7B7A 112 | 7TM2 113 | 7Z74 114 | 7PCM 115 | 7V8G 116 | 7EUU 117 | 7VTL 118 | 7ZEI 119 | 7ZC0 120 | 7DZ9 121 | 8B2M 122 | 7NE9 123 | 7ALV 124 | 7M96 125 | 7O6T 126 | 7SKO 127 | 7Z2V 128 | 7OWX 129 | 7SHW 130 | 7TNI 131 | 7ZQY 132 | 7MDF 133 | 7EXR 134 | 7W6B 135 | 7EQF 136 | 7WWO 137 | 7FBW 138 | 8EHE 139 | 7CLE 140 | 7T80 141 | 7WMV 142 | 7SMG 143 | 7WSJ 144 | 7DBU 145 | 7VHY 146 | 7W5F 147 | 7SHG 148 | 7VU3 149 | 7ATH 150 | 7FGZ 151 | 7ADS 152 | 7REO 153 | 7T7H 154 | 7X0N 155 | 7TCU 156 | 7SKH 157 | 7EF6 158 | 7TBV 159 | 7B29 160 | 7VO5 161 | 7TM1 162 | 7QLD 163 | 7BB9 164 | 7SZ8 165 | 7RLM 166 | 7WWP 167 | 
7NBV 168 | 7PLD 169 | 7DNM 170 | 7SFZ 171 | 7EAW 172 | 7QNQ 173 | 7SZX 174 | 7U2S 175 | 7WZX 176 | 7TYG 177 | 7QCE 178 | 7DCN 179 | 7WJL 180 | 7VV6 181 | 7TJ4 182 | 7VI8 183 | 8AKP 184 | 7WAO 185 | 7N7V 186 | 7EYO 187 | 7VTD 188 | 7VEG 189 | 7QY5 190 | 7ELV 191 | 7P0J 192 | 7YX8 193 | 7U4H 194 | 7TBD 195 | 7WME 196 | 7RI3 197 | 7TOH 198 | 7ZVM 199 | 7PUL 200 | 7VBO 201 | 7DM0 202 | 7XN9 203 | 7ALY 204 | 7LTB 205 | 8A28 206 | 7UBZ 207 | 8DTE 208 | 7TA2 209 | 7QST 210 | 7AN1 211 | 7FIB 212 | 8BAL 213 | 7TMJ 214 | 7REV 215 | 7PZJ 216 | 7T9X 217 | 7SUU 218 | 7KJQ 219 | 7V6P 220 | 7QA3 221 | 7ULC 222 | 7Y3X 223 | 7TMU 224 | 7OA7 225 | 7PO9 226 | 7Q20 227 | 8H2C 228 | 7VW1 229 | 7VLJ 230 | 8EP4 231 | 7P57 232 | 7QUL 233 | 7ZQE 234 | 7UJU 235 | 7WG1 236 | 7DMK 237 | 7Y8X 238 | 7EHG 239 | 7W13 240 | 7NL4 241 | 7R4J 242 | 7AOV 243 | 7RFT 244 | 7VUF 245 | 7F72 246 | 8DSR 247 | 7MK3 248 | 7MQQ 249 | 7R55 250 | 7T85 251 | 7NCY 252 | 7ZHL 253 | 7E1N 254 | 7W8F 255 | 7PGK 256 | 8GUN 257 | 7P8D 258 | 7PUK 259 | 7N9D 260 | 7XWN 261 | 7ZHA 262 | 7TVP 263 | 7VI6 264 | 7PW6 265 | 7YM0 266 | 7RWK 267 | 8DKR 268 | 7WGU 269 | 7LJI 270 | 7THW 271 | 7OB6 272 | 7N3Z 273 | 7T3S 274 | 7PAB 275 | 7F9F 276 | 7PPP 277 | 7AD5 278 | 7VGM 279 | 7WBO 280 | 7RWM 281 | 7QFI 282 | 7T91 283 | 7ANU 284 | 7UX0 285 | 7USR 286 | 7RDN 287 | 7VW5 288 | 7Q4T 289 | 7W3R 290 | 8DKQ 291 | 7RCX 292 | 7UOF 293 | 7OKR 294 | 7NX1 295 | 6ZBS 296 | 7VEV 297 | 8E8U 298 | 7WJ6 299 | 7MP4 300 | 7RPY 301 | 7R5Z 302 | 7VLM 303 | 7SNE 304 | 7WDW 305 | 8E19 306 | 7PP2 307 | 7Z5H 308 | 7P7I 309 | 7LJJ 310 | 7QPC 311 | 7VJS 312 | 7QOE 313 | 7KZH 314 | 7F6N 315 | 7TMI 316 | 7POH 317 | 8DKS 318 | 7YMO 319 | 6S5I 320 | 7N6O 321 | 7LYU 322 | 7POK 323 | 7BLK 324 | 7TCY 325 | 7W19 326 | 8B55 327 | 7SMU 328 | 7QFK 329 | 7T5T 330 | 7EPQ 331 | 7DCK 332 | 7S69 333 | 6ZSV 334 | 7ZGT 335 | 7TJ1 336 | 7V09 337 | 7ZHD 338 | 7ALL 339 | 7P1Y 340 | 7T71 341 | 7MNK 342 | 7W5Q 343 | 7PZ2 344 | 7QSQ 345 | 7QI3 346 | 7NZZ 347 | 7Q47 348 | 8D08 349 
| 7QH5 350 | 7RXQ 351 | 7F45 352 | 8D07 353 | 8EHC 354 | 7PZT 355 | 7K3C 356 | 7ZGI 357 | 7MC4 358 | 7NPQ 359 | 7VD7 360 | 7XAN 361 | 7FDP 362 | 8A0K 363 | 7TXO 364 | 7ZB1 365 | 7V5V 366 | 7WWS 367 | 7PBK 368 | 8EBG 369 | 7N0J 370 | 7UMA 371 | 7T1S 372 | 8EHB 373 | 7DWC 374 | 7K6W 375 | 7WEJ 376 | 7LRH 377 | 7ZCV 378 | 7RKC 379 | 7X8C 380 | 7PV1 381 | 7UGK 382 | 7ULN 383 | 7A66 384 | 7R7M 385 | 7M0Q 386 | 7BGS 387 | 7UPP 388 | 7O62 389 | 7VKK 390 | 7L6Y 391 | 7VG4 392 | 7V2V 393 | 7ETN 394 | 7ZTB 395 | 7AOO 396 | 7OH2 397 | 7E0M 398 | 7PEG 399 | 8CUK 400 | 7ZP0 401 | 7T6A 402 | 7BTM 403 | 7DOV 404 | 7VVV 405 | 7P22 406 | 7RUO 407 | 7E40 408 | 7O5Y 409 | 7XPK 410 | 7R0K 411 | 8D04 412 | 7TYD 413 | 7LSV 414 | 7XSI 415 | 7RTZ 416 | 7UXR 417 | 7QH3 418 | 8END 419 | 8CYK 420 | 7MRJ 421 | 7DJL 422 | 7S5B 423 | 7XUX 424 | 7EV8 425 | 7R6S 426 | 7UH4 427 | 7R9X 428 | 7F7P 429 | 7ACW 430 | 7SPN 431 | 7W70 432 | 7Q5G 433 | 7DXN 434 | 7DK9 435 | 8DT0 436 | 7FDN 437 | 7DGX 438 | 7UJB 439 | 7X4O 440 | 7F4O 441 | 7T9W 442 | 8AID 443 | 7ERQ 444 | 7EQB 445 | 7YDG 446 | 7ETR 447 | 8D27 448 | 7OUU 449 | 7R5Y 450 | 7T8I 451 | 7UZT 452 | 7X8V 453 | 7QLH 454 | 7SAF 455 | 7EN6 456 | 8D4Y 457 | 7ESJ 458 | 7VWO 459 | 7SBE 460 | 7VYU 461 | 7RVJ 462 | 7FCL 463 | 7WUO 464 | 7WWF 465 | 7VMT 466 | 7SHJ 467 | 7SKP 468 | 7KOU 469 | 6ZSU 470 | 7VGW 471 | 7X45 472 | 8GYZ 473 | 8BFE 474 | 8DGL 475 | 7Z3H 476 | 8BD1 477 | 8A0J 478 | 7JRK 479 | 7QII 480 | 7X39 481 | 7Y6B 482 | 7OIY 483 | 7SBI 484 | 8A3I 485 | 7NLI 486 | 7F4U 487 | 7TVY 488 | 7X0O 489 | 7VMH 490 | 7EPN 491 | 7WBK 492 | 8BFJ 493 | 7XFP 494 | 7LXQ 495 | 7TIL 496 | 7O61 497 | 8B8B 498 | 7W2Q 499 | 8APR 500 | 7WZE 501 | 7NYQ 502 | 7RMX 503 | 7PGE 504 | 8F43 505 | 7N2K 506 | 7UXG 507 | 7SXN 508 | 7T5U 509 | 7R22 510 | 7E3T 511 | 7PTB 512 | 7OA8 513 | 7X5T 514 | 7PL7 515 | 7SQ5 516 | 7VBS 517 | 8D03 518 | 7TAE 519 | 7T69 520 | 7WF6 521 | 7LBU 522 | 8A06 523 | 8DA2 524 | 7QFL 525 | 7KUW 526 | 7X9R 527 | 7XT3 528 | 7RB4 529 | 7PT5 530 | 7RPS 
531 | 7RXU 532 | 7TDY 533 | 7W89 534 | 7N9I 535 | 7T1M 536 | 7OBM 537 | 7K3X 538 | 7ZJC 539 | 8BDP 540 | 7V8W 541 | 7DJK 542 | 7W1K 543 | 7QFG 544 | 7DGY 545 | 7ZTQ 546 | 7F8A 547 | 7NEK 548 | 7CG9 549 | 7KOB 550 | 7TN7 551 | 8DYS 552 | 7WVR 553 | -------------------------------------------------------------------------------- /scripts/train/configs/confidence.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | accelerator: gpu 3 | devices: 1 4 | precision: 32 5 | gradient_clip_val: 10.0 6 | max_epochs: -1 7 | 8 | # Optional set wandb here 9 | # wandb: 10 | # name: boltz 11 | # project: boltz 12 | # entity: boltz 13 | 14 | 15 | output: SET_PATH_HERE 16 | pretrained: PATH_TO_STRUCTURE_CHECKPOINT_FILE 17 | resume: null 18 | disable_checkpoint: false 19 | matmul_precision: null 20 | save_top_k: -1 21 | load_confidence_from_trunk: true 22 | 23 | data: 24 | datasets: 25 | - _target_: boltz.data.module.training.DatasetConfig 26 | target_dir: PATH_TO_TARGETS_DIR 27 | msa_dir: PATH_TO_MSA_DIR 28 | prob: 1.0 29 | sampler: 30 | _target_: boltz.data.sample.cluster.ClusterSampler 31 | cropper: 32 | _target_: boltz.data.crop.boltz.BoltzCropper 33 | min_neighborhood: 0 34 | max_neighborhood: 40 35 | split: ./scripts/train/assets/validation_ids.txt 36 | 37 | filters: 38 | - _target_: boltz.data.filter.dynamic.size.SizeFilter 39 | min_chains: 1 40 | max_chains: 300 41 | - _target_: boltz.data.filter.dynamic.date.DateFilter 42 | date: "2021-09-30" 43 | ref: released 44 | - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter 45 | resolution: 4.0 46 | 47 | tokenizer: 48 | _target_: boltz.data.tokenize.boltz.BoltzTokenizer 49 | featurizer: 50 | _target_: boltz.data.feature.featurizer.BoltzFeaturizer 51 | 52 | symmetries: PATH_TO_SYMMETRY_FILE 53 | max_tokens: 512 54 | max_atoms: 4608 55 | max_seqs: 2048 56 | pad_to_max_tokens: true 57 | pad_to_max_atoms: true 58 | pad_to_max_seqs: true 59 | samples_per_epoch: 100000 60 | 
batch_size: 1 61 | num_workers: 4 62 | random_seed: 42 63 | pin_memory: true 64 | overfit: null 65 | crop_validation: true 66 | return_train_symmetries: true 67 | return_val_symmetries: true 68 | train_binder_pocket_conditioned_prop: 0.3 69 | val_binder_pocket_conditioned_prop: 0.3 70 | binder_pocket_cutoff: 6.0 71 | binder_pocket_sampling_geometric_p: 0.3 72 | min_dist: 2.0 73 | max_dist: 22.0 74 | num_bins: 64 75 | atoms_per_window_queries: 32 76 | 77 | model: 78 | _target_: boltz.model.model.Boltz1 79 | atom_s: 128 80 | atom_z: 16 81 | token_s: 384 82 | token_z: 128 83 | num_bins: 64 84 | atom_feature_dim: 389 85 | atoms_per_window_queries: 32 86 | atoms_per_window_keys: 128 87 | compile_pairformer: false 88 | nucleotide_rmsd_weight: 5.0 89 | ligand_rmsd_weight: 10.0 90 | ema: true 91 | ema_decay: 0.999 92 | 93 | embedder_args: 94 | atom_encoder_depth: 3 95 | atom_encoder_heads: 4 96 | 97 | msa_args: 98 | msa_s: 64 99 | msa_blocks: 4 100 | msa_dropout: 0.15 101 | z_dropout: 0.25 102 | pairwise_head_width: 32 103 | pairwise_num_heads: 4 104 | activation_checkpointing: true 105 | offload_to_cpu: false 106 | 107 | pairformer_args: 108 | num_blocks: 48 109 | num_heads: 16 110 | dropout: 0.25 111 | activation_checkpointing: true 112 | offload_to_cpu: false 113 | 114 | score_model_args: 115 | sigma_data: 16 116 | dim_fourier: 256 117 | atom_encoder_depth: 3 118 | atom_encoder_heads: 4 119 | token_transformer_depth: 24 120 | token_transformer_heads: 16 121 | atom_decoder_depth: 3 122 | atom_decoder_heads: 4 123 | conditioning_transition_layers: 2 124 | activation_checkpointing: true 125 | offload_to_cpu: false 126 | 127 | structure_prediction_training: false 128 | confidence_prediction: true 129 | alpha_pae: 1 130 | confidence_imitate_trunk: true 131 | confidence_model_args: 132 | num_dist_bins: 64 133 | max_dist: 22 134 | add_s_to_z_prod: true 135 | add_s_input_to_s: true 136 | use_s_diffusion: true 137 | add_z_input_to_z: true 138 | 139 | confidence_args: 140 | 
num_plddt_bins: 50 141 | num_pde_bins: 64 142 | num_pae_bins: 64 143 | 144 | training_args: 145 | recycling_steps: 3 146 | sampling_steps: 200 147 | diffusion_multiplicity: 16 148 | diffusion_samples: 1 149 | confidence_loss_weight: 3e-3 150 | diffusion_loss_weight: 4.0 151 | distogram_loss_weight: 3e-2 152 | adam_beta_1: 0.9 153 | adam_beta_2: 0.95 154 | adam_eps: 0.00000001 155 | lr_scheduler: af3 156 | base_lr: 0.0 157 | max_lr: 0.0018 158 | lr_warmup_no_steps: 1000 159 | lr_start_decay_after_n_steps: 50000 160 | lr_decay_every_n_steps: 50000 161 | lr_decay_factor: 0.95 162 | symmetry_correction: true 163 | 164 | validation_args: 165 | recycling_steps: 3 166 | sampling_steps: 200 167 | diffusion_samples: 5 168 | symmetry_correction: true 169 | run_confidence_sequentially: false 170 | 171 | diffusion_process_args: 172 | sigma_min: 0.0004 173 | sigma_max: 160.0 174 | sigma_data: 16.0 175 | rho: 7 176 | P_mean: -1.2 177 | P_std: 1.5 178 | gamma_0: 0.8 179 | gamma_min: 1.0 180 | noise_scale: 1.0 181 | step_scale: 1.0 182 | coordinate_augmentation: true 183 | alignment_reverse_diff: true 184 | synchronize_sigmas: true 185 | use_inference_model_cache: true 186 | 187 | diffusion_loss_args: 188 | add_smooth_lddt_loss: true 189 | nucleotide_loss_weight: 5.0 190 | ligand_loss_weight: 10.0 191 | -------------------------------------------------------------------------------- /scripts/train/configs/full.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | accelerator: gpu 3 | devices: 1 4 | precision: 32 5 | gradient_clip_val: 10.0 6 | max_epochs: -1 7 | 8 | # Optional set wandb here 9 | # wandb: 10 | # name: boltz 11 | # project: boltz 12 | # entity: boltz 13 | 14 | 15 | output: SET_PATH_HERE 16 | pretrained: PATH_TO_STRUCTURE_CHECKPOINT_FILE 17 | resume: null 18 | disable_checkpoint: false 19 | matmul_precision: null 20 | save_top_k: -1 21 | 22 | data: 23 | datasets: 24 | - _target_: 
boltz.data.module.training.DatasetConfig 25 | target_dir: PATH_TO_TARGETS_DIR 26 | msa_dir: PATH_TO_MSA_DIR 27 | prob: 1.0 28 | sampler: 29 | _target_: boltz.data.sample.cluster.ClusterSampler 30 | cropper: 31 | _target_: boltz.data.crop.boltz.BoltzCropper 32 | min_neighborhood: 0 33 | max_neighborhood: 40 34 | split: ./scripts/train/assets/validation_ids.txt 35 | 36 | filters: 37 | - _target_: boltz.data.filter.dynamic.size.SizeFilter 38 | min_chains: 1 39 | max_chains: 300 40 | - _target_: boltz.data.filter.dynamic.date.DateFilter 41 | date: "2021-09-30" 42 | ref: released 43 | - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter 44 | resolution: 4.0 45 | 46 | tokenizer: 47 | _target_: boltz.data.tokenize.boltz.BoltzTokenizer 48 | featurizer: 49 | _target_: boltz.data.feature.featurizer.BoltzFeaturizer 50 | 51 | symmetries: PATH_TO_SYMMETRY_FILE 52 | max_tokens: 512 53 | max_atoms: 4608 54 | max_seqs: 2048 55 | pad_to_max_tokens: true 56 | pad_to_max_atoms: true 57 | pad_to_max_seqs: true 58 | samples_per_epoch: 100000 59 | batch_size: 1 60 | num_workers: 4 61 | random_seed: 42 62 | pin_memory: true 63 | overfit: null 64 | crop_validation: true 65 | return_train_symmetries: true 66 | return_val_symmetries: true 67 | train_binder_pocket_conditioned_prop: 0.3 68 | val_binder_pocket_conditioned_prop: 0.3 69 | binder_pocket_cutoff: 6.0 70 | binder_pocket_sampling_geometric_p: 0.3 71 | min_dist: 2.0 72 | max_dist: 22.0 73 | num_bins: 64 74 | atoms_per_window_queries: 32 75 | 76 | model: 77 | _target_: boltz.model.model.Boltz1 78 | atom_s: 128 79 | atom_z: 16 80 | token_s: 384 81 | token_z: 128 82 | num_bins: 64 83 | atom_feature_dim: 389 84 | atoms_per_window_queries: 32 85 | atoms_per_window_keys: 128 86 | compile_pairformer: false 87 | nucleotide_rmsd_weight: 5.0 88 | ligand_rmsd_weight: 10.0 89 | ema: true 90 | ema_decay: 0.999 91 | 92 | embedder_args: 93 | atom_encoder_depth: 3 94 | atom_encoder_heads: 4 95 | 96 | msa_args: 97 | msa_s: 64 98 | 
msa_blocks: 4 99 | msa_dropout: 0.15 100 | z_dropout: 0.25 101 | pairwise_head_width: 32 102 | pairwise_num_heads: 4 103 | activation_checkpointing: true 104 | offload_to_cpu: false 105 | 106 | pairformer_args: 107 | num_blocks: 48 108 | num_heads: 16 109 | dropout: 0.25 110 | activation_checkpointing: true 111 | offload_to_cpu: false 112 | 113 | score_model_args: 114 | sigma_data: 16 115 | dim_fourier: 256 116 | atom_encoder_depth: 3 117 | atom_encoder_heads: 4 118 | token_transformer_depth: 24 119 | token_transformer_heads: 16 120 | atom_decoder_depth: 3 121 | atom_decoder_heads: 4 122 | conditioning_transition_layers: 2 123 | activation_checkpointing: true 124 | offload_to_cpu: false 125 | 126 | structure_prediction_training: true 127 | confidence_prediction: true 128 | alpha_pae: 1 129 | confidence_imitate_trunk: true 130 | confidence_model_args: 131 | num_dist_bins: 64 132 | max_dist: 22 133 | add_s_to_z_prod: true 134 | add_s_input_to_s: true 135 | use_s_diffusion: true 136 | add_z_input_to_z: true 137 | 138 | confidence_args: 139 | num_plddt_bins: 50 140 | num_pde_bins: 64 141 | num_pae_bins: 64 142 | 143 | training_args: 144 | recycling_steps: 3 145 | sampling_steps: 200 146 | diffusion_multiplicity: 16 147 | diffusion_samples: 1 148 | confidence_loss_weight: 3e-3 149 | diffusion_loss_weight: 4.0 150 | distogram_loss_weight: 3e-2 151 | adam_beta_1: 0.9 152 | adam_beta_2: 0.95 153 | adam_eps: 0.00000001 154 | lr_scheduler: af3 155 | base_lr: 0.0 156 | max_lr: 0.0018 157 | lr_warmup_no_steps: 1000 158 | lr_start_decay_after_n_steps: 50000 159 | lr_decay_every_n_steps: 50000 160 | lr_decay_factor: 0.95 161 | symmetry_correction: true 162 | run_confidence_sequentially: false 163 | 164 | validation_args: 165 | recycling_steps: 3 166 | sampling_steps: 200 167 | diffusion_samples: 5 168 | symmetry_correction: true 169 | 170 | diffusion_process_args: 171 | sigma_min: 0.0004 172 | sigma_max: 160.0 173 | sigma_data: 16.0 174 | rho: 7 175 | P_mean: -1.2 176 | P_std: 
1.5 177 | gamma_0: 0.8 178 | gamma_min: 1.0 179 | noise_scale: 1.0 180 | step_scale: 1.0 181 | coordinate_augmentation: true 182 | alignment_reverse_diff: true 183 | synchronize_sigmas: true 184 | use_inference_model_cache: true 185 | 186 | diffusion_loss_args: 187 | add_smooth_lddt_loss: true 188 | nucleotide_loss_weight: 5.0 189 | ligand_loss_weight: 10.0 190 | -------------------------------------------------------------------------------- /scripts/train/configs/structure.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | accelerator: gpu 3 | devices: 1 4 | precision: 32 5 | gradient_clip_val: 10.0 6 | max_epochs: -1 7 | 8 | # Optional set wandb here 9 | # wandb: 10 | # name: boltz 11 | # project: boltz 12 | # entity: boltz 13 | 14 | output: SET_PATH_HERE 15 | resume: PATH_TO_CHECKPOINT_FILE 16 | disable_checkpoint: false 17 | matmul_precision: null 18 | save_top_k: -1 19 | 20 | data: 21 | datasets: 22 | - _target_: boltz.data.module.training.DatasetConfig 23 | target_dir: PATH_TO_TARGETS_DIR 24 | msa_dir: PATH_TO_MSA_DIR 25 | prob: 1.0 26 | sampler: 27 | _target_: boltz.data.sample.cluster.ClusterSampler 28 | cropper: 29 | _target_: boltz.data.crop.boltz.BoltzCropper 30 | min_neighborhood: 0 31 | max_neighborhood: 40 32 | split: ./scripts/train/assets/validation_ids.txt 33 | 34 | filters: 35 | - _target_: boltz.data.filter.dynamic.size.SizeFilter 36 | min_chains: 1 37 | max_chains: 300 38 | - _target_: boltz.data.filter.dynamic.date.DateFilter 39 | date: "2021-09-30" 40 | ref: released 41 | - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter 42 | resolution: 9.0 43 | 44 | tokenizer: 45 | _target_: boltz.data.tokenize.boltz.BoltzTokenizer 46 | featurizer: 47 | _target_: boltz.data.feature.featurizer.BoltzFeaturizer 48 | 49 | symmetries: PATH_TO_SYMMETRY_FILE 50 | max_tokens: 512 51 | max_atoms: 4608 52 | max_seqs: 2048 53 | pad_to_max_tokens: true 54 | pad_to_max_atoms: true 55 | pad_to_max_seqs: 
true 56 | samples_per_epoch: 100000 57 | batch_size: 1 58 | num_workers: 4 59 | random_seed: 42 60 | pin_memory: true 61 | overfit: null 62 | crop_validation: false 63 | return_train_symmetries: false 64 | return_val_symmetries: true 65 | train_binder_pocket_conditioned_prop: 0.3 66 | val_binder_pocket_conditioned_prop: 0.3 67 | binder_pocket_cutoff: 6.0 68 | binder_pocket_sampling_geometric_p: 0.3 69 | min_dist: 2.0 70 | max_dist: 22.0 71 | num_bins: 64 72 | atoms_per_window_queries: 32 73 | 74 | model: 75 | _target_: boltz.model.model.Boltz1 76 | atom_s: 128 77 | atom_z: 16 78 | token_s: 384 79 | token_z: 128 80 | num_bins: 64 81 | atom_feature_dim: 389 82 | atoms_per_window_queries: 32 83 | atoms_per_window_keys: 128 84 | compile_pairformer: false 85 | nucleotide_rmsd_weight: 5.0 86 | ligand_rmsd_weight: 10.0 87 | ema: true 88 | ema_decay: 0.999 89 | 90 | embedder_args: 91 | atom_encoder_depth: 3 92 | atom_encoder_heads: 4 93 | 94 | msa_args: 95 | msa_s: 64 96 | msa_blocks: 4 97 | msa_dropout: 0.15 98 | z_dropout: 0.25 99 | pairwise_head_width: 32 100 | pairwise_num_heads: 4 101 | activation_checkpointing: true 102 | offload_to_cpu: false 103 | 104 | pairformer_args: 105 | num_blocks: 48 106 | num_heads: 16 107 | dropout: 0.25 108 | activation_checkpointing: true 109 | offload_to_cpu: false 110 | 111 | score_model_args: 112 | sigma_data: 16 113 | dim_fourier: 256 114 | atom_encoder_depth: 3 115 | atom_encoder_heads: 4 116 | token_transformer_depth: 24 117 | token_transformer_heads: 16 118 | atom_decoder_depth: 3 119 | atom_decoder_heads: 4 120 | conditioning_transition_layers: 2 121 | activation_checkpointing: true 122 | offload_to_cpu: false 123 | 124 | confidence_prediction: false 125 | confidence_model_args: 126 | num_dist_bins: 64 127 | max_dist: 22 128 | add_s_to_z_prod: true 129 | add_s_input_to_s: true 130 | use_s_diffusion: true 131 | add_z_input_to_z: true 132 | 133 | confidence_args: 134 | num_plddt_bins: 50 135 | num_pde_bins: 64 136 | num_pae_bins: 
64 137 | 138 | training_args: 139 | recycling_steps: 3 140 | sampling_steps: 20 141 | diffusion_multiplicity: 16 142 | diffusion_samples: 2 143 | confidence_loss_weight: 1e-4 144 | diffusion_loss_weight: 4.0 145 | distogram_loss_weight: 3e-2 146 | adam_beta_1: 0.9 147 | adam_beta_2: 0.95 148 | adam_eps: 0.00000001 149 | lr_scheduler: af3 150 | base_lr: 0.0 151 | max_lr: 0.0018 152 | lr_warmup_no_steps: 1000 153 | lr_start_decay_after_n_steps: 50000 154 | lr_decay_every_n_steps: 50000 155 | lr_decay_factor: 0.95 156 | 157 | validation_args: 158 | recycling_steps: 3 159 | sampling_steps: 200 160 | diffusion_samples: 5 161 | symmetry_correction: true 162 | run_confidence_sequentially: false 163 | 164 | diffusion_process_args: 165 | sigma_min: 0.0004 166 | sigma_max: 160.0 167 | sigma_data: 16.0 168 | rho: 7 169 | P_mean: -1.2 170 | P_std: 1.5 171 | gamma_0: 0.8 172 | gamma_min: 1.0 173 | noise_scale: 1.0 174 | step_scale: 1.0 175 | coordinate_augmentation: true 176 | alignment_reverse_diff: true 177 | synchronize_sigmas: true 178 | use_inference_model_cache: true 179 | 180 | diffusion_loss_args: 181 | add_smooth_lddt_loss: true 182 | nucleotide_loss_weight: 5.0 183 | ligand_loss_weight: 10.0 184 | -------------------------------------------------------------------------------- /scripts/train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | import hydra 8 | import omegaconf 9 | import pytorch_lightning as pl 10 | import torch 11 | import torch.multiprocessing 12 | from omegaconf import OmegaConf, listconfig 13 | from pytorch_lightning import LightningModule 14 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 15 | from pytorch_lightning.loggers import WandbLogger 16 | from pytorch_lightning.strategies import DDPStrategy 17 | from pytorch_lightning.utilities import 
from pytorch_lightning.utilities import rank_zero_only

from boltz.data.module.training import BoltzTrainingDataModule, DataConfig


@dataclass
class TrainConfig:
    """Train configuration.

    Attributes
    ----------
    data : DataConfig
        The data configuration.
    model : LightningModule
        The model configuration.
    output : str
        The output directory.
    trainer : Optional[dict]
        The trainer configuration.
    resume : Optional[str]
        The resume checkpoint.
    pretrained : Optional[str]
        The pretrained model.
    wandb : Optional[dict]
        The wandb configuration.
    disable_checkpoint : bool
        Disable checkpoint.
    matmul_precision : Optional[str]
        The matmul precision.
    find_unused_parameters : Optional[bool]
        Find unused parameters.
    save_top_k : Optional[int]
        Save top k checkpoints.
    validation_only : bool
        Run validation only.
    debug : bool
        Debug mode.
    strict_loading : bool
        Fail on mismatched checkpoint weights.
    load_confidence_from_trunk: Optional[bool]
        Load pre-trained confidence weights from trunk.

    """

    data: DataConfig
    model: LightningModule
    output: str
    trainer: Optional[dict] = None
    resume: Optional[str] = None
    pretrained: Optional[str] = None
    wandb: Optional[dict] = None
    disable_checkpoint: bool = False
    matmul_precision: Optional[str] = None
    find_unused_parameters: Optional[bool] = False
    save_top_k: Optional[int] = 1
    validation_only: bool = False
    debug: bool = False
    strict_loading: bool = True
    load_confidence_from_trunk: Optional[bool] = False


def train(raw_config: str, args: list[str]) -> None:  # noqa: C901, PLR0912, PLR0915
    """Run training.

    Parameters
    ----------
    raw_config : str
        The input yaml configuration.
    args : list[str]
        Any command line overrides.

    """
    # Load the configuration and apply any dotlist overrides from the CLI
    raw_config = omegaconf.OmegaConf.load(raw_config)
    args = omegaconf.OmegaConf.from_dotlist(args)
    raw_config = omegaconf.OmegaConf.merge(raw_config, args)

    # Instantiate the task (hydra resolves all _target_ entries)
    cfg = hydra.utils.instantiate(raw_config)
    cfg = TrainConfig(**cfg)

    # Set matmul precision
    if cfg.matmul_precision is not None:
        torch.set_float32_matmul_precision(cfg.matmul_precision)

    # Create trainer dict
    trainer = cfg.trainer
    if trainer is None:
        trainer = {}

    # Flip some arguments in debug mode: single device, no workers, no wandb
    devices = trainer.get("devices", 1)

    wandb = cfg.wandb
    if cfg.debug:
        if isinstance(devices, int):
            devices = 1
        elif isinstance(devices, (list, listconfig.ListConfig)):
            devices = [devices[0]]
        trainer["devices"] = devices
        cfg.data.num_workers = 0
        if wandb:
            wandb = None

    # Create objects
    data_config = DataConfig(**cfg.data)
    data_module = BoltzTrainingDataModule(data_config)
    model_module = cfg.model

    if cfg.pretrained and not cfg.resume:
        # Load the pretrained weights into the confidence module
        if cfg.load_confidence_from_trunk:
            checkpoint = torch.load(cfg.pretrained, map_location="cpu")

            # Duplicate every non-structure/non-distogram parameter under the
            # confidence module's namespace so the confidence head is also
            # initialized from the pretrained trunk weights.
            state_dict = checkpoint["state_dict"]
            new_state_dict = {}
            for key, value in state_dict.items():
                if not key.startswith(("structure_module", "distogram_module")):
                    new_state_dict["confidence_module." + key] = value
            new_state_dict.update(state_dict)

            # Update the checkpoint with the new state_dict
            checkpoint["state_dict"] = new_state_dict

            # BUGFIX: the modified checkpoint must be written somewhere for
            # load_from_checkpoint to read. Previously `file_path` was never
            # bound in this branch, raising NameError below (and the later
            # os.remove would otherwise have targeted the wrong file).
            import tempfile

            fd, file_path = tempfile.mkstemp(suffix=".ckpt")
            os.close(fd)
            torch.save(checkpoint, file_path)
        else:
            file_path = cfg.pretrained

        print(f"Loading model from {file_path}")
        model_module = type(model_module).load_from_checkpoint(
            file_path, strict=False, **(model_module.hparams)
        )

        # Remove the temporary checkpoint created above
        if cfg.load_confidence_from_trunk:
            os.remove(file_path)

    # Create checkpoint callback, tracking the best validation lddt
    callbacks = []
    dirpath = cfg.output
    if not cfg.disable_checkpoint:
        mc = ModelCheckpoint(
            monitor="val/lddt",
            save_top_k=cfg.save_top_k,
            save_last=True,
            mode="max",
            every_n_epochs=1,
        )
        callbacks = [mc]

    # Create wandb logger
    loggers = []
    if wandb:
        wdb_logger = WandbLogger(
            group=wandb["name"],
            save_dir=cfg.output,
            project=wandb["project"],
            entity=wandb["entity"],
            log_model=False,
        )
        loggers.append(wdb_logger)

        # Save the config to wandb (rank 0 only to avoid duplicate uploads)
        @rank_zero_only
        def save_config_to_wandb() -> None:
            config_out = Path(wdb_logger.experiment.dir) / "run.yaml"
            with Path.open(config_out, "w") as f:
                OmegaConf.save(raw_config, f)
            wdb_logger.experiment.save(str(config_out))

        save_config_to_wandb()

    # Set up trainer: DDP only when more than one device is requested
    strategy = "auto"
    if (isinstance(devices, int) and devices > 1) or (
        isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1
    ):
        strategy = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)

    trainer = pl.Trainer(
        default_root_dir=str(dirpath),
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
        enable_checkpointing=not cfg.disable_checkpoint,
        reload_dataloaders_every_n_epochs=1,
        **trainer,
    )

    if not cfg.strict_loading:
        model_module.strict_loading = False

    if cfg.validation_only:
        trainer.validate(
            model_module,
            datamodule=data_module,
            ckpt_path=cfg.resume,
        )
    else:
        trainer.fit(
            model_module,
            datamodule=data_module,
            ckpt_path=cfg.resume,
        )


if __name__ == "__main__":
    arg1 = sys.argv[1]
    arg2 = sys.argv[2:]
    train(arg1, arg2)
def pad_dim(data: Tensor, dim: int, pad_len: float, value: float = 0) -> Tensor:
    """Pad a tensor with `value` at the end of dimension `dim`.

    Parameters
    ----------
    data : Tensor
        The input tensor.
    dim : int
        The dimension to pad.
    pad_len : float
        The number of entries to append along `dim`.
    value : float, optional
        The fill value for the padding.

    Returns
    -------
    Tensor
        The padded tensor (the input itself when `pad_len` is 0).

    """
    if pad_len == 0:
        return data

    # F.pad takes (before, after) pairs ordered from the LAST dimension
    # backwards; the pair for `dim` is therefore the final pair in the spec,
    # and its "after" slot is the last element.
    n_pairs = data.dim() - dim
    spec = [0] * (2 * n_pairs)
    spec[-1] = pad_len
    return pad(data, tuple(spec), value=value)


def pad_to_max(data: list[Tensor], value: float = 0) -> tuple[Tensor, Tensor]:
    """Stack tensors after padding each to the maximum size per dimension.

    Parameters
    ----------
    data : List[Tensor]
        List of tensors to pad (all of the same rank).
    value : float
        The value to use for padding.

    Returns
    -------
    Tensor
        The padded, stacked tensor (the list itself for string inputs).
    Tensor
        The padding mask (1 = real data, 0 = padding), or the scalar 0
        when no padding was required.

    """
    # Strings cannot be padded or stacked: hand the list back unchanged.
    if isinstance(data[0], str):
        return data, 0

    # Fast path: identical shapes stack directly, no mask needed.
    first_shape = data[0].shape
    if all(item.shape == first_shape for item in data):
        return torch.stack(data, dim=0), 0

    # Per-dimension target size is the max over all tensors.
    rank = len(first_shape)
    target = [max(item.shape[i] for item in data) for i in range(rank)]

    # Build an F.pad spec per tensor: (before=0, after=deficit) pairs,
    # ordered from the last dimension backwards.
    specs = []
    for item in data:
        spec = []
        for i in reversed(range(rank)):
            spec.extend((0, target[i] - item.shape[i]))
        specs.append(spec)

    # The mask marks original entries with 1 and padded entries with 0.
    masks = [pad(torch.ones_like(item), s, value=0) for item, s in zip(data, specs)]
    padded = [pad(item, s, value=value) for item, s in zip(data, specs)]

    return torch.stack(padded, dim=0), torch.stack(masks, dim=0)
class DateFilter(DynamicFilter):
    """A filter that filters complexes based on their date.

    The date can be the deposition, release, or revision date.
    If the chosen date is not available, earlier dates are used as a
    fallback (revised -> released -> deposited).

    If no date is available at all, the complex is rejected.

    """

    def __init__(
        self,
        date: str,
        ref: Literal["deposited", "revised", "released"],
    ) -> None:
        """Initialize the filter.

        Parameters
        ----------
        date : str
            The maximum (ISO format) date of PDB entries to keep.
        ref : Literal["deposited", "revised", "released"]
            The reference date to use.

        Raises
        ------
        ValueError
            If `ref` is not a supported reference date.

        """
        if ref not in ("deposited", "revised", "released"):
            # BUGFIX: the message was previously built as a tuple of two
            # strings (stray trailing commas inside the parentheses), so the
            # ValueError carried a tuple instead of a readable message.
            msg = "Invalid reference date. Must be deposited, revised, or released"
            raise ValueError(msg)

        self.filter_date = datetime.fromisoformat(date)
        self.ref = ref

    def filter(self, record: Record) -> bool:
        """Filter a record based on its date.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the record's date is on or before the cutoff.

        """
        structure = record.structure

        if self.ref == "deposited":
            date = structure.deposited
        elif self.ref == "released":
            # Fall back to the deposition date when no release date exists
            date = structure.released or structure.deposited
        elif self.ref == "revised":
            date = structure.revised
            if not date and structure.released:
                date = structure.released
            elif not date:
                date = structure.deposited

        # Reject records with no usable date at all
        if date is None or date == "":
            return False

        return datetime.fromisoformat(date) <= self.filter_date
17 | 18 | """ 19 | self.min_residues = min_residues 20 | self.max_residues = max_residues 21 | 22 | def filter(self, record: Record) -> bool: 23 | """Filter structures based on their resolution. 24 | 25 | Parameters 26 | ---------- 27 | record : Record 28 | The record to filter. 29 | 30 | Returns 31 | ------- 32 | bool 33 | Whether the record should be filtered. 34 | 35 | """ 36 | num_residues = sum(chain.num_residues for chain in record.chains) 37 | return num_residues <= self.max_residues and num_residues >= self.min_residues 38 | -------------------------------------------------------------------------------- /src/boltz/data/filter/dynamic/resolution.py: -------------------------------------------------------------------------------- 1 | from boltz.data.types import Record 2 | from boltz.data.filter.dynamic.filter import DynamicFilter 3 | 4 | 5 | class ResolutionFilter(DynamicFilter): 6 | """A filter that filters complexes based on their resolution.""" 7 | 8 | def __init__(self, resolution: float = 9.0) -> None: 9 | """Initialize the filter. 10 | 11 | Parameters 12 | ---------- 13 | resolution : float, optional 14 | The maximum allowed resolution. 15 | 16 | """ 17 | self.resolution = resolution 18 | 19 | def filter(self, record: Record) -> bool: 20 | """Filter complexes based on their resolution. 21 | 22 | Parameters 23 | ---------- 24 | record : Record 25 | The record to filter. 26 | 27 | Returns 28 | ------- 29 | bool 30 | Whether the record should be filtered. 
class SizeFilter(DynamicFilter):
    """A filter that filters structures based on their chain counts."""

    def __init__(self, min_chains: int = 1, max_chains: int = 300) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_chains : int
            The minimum number of (valid) chains allowed.
        max_chains : int
            The maximum number of chains allowed.

        """
        self.min_chains = min_chains
        self.max_chains = max_chains

    def filter(self, record: Record) -> bool:
        """Filter structures based on their chain counts.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True when the structure is small enough and has enough
            valid chains.

        """
        # Total chains bounds the upper limit; only *valid* chains count
        # towards the lower limit.
        total_chains = record.structure.num_chains
        valid_chains = sum(1 for chain in record.chains if chain.valid)
        return total_chains <= self.max_chains and valid_chains >= self.min_chains
12 | 13 | Parameters 14 | ---------- 15 | subset : str 16 | The subset of data to consider, one per line. 17 | 18 | """ 19 | with Path(subset).open("r") as f: 20 | subset = f.read().splitlines() 21 | 22 | self.subset = {s.lower() for s in subset} 23 | self.reverse = reverse 24 | 25 | def filter(self, record: Record) -> bool: 26 | """Filter a data record. 27 | 28 | Parameters 29 | ---------- 30 | record : Record 31 | The object to consider filtering in / out. 32 | 33 | Returns 34 | ------- 35 | bool 36 | True if the data passes the filter, False otherwise. 37 | 38 | """ 39 | if self.reverse: 40 | return record.id.lower() not in self.subset 41 | else: # noqa: RET505 42 | return record.id.lower() in self.subset 43 | -------------------------------------------------------------------------------- /src/boltz/data/filter/static/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/filter/static/__init__.py -------------------------------------------------------------------------------- /src/boltz/data/filter/static/filter.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | 5 | from boltz.data.types import Structure 6 | 7 | 8 | class StaticFilter(ABC): 9 | """Base class for structure filters.""" 10 | 11 | @abstractmethod 12 | def filter(self, structure: Structure) -> np.ndarray: 13 | """Filter chains in a structure. 14 | 15 | Parameters 16 | ---------- 17 | structure : Structure 18 | The structure to filter chains from. 19 | 20 | Returns 21 | ------- 22 | np.ndarray 23 | The chains to keep, as a boolean mask. 
# CCD codes of common crystallization additives, buffers, ions, solvents
# and unknown entities that should not be treated as real ligands.
LIGAND_EXCLUSION = {
    "144", "15P", "1PE", "2F2", "2JC", "3HR", "3SY", "7N5", "7PE", "9JE",
    "AAE", "ABA", "ACE", "ACN", "ACT", "ACY", "AZI", "BAM", "BCN", "BCT",
    "BDN", "BEN", "BME", "BO3", "BTB", "BTC", "BU1", "C8E", "CAD", "CAQ",
    "CBM", "CCN", "CIT", "CL", "CLR", "CM", "CMO", "CO3", "CPT", "CXS",
    "D10", "DEP", "DIO", "DMS", "DN", "DOD", "DOX", "EDO", "EEE", "EGL",
    "EOH", "EOX", "EPE", "ETF", "FCY", "FJO", "FLC", "FMT", "FW5", "GOL",
    "GSH", "GTT", "GYF", "HED", "IHP", "IHS", "IMD", "IOD", "IPA", "IPH",
    "LDA", "MB3", "MEG", "MES", "MLA", "MLI", "MOH", "MPD", "MRD", "MSE",
    "MYR", "N", "NA", "NH2", "NH4", "NHE", "NO3", "O4B", "OHE", "OLA",
    "OLC", "OMB", "OME", "OXA", "P6G", "PE3", "PE4", "PEG", "PEO", "PEP",
    "PG0", "PG4", "PGE", "PGR", "PLM", "PO4", "POL", "POP", "PVO", "SAR",
    "SCN", "SEO", "SEP", "SIN", "SO4", "SPD", "SPM", "SR", "STE", "STO",
    "STU", "TAR", "TBU", "TME", "TPO", "TRS", "UNK", "UNL", "UNX", "UPL",
    "URE",
}


class ExcludedLigands(StaticFilter):
    """Invalidate non-polymer chains that contain an excluded ligand."""

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter excluded ligands.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask (True = keep).

        """
        keep = np.ones(len(structure.chains), dtype=bool)

        for idx, chain in enumerate(structure.chains):
            # Only non-polymer (ligand) chains are candidates for exclusion.
            if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]:
                continue

            # Drop the chain if any of its residues is an excluded ligand.
            start = chain["res_idx"]
            residues = structure.residues[start : start + chain["res_num"]]
            if any(res["name"] in LIGAND_EXCLUSION for res in residues):
                keep[idx] = False

        return keep
17 | """Load the given input data. 18 | 19 | Parameters 20 | ---------- 21 | record : Record 22 | The record to load. 23 | target_dir : Path 24 | The path to the data directory. 25 | msa_dir : Path 26 | The path to msa directory. 27 | 28 | Returns 29 | ------- 30 | Input 31 | The loaded input. 32 | 33 | """ 34 | # Load the structure 35 | structure = np.load(target_dir / f"{record.id}.npz") 36 | structure = Structure( 37 | atoms=structure["atoms"], 38 | bonds=structure["bonds"], 39 | residues=structure["residues"], 40 | chains=structure["chains"], 41 | connections=structure["connections"], 42 | interfaces=structure["interfaces"], 43 | mask=structure["mask"], 44 | ) 45 | 46 | msas = {} 47 | for chain in record.chains: 48 | msa_id = chain.msa_id 49 | # Load the MSA for this chain, if any 50 | if msa_id != -1: 51 | msa = np.load(msa_dir / f"{msa_id}.npz") 52 | msas[chain.chain_id] = MSA(**msa) 53 | 54 | return Input(structure, msas) 55 | 56 | 57 | def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]: 58 | """Collate the data. 59 | 60 | Parameters 61 | ---------- 62 | data : List[Dict[str, Tensor]] 63 | The data to collate. 64 | 65 | Returns 66 | ------- 67 | Dict[str, Tensor] 68 | The collated data. 
69 | 70 | """ 71 | # Get the keys 72 | keys = data[0].keys() 73 | 74 | # Collate the data 75 | collated = {} 76 | for key in keys: 77 | values = [d[key] for d in data] 78 | 79 | if key not in [ 80 | "all_coords", 81 | "all_resolved_mask", 82 | "crop_to_all_atom_map", 83 | "chain_symmetries", 84 | "amino_acids_symmetries", 85 | "ligand_symmetries", 86 | "record", 87 | ]: 88 | # Check if all have the same shape 89 | shape = values[0].shape 90 | if not all(v.shape == shape for v in values): 91 | values, _ = pad_to_max(values, 0) 92 | else: 93 | values = torch.stack(values, dim=0) 94 | 95 | # Stack the values 96 | collated[key] = values 97 | 98 | return collated 99 | 100 | 101 | class PredictionDataset(torch.utils.data.Dataset): 102 | """Base iterable dataset.""" 103 | 104 | def __init__( 105 | self, 106 | manifest: Manifest, 107 | target_dir: Path, 108 | msa_dir: Path, 109 | ) -> None: 110 | """Initialize the training dataset. 111 | 112 | Parameters 113 | ---------- 114 | manifest : Manifest 115 | The manifest to load data from. 116 | target_dir : Path 117 | The path to the target directory. 118 | msa_dir : Path 119 | The path to the msa directory. 120 | 121 | """ 122 | super().__init__() 123 | self.manifest = manifest 124 | self.target_dir = target_dir 125 | self.msa_dir = msa_dir 126 | self.tokenizer = BoltzTokenizer() 127 | self.featurizer = BoltzFeaturizer() 128 | 129 | def __getitem__(self, idx: int) -> dict: 130 | """Get an item from the dataset. 131 | 132 | Returns 133 | ------- 134 | Dict[str, Tensor] 135 | The sampled data features. 136 | 137 | """ 138 | # Get a sample from the dataset 139 | record = self.manifest.records[idx] 140 | 141 | # Get the structure 142 | try: 143 | input_data = load_input(record, self.target_dir, self.msa_dir) 144 | except Exception as e: # noqa: BLE001 145 | print(f"Failed to load input for {record.id} with error {e}. 
Skipping.") # noqa: T201 146 | return self.__getitem__(0) 147 | 148 | # Tokenize structure 149 | try: 150 | tokenized = self.tokenizer.tokenize(input_data) 151 | except Exception as e: # noqa: BLE001 152 | print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 153 | return self.__getitem__(0) 154 | 155 | # Inference specific options 156 | options = record.inference_options 157 | if options is None: 158 | binders, pocket = None, None 159 | else: 160 | binders, pocket = options.binders, options.pocket 161 | 162 | # Compute features 163 | try: 164 | features = self.featurizer.process( 165 | tokenized, 166 | training=False, 167 | max_atoms=None, 168 | max_tokens=None, 169 | max_seqs=const.max_msa_seqs, 170 | pad_to_max_seqs=False, 171 | symmetries={}, 172 | compute_symmetries=False, 173 | inference_binder=binders, 174 | inference_pocket=pocket, 175 | ) 176 | except Exception as e: # noqa: BLE001 177 | print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 178 | return self.__getitem__(0) 179 | 180 | features["record"] = record 181 | return features 182 | 183 | def __len__(self) -> int: 184 | """Get the length of the dataset. 185 | 186 | Returns 187 | ------- 188 | int 189 | The length of the dataset. 190 | 191 | """ 192 | return len(self.manifest.records) 193 | 194 | 195 | class BoltzInferenceDataModule(pl.LightningDataModule): 196 | """DataModule for Boltz inference.""" 197 | 198 | def __init__( 199 | self, 200 | manifest: Manifest, 201 | target_dir: Path, 202 | msa_dir: Path, 203 | num_workers: int, 204 | ) -> None: 205 | """Initialize the DataModule. 206 | 207 | Parameters 208 | ---------- 209 | config : DataConfig 210 | The data configuration. 211 | 212 | """ 213 | super().__init__() 214 | self.num_workers = num_workers 215 | self.manifest = manifest 216 | self.target_dir = target_dir 217 | self.msa_dir = msa_dir 218 | 219 | def predict_dataloader(self) -> DataLoader: 220 | """Get the training dataloader. 
221 | 222 | Returns 223 | ------- 224 | DataLoader 225 | The training dataloader. 226 | 227 | """ 228 | dataset = PredictionDataset( 229 | manifest=self.manifest, 230 | target_dir=self.target_dir, 231 | msa_dir=self.msa_dir, 232 | ) 233 | return DataLoader( 234 | dataset, 235 | batch_size=1, 236 | num_workers=self.num_workers, 237 | pin_memory=True, 238 | shuffle=False, 239 | collate_fn=collate, 240 | ) 241 | 242 | def transfer_batch_to_device( 243 | self, 244 | batch: dict, 245 | device: torch.device, 246 | dataloader_idx: int, # noqa: ARG002 247 | ) -> dict: 248 | """Transfer a batch to the given device. 249 | 250 | Parameters 251 | ---------- 252 | batch : Dict 253 | The batch to transfer. 254 | device : torch.device 255 | The device to transfer to. 256 | dataloader_idx : int 257 | The dataloader index. 258 | 259 | Returns 260 | ------- 261 | np.Any 262 | The transferred batch. 263 | 264 | """ 265 | for key in batch: 266 | if key not in [ 267 | "all_coords", 268 | "all_resolved_mask", 269 | "crop_to_all_atom_map", 270 | "chain_symmetries", 271 | "amino_acids_symmetries", 272 | "ligand_symmetries", 273 | "record", 274 | ]: 275 | batch[key] = batch[key].to(device) 276 | return batch 277 | -------------------------------------------------------------------------------- /src/boltz/data/msa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/msa/__init__.py -------------------------------------------------------------------------------- /src/boltz/data/parse/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/parse/__init__.py -------------------------------------------------------------------------------- /src/boltz/data/parse/a3m.py: 
def _parse_a3m(  # noqa: C901
    lines: TextIO,
    taxonomy: Optional[dict[str, str]],
    max_seqs: Optional[int] = None,
) -> MSA:
    """Process an MSA file.

    Parameters
    ----------
    lines : TextIO
        The lines of the MSA file.
    taxonomy : dict[str, str]
        The taxonomy database, if available.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    visited = set()
    sequences = []
    deletions = []
    residues = []

    # Default taxonomy id. Fixes a NameError: previously taxonomy_id was
    # only bound inside the ">" header branch, so a sequence line arriving
    # before any header crashed the parser.
    taxonomy_id = -1

    seq_idx = 0
    for line in lines:
        line = line.strip()  # noqa: PLW2901
        if not line or line.startswith("#"):
            continue

        # Get taxonomy from the header, if annotated
        if line.startswith(">"):
            header = line.split()[0]
            if taxonomy and header.startswith(">UniRef100"):
                uniref_id = header.split("_")[1]
                taxonomy_id = taxonomy.get(uniref_id)
                if taxonomy_id is None:
                    taxonomy_id = -1
            else:
                taxonomy_id = -1
            continue

        # Skip if duplicate sequence
        str_seq = line.replace("-", "").upper()
        if str_seq not in visited:
            visited.add(str_seq)
        else:
            continue

        # Process sequence: lowercase letters are insertions relative to
        # the query and are recorded as deletion counts attached to the
        # next aligned (uppercase or gap) position.
        residue = []
        deletion = []
        count = 0
        res_idx = 0
        for c in line:
            if c != "-" and c.islower():
                count += 1
                continue
            token = const.prot_letter_to_token[c]
            token = const.token_ids[token]
            residue.append(token)
            if count > 0:
                deletion.append((res_idx, count))
                count = 0
            res_idx += 1

        # Record spans into the flat residue / deletion arrays
        res_start = len(residues)
        res_end = res_start + len(residue)

        del_start = len(deletions)
        del_end = del_start + len(deletion)

        sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
        residues.extend(residue)
        deletions.extend(deletion)

        seq_idx += 1
        if (max_seqs is not None) and (seq_idx >= max_seqs):
            break

    # Create MSA object
    return MSA(
        residues=np.array(residues, dtype=MSAResidue),
        deletions=np.array(deletions, dtype=MSADeletion),
        sequences=np.array(sequences, dtype=MSASequence),
    )
def parse_a3m(
    path: Path,
    taxonomy: Optional[dict[str, str]],
    max_seqs: Optional[int] = None,
) -> MSA:
    """Process an A3M file.

    Parameters
    ----------
    path : Path
        The path to the a3m(.gz) file.
    taxonomy : dict[str, str], optional
        The taxonomy database, if available.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    # Transparently support gzip-compressed files
    if path.suffix == ".gz":
        handle = gzip.open(str(path), "rt")
    else:
        handle = path.open("r")

    with handle as f:
        return _parse_a3m(f, taxonomy, max_seqs)
def parse_csv(
    path: Path,
    max_seqs: Optional[int] = None,
) -> MSA:
    """Process a CSV-formatted MSA file.

    The file must contain exactly two columns: ``sequence`` (the aligned
    sequence) and ``key`` (an optional taxonomy identifier).

    Parameters
    ----------
    path : Path
        The path to the CSV file.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    # Read file
    data = pd.read_csv(path)

    # Check columns
    if tuple(sorted(data.columns)) != ("key", "sequence"):
        msg = "Invalid CSV format, expected columns: ['sequence', 'key']"
        raise ValueError(msg)

    visited = set()
    sequences = []
    deletions = []
    residues = []

    seq_idx = 0
    for line, key in zip(data["sequence"], data["key"]):
        line = line.strip()  # noqa: PLW2901
        if not line:
            continue

        # Use the key as taxonomy id when present; the "nan" string check
        # guards against pandas reading an empty cell as float NaN.
        taxonomy_id = -1
        if (str(key) != "nan") and (key is not None) and (key != ""):
            taxonomy_id = key

        # Skip if duplicate sequence
        str_seq = line.replace("-", "").upper()
        if str_seq not in visited:
            visited.add(str_seq)
        else:
            continue

        # Process sequence: lowercase letters are insertions relative to
        # the query, recorded as deletion counts on the next aligned position.
        residue = []
        deletion = []
        count = 0
        res_idx = 0
        for c in line:
            if c != "-" and c.islower():
                count += 1
                continue
            token = const.prot_letter_to_token[c]
            token = const.token_ids[token]
            residue.append(token)
            if count > 0:
                deletion.append((res_idx, count))
                count = 0
            res_idx += 1

        # Record spans into the flat residue / deletion arrays
        res_start = len(residues)
        res_end = res_start + len(residue)

        del_start = len(deletions)
        del_end = del_start + len(deletion)

        sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
        residues.extend(residue)
        deletions.extend(deletion)

        seq_idx += 1
        if (max_seqs is not None) and (seq_idx >= max_seqs):
            break

    # Create MSA object
    return MSA(
        residues=np.array(residues, dtype=MSAResidue),
        deletions=np.array(deletions, dtype=MSADeletion),
        sequences=np.array(sequences, dtype=MSASequence),
    )
def parse_fasta(path: Path, ccd: Mapping[str, Mol]) -> Target:  # noqa: C901
    """Parse a fasta file.

    The name of the fasta file is used as the name of this job.
    We rely on the fasta record id to determine the entity type.

    > CHAIN_ID|ENTITY_TYPE|MSA_ID
    SEQUENCE
    > CHAIN_ID|ENTITY_TYPE|MSA_ID
    ...

    Where ENTITY_TYPE is either protein, rna, dna, ccd or smiles,
    and CHAIN_ID is the chain identifier, which should be unique.
    The MSA_ID is optional and should only be used on proteins.

    Parameters
    ----------
    path : Path
        Path to the fasta file.
    ccd : Dict
        Dictionary of CCD components.

    Returns
    -------
    Target
        The parsed target.

    """
    # Read fasta file
    with path.open("r") as handle:
        records = list(SeqIO.parse(handle, "fasta"))

    # Validate every record header up front
    allowed_types = {"protein", "dna", "rna", "ccd", "smiles"}
    for seq_record in records:
        if "|" not in seq_record.id:
            msg = f"Invalid record id: {seq_record.id}"
            raise ValueError(msg)

        fields = seq_record.id.split("|")
        assert len(fields) >= 2, f"Invalid record id: {seq_record.id}"

        chain_id, entity_type = fields[:2]
        if entity_type.lower() not in allowed_types:
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)
        if chain_id == "":
            msg = "Empty chain id in input fasta!"
            raise ValueError(msg)
        if entity_type == "":
            msg = "Empty entity type in input fasta!"
            raise ValueError(msg)

    # Convert each record into the yaml schema format
    sequences = []
    for seq_record in records:
        fields = seq_record.id.split("|")
        chain_id, entity_type = fields[:2]

        # Optional third field: MSA id, proteins only
        msa_id = None
        if len(fields) == 3 and fields[2] != "":
            assert (
                entity_type.lower() == "protein"
            ), "MSA_ID is only allowed for proteins"
            msa_id = fields[2]

        entity_type = entity_type.upper()
        seq = str(seq_record.seq)

        if entity_type == "PROTEIN":
            entry = {
                "protein": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                    "msa": msa_id,
                },
            }
        elif entity_type == "RNA":
            entry = {
                "rna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                },
            }
        elif entity_type == "DNA":
            entry = {
                "dna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                },
            }
        elif entity_type == "CCD":
            entry = {
                "ligand": {
                    "id": chain_id,
                    "ccd": seq,
                },
            }
        elif entity_type == "SMILES":
            entry = {
                "ligand": {
                    "id": chain_id,
                    "smiles": seq,
                },
            }

        sequences.append(entry)

    data = {
        "sequences": sequences,
        "bonds": [],
        "version": 1,
    }

    return parse_boltz_schema(path.stem, data, ccd)
def parse_yaml(path: Path, ccd: dict[str, Mol]) -> Target:
    """Parse a Boltz input yaml / json.

    The input file should be a yaml file with the following format:

    sequences:
        - protein:
            id: A
            sequence: "MADQLTEEQIAEFKEAFSLF"
        - protein:
            id: [B, C]
            sequence: "AKLSILPWGHC"
        - rna:
            id: D
            sequence: "GCAUAGC"
        - ligand:
            id: E
            smiles: "CC1=CC=CC=C1"
        - ligand:
            id: [F, G]
            ccd: []
    constraints:
        - bond:
            atom1: [A, 1, CA]
            atom2: [A, 2, N]
        - pocket:
            binder: E
            contacts: [[B, 1], [B, 2]]
    version: 1

    Parameters
    ----------
    path : Path
        Path to the YAML input format.
    ccd : Dict
        Dictionary of CCD components.

    Returns
    -------
    Target
        The parsed target.

    """
    with path.open("r") as handle:
        raw = yaml.safe_load(handle)

    return parse_boltz_schema(path.stem, raw, ccd)
def get_interface_cluster(interface: InterfaceInfo, record: Record) -> tuple:
    """Get the cluster id for an interface.

    Parameters
    ----------
    interface : InterfaceInfo
        The interface to get the cluster id for.
    record : Record
        The record the interface is part of.

    Returns
    -------
    tuple[str, str]
        The cluster id of the interface: the sorted pair of the two
        chains' cluster ids, so (a, b) and (b, a) map to the same cluster.
        (The previous annotation/docstring claimed ``str``, but the
        function has always returned a tuple.)

    """
    chain1 = record.chains[interface.chain_1]
    chain2 = record.chains[interface.chain_2]

    cluster_1 = str(chain1.cluster_id)
    cluster_2 = str(chain2.cluster_id)

    # Sort so the id is independent of chain order
    return tuple(sorted((cluster_1, cluster_2)))
def get_chain_weight(
    chain: ChainInfo,
    record: Record,  # noqa: ARG001
    clusters: Dict[str, int],
    beta_chain: float,
    alpha_prot: float,
    alpha_nucl: float,
    alpha_ligand: float,
) -> float:
    """Get the weight of a chain.

    Parameters
    ----------
    chain : ChainInfo
        The chain to get the weight for.
    record : Record
        The record the chain is part of.
    clusters : Dict[str, int]
        The cluster sizes.
    beta_chain : float
        The beta value for chains.
    alpha_prot : float
        The alpha value for proteins.
    alpha_nucl : float
        The alpha value for nucleic acids.
    alpha_ligand : float
        The alpha value for ligands.

    Returns
    -------
    float
        The weight of the chain.

    """
    # Per-molecule-type multipliers (RNA and DNA share the nucleic alpha)
    alpha_by_type = {
        const.chain_type_ids["PROTEIN"]: alpha_prot,
        const.chain_type_ids["RNA"]: alpha_nucl,
        const.chain_type_ids["DNA"]: alpha_nucl,
        const.chain_type_ids["NONPOLYMER"]: alpha_ligand,
    }

    # Inverse cluster-size weighting, scaled by the type multiplier
    # (unknown types keep the bare 1 / cluster-size weight, as before)
    base = beta_chain / clusters[chain.cluster_id]
    return base * alpha_by_type.get(chain.mol_type, 1.0)
def get_interface_weight(
    interface: InterfaceInfo,
    record: Record,
    clusters: Dict[str, int],
    beta_interface: float,
    alpha_prot: float,
    alpha_nucl: float,
    alpha_ligand: float,
) -> float:
    """Get the weight of an interface.

    Parameters
    ----------
    interface : InterfaceInfo
        The interface to get the weight for.
    record : Record
        The record the interface is part of.
    clusters : Dict[str, int]
        The cluster sizes.
    beta_interface : float
        The beta value for interfaces.
    alpha_prot : float
        The alpha value for proteins.
    alpha_nucl : float
        The alpha value for nucleic acids.
    alpha_ligand : float
        The alpha value for ligands.

    Returns
    -------
    float
        The weight of the interface.

    """
    prot_id = const.chain_type_ids["PROTEIN"]
    nucl_ids = (const.chain_type_ids["RNA"], const.chain_type_ids["DNA"])
    ligand_id = const.chain_type_ids["NONPOLYMER"]

    # Count how many of the two partner chains fall in each molecule class
    pair = (record.chains[interface.chain_1], record.chains[interface.chain_2])
    n_prot = sum(c.mol_type == prot_id for c in pair)
    n_nucl = sum(c.mol_type in nucl_ids for c in pair)
    n_ligand = sum(c.mol_type == ligand_id for c in pair)

    # Inverse cluster-size weighting scaled by the per-type multipliers
    weight = beta_interface / clusters[get_interface_cluster(interface, record)]
    return weight * (alpha_prot * n_prot + alpha_nucl * n_nucl + alpha_ligand * n_ligand)
    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:  # noqa: C901, PLR0912
        """Sample a structure from the dataset infinitely.

        Runs three passes over the records: (1) count chain cluster
        sizes, (2) count interface cluster sizes, (3) assign each valid
        chain/interface a weight inversely proportional to its cluster
        size, then draws from the normalized distribution forever.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A data sample.

        """
        # Compute chain cluster sizes
        chain_clusters: Dict[str, int] = {}
        for record in records:
            for chain in record.chains:
                if not chain.valid:
                    continue
                cluster_id = get_chain_cluster(chain, record)
                if cluster_id not in chain_clusters:
                    chain_clusters[cluster_id] = 0
                chain_clusters[cluster_id] += 1

        # Compute interface clusters sizes
        interface_clusters: Dict[str, int] = {}
        for record in records:
            for interface in record.interfaces:
                if not interface.valid:
                    continue
                cluster_id = get_interface_cluster(interface, record)
                if cluster_id not in interface_clusters:
                    interface_clusters[cluster_id] = 0
                interface_clusters[cluster_id] += 1

        # Compute weights. Each item is (record, kind, index) where kind 0
        # is a chain and kind 1 an interface.
        items, weights = [], []
        for record in records:
            for chain_id, chain in enumerate(record.chains):
                if not chain.valid:
                    continue
                weight = get_chain_weight(
                    chain,
                    record,
                    chain_clusters,
                    self.beta_chain,
                    self.alpha_prot,
                    self.alpha_nucl,
                    self.alpha_ligand,
                )
                items.append((record, 0, chain_id))
                weights.append(weight)

            for int_id, interface in enumerate(record.interfaces):
                if not interface.valid:
                    continue
                weight = get_interface_weight(
                    interface,
                    record,
                    interface_clusters,
                    self.beta_interface,
                    self.alpha_prot,
                    self.alpha_nucl,
                    self.alpha_ligand,
                )
                items.append((record, 1, int_id))
                weights.append(weight)

        # Sample infinitely.
        # NOTE(review): assumes at least one valid chain or interface
        # exists across all records — np.sum(weights) is 0 otherwise and
        # the normalization divides by zero; confirm upstream filtering.
        weights = np.array(weights) / np.sum(weights)
        while True:
            item_idx = random.choice(len(items), p=weights)
            record, kind, index = items[item_idx]
            if kind == 0:
                yield Sample(record=record, chain_id=index)
            else:
                yield Sample(record=record, interface_id=index)
class DistillationSampler(Sampler):
    """A sampler for monomer distillation data."""

    def __init__(self, small_size: int = 200, small_prob: float = 0.01) -> None:
        """Initialize the sampler.

        Parameters
        ----------
        small_size : int, optional
            The maximum size to be considered small.
        small_prob : float, optional
            The probability of sampling a small item.

        """
        self._size = small_size
        self._prob = small_prob

    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
        """Sample a structure from the dataset infinitely.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A data sample.

        Raises
        ------
        ValueError
            If no record has a valid first chain.

        """
        # Remove records with invalid chains
        records = [r for r in records if r.chains[0].valid]
        if not records:
            msg = "No valid records to sample from."
            raise ValueError(msg)

        # Split in small and large proteins. We assume that there is only
        # one chain per record, as is the case for monomer distillation
        small = [r for r in records if r.chains[0].num_residues <= self._size]
        large = [r for r in records if r.chains[0].num_residues > self._size]

        # Sample infinitely
        while True:
            # Sample small or large
            samples = small if random.rand() < self._prob else large

            # Fall back to the other bucket when the chosen one is empty.
            # Previously random.randint(0, 0) raised a ValueError when,
            # e.g., every record was larger than `small_size`.
            if not samples:
                samples = small or large

            # Sample item from the list
            index = random.randint(0, len(samples))
            yield Sample(record=samples[index])
class RandomSampler(Sampler):
    """A simple random sampler with replacement."""

    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
        """Sample a structure from the dataset infinitely.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A data sample.

        """
        num_records = len(records)
        while True:
            # Draw a record uniformly at random, with replacement
            picked = records[random.randint(0, num_records)]

            # Strip invalid chains and interfaces before yielding
            valid_chains = [c for c in picked.chains if c.valid]
            valid_interfaces = [i for i in picked.interfaces if i.valid]
            picked = replace(picked, chains=valid_chains, interfaces=valid_interfaces)

            yield Sample(record=picked)
@dataclass
class TokenData:
    """TokenData datatype.

    One entry per token produced by BoltzTokenizer.tokenize: a standard
    residue yields a single token, while non-standard residues are
    tokenized per atom.
    """

    token_idx: int  # index of this token within the tokenized structure
    atom_idx: int  # index of the token's first atom in the structure atom table
    atom_num: int  # number of atoms covered by this token
    res_idx: int  # residue index within the parent chain
    res_type: int  # residue type id
    sym_id: int  # symmetry id of the parent chain
    asym_id: int  # asym (chain) id of the parent chain
    entity_id: int  # entity id of the parent chain
    mol_type: int  # molecule type id (e.g. PROTEIN / RNA / DNA / NONPOLYMER)
    center_idx: int  # atom index of the representative center atom
    disto_idx: int  # atom index of the distogram atom
    center_coords: np.ndarray  # coordinates of the center atom
    disto_coords: np.ndarray  # coordinates of the distogram atom
    resolved_mask: bool  # residue and its center atom are both present
    disto_mask: bool  # residue and its distogram atom are both present
46 | 47 | """ 48 | # Get structure data 49 | struct = data.structure 50 | 51 | # Create token data 52 | token_data = [] 53 | 54 | # Keep track of atom_idx to token_idx 55 | token_idx = 0 56 | atom_to_token = {} 57 | 58 | # Filter to valid chains only 59 | chains = struct.chains[struct.mask] 60 | 61 | for chain in chains: 62 | # Get residue indices 63 | res_start = chain["res_idx"] 64 | res_end = chain["res_idx"] + chain["res_num"] 65 | 66 | for res in struct.residues[res_start:res_end]: 67 | # Get atom indices 68 | atom_start = res["atom_idx"] 69 | atom_end = res["atom_idx"] + res["atom_num"] 70 | 71 | # Standard residues are tokens 72 | if res["is_standard"]: 73 | # Get center and disto atoms 74 | center = struct.atoms[res["atom_center"]] 75 | disto = struct.atoms[res["atom_disto"]] 76 | 77 | # Token is present if centers are 78 | is_present = res["is_present"] & center["is_present"] 79 | is_disto_present = res["is_present"] & disto["is_present"] 80 | 81 | # Apply chain transformation 82 | c_coords = center["coords"] 83 | d_coords = disto["coords"] 84 | 85 | # Create token 86 | token = TokenData( 87 | token_idx=token_idx, 88 | atom_idx=res["atom_idx"], 89 | atom_num=res["atom_num"], 90 | res_idx=res["res_idx"], 91 | res_type=res["res_type"], 92 | sym_id=chain["sym_id"], 93 | asym_id=chain["asym_id"], 94 | entity_id=chain["entity_id"], 95 | mol_type=chain["mol_type"], 96 | center_idx=res["atom_center"], 97 | disto_idx=res["atom_disto"], 98 | center_coords=c_coords, 99 | disto_coords=d_coords, 100 | resolved_mask=is_present, 101 | disto_mask=is_disto_present, 102 | ) 103 | token_data.append(astuple(token)) 104 | 105 | # Update atom_idx to token_idx 106 | for atom_idx in range(atom_start, atom_end): 107 | atom_to_token[atom_idx] = token_idx 108 | 109 | token_idx += 1 110 | 111 | # Non-standard are tokenized per atom 112 | else: 113 | # We use the unk protein token as res_type 114 | unk_token = const.unk_token["PROTEIN"] 115 | unk_id = const.token_ids[unk_token] 116 | 
117 | # Get atom coordinates 118 | atom_data = struct.atoms[atom_start:atom_end] 119 | atom_coords = atom_data["coords"] 120 | 121 | # Tokenize each atom 122 | for i, atom in enumerate(atom_data): 123 | # Token is present if atom is 124 | is_present = res["is_present"] & atom["is_present"] 125 | index = atom_start + i 126 | 127 | # Create token 128 | token = TokenData( 129 | token_idx=token_idx, 130 | atom_idx=index, 131 | atom_num=1, 132 | res_idx=res["res_idx"], 133 | res_type=unk_id, 134 | sym_id=chain["sym_id"], 135 | asym_id=chain["asym_id"], 136 | entity_id=chain["entity_id"], 137 | mol_type=chain["mol_type"], 138 | center_idx=index, 139 | disto_idx=index, 140 | center_coords=atom_coords[i], 141 | disto_coords=atom_coords[i], 142 | resolved_mask=is_present, 143 | disto_mask=is_present, 144 | ) 145 | token_data.append(astuple(token)) 146 | 147 | # Update atom_idx to token_idx 148 | atom_to_token[index] = token_idx 149 | token_idx += 1 150 | 151 | # Create token bonds 152 | token_bonds = [] 153 | 154 | # Add atom-atom bonds from ligands 155 | for bond in struct.bonds: 156 | if ( 157 | bond["atom_1"] not in atom_to_token 158 | or bond["atom_2"] not in atom_to_token 159 | ): 160 | continue 161 | token_bond = ( 162 | atom_to_token[bond["atom_1"]], 163 | atom_to_token[bond["atom_2"]], 164 | ) 165 | token_bonds.append(token_bond) 166 | 167 | # Add connection bonds (covalent) 168 | for conn in struct.connections: 169 | if ( 170 | conn["atom_1"] not in atom_to_token 171 | or conn["atom_2"] not in atom_to_token 172 | ): 173 | continue 174 | token_bond = ( 175 | atom_to_token[conn["atom_1"]], 176 | atom_to_token[conn["atom_2"]], 177 | ) 178 | token_bonds.append(token_bond) 179 | 180 | token_data = np.array(token_data, dtype=Token) 181 | token_bonds = np.array(token_bonds, dtype=TokenBond) 182 | tokenized = Tokenized( 183 | token_data, 184 | token_bonds, 185 | data.structure, 186 | data.msa, 187 | ) 188 | return tokenized 189 | 
from abc import ABC, abstractmethod

from boltz.data.types import Input, Tokenized


class Tokenizer(ABC):
    """Tokenize an input structure for training.

    Abstract interface; see `BoltzTokenizer` for the concrete
    implementation used by the data pipeline.
    """

    @abstractmethod
    def tokenize(self, data: Input) -> Tokenized:
        """Tokenize the input data.

        Parameters
        ----------
        data : Input
            The input data.

        Returns
        -------
        Tokenized
            The tokenized data.

        """
        raise NotImplementedError
def to_mmcif(structure: Structure, plddts: Optional[Tensor] = None) -> str:  # noqa: C901, PLR0915, PLR0912
    """Write a structure into an MMCIF file.

    Parameters
    ----------
    structure : Structure
        The input structure
    plddts : Optional[Tensor]
        Per-residue pLDDT values; when provided they are written as the
        atoms' B-iso values and as local QA metrics.

    Returns
    -------
    str
        the output MMCIF file

    """
    system = System()

    # Load periodic table for element mapping
    periodic_table = Chem.GetPeriodicTable()

    # Map entities to chain_ids
    entity_to_chains = {}
    entity_to_moltype = {}

    for chain in structure.chains:
        entity_id = chain["entity_id"]
        mol_type = chain["mol_type"]
        entity_to_chains.setdefault(entity_id, []).append(chain)
        entity_to_moltype[entity_id] = mol_type

    # Map entities to sequences
    sequences = {}
    for entity in entity_to_chains:
        # Get the first chain (all chains of an entity share a sequence)
        chain = entity_to_chains[entity][0]

        # Get the sequence
        res_start = chain["res_idx"]
        res_end = chain["res_idx"] + chain["res_num"]
        residues = structure.residues[res_start:res_end]
        sequence = [str(res["name"]) for res in residues]
        sequences[entity] = sequence

    # Create entity objects
    lig_entity = None
    entities_map = {}
    for entity, sequence in sequences.items():
        mol_type = entity_to_moltype[entity]

        # Pick the chemical-component alphabet matching the polymer type;
        # unknown components fall back to the chem_comp factory.
        if mol_type == const.chain_type_ids["PROTEIN"]:
            alphabet = ihm.LPeptideAlphabet()
            chem_comp = lambda x: ihm.LPeptideChemComp(id=x, code=x, code_canonical="X")  # noqa: E731
        elif mol_type == const.chain_type_ids["DNA"]:
            alphabet = ihm.DNAAlphabet()
            chem_comp = lambda x: ihm.DNAChemComp(id=x, code=x, code_canonical="N")  # noqa: E731
        elif mol_type == const.chain_type_ids["RNA"]:
            alphabet = ihm.RNAAlphabet()
            chem_comp = lambda x: ihm.RNAChemComp(id=x, code=x, code_canonical="N")  # noqa: E731
        elif len(sequence) > 1:
            alphabet = {}
            chem_comp = lambda x: ihm.SaccharideChemComp(id=x)  # noqa: E731
        else:
            alphabet = {}
            chem_comp = lambda x: ihm.NonPolymerChemComp(id=x)  # noqa: E731

        # Handle smiles: all "LIG" single-component entities share one
        # Entity object, created lazily on first use.
        if len(sequence) == 1 and (sequence[0] == "LIG"):
            if lig_entity is None:
                seq = [chem_comp(sequence[0])]
                lig_entity = Entity(seq)
            model_e = lig_entity
        else:
            seq = [
                alphabet[item] if item in alphabet else chem_comp(item)
                for item in sequence
            ]
            model_e = Entity(seq)

        for chain in entity_to_chains[entity]:
            chain_idx = chain["asym_id"]
            entities_map[chain_idx] = model_e

    # We don't assume that symmetry is perfect, so we dump everything
    # into the asymmetric unit, and produce just a single assembly
    chain_tags = generate_tags()
    asym_unit_map = {}
    for chain in structure.chains:
        # Define the model assembly
        chain_idx = chain["asym_id"]
        chain_tag = next(chain_tags)
        asym = AsymUnit(
            entities_map[chain_idx],
            details="Model subunit %s" % chain_tag,
            id=chain_tag,
        )
        asym_unit_map[chain_idx] = asym
    modeled_assembly = Assembly(asym_unit_map.values(), name="Modeled assembly")

    class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
        # Per-residue pLDDT QA metric written alongside the model.
        name = "pLDDT"
        software = None
        description = "Predicted lddt"

    class _MyModel(AbInitioModel):
        def get_atoms(self) -> Iterator[Atom]:
            # Add all atom sites.
            res_num = 0
            for chain in structure.chains:
                # We rename the chains in alphabetical order
                het = chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                chain_idx = chain["asym_id"]
                res_start = chain["res_idx"]
                res_end = chain["res_idx"] + chain["res_num"]

                residues = structure.residues[res_start:res_end]
                for residue in residues:
                    atom_start = residue["atom_idx"]
                    atom_end = residue["atom_idx"] + residue["atom_num"]
                    atoms = structure.atoms[atom_start:atom_end]
                    atom_coords = atoms["coords"]
                    for i, atom in enumerate(atoms):
                        # This should not happen on predictions, but just in case.
                        if not atom["is_present"]:
                            continue

                        # Atom names are stored as fixed-width code arrays;
                        # +32 shifts the codes back to ASCII — TODO confirm
                        # encoding against boltz.data.types.
                        name = atom["name"]
                        name = [chr(c + 32) for c in name if c != 0]
                        name = "".join(name)
                        element = periodic_table.GetElementSymbol(
                            atom["element"].item()
                        )
                        element = element.upper()
                        residue_index = residue["res_idx"] + 1
                        pos = atom_coords[i]
                        yield Atom(
                            asym_unit=asym_unit_map[chain_idx],
                            type_symbol=element,
                            seq_id=residue_index,
                            atom_id=name,
                            x=f"{pos[0]:.5f}",
                            y=f"{pos[1]:.5f}",
                            z=f"{pos[2]:.5f}",
                            het=het,
                            # biso is per-residue: all atoms of a residue
                            # share the same pLDDT value.
                            biso=1
                            if plddts is None
                            else round(plddts[res_num].item(), 2),
                            occupancy=1,
                        )

                    res_num += 1

        def add_plddt(self, plddts):
            # Record one local pLDDT QA metric per residue; res_num walks
            # residues in the same order as get_atoms.
            res_num = 0
            for chain in structure.chains:
                chain_idx = chain["asym_id"]
                res_start = chain["res_idx"]
                res_end = chain["res_idx"] + chain["res_num"]
                residues = structure.residues[res_start:res_end]
                # We rename the chains in alphabetical order
                for residue in residues:
                    residue_idx = residue["res_idx"] + 1
                    self.qa_metrics.append(
                        _LocalPLDDT(
                            asym_unit_map[chain_idx].residue(residue_idx),
                            plddts[res_num].item(),
                        )
                    )
                    res_num += 1

    # Add the model and modeling protocol to the file and write them out:
    model = _MyModel(assembly=modeled_assembly, name="Model")
    if plddts is not None:
        model.add_plddt(plddts)

    model_group = ModelGroup([model], name="All models")
    system.model_groups.append(model_group)

    fh = io.StringIO()
    dumper.write(fh, [system])
    return fh.getvalue()
def to_pdb(structure: Structure, plddts: Optional[Tensor] = None) -> str:  # noqa: PLR0915
    """Write a structure into a PDB file.

    Parameters
    ----------
    structure : Structure
        The input structure
    plddts : Optional[Tensor]
        Per-residue pLDDT values; written into the B-factor column
        when provided.

    Returns
    -------
    str
        the output PDB file

    """
    pdb_lines = []

    atom_index = 1
    atom_reindex_ter = []
    chain_tags = generate_tags()

    # Load periodic table for element mapping
    periodic_table = Chem.GetPeriodicTable()

    # Add all atom sites.
    res_num = 0
    for chain in structure.chains:
        # We rename the chains in alphabetical order
        chain_idx = chain["asym_id"]
        chain_tag = next(chain_tags)

        res_start = chain["res_idx"]
        res_end = chain["res_idx"] + chain["res_num"]

        residues = structure.residues[res_start:res_end]
        for residue in residues:
            atom_start = residue["atom_idx"]
            atom_end = residue["atom_idx"] + residue["atom_num"]
            atoms = structure.atoms[atom_start:atom_end]
            atom_coords = atoms["coords"]
            for i, atom in enumerate(atoms):
                # This should not happen on predictions, but just in case.
                if not atom["is_present"]:
                    continue

                record_type = (
                    "ATOM"
                    if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                    else "HETATM"
                )
                # Atom names are stored as fixed-width code arrays; +32
                # shifts the codes back to ASCII — TODO confirm encoding.
                name = atom["name"]
                name = [chr(c + 32) for c in name if c != 0]
                name = "".join(name)
                name = name if len(name) == 4 else f" {name}"  # noqa: PLR2004
                alt_loc = ""
                insertion_code = ""
                occupancy = 1.00
                element = periodic_table.GetElementSymbol(atom["element"].item())
                element = element.upper()
                charge = ""
                residue_index = residue["res_idx"] + 1
                pos = atom_coords[i]
                res_name_3 = (
                    "LIG" if record_type == "HETATM" else str(residue["name"][:3])
                )
                # B-factor is per-residue: all atoms of a residue share
                # the same pLDDT value.
                b_factor = 1.00 if plddts is None else round(plddts[res_num].item(), 2)

                # PDB is a columnar format, every space matters here!
                atom_line = (
                    f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                    f"{res_name_3:>3} {chain_tag:>1}"
                    f"{residue_index:>4}{insertion_code:>1}   "
                    f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                    f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                    f"{element:>2}{charge:>2}"
                )
                pdb_lines.append(atom_line)
                # NOTE(review): atom_reindex_ter is appended only for atoms
                # actually written, but the CONECT section below indexes it
                # with raw atom-table indices; the two align only when every
                # atom is present (true for predictions, where is_present is
                # forced True upstream) — confirm for other callers.
                atom_reindex_ter.append(atom_index)
                atom_index += 1

            res_num += 1

        should_terminate = chain_idx < (len(structure.chains) - 1)
        if should_terminate:
            # Close the chain.
            chain_end = "TER"
            chain_termination_line = (
                f"{chain_end:<6}{atom_index:>5}      "
                f"{res_name_3:>3} "
                f"{chain_tag:>1}{residue_index:>4}"
            )
            pdb_lines.append(chain_termination_line)
            # TER consumes a serial number but has no atom_reindex_ter entry.
            atom_index += 1

    # Dump CONECT records.
    for bonds in [structure.bonds, structure.connections]:
        for bond in bonds:
            atom1 = structure.atoms[bond["atom_1"]]
            atom2 = structure.atoms[bond["atom_2"]]
            if not atom1["is_present"] or not atom2["is_present"]:
                continue
            atom1_idx = atom_reindex_ter[bond["atom_1"]]
            atom2_idx = atom_reindex_ter[bond["atom_2"]]
            conect_line = f"CONECT{atom1_idx:>5}{atom2_idx:>5}"
            pdb_lines.append(conect_line)

    pdb_lines.append("END")
    pdb_lines.append("")
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return "\n".join(pdb_lines)
import string
from collections.abc import Iterator


def generate_tags() -> Iterator[str]:
    """Generate unique chain tags in alphabetical order.

    Yields single letters A..Z first, then two-letter tags AA, AB, ...,
    ZZ, then three-letter tags, for 26 + 26**2 + 26**3 tags in total.

    Bug fix: the previous implementation built multi-letter tags with the
    least-significant character first (A..Z, AA, BA, CA, ...), which
    contradicts the "alphabetical order" chain renaming documented in
    to_pdb/to_mmcif and the usual PDB/mmCIF AA, AB, AC convention. The
    first 26 tags (the common case) are unchanged.

    Yields
    ------
    str
        The next chain tag

    """
    uppercase = string.ascii_uppercase
    base = len(uppercase)
    for length in range(1, 4):
        for index in range(base**length):
            # Decode `index` in base 26, most-significant character first,
            # so tags of a given length come out alphabetically sorted.
            tag = ""
            remainder = index
            for _ in range(length):
                tag = uppercase[remainder % base] + tag
                remainder //= base
            yield tag
class BoltzWriter(BasePredictionWriter):
    """Custom writer for predictions.

    Writes one subdirectory per record under `output_dir`, containing the
    predicted structure (pdb/mmcif/npz) plus optional confidence summaries
    and pLDDT/PAE/PDE arrays, ranked by confidence score.
    """

    def __init__(
        self,
        data_dir: str,
        output_dir: str,
        output_format: Literal["pdb", "mmcif"] = "mmcif",
    ) -> None:
        """Initialize the writer.

        Parameters
        ----------
        data_dir : str
            The directory holding the processed input structures
            (`{record.id}.npz` files loaded during writing).
        output_dir : str
            The directory to save the predictions.
        output_format : Literal["pdb", "mmcif"]
            The structure output format, by default "mmcif".

        Raises
        ------
        ValueError
            If `output_format` is not "pdb" or "mmcif".

        """
        super().__init__(write_interval="batch")
        if output_format not in ["pdb", "mmcif"]:
            msg = f"Invalid output format: {output_format}"
            raise ValueError(msg)

        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_format = output_format
        # Count of predictions that raised during inference.
        self.failed = 0

        # Create the output directories
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
        prediction: dict[str, Tensor],
        batch_indices: list[int],  # noqa: ARG002
        batch: dict[str, Tensor],
        batch_idx: int,  # noqa: ARG002
        dataloader_idx: int,  # noqa: ARG002
    ) -> None:
        """Write the predictions to disk."""
        # Skip (and count) batches whose prediction raised upstream.
        if prediction["exception"]:
            self.failed += 1
            return

        # Get the records
        records: list[Record] = batch["record"]

        # Get the predictions
        coords = prediction["coords"]
        coords = coords.unsqueeze(0)

        pad_masks = prediction["masks"]

        # Get ranking: model index -> rank by descending confidence score.
        argsort = torch.argsort(prediction["confidence_score"], descending=True)
        idx_to_rank = {idx.item(): rank for rank, idx in enumerate(argsort)}

        # Iterate over the records
        for record, coord, pad_mask in zip(records, coords, pad_masks):
            # Load the structure
            path = self.data_dir / f"{record.id}.npz"
            structure: Structure = Structure.load(path)

            # Compute chain map with masked removed, to be used later
            # (maps new chain index -> original chain index).
            chain_map = {}
            for i, mask in enumerate(structure.mask):
                if mask:
                    chain_map[len(chain_map)] = i

            # Remove masked chains completely
            structure = structure.remove_invalid_chains()

            for model_idx in range(coord.shape[0]):
                # Get model coord
                model_coord = coord[model_idx]
                # Unpad
                coord_unpad = model_coord[pad_mask.bool()]
                coord_unpad = coord_unpad.cpu().numpy()

                # New atom table: predicted coordinates, every atom present.
                atoms = structure.atoms
                atoms["coords"] = coord_unpad
                atoms["is_present"] = True

                # New residue table
                residues = structure.residues
                residues["is_present"] = True

                # Update the structure
                interfaces = np.array([], dtype=Interface)
                new_structure: Structure = replace(
                    structure,
                    atoms=atoms,
                    residues=residues,
                    interfaces=interfaces,
                )

                # Update chain info
                # NOTE(review): chain_info is built but never used below —
                # confirm whether it should be written back to the record.
                chain_info = []
                for chain in new_structure.chains:
                    old_chain_idx = chain_map[chain["asym_id"]]
                    old_chain_info = record.chains[old_chain_idx]
                    new_chain_info = replace(
                        old_chain_info,
                        chain_id=int(chain["asym_id"]),
                        valid=True,
                    )
                    chain_info.append(new_chain_info)

                # Save the structure
                struct_dir = self.output_dir / record.id
                struct_dir.mkdir(exist_ok=True)

                # Get plddt's
                plddts = None
                if "plddt" in prediction:
                    plddts = prediction["plddt"][model_idx]

                # Create path name
                outname = f"{record.id}_model_{idx_to_rank[model_idx]}"

                # Save the structure
                if self.output_format == "pdb":
                    path = struct_dir / f"{outname}.pdb"
                    with path.open("w") as f:
                        f.write(to_pdb(new_structure, plddts=plddts))
                elif self.output_format == "mmcif":
                    path = struct_dir / f"{outname}.cif"
                    with path.open("w") as f:
                        f.write(to_mmcif(new_structure, plddts=plddts))
                else:
                    path = struct_dir / f"{outname}.npz"
                    np.savez_compressed(path, **asdict(new_structure))

                # Save confidence summary
                if "plddt" in prediction:
                    path = (
                        struct_dir
                        / f"confidence_{record.id}_model_{idx_to_rank[model_idx]}.json"
                    )
                    confidence_summary_dict = {}
                    for key in [
                        "confidence_score",
                        "ptm",
                        "iptm",
                        "ligand_iptm",
                        "protein_iptm",
                        "complex_plddt",
                        "complex_iplddt",
                        "complex_pde",
                        "complex_ipde",
                    ]:
                        confidence_summary_dict[key] = prediction[key][model_idx].item()
                    # Per-chain pTM: diagonal of the pairwise chain ipTM table.
                    confidence_summary_dict["chains_ptm"] = {
                        idx: prediction["pair_chains_iptm"][idx][idx][model_idx].item()
                        for idx in prediction["pair_chains_iptm"]
                    }
                    confidence_summary_dict["pair_chains_iptm"] = {
                        idx1: {
                            idx2: prediction["pair_chains_iptm"][idx1][idx2][
                                model_idx
                            ].item()
                            for idx2 in prediction["pair_chains_iptm"][idx1]
                        }
                        for idx1 in prediction["pair_chains_iptm"]
                    }
                    with path.open("w") as f:
                        f.write(
                            json.dumps(
                                confidence_summary_dict,
                                indent=4,
                            )
                        )

                    # Save plddt
                    plddt = prediction["plddt"][model_idx]
                    path = (
                        struct_dir
                        / f"plddt_{record.id}_model_{idx_to_rank[model_idx]}.npz"
                    )
                    np.savez_compressed(path, plddt=plddt.cpu().numpy())

                # Save pae
                if "pae" in prediction:
                    pae = prediction["pae"][model_idx]
                    path = (
                        struct_dir
                        / f"pae_{record.id}_model_{idx_to_rank[model_idx]}.npz"
                    )
                    np.savez_compressed(path, pae=pae.cpu().numpy())

                # Save pde
                if "pde" in prediction:
                    pde = prediction["pde"][model_idx]
                    path = (
                        struct_dir
                        / f"pde_{record.id}_model_{idx_to_rank[model_idx]}.npz"
                    )
                    np.savez_compressed(path, pde=pde.cpu().numpy())

    def on_predict_epoch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        """Print the number of failed examples."""
        # Print number of failed examples
        print(f"Number of failed examples: {self.failed}")  # noqa: T201
class AttentionPairBias(nn.Module):
    """Attention pair bias layer."""

    def __init__(
        self,
        c_s: int,
        c_z: int,
        num_heads: int,
        inf: float = 1e6,
        initial_norm: bool = True,
    ) -> None:
        """Initialize the attention pair bias layer.

        Parameters
        ----------
        c_s : int
            The input sequence dimension.
        c_z : int
            The input pairwise dimension.
        num_heads : int
            The number of heads.
        inf : float, optional
            The inf value, by default 1e6
        initial_norm: bool, optional
            Whether to apply layer norm to the input, by default True

        """
        super().__init__()

        # Heads must evenly partition the sequence dimension.
        assert c_s % num_heads == 0

        self.c_s = c_s
        self.num_heads = num_heads
        self.head_dim = c_s // num_heads
        self.inf = inf

        self.initial_norm = initial_norm
        if self.initial_norm:
            self.norm_s = nn.LayerNorm(c_s)

        # Query keeps a bias; key/value/gate projections do not.
        self.proj_q = nn.Linear(c_s, c_s)
        self.proj_k = nn.Linear(c_s, c_s, bias=False)
        self.proj_v = nn.Linear(c_s, c_s, bias=False)
        self.proj_g = nn.Linear(c_s, c_s, bias=False)

        # Pairwise tensor -> per-head additive attention bias.
        self.proj_z = nn.Sequential(
            nn.LayerNorm(c_z),
            nn.Linear(c_z, num_heads, bias=False),
            Rearrange("b ... h -> b h ..."),
        )

        # Zero-initialized output projection so the layer initially
        # contributes nothing (no bias, zero weight).
        self.proj_o = nn.Linear(c_s, c_s, bias=False)
        init.final_init_(self.proj_o.weight)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        multiplicity: int = 1,
        to_keys=None,
        model_cache=None,
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        s : torch.Tensor
            The input sequence tensor (B, S, D)
        z : torch.Tensor
            The input pairwise tensor (B, N, N, D)
        mask : torch.Tensor
            The pairwise mask tensor (B, N, N)
        multiplicity : int, optional
            The diffusion batch size, by default 1
        to_keys : callable, optional
            When given, maps the (normalized) sequence tensor and the mask
            into the key space before the key/value projections.
        model_cache : dict, optional
            When given, caches the pairwise projection under key "z" so it
            is computed once per diffusion roll-out and reused afterwards.

        Returns
        -------
        torch.Tensor
            The output sequence tensor.

        """
        B = s.shape[0]

        # Layer norms
        if self.initial_norm:
            s = self.norm_s(s)

        if to_keys is not None:
            k_in = to_keys(s)
            mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)
        else:
            k_in = s

        # Compute projections
        q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
        k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
        v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)

        # Caching z projection during diffusion roll-out
        if model_cache is None or "z" not in model_cache:
            z = self.proj_z(z)

            if model_cache is not None:
                model_cache["z"] = z
        else:
            z = model_cache["z"]
        # Repeat the pair bias once per diffusion sample in the batch.
        z = z.repeat_interleave(multiplicity, 0)

        g = self.proj_g(s).sigmoid()

        # Attention is computed in fp32 for numerical stability.
        with torch.autocast("cuda", enabled=False):
            # Compute attention weights
            attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
            attn = attn / (self.head_dim**0.5) + z.float()
            attn = attn + (1 - mask[:, None, None].float()) * -self.inf
            attn = attn.softmax(dim=-1)

            # Compute output
            o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)
            o = o.reshape(B, -1, self.c_s)
            o = self.proj_o(g * o)

        return o
def get_dropout_mask(
    dropout: float,
    z: Tensor,
    training: bool,
    columnwise: bool = False,
) -> Tensor:
    """Build a shared, inverse-scaled dropout mask for a pairwise tensor.

    The mask is sampled once per row (or per column when ``columnwise``)
    and broadcasts over the remaining axes; kept entries are scaled by
    1 / (1 - rate) so the expected value is preserved.

    Parameters
    ----------
    dropout : float
        The dropout rate
    z : torch.Tensor
        The tensor to apply dropout to
    training : bool
        Whether the model is in training mode
    columnwise : bool, optional
        Whether to apply dropout columnwise

    Returns
    -------
    torch.Tensor
        The dropout mask

    """
    # Dropout is a no-op outside training (rate collapses to 0).
    rate = dropout * training
    # Pick the broadcast shape: share across rows or across columns.
    ref = z[:, 0:1, :, 0:1] if columnwise else z[:, :, 0:1, 0:1]
    keep = torch.rand_like(ref) > rate
    return keep * 1.0 / (1.0 - rate)
import math
import numpy as np
from scipy.stats import truncnorm
import torch


def _prod(nums):
    """Return the product of all entries in *nums*."""
    result = 1
    for factor in nums:
        result *= factor
    return result


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    """Resolve a fan mode name to the corresponding fan count."""
    fan_out, fan_in = linear_weight_shape

    options = {
        "fan_in": fan_in,
        "fan_out": fan_out,
        "fan_avg": (fan_in + fan_out) / 2,
    }
    if fan not in options:
        raise ValueError("Invalid fan option")

    return options[fan]


def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """Fill *weights* in place from a fan-scaled truncated normal."""
    shape = weights.shape
    # Scale the variance by the requested fan, as in lecun/he schemes.
    scale = scale / max(1, _calculate_fan(shape, fan))
    a, b = -2, 2
    # Correct the std so the truncated sample has the target variance.
    std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
    samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=_prod(shape))
    samples = np.reshape(samples, shape)
    with torch.no_grad():
        weights.copy_(torch.tensor(samples, device=weights.device))


def lecun_normal_init_(weights):
    """LeCun-normal initialization (fan_in, scale 1)."""
    trunc_normal_init_(weights, scale=1.0)


def he_normal_init_(weights):
    """He-normal initialization (fan_in, scale 2)."""
    trunc_normal_init_(weights, scale=2.0)


def glorot_uniform_init_(weights):
    """Glorot/Xavier uniform initialization."""
    torch.nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    """Zero-fill, used for final/output projections."""
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    """Zero-fill, used for gating projections."""
    with torch.no_grad():
        weights.fill_(0.0)


def bias_init_zero_(bias):
    """Fill a bias tensor with zeros."""
    with torch.no_grad():
        bias.fill_(0.0)


def bias_init_one_(bias):
    """Fill a bias tensor with ones."""
    with torch.no_grad():
        bias.fill_(1.0)


def normal_init_(weights):
    """Kaiming-normal initialization with linear nonlinearity."""
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")


def ipa_point_weights_init_(weights):
    """Fill with softplus^-1(1) so softplus(weights) starts at 1."""
    with torch.no_grad():
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)
class OuterProductMean(nn.Module):
    """Outer product mean layer."""

    def __init__(self, c_in: int, c_hidden: int, c_out: int) -> None:
        """Initialize the outer product mean layer.

        Parameters
        ----------
        c_in : int
            The input dimension.
        c_hidden : int
            The hidden dimension.
        c_out : int
            The output dimension.

        """
        super().__init__()
        self.c_hidden = c_hidden
        self.norm = nn.LayerNorm(c_in)
        self.proj_a = nn.Linear(c_in, c_hidden, bias=False)
        self.proj_b = nn.Linear(c_in, c_hidden, bias=False)
        self.proj_o = nn.Linear(c_hidden * c_hidden, c_out)
        # Zero-init so the layer initially contributes nothing.
        init.final_init_(self.proj_o.weight)
        init.final_init_(self.proj_o.bias)

    def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        m : torch.Tensor
            The sequence tensor (B, S, N, c_in).
        mask : torch.Tensor
            The mask tensor (B, S, N).
        chunk_size : int, optional
            When set (and not training), the output projection is computed
            in chunks over the hidden dimension to reduce peak memory.

        Returns
        -------
        torch.Tensor
            The output tensor (B, N, N, c_out).

        """
        # Expand mask
        mask = mask.unsqueeze(-1).to(m)

        # Compute projections
        m = self.norm(m)
        a = self.proj_a(m) * mask
        b = self.proj_b(m) * mask

        # Compute outer product mean
        if chunk_size is not None and not self.training:
            # Compute the pairwise mask normalizer in fixed 64-wide slices
            # over the sequence dimension to bound memory.
            num_mask = None
            for i in range(0, mask.shape[1], 64):
                part = (
                    mask[:, i : i + 64, None, :] * mask[:, i : i + 64, :, None]
                ).sum(1)
                num_mask = part if num_mask is None else num_mask + part
            num_mask = num_mask.clamp(min=1)

            # Compute sequentially in chunks over the hidden dimension of
            # `a`. Each chunk of a-channels pairs with the matching columns
            # of proj_o's weight, since the a-channel is the outer axis of
            # the flattened outer product.
            z_out = None
            for i in range(0, self.c_hidden, chunk_size):
                a_chunk = a[:, :, :, i : i + chunk_size]
                sliced_weight_proj_o = self.proj_o.weight[
                    :, i * self.c_hidden : (i + chunk_size) * self.c_hidden
                ]

                z = torch.einsum("bsic,bsjd->bijcd", a_chunk, b)
                z = z.reshape(*z.shape[:3], -1)
                z = z / num_mask

                # Project to output (weight only; bias added once below)
                part = z.to(m) @ sliced_weight_proj_o.T
                z_out = part if z_out is None else z_out + part

            # Bug fix: the unchunked path applies the full proj_o Linear
            # (weight AND bias) via self.proj_o(z); the chunked path used
            # to apply only the weight, silently dropping the trained bias
            # at inference. Add the bias exactly once after accumulating.
            return z_out + self.proj_o.bias

        # Full (unchunked) path
        mask = mask[:, :, None, :] * mask[:, :, :, None]
        num_mask = mask.sum(1).clamp(min=1)
        z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float())
        z = z.reshape(*z.shape[:3], -1)
        z = z / num_mask

        # Project to output
        return self.proj_o(z.to(m))
-> None: 18 | """Initialize the pair weighted averaging layer. 19 | 20 | Parameters 21 | ---------- 22 | c_m: int 23 | The dimension of the input sequence. 24 | c_z: int 25 | The dimension of the input pairwise tensor. 26 | c_h: int 27 | The dimension of the hidden. 28 | num_heads: int 29 | The number of heads. 30 | inf: float 31 | The value to use for masking, default 1e6. 32 | 33 | """ 34 | super().__init__() 35 | self.c_m = c_m 36 | self.c_z = c_z 37 | self.c_h = c_h 38 | self.num_heads = num_heads 39 | self.inf = inf 40 | 41 | self.norm_m = nn.LayerNorm(c_m) 42 | self.norm_z = nn.LayerNorm(c_z) 43 | 44 | self.proj_m = nn.Linear(c_m, c_h * num_heads, bias=False) 45 | self.proj_g = nn.Linear(c_m, c_h * num_heads, bias=False) 46 | self.proj_z = nn.Linear(c_z, num_heads, bias=False) 47 | self.proj_o = nn.Linear(c_h * num_heads, c_m, bias=False) 48 | init.final_init_(self.proj_o.weight) 49 | 50 | def forward( 51 | self, m: Tensor, z: Tensor, mask: Tensor, chunk_heads: False = bool 52 | ) -> Tensor: 53 | """Forward pass. 
54 | 55 | Parameters 56 | ---------- 57 | m : torch.Tensor 58 | The input sequence tensor (B, S, N, D) 59 | z : torch.Tensor 60 | The input pairwise tensor (B, N, N, D) 61 | mask : torch.Tensor 62 | The pairwise mask tensor (B, N, N) 63 | 64 | Returns 65 | ------- 66 | torch.Tensor 67 | The output sequence tensor (B, S, N, D) 68 | 69 | """ 70 | # Compute layer norms 71 | m = self.norm_m(m) 72 | z = self.norm_z(z) 73 | 74 | if chunk_heads and not self.training: 75 | # Compute heads sequentially 76 | o_chunks = [] 77 | for head_idx in range(self.num_heads): 78 | sliced_weight_proj_m = self.proj_m.weight[ 79 | head_idx * self.c_h : (head_idx + 1) * self.c_h, : 80 | ] 81 | sliced_weight_proj_g = self.proj_g.weight[ 82 | head_idx * self.c_h : (head_idx + 1) * self.c_h, : 83 | ] 84 | sliced_weight_proj_z = self.proj_z.weight[head_idx : (head_idx + 1), :] 85 | sliced_weight_proj_o = self.proj_o.weight[ 86 | :, head_idx * self.c_h : (head_idx + 1) * self.c_h 87 | ] 88 | 89 | # Project input tensors 90 | v: Tensor = m @ sliced_weight_proj_m.T 91 | v = v.reshape(*v.shape[:3], 1, self.c_h) 92 | v = v.permute(0, 3, 1, 2, 4) 93 | 94 | # Compute weights 95 | b: Tensor = z @ sliced_weight_proj_z.T 96 | b = b.permute(0, 3, 1, 2) 97 | b = b + (1 - mask[:, None]) * -self.inf 98 | w = torch.softmax(b, dim=-1) 99 | 100 | # Compute gating 101 | g: Tensor = m @ sliced_weight_proj_g.T 102 | g = g.sigmoid() 103 | 104 | # Compute output 105 | o = torch.einsum("bhij,bhsjd->bhsid", w, v) 106 | o = o.permute(0, 2, 3, 1, 4) 107 | o = o.reshape(*o.shape[:3], 1 * self.c_h) 108 | o_chunks = g * o 109 | if head_idx == 0: 110 | o_out = o_chunks @ sliced_weight_proj_o.T 111 | else: 112 | o_out += o_chunks @ sliced_weight_proj_o.T 113 | return o_out 114 | else: 115 | # Project input tensors 116 | v: Tensor = self.proj_m(m) 117 | v = v.reshape(*v.shape[:3], self.num_heads, self.c_h) 118 | v = v.permute(0, 3, 1, 2, 4) 119 | 120 | # Compute weights 121 | b: Tensor = self.proj_z(z) 122 | b = b.permute(0, 
3, 1, 2) 123 | b = b + (1 - mask[:, None]) * -self.inf 124 | w = torch.softmax(b, dim=-1) 125 | 126 | # Compute gating 127 | g: Tensor = self.proj_g(m) 128 | g = g.sigmoid() 129 | 130 | # Compute output 131 | o = torch.einsum("bhij,bhsjd->bhsid", w, v) 132 | o = o.permute(0, 2, 3, 1, 4) 133 | o = o.reshape(*o.shape[:3], self.num_heads * self.c_h) 134 | o = self.proj_o(g * o) 135 | return o 136 | -------------------------------------------------------------------------------- /src/boltz/model/layers/transition.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import Tensor, nn 4 | 5 | import boltz.model.layers.initialize as init 6 | 7 | 8 | class Transition(nn.Module): 9 | """Perform a two-layer MLP.""" 10 | 11 | def __init__( 12 | self, 13 | dim: int = 128, 14 | hidden: int = 512, 15 | out_dim: Optional[int] = None, 16 | ) -> None: 17 | """Initialize the TransitionUpdate module. 18 | 19 | Parameters 20 | ---------- 21 | dim: int 22 | The dimension of the input, default 128 23 | hidden: int 24 | The dimension of the hidden, default 512 25 | out_dim: Optional[int] 26 | The dimension of the output, default None 27 | 28 | """ 29 | super().__init__() 30 | if out_dim is None: 31 | out_dim = dim 32 | 33 | self.norm = nn.LayerNorm(dim, eps=1e-5) 34 | self.fc1 = nn.Linear(dim, hidden, bias=False) 35 | self.fc2 = nn.Linear(dim, hidden, bias=False) 36 | self.fc3 = nn.Linear(hidden, out_dim, bias=False) 37 | self.silu = nn.SiLU() 38 | self.hidden = hidden 39 | 40 | init.bias_init_one_(self.norm.weight) 41 | init.bias_init_zero_(self.norm.bias) 42 | 43 | init.lecun_normal_init_(self.fc1.weight) 44 | init.lecun_normal_init_(self.fc2.weight) 45 | init.final_init_(self.fc3.weight) 46 | 47 | def forward(self, x: Tensor, chunk_size: int = None) -> Tensor: 48 | """Perform a forward pass. 
49 | 50 | Parameters 51 | ---------- 52 | x: torch.Tensor 53 | The input data of shape (..., D) 54 | 55 | Returns 56 | ------- 57 | x: torch.Tensor 58 | The output data of shape (..., D) 59 | 60 | """ 61 | x = self.norm(x) 62 | 63 | if chunk_size is None or self.training: 64 | x = self.silu(self.fc1(x)) * self.fc2(x) 65 | x = self.fc3(x) 66 | return x 67 | else: 68 | # Compute in chunks 69 | for i in range(0, self.hidden, chunk_size): 70 | fc1_slice = self.fc1.weight[i : i + chunk_size, :] 71 | fc2_slice = self.fc2.weight[i : i + chunk_size, :] 72 | fc3_slice = self.fc3.weight[:, i : i + chunk_size] 73 | x_chunk = self.silu((x @ fc1_slice.T)) * (x @ fc2_slice.T) 74 | if i == 0: 75 | x_out = x_chunk @ fc3_slice.T 76 | else: 77 | x_out = x_out + x_chunk @ fc3_slice.T 78 | return x_out 79 | -------------------------------------------------------------------------------- /src/boltz/model/layers/triangular_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/model/layers/triangular_attention/__init__.py -------------------------------------------------------------------------------- /src/boltz/model/layers/triangular_attention/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial, partialmethod
from typing import List, Optional

import torch
import torch.nn as nn

from boltz.model.layers.triangular_attention.primitives import (
    Attention,
    LayerNorm,
    Linear,
)
from boltz.model.layers.triangular_attention.utils import (
    chunk_layer,
    permute_final_dims,
)


class TriangleAttention(nn.Module):
    """Triangle attention over the pair representation.

    With starting=True this is attention along the starting (row) axis
    (AF2 Algorithm 13); with starting=False the input is transposed first,
    giving attention along the ending axis (Algorithm 14).
    """

    def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                If True attend along the starting (row) axis; if False the
                input is transposed so attention runs along the ending axis
            inf:
                Large constant used to build the additive attention mask,
                default 1e9
        """
        super(TriangleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        # Projects the pair activations into a per-head additive attention bias
        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """Run the MHA over chunks of the batch-like dims to bound peak memory."""
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }

        return chunk_layer(
            partial(
                self.mha,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                use_lma=use_lma,
            ),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
            # Write results back into x when the caller allows in-place reuse
            _out=x if inplace_safe else None,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            mask:
                [*, I, J] mask (defaults to all-ones when omitted)
            chunk_size:
                If set, run attention in chunks via _chunk to save memory
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        if not self.starting:
            # Ending-node attention: swap I/J so attention runs over rows
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J] — 0 where kept, -inf where masked
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(
                x,
                biases,
                chunk_size,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
            )
        else:
            x = self.mha(
                q_x=x,
                kv_x=x,
                biases=biases,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                use_lma=use_lma,
            )

        if not self.starting:
            # Undo the transpose applied for ending-node attention
            x = x.transpose(-2, -3)

        return x


# Implements Algorithm 13
TriangleAttentionStartingNode = TriangleAttention


class TriangleAttentionEndingNode(TriangleAttention):
    """Implement Algorithm 14."""

    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
# --------------------------------------------------------------------------------
# /src/boltz/model/layers/triangular_mult.py:
# --------------------------------------------------------------------------------
import torch
from torch import Tensor, nn

from boltz.model.layers import initialize as init


class TriangleMultiplicationOutgoing(nn.Module):
    """TriangleMultiplicationOutgoing."""

    def __init__(self, dim: int = 128) -> None:
        """Initialize the TriangularUpdate module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128

        """
        super().__init__()

        # Input projection/gate produce 2*dim channels, later split into a, b
        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        self.norm_out = nn.LayerNorm(dim)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)

        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)

        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)

        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)

        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """Perform a forward pass.
        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)

        """
        # Input gating: D -> D
        x = self.norm_in(x)
        x_in = x
        x = self.p_in(x) * self.g_in(x).sigmoid()

        # Apply mask
        x = x * mask.unsqueeze(-1)

        # Split input and cast to float
        # (float32 keeps the einsum accumulation stable under mixed precision)
        a, b = torch.chunk(x.float(), 2, dim=-1)

        # Triangular projection: combine outgoing edges (i, k) and (j, k)
        # over their shared neighbor k
        x = torch.einsum("bikd,bjkd->bijd", a, b)

        # Output gating (the gate is computed from the pre-projection input)
        x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()

        return x


class TriangleMultiplicationIncoming(nn.Module):
    """TriangleMultiplicationIncoming."""

    def __init__(self, dim: int = 128) -> None:
        """Initialize the TriangularUpdate module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128

        """
        super().__init__()

        # Input projection/gate produce 2*dim channels, later split into a, b
        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        self.norm_out = nn.LayerNorm(dim)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)

        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)

        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)

        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)

        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)

        """
        # Input gating: D -> D
        x = self.norm_in(x)
        x_in = x
        x = self.p_in(x) * self.g_in(x).sigmoid()

        # Apply mask
        x = x * mask.unsqueeze(-1)

        # Split input and cast to float
        a, b = torch.chunk(x.float(), 2, dim=-1)

        # Triangular projection: combine incoming edges (k, i) and (k, j)
        # over their shared neighbor k
        x = torch.einsum("bkid,bkjd->bijd", a, b)

        # Output gating
        x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()

        return x
# --------------------------------------------------------------------------------
# /src/boltz/model/loss/__init__.py:
# --------------------------------------------------------------------------------
# /src/boltz/model/loss/diffusion.py:
# --------------------------------------------------------------------------------
# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang

from einops import einsum
import torch
import torch.nn.functional as F


def weighted_rigid_align(
    true_coords,
    pred_coords,
    weights,
    mask,
):
    """Compute weighted alignment.
    Parameters
    ----------
    true_coords: torch.Tensor
        The ground truth atom coordinates
    pred_coords: torch.Tensor
        The predicted atom coordinates
    weights: torch.Tensor
        The weights for alignment
    mask: torch.Tensor
        The atoms mask

    Returns
    -------
    torch.Tensor
        Aligned (detached) ground-truth coordinates, rotated/translated onto
        the prediction frame

    """

    batch_size, num_points, dim = true_coords.shape
    weights = (mask * weights).unsqueeze(-1)

    # Compute weighted centroids
    true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
        dim=1, keepdim=True
    )
    pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
        dim=1, keepdim=True
    )

    # Center the coordinates
    true_coords_centered = true_coords - true_centroid
    pred_coords_centered = pred_coords - pred_centroid

    # NOTE(review): message says "<= dim+1" but the condition is "< dim+1"
    if num_points < (dim + 1):
        print(
            "Warning: The size of one of the point clouds is <= dim+1. "
            + "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Compute the weighted covariance matrix (Kabsch algorithm)
    cov_matrix = einsum(
        weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j"
    )

    # Compute the SVD of the covariance matrix, required float32 for svd and determinant
    original_dtype = cov_matrix.dtype
    cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
    U, S, V = torch.linalg.svd(
        cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
    )
    # torch returns V^H; take the conjugate transpose to get V
    V = V.mH

    # Catch ambiguous rotation by checking the magnitude of singular values
    if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
        print(
            "Warning: Excessively low rank of "
            + "cross-correlation between aligned point clouds. "
            + "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Compute the rotation matrix
    rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)

    # Ensure proper rotation matrix with determinant 1 (no reflections)
    # NOTE(review): this local `F` shadows the module-level
    # `torch.nn.functional as F` import inside this function's scope.
    F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
        None
    ].repeat(batch_size, 1, 1)
    F[:, -1, -1] = torch.det(rot_matrix)
    rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l")
    rot_matrix = rot_matrix.to(dtype=original_dtype)

    # Apply the rotation and translation
    aligned_coords = (
        einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j")
        + pred_centroid
    )
    # The alignment is a fixed target: no gradients flow through it
    aligned_coords.detach_()

    return aligned_coords


def smooth_lddt_loss(
    pred_coords,
    true_coords,
    is_nucleotide,
    coords_mask,
    nucleic_acid_cutoff: float = 30.0,
    other_cutoff: float = 15.0,
    multiplicity: int = 1,
):
    """Compute the smooth (differentiable) LDDT loss.

    Parameters
    ----------
    pred_coords: torch.Tensor
        The predicted atom coordinates
    true_coords: torch.Tensor
        The ground truth atom coordinates
    is_nucleotide: torch.Tensor
        Indicator of nucleotide atoms (selects which distance cutoff applies)
    coords_mask: torch.Tensor
        The atoms mask
    nucleic_acid_cutoff: float
        The nucleic acid cutoff
    other_cutoff: float
        The non nucleic acid cutoff
    multiplicity: int
        The multiplicity

    Returns
    -------
    torch.Tensor
        Scalar loss: one minus the mean smooth LDDT.

    """
    B, N, _ = true_coords.shape
    true_dists = torch.cdist(true_coords, true_coords)
    is_nucleotide = is_nucleotide.repeat_interleave(multiplicity, 0)

    coords_mask = coords_mask.repeat_interleave(multiplicity, 0)
    is_nucleotide_pair = is_nucleotide.unsqueeze(-1).expand(
        -1, -1, is_nucleotide.shape[-1]
    )

    # Pair inclusion cutoff depends on whether the reference atom is a nucleotide
    mask = (
        is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
        + (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
    )
    # Exclude self-pairs and padded atoms
    mask = mask * (1 - torch.eye(pred_coords.shape[1], device=pred_coords.device))
    mask = mask * (coords_mask.unsqueeze(-1) * coords_mask.unsqueeze(-2))

    # Compute distances between all pairs of atoms
    pred_dists = torch.cdist(pred_coords, pred_coords)
    dist_diff = torch.abs(true_dists - pred_dists)

    # Compute epsilon values: smooth (sigmoid) version of the four LDDT
    # thresholds (0.5, 1, 2, 4 Å), averaged over the diffusion multiplicity
    eps = (
        (
            (
                F.sigmoid(0.5 - dist_diff)
                + F.sigmoid(1.0 - dist_diff)
                + F.sigmoid(2.0 - dist_diff)
                + F.sigmoid(4.0 - dist_diff)
            )
            / 4.0
        )
        .view(multiplicity, B // multiplicity, N, N)
        .mean(dim=0)
    )

    # Calculate masked averaging
    eps = eps.repeat_interleave(multiplicity, 0)
    num = (eps * mask).sum(dim=(-1, -2))
    den = mask.sum(dim=(-1, -2)).clamp(min=1)
    lddt = num / den

    return 1.0 - lddt.mean()
# --------------------------------------------------------------------------------
# /src/boltz/model/loss/distogram.py:
# --------------------------------------------------------------------------------
from typing import Dict, Tuple

import torch
from torch import Tensor


def distogram_loss(
    output: Dict[str, Tensor],
    feats: Dict[str, Tensor],
) -> Tuple[Tensor, Tensor]:
    """Compute the distogram loss.

    Parameters
    ----------
    output : Dict[str, Tensor]
        Output of the model
    feats : Dict[str, Tensor]
        Input features

    Returns
    -------
    Tensor
        The globally averaged loss.
    Tensor
        Per example loss.
    """
    # Get predicted distograms
    pred = output["pdistogram"]

    # Compute target distogram
    target = feats["disto_target"]

    # Combine target mask and padding mask; zero out the diagonal so
    # self-pairs do not contribute
    mask = feats["token_disto_mask"]
    mask = mask[:, None, :] * mask[:, :, None]
    mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)

    # Compute the distogram loss (cross-entropy over distance bins)
    errors = -1 * torch.sum(
        target * torch.nn.functional.log_softmax(pred, dim=-1),
        dim=-1,
    )
    denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
    mean = errors * mask
    mean = torch.sum(mean, dim=-1)
    mean = mean / denom[..., None]
    batch_loss = torch.sum(mean, dim=-1)
    global_loss = torch.mean(batch_loss)
    return global_loss, batch_loss
# --------------------------------------------------------------------------------
# /src/boltz/model/modules/__init__.py:
# --------------------------------------------------------------------------------
# /src/boltz/model/modules/confidence_utils.py:
# --------------------------------------------------------------------------------
import torch
from torch import nn

from boltz.data import const
from boltz.model.loss.confidence import compute_frame_pred


def compute_aggregated_metric(logits, end=1.0):
    """Compute the metric from the logits.

    Parameters
    ----------
    logits : torch.Tensor
        The logits of the metric
    end : float
        Max value of the metric, by default 1.0

    Returns
    -------
    Tensor
        The metric value (expected value of the binned distribution)

    """
    num_bins = logits.shape[-1]
    bin_width = end / num_bins
    # Bin centers: 0.5*w, 1.5*w, ..., end - 0.5*w
    bounds = torch.arange(
        start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
    )
    probs = nn.functional.softmax(logits, dim=-1)
    # Expected value: sum of bin probabilities times bin centers
    plddt = torch.sum(
        probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
        dim=-1,
    )
    return plddt


def tm_function(d, Nres):
    """Compute the rescaling function for pTM.

    Parameters
    ----------
    d : torch.Tensor
        The input
    Nres : torch.Tensor
        The number of residues

    Returns
    -------
    Tensor
        Output of the function

    """
    # d0 is the standard TM-score normalization length (clipped at 19 residues)
    d0 = 1.24 * (torch.clip(Nres, min=19) - 15) ** (1 / 3) - 1.8
    return 1 / (1 + (d / d0) ** 2)


def compute_ptms(logits, x_preds, feats, multiplicity):
    """Compute pTM and ipTM scores.

    Parameters
    ----------
    logits : torch.Tensor
        pae logits
    x_preds : torch.Tensor
        The predicted coordinates
    feats : Dict[str, torch.Tensor]
        The input features
    multiplicity : int
        The batch size of the diffusion roll-out

    Returns
    -------
    Tensor
        pTM score
    Tensor
        ipTM score
    Tensor
        ligand ipTM score
    Tensor
        protein ipTM score
    dict
        Pairwise per-chain ipTM scores, keyed by asym_id pairs

    """
    # Compute mask for collinear and overlapping tokens
    _, mask_collinear_pred = compute_frame_pred(
        x_preds, feats["frames_idx"], feats, multiplicity, inference=True
    )
    mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
    maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1])
    # pTM uses all valid token pairs; ipTM keeps only inter-chain pairs
    pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None]
    asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
    pair_mask_iptm = (
        maski[:, :, None]
        * (asym_id[:, None, :] != asym_id[:, :, None])
        * mask_pad[:, None, :]
        * mask_pad[:, :, None]
    )

    # Extract pae values (bin centers over [0, 32])
    num_bins = logits.shape[-1]
    bin_width = 32.0 / num_bins
    end = 32.0
    pae_value = torch.arange(
        start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
    ).unsqueeze(0)
    N_res = mask_pad.sum(dim=-1, keepdim=True)

    # compute pTM and ipTM: expected TM value per pair, averaged over the
    # second token and maximized over the anchor token
    tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2)
    probs = nn.functional.softmax(logits, dim=-1)
    tm_expected_value = torch.sum(
        probs * tm_value,
        dim=-1,
    )  # shape (B, N, N)
    ptm = torch.max(
        torch.sum(tm_expected_value * pair_mask_ptm, dim=-1)
        / (torch.sum(pair_mask_ptm, dim=-1) + 1e-5),
        dim=1,
    ).values
    iptm = torch.max(
        torch.sum(tm_expected_value * pair_mask_iptm, dim=-1)
        / (torch.sum(pair_mask_iptm, dim=-1) + 1e-5),
        dim=1,
    ).values

    # compute ligand and protein ipTM
    token_type = feats["mol_type"]
    token_type = token_type.repeat_interleave(multiplicity, 0)
    is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
    is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()

    # Ligand ipTM restricts pairs to ligand<->protein (both directions)
    ligand_iptm_mask = (
        maski[:, :, None]
        * (asym_id[:, None, :] != asym_id[:, :, None])
        * mask_pad[:, None, :]
        * mask_pad[:, :, None]
        * (
            (is_ligand_token[:, :, None] * is_protein_token[:, None, :])
            + (is_protein_token[:, :, None] * is_ligand_token[:, None, :])
        )
    )
    # NOTE(review): "ipmt" looks like a typo for "iptm" in this local name
    protein_ipmt_mask = (
        maski[:, :, None]
        * (asym_id[:, None, :] != asym_id[:, :, None])
        * mask_pad[:, None, :]
        * mask_pad[:, :, None]
        * (is_protein_token[:, :, None] * is_protein_token[:, None, :])
    )

    ligand_iptm = torch.max(
        torch.sum(tm_expected_value * ligand_iptm_mask, dim=-1)
        / (torch.sum(ligand_iptm_mask, dim=-1) + 1e-5),
        dim=1,
    ).values
    protein_iptm = torch.max(
        torch.sum(tm_expected_value * protein_ipmt_mask, dim=-1)
        / (torch.sum(protein_ipmt_mask, dim=-1) + 1e-5),
        dim=1,
    ).values

    # Compute pair chain ipTM for every ordered pair of chains
    chain_pair_iptm = {}
    asym_ids_list = torch.unique(asym_id).tolist()
    for idx1 in asym_ids_list:
        chain_iptm = {}
        for idx2 in asym_ids_list:
            mask_pair_chain = (
                maski[:, :, None]
                * (asym_id[:, None, :] == idx1)
                * (asym_id[:, :, None] == idx2)
                * mask_pad[:, None, :]
                * mask_pad[:, :, None]
            )

            chain_iptm[idx2] = torch.max(
                torch.sum(tm_expected_value * mask_pair_chain, dim=-1)
                / (torch.sum(mask_pair_chain, dim=-1) + 1e-5),
                dim=1,
            ).values
        chain_pair_iptm[idx1] = chain_iptm

    return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm
# --------------------------------------------------------------------------------
# /src/boltz/model/optim/__init__.py:
# --------------------------------------------------------------------------------
# /src/boltz/model/optim/scheduler.py:
# --------------------------------------------------------------------------------
import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Implements the learning rate schedule defined AF3.

    A linear warmup is followed by a plateau at the maximum
    learning rate and then exponential decay. Note that the
    initial learning rate of the optimizer in question is
    ignored; use this class' base_lr parameter to specify
    the starting point of the warmup.

    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.0,
        max_lr: float = 1.8e-3,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ) -> None:
        """Initialize the learning rate scheduler.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            The optimizer.
        last_epoch : int, optional
            The last epoch, by default -1
        verbose : bool, optional
            Whether to print verbose output, by default False
        base_lr : float, optional
            The base learning rate, by default 0.0
        max_lr : float, optional
            The maximum learning rate, by default 1.8e-3
        warmup_no_steps : int, optional
            The number of warmup steps, by default 1000
        start_decay_after_n_steps : int, optional
            The number of steps after which to start decay, by default 50000
        decay_every_n_steps : int, optional
            The number of steps after which to decay, by default 50000
        decay_factor : float, optional
            The decay factor, by default 0.95

        """
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if v < 0:
                msg = f"{k} must be nonnegative"
                raise ValueError(msg)

        if warmup_no_steps > start_decay_after_n_steps:
            msg = "warmup_no_steps must not exceed start_decay_after_n_steps"
            raise ValueError(msg)

        # All schedule attributes must be assigned BEFORE super().__init__,
        # because the base class immediately calls get_lr() during construction.
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def state_dict(self) -> dict:
        """Return scheduler state, excluding the (non-serializable) optimizer."""
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore scheduler state produced by `state_dict`."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Compute the learning rate(s) for the current step (self.last_epoch)."""
        if not self._get_lr_called_within_step:
            msg = (
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )
            raise RuntimeError(msg)

        step_no = self.last_epoch

        if step_no <= self.warmup_no_steps:
            # Linear warmup toward max_lr.
            # NOTE(review): assumes warmup_no_steps >= 1; zero passes the
            # nonnegativity check above but divides by zero here — confirm.
            lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
        elif step_no > self.start_decay_after_n_steps:
            # Step-wise exponential decay every decay_every_n_steps
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor**exp)
        else:  # plateau
            lr = self.max_lr

        # The same schedule is applied to every parameter group
        return [lr for group in self.optimizer.param_groups]
# --------------------------------------------------------------------------------
# /tests/model/layers/test_outer_product_mean.py:
# --------------------------------------------------------------------------------
import pytorch_lightning
import torch
import torch.nn as nn

import unittest

from boltz.model.layers.outer_product_mean import OuterProductMean


class OuterProductMeanTest(unittest.TestCase):
    def setUp(self):
        # Fixed dims and seed so chunked/unchunked comparisons are reproducible
        self.c_in = 32
        self.c_hidden = 16
        self.c_out = 64

        torch.set_grad_enabled(False)
        pytorch_lightning.seed_everything(1100)
        self.layer = OuterProductMean(self.c_in, self.c_hidden, self.c_out)

        # Initialize layer
        for name, param in self.layer.named_parameters():
            nn.init.normal_(param, mean=1., std=1.)
    def test_chunk(self):
        """Chunked forward passes must match the unchunked output for a range
        of chunk sizes, including sizes larger than the sequence length."""
        chunk_sizes = [16, 33, 64, 83, 100]
        B, S, N = 1, 49, 84
        m = torch.randn(size=(B, S, N, self.c_in))
        # NOTE(review): torch.randint's `high` bound is exclusive, so
        # low=0, high=1 yields an all-zero mask -- probably meant high=2
        # for a random 0/1 mask; confirm intent.
        mask = torch.randint(low=0, high=1, size=(B, S, N))

        with torch.no_grad():
            # Unchunked reference output.
            exp_output = self.layer(m=m, mask=mask)
            for chunk_size in chunk_sizes:
                with self.subTest(chunk_size=chunk_size):
                    act_output = self.layer(m=m, mask=mask, chunk_size=chunk_size)
                    assert torch.allclose(exp_output, act_output, atol=1e-8)


# ---------------------------------------------------------------------------
# /tests/model/layers/test_triangle_attention.py
# ---------------------------------------------------------------------------
import pytorch_lightning
import torch
import torch.nn as nn

import unittest

from boltz.model.layers.triangular_attention.attention import TriangleAttention


# NOTE(review): class name appears copy-pasted from the outer-product-mean
# test; it actually exercises TriangleAttention, so a name such as
# TriangleAttentionTest would be clearer.
class OuterProductMeanTest(unittest.TestCase):
    def setUp(self):
        # Channel sizes and head count for the layer under test.
        self.c_in = 128
        self.c_hidden = 32
        self.no_heads = 1

        # Deterministic, inference-only setup.
        torch.set_grad_enabled(False)
        pytorch_lightning.seed_everything(1100)
        self.layer = TriangleAttention(self.c_in, self.c_hidden, self.no_heads)

        # Initialize layer with non-default weights so outputs are not
        # trivially zero.
        for name, param in self.layer.named_parameters():
            nn.init.normal_(param, mean=1., std=1.)
    def test_chunk(self):
        """Chunked attention must match the unchunked output for a range of
        chunk sizes, including sizes larger than N."""
        chunk_sizes = [16, 33, 64, 100]
        B, N = 1, 99
        m = torch.randn(size=(B, N, N, self.c_in))
        # NOTE(review): torch.randint's `high` bound is exclusive, so this
        # mask is all zeros -- probably meant high=2; confirm intent.
        mask = torch.randint(low=0, high=1, size=(B, N, N))

        with torch.no_grad():
            # Unchunked reference output.
            exp_output = self.layer(x=m, mask=mask)
            for chunk_size in chunk_sizes:
                with self.subTest(chunk_size=chunk_size):
                    act_output = self.layer(x=m, mask=mask, chunk_size=chunk_size)
                    assert torch.allclose(exp_output, act_output, atol=1e-8)


# ---------------------------------------------------------------------------
# /tests/test_regression.py
# ---------------------------------------------------------------------------
import os
import pickle
from dataclasses import asdict
import pprint

import torch
import torch.nn as nn

import pytest
import unittest

from lightning_fabric import seed_everything

from boltz.main import MODEL_URL
from boltz.model.model import Boltz1

import test_utils

# Directory holding downloaded regression fixtures (created on demand).
tests_dir = os.path.dirname(os.path.abspath(__file__))
test_data_dir = os.path.join(tests_dir, 'data')

@pytest.mark.regression
class RegressionTester(unittest.TestCase):
    """Regression tests comparing Boltz1 submodule outputs against recorded
    reference activations downloaded as a pickle fixture."""

    @classmethod
    def setUpClass(cls):
        # Run on GPU when available; fixtures are mapped to the same device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Download the released checkpoint into the shared ~/.boltz cache
        # unless it is already present.
        cache = os.path.expanduser('~/.boltz')
        checkpoint_url = MODEL_URL
        model_name = checkpoint_url.split("/")[-1]
        checkpoint = os.path.join(cache, model_name)
        if not os.path.exists(checkpoint):
            test_utils.download_file(checkpoint_url, checkpoint)

        # Reference features/activations recorded from a known-good run.
        regression_feats_path = os.path.join(test_data_dir, 'ligand_regression_feats.pkl')
        if not os.path.exists(regression_feats_path):
            regression_feats_url = "https://www.dropbox.com/scl/fi/1avbcvoor5jcnvpt07tp6/ligand_regression_feats.pkl?rlkey=iwtm9gpxgrbp51jbizq937pqf&st=jnbky253&dl=1"
            test_utils.download_file(regression_feats_url, regression_feats_path)

        # NOTE(review): torch.load unpickles arbitrary objects -- acceptable
        # here only because the fixture URL is project-controlled.
        regression_feats = torch.load(regression_feats_path, map_location=device)
        model_module: nn.Module = Boltz1.load_from_checkpoint(checkpoint, map_location=device)
        model_module.to(device)
        model_module.eval()

        coords = regression_feats["feats"]["coords"]
        # Coords should be rank 4
        if len(coords.shape) == 3:
            coords = coords.unsqueeze(0)
        regression_feats["feats"]["coords"] = coords
        # Move every tensor-like feature onto the target device.
        for key, val in regression_feats["feats"].items():
            if hasattr(val, "to"):
                regression_feats["feats"][key] = val.to(device)

        cls.model_module = model_module.to(device)
        cls.regression_feats = regression_feats

    def test_input_embedder(self):
        # Input-embedder output must match the recorded reference.
        exp_s_inputs = self.regression_feats["s_inputs"]
        act_s_inputs = self.model_module.input_embedder(self.regression_feats["feats"])

        assert torch.allclose(exp_s_inputs, act_s_inputs, atol=1e-5)

    def test_rel_pos(self):
        # Relative position encoding must match the recorded reference.
        exp_rel_pos_encoding = self.regression_feats["relative_position_encoding"]
        act_rel_pos_encoding = self.model_module.rel_pos(self.regression_feats["feats"])

        assert torch.allclose(exp_rel_pos_encoding, act_rel_pos_encoding, atol=1e-5)

    @pytest.mark.slow
    def test_structure_output(self):
        # Structure-module regression: the diffusion path is made
        # deterministic below (no augmentation, sigma_data = 0) so selected
        # outputs can be compared exactly against the recorded references.
        exp_structure_output = self.regression_feats["structure_output"]
        s = self.regression_feats["s"]
        z = self.regression_feats["z"]
        s_inputs = self.regression_feats["s_inputs"]
        feats = self.regression_feats["feats"]
        relative_position_encoding = self.regression_feats["relative_position_encoding"]
        multiplicity_diffusion_train = self.regression_feats["multiplicity_diffusion_train"]

        self.model_module.structure_module.coordinate_augmentation = False
        self.model_module.structure_module.sigma_data = 0.0

        # Seed so the remaining stochastic pieces match the recorded run.
        seed_everything(self.regression_feats["seed"])
        act_structure_output = self.model_module.structure_module(
            s_trunk=s,
            z_trunk=z,
            s_inputs=s_inputs,
            feats=feats,
            relative_position_encoding=relative_position_encoding,
            multiplicity=multiplicity_diffusion_train,
        )

        # Both output dicts must expose the same set of keys.
        act_keys = act_structure_output.keys()
        exp_keys = exp_structure_output.keys()
        assert act_keys == exp_keys

        # Other keys have some randomness, so we will only check the keys that
        # we can make deterministic with sigma_data = 0.0 (above).
        check_keys = ["noised_atom_coords", "aligned_true_atom_coords"]
        for key in check_keys:
            exp_val = exp_structure_output[key]
            act_val = act_structure_output[key]
            assert exp_val.shape == act_val.shape, f"Shape mismatch in {key}"
            assert torch.allclose(exp_val, act_val, atol=1e-4)


if __name__ == '__main__':
    unittest.main()


# ---------------------------------------------------------------------------
# /tests/test_utils.py
# ---------------------------------------------------------------------------
import os

import requests

def download_file(url, filepath, verbose=True):
    """Download ``url`` to ``filepath``, creating parent directories as
    needed, and return ``filepath``.

    NOTE(review): on a non-200 response the failure is only printed and the
    path is still returned, so callers may proceed with a missing file; the
    request also has no timeout and buffers the entire body in memory --
    confirm this is acceptable for the fixture sizes involved.
    """
    if verbose:
        print(f"Downloading {url} to {filepath}")
    response = requests.get(url)

    # Ensure the destination directory exists before writing.
    target_dir = os.path.dirname(filepath)
    if target_dir and not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # Check if the request was successful
    if response.status_code == 200:
        with open(filepath, 'wb') as file:
            file.write(response.content)
    else:
        print(f"Failed to download file. Status code: {response.status_code}")

    return filepath