├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── boltz1_pred_figure.png
│   ├── evaluation.md
│   ├── plot_casp.png
│   ├── plot_test.png
│   ├── prediction.md
│   └── training.md
├── examples
│   ├── ligand.fasta
│   ├── ligand.yaml
│   ├── msa
│   │   ├── seq1.a3m
│   │   └── seq2.a3m
│   ├── multimer.yaml
│   ├── pocket.yaml
│   ├── prot.fasta
│   ├── prot.yaml
│   ├── prot_custom_msa.yaml
│   └── prot_no_msa.yaml
├── pyproject.toml
├── scripts
│   ├── eval
│   │   ├── aggregate_evals.py
│   │   └── run_evals.py
│   ├── process
│   │   ├── README.md
│   │   ├── ccd.py
│   │   ├── cluster.py
│   │   ├── mmcif.py
│   │   ├── msa.py
│   │   ├── rcsb.py
│   │   └── requirements.txt
│   └── train
│       ├── README.md
│       ├── assets
│       │   ├── casp15_ids.txt
│       │   ├── test_ids.txt
│       │   └── validation_ids.txt
│       ├── configs
│       │   ├── confidence.yaml
│       │   ├── full.yaml
│       │   └── structure.yaml
│       └── train.py
├── src
│   └── boltz
│       ├── __init__.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── const.py
│       │   ├── crop
│       │   │   ├── __init__.py
│       │   │   ├── boltz.py
│       │   │   └── cropper.py
│       │   ├── feature
│       │   │   ├── __init__.py
│       │   │   ├── featurizer.py
│       │   │   ├── pad.py
│       │   │   └── symmetry.py
│       │   ├── filter
│       │   │   ├── __init__.py
│       │   │   ├── dynamic
│       │   │   │   ├── __init__.py
│       │   │   │   ├── date.py
│       │   │   │   ├── filter.py
│       │   │   │   ├── max_residues.py
│       │   │   │   ├── resolution.py
│       │   │   │   ├── size.py
│       │   │   │   └── subset.py
│       │   │   └── static
│       │   │       ├── __init__.py
│       │   │       ├── filter.py
│       │   │       ├── ligand.py
│       │   │       └── polymer.py
│       │   ├── module
│       │   │   ├── __init__.py
│       │   │   ├── inference.py
│       │   │   └── training.py
│       │   ├── msa
│       │   │   ├── __init__.py
│       │   │   └── mmseqs2.py
│       │   ├── parse
│       │   │   ├── __init__.py
│       │   │   ├── a3m.py
│       │   │   ├── csv.py
│       │   │   ├── fasta.py
│       │   │   ├── schema.py
│       │   │   └── yaml.py
│       │   ├── sample
│       │   │   ├── __init__.py
│       │   │   ├── cluster.py
│       │   │   ├── distillation.py
│       │   │   ├── random.py
│       │   │   └── sampler.py
│       │   ├── tokenize
│       │   │   ├── __init__.py
│       │   │   ├── boltz.py
│       │   │   └── tokenizer.py
│       │   ├── types.py
│       │   └── write
│       │       ├── __init__.py
│       │       ├── mmcif.py
│       │       ├── pdb.py
│       │       ├── utils.py
│       │       └── writer.py
│       ├── main.py
│       └── model
│           ├── __init__.py
│           ├── layers
│           │   ├── __init__.py
│           │   ├── attention.py
│           │   ├── dropout.py
│           │   ├── initialize.py
│           │   ├── outer_product_mean.py
│           │   ├── pair_averaging.py
│           │   ├── transition.py
│           │   ├── triangular_attention
│           │   │   ├── __init__.py
│           │   │   ├── attention.py
│           │   │   ├── primitives.py
│           │   │   └── utils.py
│           │   └── triangular_mult.py
│           ├── loss
│           │   ├── __init__.py
│           │   ├── confidence.py
│           │   ├── diffusion.py
│           │   ├── distogram.py
│           │   └── validation.py
│           ├── model.py
│           ├── modules
│           │   ├── __init__.py
│           │   ├── confidence.py
│           │   ├── confidence_utils.py
│           │   ├── diffusion.py
│           │   ├── encoders.py
│           │   ├── transformers.py
│           │   ├── trunk.py
│           │   └── utils.py
│           └── optim
│               ├── __init__.py
│               ├── ema.py
│               └── scheduler.py
└── tests
    ├── model
    │   └── layers
    │       ├── test_outer_product_mean.py
    │       └── test_triangle_attention.py
    ├── test_regression.py
    └── test_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Boltz-1: Democratizing Biomolecular Interaction Modeling
2 |
3 |
4 |
5 |
6 | ![Boltz-1 predictions](docs/boltz1_pred_figure.png)
7 |
8 | Boltz-1 is the state-of-the-art open-source model that predicts the 3D structure of proteins, RNA, DNA, and small molecules; it handles modified residues, covalent ligands, and glycans, and it can condition the generation on pocket residues.
9 |
10 | For more information about the model, see our [technical report](https://doi.org/10.1101/2024.11.19.624167).
11 |
12 | ## Installation
13 | Install boltz from PyPI (recommended):
14 |
15 | ```
16 | pip install boltz -U
17 | ```
18 |
19 | or directly from GitHub for daily updates:
20 |
21 | ```
22 | git clone https://github.com/jwohlwend/boltz.git
23 | cd boltz; pip install -e .
24 | ```
25 | > Note: we recommend installing boltz in a fresh Python environment
26 |
27 | ## Inference
28 |
29 | You can run inference using Boltz-1 with:
30 |
31 | ```
32 | boltz predict input_path --use_msa_server
33 | ```
34 |
35 | Boltz currently accepts three input formats:
36 |
37 | 1. A FASTA file, for most use cases
38 |
39 | 2. A comprehensive YAML schema, for more complex use cases
40 |
41 | 3. A directory containing files of the above formats, for batched processing
42 |
43 | To see all available options, run `boltz predict --help`. For more information on these input formats, see our [prediction instructions](docs/prediction.md); a minimal example is shown below.
44 |
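For reference, a minimal FASTA input looks like this (adapted from the bundled `examples/prot.fasta` and `examples/ligand.fasta`; each record header has the form `>CHAIN_ID|ENTITY_TYPE|MSA_PATH`, where the MSA path applies only to protein chains):

```
>A|protein|./examples/msa/seq2.a3m
QLEDSEVEAVAKGLEEMYANGVTEDNFKNYVKNNFAQQEISSVEEELNVNISDSCVANKIKDEFFAMISISAIVKAAQKKAWKELAVTVLRFAKANGLKTNAIIVAGQLALWAVQCG
>B|smiles
N[C@@H](Cc1ccc(O)cc1)C(=O)O
```
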
45 | ## Evaluation
46 |
47 | To encourage reproducibility and facilitate comparison with other models, we provide the evaluation scripts and predictions for Boltz-1, Chai-1, and AlphaFold3 on our test benchmark dataset as well as CASP15. These datasets were curated to contain biomolecules distinct from the training data, and to benchmark the models fairly we ran them all with the same input MSAs and the same number of recycling and diffusion steps. More details on these evaluations can be found in our [evaluation instructions](docs/evaluation.md).
48 |
49 | ![Test set results](docs/plot_test.png)
50 | ![CASP15 results](docs/plot_casp.png)
51 |
52 |
53 | ## Training
54 |
55 | If you're interested in retraining the model, see our [training instructions](docs/training.md).
56 |
57 | ## Contributing
58 |
59 | We welcome external contributions and are eager to engage with the community. Connect with us on our [Slack channel](https://join.slack.com/t/boltz-community/shared_invite/zt-2w0bw6dtt-kZU4png9HUgprx9NK2xXZw) to discuss advancements, share insights, and foster collaboration around Boltz-1.
60 |
61 | ## Coming very soon
62 |
63 | - [x] Auto-generated MSAs using MMseqs2
64 | - [x] More examples
65 | - [x] Support for custom paired MSA
66 | - [x] Confidence model checkpoint
67 | - [x] Chunking for lower memory usage
68 | - [x] Pocket conditioning support
69 | - [x] Full data processing pipeline
70 | - [ ] Colab notebook for inference
71 | - [ ] Kernel integration
72 |
73 | ## License
74 |
75 | Our model and code are released under the MIT License and can be freely used for both academic and commercial purposes.
76 |
77 |
78 | ## Cite
79 |
80 | If you use this code or the models in your research, please cite the following paper:
81 |
82 | ```bibtex
83 | @article{wohlwend2024boltz1,
84 | author = {Wohlwend, Jeremy and Corso, Gabriele and Passaro, Saro and Reveiz, Mateo and Leidal, Ken and Swiderski, Wojtek and Portnoi, Tally and Chinn, Itamar and Silterra, Jacob and Jaakkola, Tommi and Barzilay, Regina},
85 | title = {Boltz-1: Democratizing Biomolecular Interaction Modeling},
86 | year = {2024},
87 | doi = {10.1101/2024.11.19.624167},
88 | journal = {bioRxiv}
89 | }
90 | ```
91 |
92 | In addition, if you use the automatic MSA generation, please cite:
93 |
94 | ```bibtex
95 | @article{mirdita2022colabfold,
96 | title={ColabFold: making protein folding accessible to all},
97 | author={Mirdita, Milot and Sch{\"u}tze, Konstantin and Moriwaki, Yoshitaka and Heo, Lim and Ovchinnikov, Sergey and Steinegger, Martin},
98 | journal={Nature methods},
99 | year={2022},
100 | }
101 | ```
102 |
--------------------------------------------------------------------------------
/docs/boltz1_pred_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/docs/boltz1_pred_figure.png
--------------------------------------------------------------------------------
/docs/evaluation.md:
--------------------------------------------------------------------------------
1 | # Evaluation
2 |
3 | To encourage reproducibility and facilitate comparison with other models, we provide the evaluation scripts and predictions for Boltz-1, Chai-1, and AlphaFold3 on our test benchmark dataset as well as CASP15. These datasets were curated to contain biomolecules distinct from the training data, and to benchmark the models fairly we ran them all with the same input MSAs and the same number of recycling and diffusion steps.
4 |
5 | ![Test set results](plot_test.png)
6 | ![CASP15 results](plot_casp.png)
7 |
8 |
9 | ## Evaluation files
10 |
11 | You can download all the MSAs, input files, output files and evaluation outputs for Boltz-1, Chai-1, and AlphaFold3 from this [Google Drive folder](https://drive.google.com/file/d/1JvHlYUMINOaqPTunI9wBYrfYniKgVmxf/view?usp=sharing).
12 |
13 | The files are organized as follows:
14 |
15 | ```
16 | boltz_results_final/
17 | ├── inputs/            # Input files for every model
18 | │   ├── casp15/
19 | │   │   ├── af3
20 | │   │   ├── boltz
21 | │   │   ├── chai
22 | │   │   └── msa
23 | │   └── test/
24 | ├── targets/           # Target files from PDB
25 | │   ├── casp15
26 | │   └── test
27 | ├── outputs/           # Output files for every model
28 | │   ├── casp15
29 | │   └── test
30 | ├── evals/             # Output of evaluation script for every model
31 | │   ├── casp15
32 | │   └── test
33 | ├── results_casp.csv   # Summary of evaluation results for CASP15
34 | └── results_test.csv   # Summary of evaluation results for test set
35 | ```
36 |
37 | ## Evaluation setup
38 |
39 | We evaluate the model on two datasets:
40 | - PDB test set: 541 targets released after our validation cut-off date, with at most 40% sequence similarity for proteins and at most 80% Tanimoto similarity for ligands.
41 | - CASP15: 66 difficult targets from the CASP 2022 competition.
42 |
43 | We benchmark Boltz-1 against Chai-1 and AF3, two other state-of-the-art structure prediction models that are considerably more closed source in terms of model code, training, and data pipelines. Note that we removed overlap with our validation set, but we cannot ensure there is no overlap with the AF3 or Chai-1 validation sets, as those were not published.
44 |
45 | For a fair comparison, we run the models with the following setup:
46 | - Same MSAs.
47 | - Same parameters: 10 recycling steps, 200 sampling steps, 5 samples.
48 | - We report both oracle and top-1 numbers among the 5 samples.
49 |
50 |
51 | ## Evaluation script
52 |
53 | We also provide the scripts we used to evaluate the models and aggregate results. The evaluations were run through [OpenStructure](https://openstructure.org/docs/2.9.0/) version 2.8.0 (using this specific version is important for reproducing the results). You can find these scripts at `scripts/eval/run_evals.py` and `scripts/eval/aggregate_evals.py`; a sample invocation is shown below.
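54 |
55 | For example, scoring a directory of Boltz predictions against the test-set targets might look like the following (the paths are placeholders; `run_evals.py` shells out to Docker and expects a local image tagged `openstructure-0.2.8`):
56 |
57 | ```
58 | python scripts/eval/run_evals.py outputs/test/boltz targets/test evals/test/boltz \
59 |     --format boltz --testset test --mount /path/to/data
60 | ```
61 |
62 | The per-target JSON files written by `run_evals.py` can then be collected into summary tables with `scripts/eval/aggregate_evals.py`.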
--------------------------------------------------------------------------------
/docs/plot_casp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/docs/plot_casp.png
--------------------------------------------------------------------------------
/docs/plot_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/docs/plot_test.png
--------------------------------------------------------------------------------
/examples/ligand.fasta:
--------------------------------------------------------------------------------
1 | >A|protein|./examples/msa/seq1.a3m
2 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
3 | >B|protein|./examples/msa/seq1.a3m
4 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
5 | >C|ccd
6 | SAH
7 | >D|ccd
8 | SAH
9 | >E|smiles
10 | N[C@@H](Cc1ccc(O)cc1)C(=O)O
11 | >F|smiles
12 | N[C@@H](Cc1ccc(O)cc1)C(=O)O
--------------------------------------------------------------------------------
/examples/ligand.yaml:
--------------------------------------------------------------------------------
1 | version: 1 # Optional, defaults to 1
2 | sequences:
3 | - protein:
4 | id: [A, B]
5 | sequence: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
6 | msa: ./examples/msa/seq1.a3m
7 | - ligand:
8 | id: [C, D]
9 | ccd: SAH
10 | - ligand:
11 | id: [E, F]
12 | smiles: N[C@@H](Cc1ccc(O)cc1)C(=O)O
13 |
--------------------------------------------------------------------------------
/examples/multimer.yaml:
--------------------------------------------------------------------------------
1 | version: 1 # Optional, defaults to 1
2 | sequences:
3 | - protein:
4 | id: A
5 | sequence: MAHHHHHHVAVDAVSFTLLQDQLQSVLDTLSEREAGVVRLRFGLTDGQPRTLDEIGQVYGVTRERIRQIESKTMSKLRHPSRSQVLRDYLDGSSGSGTPEERLLRAIFGEKA
6 | - protein:
7 | id: B
8 | sequence: MRYAFAAEATTCNAFWRNVDMTVTALYEVPLGVCTQDPDRWTTTPDDEAKTLCRACPRRWLCARDAVESAGAEGLWAGVVIPESGRARAFALGQLRSLAERNGYPVRDHRVSAQSA
9 |
--------------------------------------------------------------------------------
/examples/pocket.yaml:
--------------------------------------------------------------------------------
1 | sequences:
2 | - protein:
3 | id: [A1]
4 | sequence: MYNMRRLSLSPTFSMGFHLLVTVSLLFSHVDHVIAETEMEGEGNETGECTGSYYCKKGVILPIWEPQDPSFGDKIARATVYFVAMVYMFLGVSIIADRFMSSIEVITSQEKEITIKKPNGETTKTTVRIWNETVSNLTLMALGSSAPEILLSVIEVCGHNFTAGDLGPSTIVGSAAFNMFIIIALCVYVVPDGETRKIKHLRVFFVTAAWSIFAYTWLYIILSVISPGVVEVWEGLLTFFFFPICVVFAWVADRRLLFYKYVYKRYRAGKQRGMIIEHEGDRPSSKTEIEMDGKVVNSHVENFLDGALVLEVDERDQDDEEARREMARILKELKQKHPDKEIEQLIELANYQVLSQQQKSRAFYRIQATRLMTGAGNILKRHAADQARKAVSMHEVNTEVTENDPVSKIFFEQGTYQCLENCGTVALTIIRRGGDLTNTVFVDFRTEDGTANAGSDYEFTEGTVVFKPGDTQKEIRVGIIDDDIFEEDENFLVHLSNVKVSSEASEDGILEANHVSTLACLGSPSTATVTIFDDDHAGIFTFEEPVTHVSESIGIMEVKVLRTSGARGNVIVPYKTIEGTARGGGEDFEDTCGELEFQNDEIVKIITIRIFDREEYEKECSFSLVLEEPKWIRRGMKGGFTITDEYDDKQPLTSKEEEERRIAEMGRPILGEHTKLEVIIEESYEFKSTVDKLIKKTNLALVVGTNSWREQFIEAITVSAGEDDDDDECGEEKLPSCFDYVMHFLTVFWKVLFAFVPPTEYWNGWACFIVSILMIGLLTAFIGDLASHFGCTIGLKDSVTAVVFVALGTSVPDTFASKVAATQDQYADASIGNVTGSNAVNVFLGIGVAWSIAAIYHAANGEQFKVSPGTLAFSVTLFTIFAFINVGVLLYRRRPEIGGELGGPRTAKLLTSCLFVLLWLLYIFFSSLEAYCHIKGF
5 | - ligand:
6 | ccd: EKY
7 | id: [B1]
8 | constraints:
9 | - pocket:
10 | binder: B1
11 | contacts: [ [ A1, 829 ], [ A1, 138 ] ]
12 |
13 |
--------------------------------------------------------------------------------
/examples/prot.fasta:
--------------------------------------------------------------------------------
1 | >A|protein|./examples/msa/seq2.a3m
2 | QLEDSEVEAVAKGLEEMYANGVTEDNFKNYVKNNFAQQEISSVEEELNVNISDSCVANKIKDEFFAMISISAIVKAAQKKAWKELAVTVLRFAKANGLKTNAIIVAGQLALWAVQCG
--------------------------------------------------------------------------------
/examples/prot.yaml:
--------------------------------------------------------------------------------
1 | version: 1 # Optional, defaults to 1
2 | sequences:
3 | - protein:
4 | id: A
5 | sequence: QLEDSEVEAVAKGLEEMYANGVTEDNFKNYVKNNFAQQEISSVEEELNVNISDSCVANKIKDEFFAMISISAIVKAAQKKAWKELAVTVLRFAKANGLKTNAIIVAGQLALWAVQCG
6 |
7 |
--------------------------------------------------------------------------------
/examples/prot_custom_msa.yaml:
--------------------------------------------------------------------------------
1 | version: 1 # Optional, defaults to 1
2 | sequences:
3 | - protein:
4 | id: A
5 | sequence: QLEDSEVEAVAKGLEEMYANGVTEDNFKNYVKNNFAQQEISSVEEELNVNISDSCVANKIKDEFFAMISISAIVKAAQKKAWKELAVTVLRFAKANGLKTNAIIVAGQLALWAVQCG
6 | msa: ./examples/msa/seq2.a3m
7 |
8 |
--------------------------------------------------------------------------------
/examples/prot_no_msa.yaml:
--------------------------------------------------------------------------------
1 | version: 1 # Optional, defaults to 1
2 | sequences:
3 | - protein:
4 | id: A
5 | sequence: QLEDSEVEAVAKGLEEMYANGVTEDNFKNYVKNNFAQQEISSVEEELNVNISDSCVANKIKDEFFAMISISAIVKAAQKKAWKELAVTVLRFAKANGLKTNAIIVAGQLALWAVQCG
6 | msa: empty
7 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "boltz"
7 | version = "0.4.0"
8 | requires-python = ">=3.9"
9 | description = "Boltz-1"
10 | readme = "README.md"
11 | dependencies = [
12 | "torch>=2.2",
13 | "numpy==1.26.3",
14 | "hydra-core==1.3.2",
15 | "pytorch-lightning==2.4.0",
16 | "rdkit>=2024.3.2",
17 | "dm-tree==0.1.8",
18 | "requests==2.32.3",
19 | "pandas>=2.2.2",
20 | "types-requests",
21 | "einops==0.8.0",
22 | "einx==0.3.0",
23 | "fairscale==0.4.13",
24 | "mashumaro==3.14",
25 | "modelcif==1.2",
26 | "wandb==0.18.7",
27 | "click==8.1.7",
28 | "pyyaml==6.0.2",
29 | "biopython==1.84",
30 | "scipy==1.13.1",
31 | ]
32 |
33 | [project.scripts]
34 | boltz = "boltz.main:cli"
35 |
36 | [project.optional-dependencies]
37 | lint = ["ruff"]
38 | test = ["pytest", "requests"]
39 |
40 | [tool.ruff]
41 | src = ["src"]
42 | extend-exclude = ["conf.py"]
43 | target-version = "py39"
44 | lint.select = ["ALL"]
45 | lint.ignore = [
46 | "COM812", # Conflicts with the formatter
47 | "ISC001", # Conflicts with the formatter
48 | "ANN101", # "missing-type-self"
49 | "RET504", # Unnecessary assignment to `x` before `return` statementRuff
50 | "S101", # Use of `assert` detected
51 | "D100", # Missing docstring in public module
52 | "D104", # Missing docstring in public package
53 | "PT001", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
54 | "PT004", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
55 | "PT005", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
56 | "PT023", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
57 | "FBT001",
58 | "FBT002",
59 | "PLR0913", # Too many arguments to init (> 5)
60 | ]
61 |
62 | [tool.ruff.lint.per-file-ignores]
63 | "**/__init__.py" = [
64 | "F401", # Imported but unused
65 | "F403", # Wildcard imports
66 | ]
67 | "docs/**" = [
68 | "INP001", # Requires __init__.py but folder is not a package.
69 | ]
70 | "scripts/**" = [
71 | "INP001", # Requires __init__.py but folder is not a package.
72 | ]
73 |
74 | [tool.ruff.lint.pyupgrade]
75 | # Preserve types, even if a file imports `from __future__ import annotations`(https://github.com/astral-sh/ruff/issues/5434)
76 | keep-runtime-typing = true
77 |
78 | [tool.ruff.lint.pydocstyle]
79 | convention = "numpy"
80 |
81 | [tool.pytest.ini_options]
82 | markers = [
83 | "slow: marks tests as slow (deselect with '-m \"not slow\"')",
84 | "regression",
85 | ]
86 |
--------------------------------------------------------------------------------
/scripts/eval/run_evals.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import concurrent.futures
3 | import subprocess
4 | from pathlib import Path
5 |
6 | from tqdm import tqdm
7 |
8 | OST_COMPARE_STRUCTURE = r"""
9 | #!/bin/bash
10 | # https://openstructure.org/docs/2.7/actions/#ost-compare-structures
11 |
12 | IMAGE_NAME=openstructure-0.2.8
13 |
14 | command="compare-structures \
15 | -m {model_file} \
16 | -r {reference_file} \
17 | --fault-tolerant \
18 | --min-pep-length 4 \
19 | --min-nuc-length 4 \
20 | -o {output_path} \
21 | --lddt --bb-lddt --qs-score --dockq \
22 | --ics --ips --rigid-scores --patch-scores --tm-score"
23 |
24 | sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command
25 | """
26 |
27 |
28 | OST_COMPARE_LIGAND = r"""
29 | #!/bin/bash
30 | # https://openstructure.org/docs/2.7/actions/#ost-compare-structures
31 |
32 | IMAGE_NAME=openstructure-0.2.8
33 |
34 | command="compare-ligand-structures \
35 | -m {model_file} \
36 | -r {reference_file} \
37 | --fault-tolerant \
38 | --lddt-pli --rmsd \
39 | --substructure-match \
40 | -o {output_path}"
41 |
42 | sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command
43 | """
44 |
45 |
46 | def evaluate_structure(
47 | name: str,
48 | pred: Path,
49 | reference: Path,
50 | outdir: str,
51 | mount: str,
52 | executable: str = "/bin/bash",
53 | ) -> None:
54 | """Evaluate the structure."""
55 | # Evaluate polymer metrics
56 | out_path = Path(outdir) / f"{name}.json"
57 |
58 | if out_path.exists():
59 | print( # noqa: T201
60 | f"Skipping recomputation of {name} as protein json file already exists"
61 | )
62 | else:
63 | subprocess.run(
64 | OST_COMPARE_STRUCTURE.format(
65 | model_file=str(pred),
66 | reference_file=str(reference),
67 | output_path=str(out_path),
68 | mount=mount,
69 | ),
70 | shell=True, # noqa: S602
71 | check=False,
72 | executable=executable,
73 | capture_output=True,
74 | )
75 |
76 | # Evaluate ligand metrics
77 | out_path = Path(outdir) / f"{name}_ligand.json"
78 | if out_path.exists():
79 | print(f"Skipping recomputation of {name} as ligand json file already exists") # noqa: T201
80 | else:
81 | subprocess.run(
82 | OST_COMPARE_LIGAND.format(
83 | model_file=str(pred),
84 | reference_file=str(reference),
85 | output_path=str(out_path),
86 | mount=mount,
87 | ),
88 | shell=True, # noqa: S602
89 | check=False,
90 | executable=executable,
91 | capture_output=True,
92 | )
93 |
94 |
95 | def main(args):
96 | # Aggregate the predictions and references
97 | files = list(args.data.iterdir())
98 | names = {f.stem.lower(): f for f in files}
99 |
100 | # Create the output directory
101 | args.outdir.mkdir(parents=True, exist_ok=True)
102 |
103 | first_item = True
104 | with concurrent.futures.ThreadPoolExecutor(args.max_workers) as executor:
105 | futures = []
106 | for name, folder in names.items():
107 | for model_id in range(5):
108 | # Split the input data
109 | if args.format == "af3":
110 | pred_path = folder / f"seed-1_sample-{model_id}" / "model.cif"
111 | elif args.format == "chai":
112 | pred_path = folder / f"pred.model_idx_{model_id}.cif"
113 | elif args.format == "boltz":
114 | name_file = (
115 | f"{name[0].upper()}{name[1:]}"
116 | if args.testset == "casp"
117 | else name.lower()
118 | )
119 | pred_path = folder / f"{name_file}_model_{model_id}.cif"
120 |
121 | if args.testset == "casp":
122 | ref_path = args.pdb / f"{name[0].upper()}{name[1:]}.cif"
123 | elif args.testset == "test":
124 | ref_path = args.pdb / f"{name.lower()}.cif.gz"
125 |
126 | if first_item:
127 | # Evaluate the first item serially; this ensures
128 | # that the docker image is downloaded before parallel runs
129 | evaluate_structure(
130 | name=f"{name}_model_{model_id}",
131 | pred=str(pred_path),
132 | reference=str(ref_path),
133 | outdir=str(args.outdir),
134 | mount=args.mount,
135 | executable=args.executable,
136 | )
137 | first_item = False
138 | else:
139 | future = executor.submit(
140 | evaluate_structure,
141 | name=f"{name}_model_{model_id}",
142 | pred=str(pred_path),
143 | reference=str(ref_path),
144 | outdir=str(args.outdir),
145 | mount=args.mount,
146 | executable=args.executable,
147 | )
148 | futures.append(future)
149 |
150 | # Wait for all tasks to complete
151 | with tqdm(total=len(futures)) as pbar:
152 | for _ in concurrent.futures.as_completed(futures):
153 | pbar.update(1)
154 |
155 |
156 | if __name__ == "__main__":
157 | parser = argparse.ArgumentParser()
158 | parser.add_argument("data", type=Path)
159 | parser.add_argument("pdb", type=Path)
160 | parser.add_argument("outdir", type=Path)
161 | parser.add_argument("--format", type=str, default="af3")
162 | parser.add_argument("--testset", type=str, default="casp")
163 | parser.add_argument("--mount", type=str)
164 | parser.add_argument("--executable", type=str, default="/bin/bash")
165 | parser.add_argument("--max-workers", type=int, default=32)
166 | args = parser.parse_args()
167 | main(args)
--------------------------------------------------------------------------------
/scripts/process/README.md:
--------------------------------------------------------------------------------
1 | Please see our [data processing instructions](../../docs/training.md).
--------------------------------------------------------------------------------
/scripts/process/ccd.py:
--------------------------------------------------------------------------------
1 | """Compute conformers and symmetries for all the CCD molecules."""
2 |
3 | import argparse
4 | import multiprocessing
5 | import pickle
6 | import sys
7 | from functools import partial
8 | from pathlib import Path
9 |
10 | import pandas as pd
11 | import rdkit
12 | from p_tqdm import p_uimap
13 | from pdbeccdutils.core import ccd_reader
14 | from pdbeccdutils.core.component import ConformerType
15 | from rdkit import rdBase
16 | from rdkit.Chem import AllChem
17 | from rdkit.Chem.rdchem import Conformer, Mol
18 |
19 |
20 | def load_molecules(components: str) -> list[Mol]:
21 | """Load the CCD components file.
22 |
23 | Parameters
24 | ----------
25 | components : str
26 | Path to the CCD components file.
27 |
28 | Returns
29 | -------
30 | list[Mol]
31 |
32 | """
33 | components: dict[str, ccd_reader.CCDReaderResult]
34 | components = ccd_reader.read_pdb_components_file(components)
35 |
36 | mols = []
37 | for name, component in components.items():
38 | mol = component.component.mol
39 | mol.SetProp("PDB_NAME", name)
40 | mols.append(mol)
41 |
42 | return mols
43 |
44 |
45 | def compute_3d(mol: Mol, version: str = "v3") -> bool:
46 | """Generate 3D coordinates using EKTDG method.
47 |
48 | Taken from `pdbeccdutils.core.component.Component`.
49 |
50 | Parameters
51 | ----------
52 | mol: Mol
53 | The RDKit molecule to process
54 | version: str, optional
55 | The ETKDG version, defaults to v3.
56 |
57 | Returns
58 | -------
59 | bool
60 | Whether computation was successful.
61 |
62 | """
63 | if version == "v3":
64 | options = rdkit.Chem.AllChem.ETKDGv3()
65 | elif version == "v2":
66 | options = rdkit.Chem.AllChem.ETKDGv2()
67 | else:
68 | options = rdkit.Chem.AllChem.ETKDGv2()
69 |
70 | options.clearConfs = False
71 | conf_id = -1
72 |
73 | try:
74 | conf_id = rdkit.Chem.AllChem.EmbedMolecule(mol, options)
75 | rdkit.Chem.AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000)
76 |
77 | except RuntimeError:
78 | pass # Force field issue here
79 | except ValueError:
80 | pass # sanitization issue here
81 |
82 | if conf_id != -1:
83 | conformer = mol.GetConformer(conf_id)
84 | conformer.SetProp("name", ConformerType.Computed.name)
85 | conformer.SetProp("coord_generation", f"ETKDG{version}")
86 |
87 | return True
88 |
89 | return False
90 |
91 |
92 | def get_conformer(mol: Mol, c_type: ConformerType) -> Conformer:
93 | """Retrieve an rdkit object for a deemed conformer.
94 |
95 | Taken from `pdbeccdutils.core.component.Component`.
96 |
97 | Parameters
98 | ----------
99 | mol: Mol
100 | The molecule to process.
101 | c_type: ConformerType
102 | The conformer type to extract.
103 |
104 | Returns
105 | -------
106 | Conformer
107 | The desired conformer, if any.
108 |
109 | Raises
110 | ------
111 | ValueError
112 | If there are no conformers of the given type.
113 |
114 | """
115 | for c in mol.GetConformers():
116 | try:
117 | if c.GetProp("name") == c_type.name:
118 | return c
119 | except KeyError: # noqa: PERF203
120 | pass
121 |
122 | msg = f"Conformer {c_type.name} does not exist."
123 | raise ValueError(msg)
124 |
125 |
126 | def compute_symmetries(mol: Mol) -> list[list[int]]:
127 | """Compute the symmetries of a molecule.
128 |
129 | Parameters
130 | ----------
131 | mol : Mol
132 | The molecule to process
133 |
134 | Returns
135 | -------
136 | list[list[int]]
137 | The symmetries as a list of index permutations
138 |
139 | """
140 | mol = AllChem.RemoveHs(mol)
141 | idx_map = {}
142 | atom_idx = 0
143 | for i, atom in enumerate(mol.GetAtoms()):
144 | # Skip leaving atoms (flagged via the "leaving_atom" property)
145 | if int(atom.GetProp("leaving_atom")):
146 | continue
147 | idx_map[i] = atom_idx
148 | atom_idx += 1
149 |
150 | # Calculate self permutations
151 | permutations = []
152 | raw_permutations = mol.GetSubstructMatches(mol, uniquify=False)
153 | for raw_permutation in raw_permutations:
154 | # Filter out permutations with leaving atoms
155 | try:
156 | if {raw_permutation[idx] for idx in idx_map} == set(idx_map.keys()):
157 | permutation = [
158 | idx_map[idx] for idx in raw_permutation if idx in idx_map
159 | ]
160 | permutations.append(permutation)
161 | except Exception: # noqa: S110, PERF203, BLE001
162 | pass
163 | serialized_permutations = pickle.dumps(permutations)
164 | mol.SetProp("symmetries", serialized_permutations.hex())
165 | return permutations
166 |
167 |
168 | def process(mol: Mol, output: str) -> tuple[str, str]:
169 | """Process a CCD component.
170 |
171 | Parameters
172 | ----------
173 | mol : Mol
174 | The molecule to process
175 | output : str
176 | The directory to save the molecules
177 |
178 | Returns
179 | -------
180 | str
181 | The name of the component
182 | str
183 | The result of the conformer generation
184 |
185 | """
186 | # Get name
187 | name = mol.GetProp("PDB_NAME")
188 |
189 | # Check if single atom
190 | if mol.GetNumAtoms() == 1:
191 | result = "single"
192 | else:
193 | # Get the 3D conformer
194 | try:
195 | # Try to generate a 3D conformer with RDKit
196 | success = compute_3d(mol, version="v3")
197 | if success:
198 | _ = get_conformer(mol, ConformerType.Computed)
199 | result = "computed"
200 |
201 | # Otherwise, default to the ideal coordinates
202 | else:
203 | _ = get_conformer(mol, ConformerType.Ideal)
204 | result = "ideal"
205 | except ValueError:
206 | result = "failed"
207 |
208 | # Dump the molecule
209 | path = Path(output) / f"{name}.pkl"
210 | with path.open("wb") as f:
211 | pickle.dump(mol, f)
212 |
213 | # Output the results
214 | return name, result
215 |
216 |
217 | def main(args: argparse.Namespace) -> None:
218 | """Process conformers."""
219 | # Set property saving
220 | rdkit.Chem.SetDefaultPickleProperties(rdkit.Chem.PropertyPickleOptions.AllProps)
221 |
222 | # Load components
223 | print("Loading components") # noqa: T201
224 | molecules = load_molecules(args.components)
225 |
226 | # Reset stdout and stderr, as pdbccdutils messes with them
227 | sys.stdout = sys.__stdout__
228 | sys.stderr = sys.__stderr__
229 |
230 | # Disable rdkit warnings
231 | blocker = rdBase.BlockLogs() # noqa: F841
232 |
233 | # Setup processing function
234 | outdir = Path(args.outdir)
235 | outdir.mkdir(parents=True, exist_ok=True)
236 | mol_output = outdir / "mols"
237 | mol_output.mkdir(parents=True, exist_ok=True)
238 | process_fn = partial(process, output=str(mol_output))
239 |
240 | # Process the files in parallel
241 | print("Processing components") # noqa: T201
242 | metadata = []
243 | num_processes = min(max(1, args.num_processes), multiprocessing.cpu_count())
244 | for name, result in p_uimap(
245 | process_fn,
246 | molecules,
247 | num_cpus=num_processes,
248 | ):
249 | metadata.append({"name": name, "result": result})
250 |
251 | # Load and group outputs
252 | molecules = {}
253 | for item in metadata:
254 | if item["result"] == "failed":
255 | continue
256 |
257 | # Load the mol file
258 | path = mol_output / f"{item['name']}.pkl"
259 | with path.open("rb") as f:
260 | mol = pickle.load(f) # noqa: S301
261 | molecules[item["name"]] = mol
262 |
263 | # Dump metadata
264 | path = outdir / "results.csv"
265 | metadata = pd.DataFrame(metadata)
266 | metadata.to_csv(path)
267 |
268 | # Dump the components
269 | path = outdir / "ccd.pkl"
270 | with path.open("wb") as f:
271 | pickle.dump(molecules, f)
272 |
273 |
274 | if __name__ == "__main__":
275 | parser = argparse.ArgumentParser()
276 | parser.add_argument("--components", type=str)
277 | parser.add_argument("--outdir", type=str)
278 | parser.add_argument(
279 | "--num_processes",
280 | type=int,
281 | default=multiprocessing.cpu_count(),
282 | )
283 | args = parser.parse_args()
284 | main(args)
285 |
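286 | # Example invocation (hypothetical paths; the components file is the wwPDB CCD dump):
287 | #   python scripts/process/ccd.py --components components.cif --outdir processed/ccd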
--------------------------------------------------------------------------------
/scripts/process/cluster.py:
--------------------------------------------------------------------------------
1 | """Create a mapping from structure and chain ID to MSA indices."""
2 |
3 | import argparse
4 | import hashlib
5 | import json
6 | import pickle
7 | import subprocess
8 | from pathlib import Path
9 |
10 | import pandas as pd
11 | from Bio import SeqIO
12 |
13 |
14 | def hash_sequence(seq: str) -> str:
15 | """Hash a sequence."""
16 | return hashlib.sha256(seq.encode()).hexdigest()
17 |
18 |
19 | def main(args: argparse.Namespace) -> None:
20 | """Create clustering."""
21 | # Set output directory
22 | outdir = Path(args.outdir)
23 | outdir.mkdir(parents=True, exist_ok=True)
24 |
25 | # Split the sequences into proteins and nucleotides
26 | with Path(args.sequences).open("r") as f:
27 | data = list(SeqIO.parse(f, "fasta"))
28 |
29 | proteins = set()
30 | nucleotides = set()
31 |
32 | for seq in data:
33 | if set(seq.seq).issubset({"A", "C", "G", "T", "U"}):
34 | nucleotides.add(str(seq.seq))
35 | else:
36 | proteins.add(str(seq.seq))
37 |
38 | # Run mmseqs on the protein data
39 | proteins = [f">{hash_sequence(seq)}\n{seq}" for seq in proteins]
40 | with (outdir / "proteins.fasta").open("w") as f:
41 | f.write("\n".join(proteins))
42 |
43 | subprocess.run(
44 | f"{args.mmseqs} easy-cluster {outdir / 'proteins.fasta'} {outdir / 'clust_prot'} {outdir / 'tmp'} --min-seq-id 0.4", # noqa: E501
45 | shell=True, # noqa: S602
46 | check=True,
47 | )
48 |
49 | # Load protein clusters
50 | clustering_path = outdir / "clust_prot_cluster.tsv"
51 | protein_data = pd.read_csv(clustering_path, sep="\t", header=None)
52 | clusters = protein_data[0]
53 | items = protein_data[1]
54 | clustering = dict(zip(list(items), list(clusters)))
55 |
56 | # Each unique rna sequence is given an id
57 | visited = {}
58 | for nucl in nucleotides:
59 | nucl_id = hash_sequence(nucl)
60 | if nucl not in visited:
61 | clustering[nucl_id] = nucl_id
62 | visited[nucl] = nucl_id
63 | else:
64 | clustering[nucl_id] = visited[nucl]
65 |
66 | # Load ligand data
67 | with Path(args.ccd).open("rb") as handle:
68 | ligand_data = pickle.load(handle) # noqa: S301
69 |
70 | # Each unique ligand CCD is given an id
71 | visited = {}
72 | for ccd_code in ligand_data:
73 | clustering[ccd_code] = ccd_code
74 |
75 | # Save clustering
76 | with (outdir / "clustering.json").open("w") as handle:
77 | json.dump(clustering, handle)
78 |
79 |
80 | if __name__ == "__main__":
81 | parser = argparse.ArgumentParser()
82 | parser.add_argument(
83 | "--sequences",
84 | type=str,
85 | help="Input to protein fasta.",
86 | required=True,
87 | )
88 | parser.add_argument(
89 | "--ccd",
90 | type=str,
91 | help="Input to rna fasta.",
92 | required=True,
93 | )
94 | parser.add_argument(
95 | "--outdir",
96 | type=str,
97 | help="Output directory.",
98 | required=True,
99 | )
100 | parser.add_argument(
101 | "--mmseqs",
102 | type=str,
103 | help="Path to mmseqs program.",
104 | default="mmseqs",
105 | )
106 | args = parser.parse_args()
107 | main(args)
108 |
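109 | # Example invocation (hypothetical paths; ccd.pkl is the output of scripts/process/ccd.py):
110 | #   python scripts/process/cluster.py --sequences sequences.fasta --ccd processed/ccd/ccd.pkl --outdir processed/cluster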
--------------------------------------------------------------------------------
/scripts/process/msa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | from dataclasses import asdict
4 | from functools import partial
5 | from pathlib import Path
6 | from typing import Any
7 |
8 | import numpy as np
9 | from redis import Redis
10 | from tqdm import tqdm
11 |
12 | from boltz.data.parse.a3m import parse_a3m
13 |
14 |
15 | class Resource:
16 | """A shared resource for processing."""
17 |
18 | def __init__(self, host: str, port: int) -> None:
19 | """Initialize the redis database."""
20 | self._redis = Redis(host=host, port=port)
21 |
22 | def get(self, key: str) -> Any: # noqa: ANN401
23 | """Get an item from the Redis database."""
24 | return self._redis.get(key)
25 |
26 | def __getitem__(self, key: str) -> Any: # noqa: ANN401
27 | """Get an item from the resource."""
28 | out = self.get(key)
29 | if out is None:
30 | raise KeyError(key)
31 | return out
32 |
33 |
34 | def process_msa(
35 | data: list,
36 | host: str,
37 | port: int,
38 | outdir: str,
39 | max_seqs: int,
40 | ) -> None:
41 | """Run processing in a worker thread."""
42 | outdir = Path(outdir)
43 | resource = Resource(host=host, port=port)
44 | for path in tqdm(data, total=len(data)):
45 | out_path = outdir / f"{path.stem}.npz"
46 | if not out_path.exists():
47 | msa = parse_a3m(path, resource, max_seqs)
48 | np.savez_compressed(out_path, **asdict(msa))
49 |
50 |
51 | def process(args) -> None:
52 | """Run the data processing task."""
53 | # Create output directory
54 | args.outdir.mkdir(parents=True, exist_ok=True)
55 |
56 | # Check if we can run in parallel
57 | num_processes = min(args.num_processes, multiprocessing.cpu_count())
58 | parallel = num_processes > 1
59 |
60 | # Load shared data from redis
61 | print("Loading shared data from Redis...")
62 | shared_data = Resource(host=args.redis_host, port=args.redis_port)
63 |
64 | # Get data points
65 | print("Fetching data...")
66 | data = list(args.msadir.rglob("*.a3m*"))
67 | print(f"Found {len(data)} MSAs.")
68 |
69 | # Randomly permute the data
70 | random = np.random.RandomState()
71 | permute = random.permutation(len(data))
72 | data = [data[i] for i in permute]
73 |
74 | # Run processing
75 | if parallel:
76 | # Create processing function
77 | fn = partial(
78 | process_msa,
79 | host=args.redis_host,
80 | port=args.redis_port,
81 | outdir=args.outdir, max_seqs=args.max_seqs,
82 | )
83 |
84 | # Split the data into random chunks
85 | size = len(data) // num_processes
86 | chunks = [data[i : i + size] for i in range(0, len(data), size)]
87 |
88 | # Run processing in parallel
89 | with multiprocessing.Pool(num_processes) as pool: # noqa: SIM117
90 | with tqdm(total=len(chunks)) as pbar:
91 | for _ in pool.imap_unordered(fn, chunks):
92 | pbar.update()
93 | else:
94 | # process_msa iterates over the full list of paths itself
95 | process_msa(data, args.redis_host, args.redis_port, args.outdir, args.max_seqs)
96 |
97 |
98 | if __name__ == "__main__":
99 | parser = argparse.ArgumentParser(description="Process MSA data.")
100 | parser.add_argument(
101 | "--msadir",
102 | type=Path,
103 | required=True,
104 | help="The MSA data directory.",
105 | )
106 | parser.add_argument(
107 | "--outdir",
108 | type=Path,
109 | default="data",
110 | help="The output directory.",
111 | )
112 | parser.add_argument(
113 | "--num-processes",
114 | type=int,
115 | default=multiprocessing.cpu_count(),
116 | help="The number of processes.",
117 | )
118 | parser.add_argument(
119 | "--redis-host",
120 | type=str,
121 | default="localhost",
122 | help="The Redis host.",
123 | )
124 | parser.add_argument(
125 | "--redis-port",
126 | type=int,
127 | default=7777,
128 | help="The Redis port.",
129 | )
130 | parser.add_argument(
131 | "--max-seqs",
132 | type=int,
133 | default=16384,
134 | help="The maximum number of sequences.",
135 | )
136 | args = parser.parse_args()
137 | process(args)
138 |
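139 | # Example invocation (hypothetical paths; assumes a Redis instance holding the shared
140 | # lookup data used by parse_a3m, reachable on the given host/port):
141 | #   python scripts/process/msa.py --msadir msas/ --outdir processed/msa --redis-host localhost --redis-port 7777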
--------------------------------------------------------------------------------
/scripts/process/requirements.txt:
--------------------------------------------------------------------------------
1 | gemmi
2 | pdbeccdutils
3 | redis
4 | scikit-learn
5 | p_tqdm
--------------------------------------------------------------------------------
/scripts/train/README.md:
--------------------------------------------------------------------------------
1 | Please see our [training instructions](../../docs/training.md).
--------------------------------------------------------------------------------
/scripts/train/assets/casp15_ids.txt:
--------------------------------------------------------------------------------
1 | T1112
2 | T1118v1
3 | T1154
4 | T1137s1
5 | T1188
6 | T1157s1
7 | T1137s6
8 | R1117
9 | H1106
10 | T1106s2
11 | R1149
12 | T1158
13 | T1137s2
14 | T1145
15 | T1121
16 | T1123
17 | T1113
18 | R1156
19 | T1114s1
20 | T1183
21 | R1107
22 | T1137s7
23 | T1124
24 | T1178
25 | T1147
26 | R1128
27 | T1161
28 | R1108
29 | T1194
30 | T1185s2
31 | T1176
32 | T1158v3
33 | T1137s4
34 | T1160
35 | T1120
36 | H1185
37 | T1134s1
38 | T1119
39 | H1151
40 | T1137s8
41 | T1133
42 | T1187
43 | H1157
44 | T1122
45 | T1104
46 | T1158v2
47 | T1137s5
48 | T1129s2
49 | T1174
50 | T1157s2
51 | T1155
52 | T1158v4
53 | T1152
54 | T1137s9
55 | T1134s2
56 | T1125
57 | R1116
58 | H1134
59 | R1136
60 | T1159
61 | T1137s3
62 | T1185s1
63 | T1179
64 | T1106s1
65 | T1132
66 | T1185s4
67 | T1114s3
68 | T1114s2
69 | T1151s2
70 | T1158v1
71 | R1117v2
72 | T1173
73 |
--------------------------------------------------------------------------------
/scripts/train/assets/test_ids.txt:
--------------------------------------------------------------------------------
1 | 8BZ4
2 | 8URN
3 | 7U71
4 | 7Z64
5 | 7Y3Z
6 | 8SOT
7 | 8GH8
8 | 8IIB
9 | 7U08
10 | 8EB5
11 | 8G49
12 | 8K7Y
13 | 7QQD
14 | 8EIL
15 | 8JQE
16 | 8V1K
17 | 7ZRZ
18 | 7YN2
19 | 8D40
20 | 8RXO
21 | 8SXS
22 | 7UDL
23 | 8ADD
24 | 7Z3I
25 | 7YUK
26 | 7XWY
27 | 8F9Y
28 | 8WO7
29 | 8C27
30 | 8I3J
31 | 8HVC
32 | 8SXU
33 | 8K1I
34 | 8FTV
35 | 8ERC
36 | 8DVQ
37 | 8DTQ
38 | 8J12
39 | 8D0P
40 | 8POG
41 | 8HN0
42 | 7QPK
43 | 8AGR
44 | 8GXR
45 | 8K7X
46 | 8BL6
47 | 8HAW
48 | 8SRO
49 | 8HHM
50 | 8C26
51 | 7SPQ
52 | 8SME
53 | 7XGV
54 | 8GTY
55 | 8Q42
56 | 8BRY
57 | 8HDV
58 | 8B3Z
59 | 7XNJ
60 | 8EEL
61 | 8IOI
62 | 8Q70
63 | 8Y4U
64 | 8ANT
65 | 8IUB
66 | 8D49
67 | 8CPQ
68 | 8BAT
69 | 8E2B
70 | 8IWP
71 | 8IJT
72 | 7Y01
73 | 8CJG
74 | 8HML
75 | 8WU2
76 | 8VRM
77 | 8J1J
78 | 8DAJ
79 | 8SUT
80 | 8PTJ
81 | 8IVZ
82 | 8SDZ
83 | 7YDQ
84 | 8JU7
85 | 8K34
86 | 8B6Q
87 | 8F7N
88 | 8IBZ
89 | 7WOI
90 | 8R7D
91 | 8T65
92 | 8IQC
93 | 8SIU
94 | 8QK8
95 | 8HIG
96 | 7Y43
97 | 8IN8
98 | 8IBW
99 | 8GOY
100 | 7ZAO
101 | 8J9G
102 | 7ZCA
103 | 8HIO
104 | 8EFZ
105 | 8IQ8
106 | 8OQ0
107 | 8HHL
108 | 7XMW
109 | 8GI1
110 | 8AYR
111 | 7ZCB
112 | 8BRD
113 | 8IN6
114 | 8I3F
115 | 8HIU
116 | 8ER5
117 | 8WIL
118 | 7YPR
119 | 8UA2
120 | 8BW6
121 | 8IL8
122 | 8J3R
123 | 8K1F
124 | 8OHI
125 | 8WCT
126 | 8AN0
127 | 8BDQ
128 | 7FCT
129 | 8J69
130 | 8HTX
131 | 8PE3
132 | 8K5U
133 | 8AXT
134 | 8PSO
135 | 8JHR
136 | 8GY0
137 | 8QCW
138 | 8K3D
139 | 8P6J
140 | 8J0Q
141 | 7XS3
142 | 8DHJ
143 | 8EIN
144 | 7WKP
145 | 8GAQ
146 | 7WRN
147 | 8AHD
148 | 7SC4
149 | 8B3E
150 | 8AAS
151 | 8UZ8
152 | 8Q1K
153 | 8K5K
154 | 8B45
155 | 8PT7
156 | 7ZPN
157 | 8UQ9
158 | 8TJG
159 | 8TN8
160 | 8B2E
161 | 7XFZ
162 | 8FW7
163 | 8B3W
164 | 7T4W
165 | 8SVA
166 | 7YL4
167 | 8GLD
168 | 8OEI
169 | 8GMX
170 | 8OWF
171 | 8FNR
172 | 8IRQ
173 | 8JDG
174 | 7UXA
175 | 8TKA
176 | 7YH1
177 | 8HUZ
178 | 8TA2
179 | 8E5D
180 | 7YUN
181 | 7UOI
182 | 7WMY
183 | 8AA9
184 | 8ISZ
185 | 8EXA
186 | 8E7F
187 | 8B2S
188 | 8TP8
189 | 8GSY
190 | 7XRX
191 | 8SY3
192 | 8CIL
193 | 8WBR
194 | 7XF1
195 | 7YPO
196 | 8AXF
197 | 7QNL
198 | 8OYY
199 | 7R1N
200 | 8H5S
201 | 8B6U
202 | 8IBX
203 | 8Q43
204 | 8OW8
205 | 7XSG
206 | 8U0M
207 | 8IOO
208 | 8HR5
209 | 8BVK
210 | 8P0C
211 | 7TL6
212 | 8J48
213 | 8S0U
214 | 8K8A
215 | 8G53
216 | 7XYO
217 | 8POF
218 | 8U1K
219 | 8HF2
220 | 8K4L
221 | 8JAH
222 | 8KGZ
223 | 8BNB
224 | 7UG2
225 | 8A0A
226 | 8Q3Z
227 | 8XBI
228 | 8JNM
229 | 8GPS
230 | 8K1R
231 | 8Q66
232 | 7YLQ
233 | 7YNX
234 | 8IMD
235 | 7Y8H
236 | 8OXU
237 | 8BVE
238 | 8B4E
239 | 8V14
240 | 7R5I
241 | 8IR2
242 | 8UK7
243 | 8EBB
244 | 7XCC
245 | 8AEP
246 | 7YDW
247 | 8XX9
248 | 7VS6
249 | 8K3F
250 | 8CQM
251 | 7XH4
252 | 8BH9
253 | 7VXT
254 | 8SM9
255 | 8HGU
256 | 8PSQ
257 | 8SSU
258 | 8VXA
259 | 8GSX
260 | 8GHZ
261 | 8BJ3
262 | 8C9V
263 | 8T66
264 | 7XPC
265 | 8RH3
266 | 8CMQ
267 | 8AGG
268 | 8ERM
269 | 8P6M
270 | 8BUX
271 | 7S2J
272 | 8G32
273 | 8AXJ
274 | 8CID
275 | 8CPK
276 | 8P5Q
277 | 8HP8
278 | 7YUJ
279 | 8PT2
280 | 7YK3
281 | 7YYG
282 | 8ABV
283 | 7XL7
284 | 7YLZ
285 | 8JWS
286 | 8IW5
287 | 8SM6
288 | 8BBZ
289 | 8EOV
290 | 8PXC
291 | 7UWV
292 | 8A9N
293 | 7YH5
294 | 8DEO
295 | 7X2X
296 | 8W7P
297 | 8B5W
298 | 8CIH
299 | 8RB4
300 | 8HLG
301 | 8J8H
302 | 8UA5
303 | 7YKM
304 | 8S9W
305 | 7YPD
306 | 8GA6
307 | 7YPQ
308 | 8X7X
309 | 8HI8
310 | 8H7A
311 | 8C4D
312 | 8XAT
313 | 8W8S
314 | 8HM4
315 | 8H3Z
316 | 7W91
317 | 8GPP
318 | 8TNM
319 | 7YSI
320 | 8OML
321 | 8BBR
322 | 7YOJ
323 | 8JZX
324 | 8I3X
325 | 8AU6
326 | 8ITO
327 | 7SFY
328 | 8B6P
329 | 7Y8S
330 | 8ESL
331 | 8DSP
332 | 8CLZ
333 | 8F72
334 | 8QLD
335 | 8K86
336 | 8G8E
337 | 8QDO
338 | 8ANU
339 | 8PT6
340 | 8F5D
341 | 8DQ6
342 | 8IFK
343 | 8OJN
344 | 8SSC
345 | 7QRR
346 | 8E55
347 | 7TPU
348 | 7UQU
349 | 8HFP
350 | 7XGT
351 | 8A39
352 | 8CB2
353 | 8ACR
354 | 8G5S
355 | 7TZL
356 | 8T4R
357 | 8H18
358 | 7UI4
359 | 8Q41
360 | 8K76
361 | 7WUY
362 | 8VXC
363 | 8GYG
364 | 8IMS
365 | 8IKS
366 | 8X51
367 | 7Y7O
368 | 8PX4
369 | 8BF8
370 | 7XMJ
371 | 8GDW
372 | 7YTU
373 | 8CH4
374 | 7XHZ
375 | 7YH4
376 | 8PSN
377 | 8A16
378 | 8FBJ
379 | 7Y9G
380 | 8JI2
381 | 7YR9
382 | 8SW0
383 | 8A90
384 | 8X6V
385 | 8H8P
386 | 7WJU
387 | 8PSS
388 | 8HL8
389 | 8FJD
390 | 8PM4
391 | 7UK8
392 | 8DX0
393 | 8PHB
394 | 8FBN
395 | 8FXF
396 | 8GKH
397 | 8ENR
398 | 8PTH
399 | 8CBV
400 | 8GKV
401 | 8CQO
402 | 8OK3
403 | 8GSR
404 | 8TPK
405 | 8H1J
406 | 8QFL
407 | 8CHW
408 | 7V34
409 | 8HE2
410 | 7ZIE
411 | 8A50
412 | 7Z8E
413 | 8ILL
414 | 7WWC
415 | 7XVI
416 | 8Q2A
417 | 8HNO
418 | 8PR6
419 | 7XCA
420 | 7XGS
421 | 8H55
422 | 8FJE
423 | 7UNH
424 | 8AY2
425 | 8ARD
426 | 8HBR
427 | 8EWG
428 | 8D4A
429 | 8FIT
430 | 8E5E
431 | 8PMU
432 | 8F5G
433 | 8AMU
434 | 8CPN
435 | 7QPL
436 | 8EHN
437 | 8SQU
438 | 8F70
439 | 8FX9
440 | 7UR2
441 | 8T1M
442 | 7ZDS
443 | 7YH2
444 | 8B6A
445 | 8CHX
446 | 8G0N
447 | 8GY4
448 | 7YKG
449 | 8BH8
450 | 8BVI
451 | 7XF2
452 | 8BFY
453 | 8IA3
454 | 8JW3
455 | 8OQJ
456 | 8TFS
457 | 7Y1S
458 | 8HBB
459 | 8AF9
460 | 8IP1
461 | 7XZ3
462 | 8T0P
463 | 7Y16
464 | 8BRP
465 | 8JNX
466 | 8JP0
467 | 8EC3
468 | 8PZH
469 | 7URP
470 | 8B4D
471 | 8JFR
472 | 8GYR
473 | 7XFS
474 | 8SMQ
475 | 7WNH
476 | 8H0L
477 | 8OWI
478 | 8HFC
479 | 7X6G
480 | 8FKL
481 | 8PAG
482 | 8UPI
483 | 8D4B
484 | 8BCK
485 | 8JFU
486 | 8FUQ
487 | 8IF8
488 | 8PAQ
489 | 8HDU
490 | 8W9O
491 | 8ACA
492 | 7YIA
493 | 7ZFR
494 | 7Y9A
495 | 8TTO
496 | 7YFX
497 | 8B2H
498 | 8PSU
499 | 8ACC
500 | 8JMR
501 | 8IHA
502 | 7UYX
503 | 8DWJ
504 | 8BY5
505 | 8EZW
506 | 8A82
507 | 8TVL
508 | 8R79
509 | 8R8A
510 | 8AHZ
511 | 8AYV
512 | 8JHU
513 | 8Q44
514 | 8ARE
515 | 8OLJ
516 | 7Y95
517 | 7XP0
518 | 8EX9
519 | 8BID
520 | 8Q40
521 | 7QSJ
522 | 7UBA
523 | 7XFU
524 | 8OU1
525 | 8G2V
526 | 8YA7
527 | 8GMZ
528 | 8T8L
529 | 8CK0
530 | 7Y4H
531 | 8IOM
532 | 7ZLQ
533 | 8BZ2
534 | 8B4C
535 | 8DZJ
536 | 8CEG
537 | 8IBY
538 | 8T3J
539 | 8IVI
540 | 8ITN
541 | 8CR7
542 | 8TGH
543 | 8OKH
544 | 7UI8
545 | 8EHT
546 | 8ADC
547 | 8T4C
548 | 7XBJ
549 | 8CLU
550 | 7QA1
551 |
--------------------------------------------------------------------------------
/scripts/train/assets/validation_ids.txt:
--------------------------------------------------------------------------------
1 | 7UTN
2 | 7F9H
3 | 7TZV
4 | 7ZHH
5 | 7SOV
6 | 7EOF
7 | 7R8H
8 | 8AW3
9 | 7F2F
10 | 8BAO
11 | 7BCB
12 | 7D8T
13 | 7D3T
14 | 7BHY
15 | 7YZ7
16 | 8DC2
17 | 7SOW
18 | 8CTL
19 | 7SOS
20 | 7V6W
21 | 7Z55
22 | 7NQF
23 | 7VTN
24 | 7KSP
25 | 7BJQ
26 | 7YZC
27 | 7Y3L
28 | 7TDX
29 | 7R8I
30 | 7OYK
31 | 7TZ1
32 | 7KIJ
33 | 7T8K
34 | 7KII
35 | 7YZA
36 | 7VP4
37 | 7KIK
38 | 7M5W
39 | 7Q94
40 | 7BCA
41 | 7YZB
42 | 7OG0
43 | 7VTI
44 | 7SOP
45 | 7S03
46 | 7YZG
47 | 7TXC
48 | 7VP5
49 | 7Y3I
50 | 7TDW
51 | 8B0R
52 | 7R8G
53 | 7FEF
54 | 7VP1
55 | 7VP3
56 | 7RGU
57 | 7DV2
58 | 7YZD
59 | 7OFZ
60 | 7Y3K
61 | 7TEC
62 | 7WQ5
63 | 7VP2
64 | 7EDB
65 | 7VP7
66 | 7PDV
67 | 7XHT
68 | 7R6R
69 | 8CSH
70 | 8CSZ
71 | 7V9O
72 | 7Q1C
73 | 8EDC
74 | 7PWI
75 | 7FI1
76 | 7ESI
77 | 7F0Y
78 | 7EYR
79 | 7ZVA
80 | 7WEG
81 | 7E4N
82 | 7U5Q
83 | 7FAV
84 | 7LJ2
85 | 7S6F
86 | 7B3N
87 | 7V4P
88 | 7AJO
89 | 7WH1
90 | 8DQP
91 | 7STT
92 | 7VQ7
93 | 7E4J
94 | 7RIS
95 | 7FH8
96 | 7BMW
97 | 7RD0
98 | 7V54
99 | 7LKC
100 | 7OU1
101 | 7QOD
102 | 7PX1
103 | 7EBY
104 | 7U1V
105 | 7PLP
106 | 7T8N
107 | 7SJK
108 | 7RGB
109 | 7TEM
110 | 7UG9
111 | 7B7A
112 | 7TM2
113 | 7Z74
114 | 7PCM
115 | 7V8G
116 | 7EUU
117 | 7VTL
118 | 7ZEI
119 | 7ZC0
120 | 7DZ9
121 | 8B2M
122 | 7NE9
123 | 7ALV
124 | 7M96
125 | 7O6T
126 | 7SKO
127 | 7Z2V
128 | 7OWX
129 | 7SHW
130 | 7TNI
131 | 7ZQY
132 | 7MDF
133 | 7EXR
134 | 7W6B
135 | 7EQF
136 | 7WWO
137 | 7FBW
138 | 8EHE
139 | 7CLE
140 | 7T80
141 | 7WMV
142 | 7SMG
143 | 7WSJ
144 | 7DBU
145 | 7VHY
146 | 7W5F
147 | 7SHG
148 | 7VU3
149 | 7ATH
150 | 7FGZ
151 | 7ADS
152 | 7REO
153 | 7T7H
154 | 7X0N
155 | 7TCU
156 | 7SKH
157 | 7EF6
158 | 7TBV
159 | 7B29
160 | 7VO5
161 | 7TM1
162 | 7QLD
163 | 7BB9
164 | 7SZ8
165 | 7RLM
166 | 7WWP
167 | 7NBV
168 | 7PLD
169 | 7DNM
170 | 7SFZ
171 | 7EAW
172 | 7QNQ
173 | 7SZX
174 | 7U2S
175 | 7WZX
176 | 7TYG
177 | 7QCE
178 | 7DCN
179 | 7WJL
180 | 7VV6
181 | 7TJ4
182 | 7VI8
183 | 8AKP
184 | 7WAO
185 | 7N7V
186 | 7EYO
187 | 7VTD
188 | 7VEG
189 | 7QY5
190 | 7ELV
191 | 7P0J
192 | 7YX8
193 | 7U4H
194 | 7TBD
195 | 7WME
196 | 7RI3
197 | 7TOH
198 | 7ZVM
199 | 7PUL
200 | 7VBO
201 | 7DM0
202 | 7XN9
203 | 7ALY
204 | 7LTB
205 | 8A28
206 | 7UBZ
207 | 8DTE
208 | 7TA2
209 | 7QST
210 | 7AN1
211 | 7FIB
212 | 8BAL
213 | 7TMJ
214 | 7REV
215 | 7PZJ
216 | 7T9X
217 | 7SUU
218 | 7KJQ
219 | 7V6P
220 | 7QA3
221 | 7ULC
222 | 7Y3X
223 | 7TMU
224 | 7OA7
225 | 7PO9
226 | 7Q20
227 | 8H2C
228 | 7VW1
229 | 7VLJ
230 | 8EP4
231 | 7P57
232 | 7QUL
233 | 7ZQE
234 | 7UJU
235 | 7WG1
236 | 7DMK
237 | 7Y8X
238 | 7EHG
239 | 7W13
240 | 7NL4
241 | 7R4J
242 | 7AOV
243 | 7RFT
244 | 7VUF
245 | 7F72
246 | 8DSR
247 | 7MK3
248 | 7MQQ
249 | 7R55
250 | 7T85
251 | 7NCY
252 | 7ZHL
253 | 7E1N
254 | 7W8F
255 | 7PGK
256 | 8GUN
257 | 7P8D
258 | 7PUK
259 | 7N9D
260 | 7XWN
261 | 7ZHA
262 | 7TVP
263 | 7VI6
264 | 7PW6
265 | 7YM0
266 | 7RWK
267 | 8DKR
268 | 7WGU
269 | 7LJI
270 | 7THW
271 | 7OB6
272 | 7N3Z
273 | 7T3S
274 | 7PAB
275 | 7F9F
276 | 7PPP
277 | 7AD5
278 | 7VGM
279 | 7WBO
280 | 7RWM
281 | 7QFI
282 | 7T91
283 | 7ANU
284 | 7UX0
285 | 7USR
286 | 7RDN
287 | 7VW5
288 | 7Q4T
289 | 7W3R
290 | 8DKQ
291 | 7RCX
292 | 7UOF
293 | 7OKR
294 | 7NX1
295 | 6ZBS
296 | 7VEV
297 | 8E8U
298 | 7WJ6
299 | 7MP4
300 | 7RPY
301 | 7R5Z
302 | 7VLM
303 | 7SNE
304 | 7WDW
305 | 8E19
306 | 7PP2
307 | 7Z5H
308 | 7P7I
309 | 7LJJ
310 | 7QPC
311 | 7VJS
312 | 7QOE
313 | 7KZH
314 | 7F6N
315 | 7TMI
316 | 7POH
317 | 8DKS
318 | 7YMO
319 | 6S5I
320 | 7N6O
321 | 7LYU
322 | 7POK
323 | 7BLK
324 | 7TCY
325 | 7W19
326 | 8B55
327 | 7SMU
328 | 7QFK
329 | 7T5T
330 | 7EPQ
331 | 7DCK
332 | 7S69
333 | 6ZSV
334 | 7ZGT
335 | 7TJ1
336 | 7V09
337 | 7ZHD
338 | 7ALL
339 | 7P1Y
340 | 7T71
341 | 7MNK
342 | 7W5Q
343 | 7PZ2
344 | 7QSQ
345 | 7QI3
346 | 7NZZ
347 | 7Q47
348 | 8D08
349 | 7QH5
350 | 7RXQ
351 | 7F45
352 | 8D07
353 | 8EHC
354 | 7PZT
355 | 7K3C
356 | 7ZGI
357 | 7MC4
358 | 7NPQ
359 | 7VD7
360 | 7XAN
361 | 7FDP
362 | 8A0K
363 | 7TXO
364 | 7ZB1
365 | 7V5V
366 | 7WWS
367 | 7PBK
368 | 8EBG
369 | 7N0J
370 | 7UMA
371 | 7T1S
372 | 8EHB
373 | 7DWC
374 | 7K6W
375 | 7WEJ
376 | 7LRH
377 | 7ZCV
378 | 7RKC
379 | 7X8C
380 | 7PV1
381 | 7UGK
382 | 7ULN
383 | 7A66
384 | 7R7M
385 | 7M0Q
386 | 7BGS
387 | 7UPP
388 | 7O62
389 | 7VKK
390 | 7L6Y
391 | 7VG4
392 | 7V2V
393 | 7ETN
394 | 7ZTB
395 | 7AOO
396 | 7OH2
397 | 7E0M
398 | 7PEG
399 | 8CUK
400 | 7ZP0
401 | 7T6A
402 | 7BTM
403 | 7DOV
404 | 7VVV
405 | 7P22
406 | 7RUO
407 | 7E40
408 | 7O5Y
409 | 7XPK
410 | 7R0K
411 | 8D04
412 | 7TYD
413 | 7LSV
414 | 7XSI
415 | 7RTZ
416 | 7UXR
417 | 7QH3
418 | 8END
419 | 8CYK
420 | 7MRJ
421 | 7DJL
422 | 7S5B
423 | 7XUX
424 | 7EV8
425 | 7R6S
426 | 7UH4
427 | 7R9X
428 | 7F7P
429 | 7ACW
430 | 7SPN
431 | 7W70
432 | 7Q5G
433 | 7DXN
434 | 7DK9
435 | 8DT0
436 | 7FDN
437 | 7DGX
438 | 7UJB
439 | 7X4O
440 | 7F4O
441 | 7T9W
442 | 8AID
443 | 7ERQ
444 | 7EQB
445 | 7YDG
446 | 7ETR
447 | 8D27
448 | 7OUU
449 | 7R5Y
450 | 7T8I
451 | 7UZT
452 | 7X8V
453 | 7QLH
454 | 7SAF
455 | 7EN6
456 | 8D4Y
457 | 7ESJ
458 | 7VWO
459 | 7SBE
460 | 7VYU
461 | 7RVJ
462 | 7FCL
463 | 7WUO
464 | 7WWF
465 | 7VMT
466 | 7SHJ
467 | 7SKP
468 | 7KOU
469 | 6ZSU
470 | 7VGW
471 | 7X45
472 | 8GYZ
473 | 8BFE
474 | 8DGL
475 | 7Z3H
476 | 8BD1
477 | 8A0J
478 | 7JRK
479 | 7QII
480 | 7X39
481 | 7Y6B
482 | 7OIY
483 | 7SBI
484 | 8A3I
485 | 7NLI
486 | 7F4U
487 | 7TVY
488 | 7X0O
489 | 7VMH
490 | 7EPN
491 | 7WBK
492 | 8BFJ
493 | 7XFP
494 | 7LXQ
495 | 7TIL
496 | 7O61
497 | 8B8B
498 | 7W2Q
499 | 8APR
500 | 7WZE
501 | 7NYQ
502 | 7RMX
503 | 7PGE
504 | 8F43
505 | 7N2K
506 | 7UXG
507 | 7SXN
508 | 7T5U
509 | 7R22
510 | 7E3T
511 | 7PTB
512 | 7OA8
513 | 7X5T
514 | 7PL7
515 | 7SQ5
516 | 7VBS
517 | 8D03
518 | 7TAE
519 | 7T69
520 | 7WF6
521 | 7LBU
522 | 8A06
523 | 8DA2
524 | 7QFL
525 | 7KUW
526 | 7X9R
527 | 7XT3
528 | 7RB4
529 | 7PT5
530 | 7RPS
531 | 7RXU
532 | 7TDY
533 | 7W89
534 | 7N9I
535 | 7T1M
536 | 7OBM
537 | 7K3X
538 | 7ZJC
539 | 8BDP
540 | 7V8W
541 | 7DJK
542 | 7W1K
543 | 7QFG
544 | 7DGY
545 | 7ZTQ
546 | 7F8A
547 | 7NEK
548 | 7CG9
549 | 7KOB
550 | 7TN7
551 | 8DYS
552 | 7WVR
553 |
--------------------------------------------------------------------------------
/scripts/train/configs/confidence.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | accelerator: gpu
3 | devices: 1
4 | precision: 32
5 | gradient_clip_val: 10.0
6 | max_epochs: -1
7 |
8 | # Optionally, set wandb here
9 | # wandb:
10 | # name: boltz
11 | # project: boltz
12 | # entity: boltz
13 |
14 |
15 | output: SET_PATH_HERE
16 | pretrained: PATH_TO_STRUCTURE_CHECKPOINT_FILE
17 | resume: null
18 | disable_checkpoint: false
19 | matmul_precision: null
20 | save_top_k: -1
21 | load_confidence_from_trunk: true
22 |
23 | data:
24 | datasets:
25 | - _target_: boltz.data.module.training.DatasetConfig
26 | target_dir: PATH_TO_TARGETS_DIR
27 | msa_dir: PATH_TO_MSA_DIR
28 | prob: 1.0
29 | sampler:
30 | _target_: boltz.data.sample.cluster.ClusterSampler
31 | cropper:
32 | _target_: boltz.data.crop.boltz.BoltzCropper
33 | min_neighborhood: 0
34 | max_neighborhood: 40
35 | split: ./scripts/train/assets/validation_ids.txt
36 |
37 | filters:
38 | - _target_: boltz.data.filter.dynamic.size.SizeFilter
39 | min_chains: 1
40 | max_chains: 300
41 | - _target_: boltz.data.filter.dynamic.date.DateFilter
42 | date: "2021-09-30"
43 | ref: released
44 | - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
45 | resolution: 4.0
46 |
47 | tokenizer:
48 | _target_: boltz.data.tokenize.boltz.BoltzTokenizer
49 | featurizer:
50 | _target_: boltz.data.feature.featurizer.BoltzFeaturizer
51 |
52 | symmetries: PATH_TO_SYMMETRY_FILE
53 | max_tokens: 512
54 | max_atoms: 4608
55 | max_seqs: 2048
56 | pad_to_max_tokens: true
57 | pad_to_max_atoms: true
58 | pad_to_max_seqs: true
59 | samples_per_epoch: 100000
60 | batch_size: 1
61 | num_workers: 4
62 | random_seed: 42
63 | pin_memory: true
64 | overfit: null
65 | crop_validation: true
66 | return_train_symmetries: true
67 | return_val_symmetries: true
68 | train_binder_pocket_conditioned_prop: 0.3
69 | val_binder_pocket_conditioned_prop: 0.3
70 | binder_pocket_cutoff: 6.0
71 | binder_pocket_sampling_geometric_p: 0.3
72 | min_dist: 2.0
73 | max_dist: 22.0
74 | num_bins: 64
75 | atoms_per_window_queries: 32
76 |
77 | model:
78 | _target_: boltz.model.model.Boltz1
79 | atom_s: 128
80 | atom_z: 16
81 | token_s: 384
82 | token_z: 128
83 | num_bins: 64
84 | atom_feature_dim: 389
85 | atoms_per_window_queries: 32
86 | atoms_per_window_keys: 128
87 | compile_pairformer: false
88 | nucleotide_rmsd_weight: 5.0
89 | ligand_rmsd_weight: 10.0
90 | ema: true
91 | ema_decay: 0.999
92 |
93 | embedder_args:
94 | atom_encoder_depth: 3
95 | atom_encoder_heads: 4
96 |
97 | msa_args:
98 | msa_s: 64
99 | msa_blocks: 4
100 | msa_dropout: 0.15
101 | z_dropout: 0.25
102 | pairwise_head_width: 32
103 | pairwise_num_heads: 4
104 | activation_checkpointing: true
105 | offload_to_cpu: false
106 |
107 | pairformer_args:
108 | num_blocks: 48
109 | num_heads: 16
110 | dropout: 0.25
111 | activation_checkpointing: true
112 | offload_to_cpu: false
113 |
114 | score_model_args:
115 | sigma_data: 16
116 | dim_fourier: 256
117 | atom_encoder_depth: 3
118 | atom_encoder_heads: 4
119 | token_transformer_depth: 24
120 | token_transformer_heads: 16
121 | atom_decoder_depth: 3
122 | atom_decoder_heads: 4
123 | conditioning_transition_layers: 2
124 | activation_checkpointing: true
125 | offload_to_cpu: false
126 |
127 | structure_prediction_training: false
128 | confidence_prediction: true
129 | alpha_pae: 1
130 | confidence_imitate_trunk: true
131 | confidence_model_args:
132 | num_dist_bins: 64
133 | max_dist: 22
134 | add_s_to_z_prod: true
135 | add_s_input_to_s: true
136 | use_s_diffusion: true
137 | add_z_input_to_z: true
138 |
139 | confidence_args:
140 | num_plddt_bins: 50
141 | num_pde_bins: 64
142 | num_pae_bins: 64
143 |
144 | training_args:
145 | recycling_steps: 3
146 | sampling_steps: 200
147 | diffusion_multiplicity: 16
148 | diffusion_samples: 1
149 | confidence_loss_weight: 3e-3
150 | diffusion_loss_weight: 4.0
151 | distogram_loss_weight: 3e-2
152 | adam_beta_1: 0.9
153 | adam_beta_2: 0.95
154 | adam_eps: 0.00000001
155 | lr_scheduler: af3
156 | base_lr: 0.0
157 | max_lr: 0.0018
158 | lr_warmup_no_steps: 1000
159 | lr_start_decay_after_n_steps: 50000
160 | lr_decay_every_n_steps: 50000
161 | lr_decay_factor: 0.95
162 | symmetry_correction: true
163 |
164 | validation_args:
165 | recycling_steps: 3
166 | sampling_steps: 200
167 | diffusion_samples: 5
168 | symmetry_correction: true
169 | run_confidence_sequentially: false
170 |
171 | diffusion_process_args:
172 | sigma_min: 0.0004
173 | sigma_max: 160.0
174 | sigma_data: 16.0
175 | rho: 7
176 | P_mean: -1.2
177 | P_std: 1.5
178 | gamma_0: 0.8
179 | gamma_min: 1.0
180 | noise_scale: 1.0
181 | step_scale: 1.0
182 | coordinate_augmentation: true
183 | alignment_reverse_diff: true
184 | synchronize_sigmas: true
185 | use_inference_model_cache: true
186 |
187 | diffusion_loss_args:
188 | add_smooth_lddt_loss: true
189 | nucleotide_loss_weight: 5.0
190 | ligand_loss_weight: 10.0
191 |
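This config implements the confidence fine-tuning stage: it starts from a trained structure checkpoint (pretrained), copies the trunk weights into the confidence module (load_confidence_from_trunk: true), and turns structure training off (structure_prediction_training: false). Since train.py merges dotlist overrides on top of the file, the placeholder paths can be supplied at launch; a sketch with illustrative paths:

    python scripts/train/train.py scripts/train/configs/confidence.yaml \
        output=/data/boltz/confidence_run \
        pretrained=/data/boltz/checkpoints/structure.ckpt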
--------------------------------------------------------------------------------
/scripts/train/configs/full.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | accelerator: gpu
3 | devices: 1
4 | precision: 32
5 | gradient_clip_val: 10.0
6 | max_epochs: -1
7 |
8 | # Optionally, set wandb here
9 | # wandb:
10 | # name: boltz
11 | # project: boltz
12 | # entity: boltz
13 |
14 |
15 | output: SET_PATH_HERE
16 | pretrained: PATH_TO_STRUCTURE_CHECKPOINT_FILE
17 | resume: null
18 | disable_checkpoint: false
19 | matmul_precision: null
20 | save_top_k: -1
21 |
22 | data:
23 | datasets:
24 | - _target_: boltz.data.module.training.DatasetConfig
25 | target_dir: PATH_TO_TARGETS_DIR
26 | msa_dir: PATH_TO_MSA_DIR
27 | prob: 1.0
28 | sampler:
29 | _target_: boltz.data.sample.cluster.ClusterSampler
30 | cropper:
31 | _target_: boltz.data.crop.boltz.BoltzCropper
32 | min_neighborhood: 0
33 | max_neighborhood: 40
34 | split: ./scripts/train/assets/validation_ids.txt
35 |
36 | filters:
37 | - _target_: boltz.data.filter.dynamic.size.SizeFilter
38 | min_chains: 1
39 | max_chains: 300
40 | - _target_: boltz.data.filter.dynamic.date.DateFilter
41 | date: "2021-09-30"
42 | ref: released
43 | - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
44 | resolution: 4.0
45 |
46 | tokenizer:
47 | _target_: boltz.data.tokenize.boltz.BoltzTokenizer
48 | featurizer:
49 | _target_: boltz.data.feature.featurizer.BoltzFeaturizer
50 |
51 | symmetries: PATH_TO_SYMMETRY_FILE
52 | max_tokens: 512
53 | max_atoms: 4608
54 | max_seqs: 2048
55 | pad_to_max_tokens: true
56 | pad_to_max_atoms: true
57 | pad_to_max_seqs: true
58 | samples_per_epoch: 100000
59 | batch_size: 1
60 | num_workers: 4
61 | random_seed: 42
62 | pin_memory: true
63 | overfit: null
64 | crop_validation: true
65 | return_train_symmetries: true
66 | return_val_symmetries: true
67 | train_binder_pocket_conditioned_prop: 0.3
68 | val_binder_pocket_conditioned_prop: 0.3
69 | binder_pocket_cutoff: 6.0
70 | binder_pocket_sampling_geometric_p: 0.3
71 | min_dist: 2.0
72 | max_dist: 22.0
73 | num_bins: 64
74 | atoms_per_window_queries: 32
75 |
76 | model:
77 | _target_: boltz.model.model.Boltz1
78 | atom_s: 128
79 | atom_z: 16
80 | token_s: 384
81 | token_z: 128
82 | num_bins: 64
83 | atom_feature_dim: 389
84 | atoms_per_window_queries: 32
85 | atoms_per_window_keys: 128
86 | compile_pairformer: false
87 | nucleotide_rmsd_weight: 5.0
88 | ligand_rmsd_weight: 10.0
89 | ema: true
90 | ema_decay: 0.999
91 |
92 | embedder_args:
93 | atom_encoder_depth: 3
94 | atom_encoder_heads: 4
95 |
96 | msa_args:
97 | msa_s: 64
98 | msa_blocks: 4
99 | msa_dropout: 0.15
100 | z_dropout: 0.25
101 | pairwise_head_width: 32
102 | pairwise_num_heads: 4
103 | activation_checkpointing: true
104 | offload_to_cpu: false
105 |
106 | pairformer_args:
107 | num_blocks: 48
108 | num_heads: 16
109 | dropout: 0.25
110 | activation_checkpointing: true
111 | offload_to_cpu: false
112 |
113 | score_model_args:
114 | sigma_data: 16
115 | dim_fourier: 256
116 | atom_encoder_depth: 3
117 | atom_encoder_heads: 4
118 | token_transformer_depth: 24
119 | token_transformer_heads: 16
120 | atom_decoder_depth: 3
121 | atom_decoder_heads: 4
122 | conditioning_transition_layers: 2
123 | activation_checkpointing: true
124 | offload_to_cpu: false
125 |
126 | structure_prediction_training: true
127 | confidence_prediction: true
128 | alpha_pae: 1
129 | confidence_imitate_trunk: true
130 | confidence_model_args:
131 | num_dist_bins: 64
132 | max_dist: 22
133 | add_s_to_z_prod: true
134 | add_s_input_to_s: true
135 | use_s_diffusion: true
136 | add_z_input_to_z: true
137 |
138 | confidence_args:
139 | num_plddt_bins: 50
140 | num_pde_bins: 64
141 | num_pae_bins: 64
142 |
143 | training_args:
144 | recycling_steps: 3
145 | sampling_steps: 200
146 | diffusion_multiplicity: 16
147 | diffusion_samples: 1
148 | confidence_loss_weight: 3e-3
149 | diffusion_loss_weight: 4.0
150 | distogram_loss_weight: 3e-2
151 | adam_beta_1: 0.9
152 | adam_beta_2: 0.95
153 | adam_eps: 0.00000001
154 | lr_scheduler: af3
155 | base_lr: 0.0
156 | max_lr: 0.0018
157 | lr_warmup_no_steps: 1000
158 | lr_start_decay_after_n_steps: 50000
159 | lr_decay_every_n_steps: 50000
160 | lr_decay_factor: 0.95
161 | symmetry_correction: true
162 | run_confidence_sequentially: false
163 |
164 | validation_args:
165 | recycling_steps: 3
166 | sampling_steps: 200
167 | diffusion_samples: 5
168 | symmetry_correction: true
169 |
170 | diffusion_process_args:
171 | sigma_min: 0.0004
172 | sigma_max: 160.0
173 | sigma_data: 16.0
174 | rho: 7
175 | P_mean: -1.2
176 | P_std: 1.5
177 | gamma_0: 0.8
178 | gamma_min: 1.0
179 | noise_scale: 1.0
180 | step_scale: 1.0
181 | coordinate_augmentation: true
182 | alignment_reverse_diff: true
183 | synchronize_sigmas: true
184 | use_inference_model_cache: true
185 |
186 | diffusion_loss_args:
187 | add_smooth_lddt_loss: true
188 | nucleotide_loss_weight: 5.0
189 | ligand_loss_weight: 10.0
190 |
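Relative to confidence.yaml, the key change here is structure_prediction_training: true (and no load_confidence_from_trunk), so the structure and confidence heads train jointly from the pretrained checkpoint. Nested keys can also be overridden with the same dotlist syntax; the values below are illustrative:

    python scripts/train/train.py scripts/train/configs/full.yaml \
        trainer.devices=4 data.num_workers=8 save_top_k=3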
--------------------------------------------------------------------------------
/scripts/train/configs/structure.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | accelerator: gpu
3 | devices: 1
4 | precision: 32
5 | gradient_clip_val: 10.0
6 | max_epochs: -1
7 |
8 | # Optionally, set wandb here
9 | # wandb:
10 | # name: boltz
11 | # project: boltz
12 | # entity: boltz
13 |
14 | output: SET_PATH_HERE
15 | resume: PATH_TO_CHECKPOINT_FILE
16 | disable_checkpoint: false
17 | matmul_precision: null
18 | save_top_k: -1
19 |
20 | data:
21 | datasets:
22 | - _target_: boltz.data.module.training.DatasetConfig
23 | target_dir: PATH_TO_TARGETS_DIR
24 | msa_dir: PATH_TO_MSA_DIR
25 | prob: 1.0
26 | sampler:
27 | _target_: boltz.data.sample.cluster.ClusterSampler
28 | cropper:
29 | _target_: boltz.data.crop.boltz.BoltzCropper
30 | min_neighborhood: 0
31 | max_neighborhood: 40
32 | split: ./scripts/train/assets/validation_ids.txt
33 |
34 | filters:
35 | - _target_: boltz.data.filter.dynamic.size.SizeFilter
36 | min_chains: 1
37 | max_chains: 300
38 | - _target_: boltz.data.filter.dynamic.date.DateFilter
39 | date: "2021-09-30"
40 | ref: released
41 | - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
42 | resolution: 9.0
43 |
44 | tokenizer:
45 | _target_: boltz.data.tokenize.boltz.BoltzTokenizer
46 | featurizer:
47 | _target_: boltz.data.feature.featurizer.BoltzFeaturizer
48 |
49 | symmetries: PATH_TO_SYMMETRY_FILE
50 | max_tokens: 512
51 | max_atoms: 4608
52 | max_seqs: 2048
53 | pad_to_max_tokens: true
54 | pad_to_max_atoms: true
55 | pad_to_max_seqs: true
56 | samples_per_epoch: 100000
57 | batch_size: 1
58 | num_workers: 4
59 | random_seed: 42
60 | pin_memory: true
61 | overfit: null
62 | crop_validation: false
63 | return_train_symmetries: false
64 | return_val_symmetries: true
65 | train_binder_pocket_conditioned_prop: 0.3
66 | val_binder_pocket_conditioned_prop: 0.3
67 | binder_pocket_cutoff: 6.0
68 | binder_pocket_sampling_geometric_p: 0.3
69 | min_dist: 2.0
70 | max_dist: 22.0
71 | num_bins: 64
72 | atoms_per_window_queries: 32
73 |
74 | model:
75 | _target_: boltz.model.model.Boltz1
76 | atom_s: 128
77 | atom_z: 16
78 | token_s: 384
79 | token_z: 128
80 | num_bins: 64
81 | atom_feature_dim: 389
82 | atoms_per_window_queries: 32
83 | atoms_per_window_keys: 128
84 | compile_pairformer: false
85 | nucleotide_rmsd_weight: 5.0
86 | ligand_rmsd_weight: 10.0
87 | ema: true
88 | ema_decay: 0.999
89 |
90 | embedder_args:
91 | atom_encoder_depth: 3
92 | atom_encoder_heads: 4
93 |
94 | msa_args:
95 | msa_s: 64
96 | msa_blocks: 4
97 | msa_dropout: 0.15
98 | z_dropout: 0.25
99 | pairwise_head_width: 32
100 | pairwise_num_heads: 4
101 | activation_checkpointing: true
102 | offload_to_cpu: false
103 |
104 | pairformer_args:
105 | num_blocks: 48
106 | num_heads: 16
107 | dropout: 0.25
108 | activation_checkpointing: true
109 | offload_to_cpu: false
110 |
111 | score_model_args:
112 | sigma_data: 16
113 | dim_fourier: 256
114 | atom_encoder_depth: 3
115 | atom_encoder_heads: 4
116 | token_transformer_depth: 24
117 | token_transformer_heads: 16
118 | atom_decoder_depth: 3
119 | atom_decoder_heads: 4
120 | conditioning_transition_layers: 2
121 | activation_checkpointing: true
122 | offload_to_cpu: false
123 |
124 | confidence_prediction: false
125 | confidence_model_args:
126 | num_dist_bins: 64
127 | max_dist: 22
128 | add_s_to_z_prod: true
129 | add_s_input_to_s: true
130 | use_s_diffusion: true
131 | add_z_input_to_z: true
132 |
133 | confidence_args:
134 | num_plddt_bins: 50
135 | num_pde_bins: 64
136 | num_pae_bins: 64
137 |
138 | training_args:
139 | recycling_steps: 3
140 | sampling_steps: 20
141 | diffusion_multiplicity: 16
142 | diffusion_samples: 2
143 | confidence_loss_weight: 1e-4
144 | diffusion_loss_weight: 4.0
145 | distogram_loss_weight: 3e-2
146 | adam_beta_1: 0.9
147 | adam_beta_2: 0.95
148 | adam_eps: 0.00000001
149 | lr_scheduler: af3
150 | base_lr: 0.0
151 | max_lr: 0.0018
152 | lr_warmup_no_steps: 1000
153 | lr_start_decay_after_n_steps: 50000
154 | lr_decay_every_n_steps: 50000
155 | lr_decay_factor: 0.95
156 |
157 | validation_args:
158 | recycling_steps: 3
159 | sampling_steps: 200
160 | diffusion_samples: 5
161 | symmetry_correction: true
162 | run_confidence_sequentially: false
163 |
164 | diffusion_process_args:
165 | sigma_min: 0.0004
166 | sigma_max: 160.0
167 | sigma_data: 16.0
168 | rho: 7
169 | P_mean: -1.2
170 | P_std: 1.5
171 | gamma_0: 0.8
172 | gamma_min: 1.0
173 | noise_scale: 1.0
174 | step_scale: 1.0
175 | coordinate_augmentation: true
176 | alignment_reverse_diff: true
177 | synchronize_sigmas: true
178 | use_inference_model_cache: true
179 |
180 | diffusion_loss_args:
181 | add_smooth_lddt_loss: true
182 | nucleotide_loss_weight: 5.0
183 | ligand_loss_weight: 10.0
184 |
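This is the initial structure-training stage: confidence_prediction is off, the resolution filter is looser (9.0 rather than 4.0), and training runs only 20 sampling steps. For a quick smoke test, the debug flag defined in train.py below drops to a single device and zero dataloader workers; the output path is illustrative:

    python scripts/train/train.py scripts/train/configs/structure.yaml \
        output=/tmp/boltz_debug resume=null debug=true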
--------------------------------------------------------------------------------
/scripts/train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 | from typing import Optional
6 |
7 | import hydra
8 | import omegaconf
9 | import pytorch_lightning as pl
10 | import torch
11 | import torch.multiprocessing
12 | from omegaconf import OmegaConf, listconfig
13 | from pytorch_lightning import LightningModule
14 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
15 | from pytorch_lightning.loggers import WandbLogger
16 | from pytorch_lightning.strategies import DDPStrategy
17 | from pytorch_lightning.utilities import rank_zero_only
18 |
19 | from boltz.data.module.training import BoltzTrainingDataModule, DataConfig
20 |
21 |
22 | @dataclass
23 | class TrainConfig:
24 | """Train configuration.
25 |
26 | Attributes
27 | ----------
28 | data : DataConfig
29 | The data configuration.
30 | model : ModelConfig
31 | The model configuration.
32 | output : str
33 | The output directory.
34 | trainer : Optional[dict]
35 | The trainer configuration.
36 | resume : Optional[str]
37 | The resume checkpoint.
38 | pretrained : Optional[str]
39 | The pretrained model.
40 | wandb : Optional[dict]
41 | The wandb configuration.
42 | disable_checkpoint : bool
43 | Disable checkpoint.
44 | matmul_precision : Optional[str]
45 | The matmul precision.
46 | find_unused_parameters : Optional[bool]
47 | Find unused parameters.
48 | save_top_k : Optional[int]
49 | Save top k checkpoints.
50 | validation_only : bool
51 | Run validation only.
52 | debug : bool
53 | Debug mode.
54 | strict_loading : bool
55 | Fail on mismatched checkpoint weights.
56 |     load_confidence_from_trunk : Optional[bool]
57 | Load pre-trained confidence weights from trunk.
58 |
59 | """
60 |
61 | data: DataConfig
62 | model: LightningModule
63 | output: str
64 | trainer: Optional[dict] = None
65 | resume: Optional[str] = None
66 | pretrained: Optional[str] = None
67 | wandb: Optional[dict] = None
68 | disable_checkpoint: bool = False
69 | matmul_precision: Optional[str] = None
70 | find_unused_parameters: Optional[bool] = False
71 | save_top_k: Optional[int] = 1
72 | validation_only: bool = False
73 | debug: bool = False
74 | strict_loading: bool = True
75 | load_confidence_from_trunk: Optional[bool] = False
76 |
77 |
78 | def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR0915
79 | """Run training.
80 |
81 | Parameters
82 | ----------
83 | raw_config : str
84 | The input yaml configuration.
85 | args : list[str]
86 | Any command line overrides.
87 |
88 | """
89 | # Load the configuration
90 | raw_config = omegaconf.OmegaConf.load(raw_config)
91 |
92 | # Apply input arguments
93 | args = omegaconf.OmegaConf.from_dotlist(args)
94 | raw_config = omegaconf.OmegaConf.merge(raw_config, args)
95 |
96 | # Instantiate the task
97 | cfg = hydra.utils.instantiate(raw_config)
98 | cfg = TrainConfig(**cfg)
99 |
100 | # Set matmul precision
101 | if cfg.matmul_precision is not None:
102 | torch.set_float32_matmul_precision(cfg.matmul_precision)
103 |
104 | # Create trainer dict
105 | trainer = cfg.trainer
106 | if trainer is None:
107 | trainer = {}
108 |
109 | # Flip some arguments in debug mode
110 | devices = trainer.get("devices", 1)
111 |
112 | wandb = cfg.wandb
113 | if cfg.debug:
114 | if isinstance(devices, int):
115 | devices = 1
116 | elif isinstance(devices, (list, listconfig.ListConfig)):
117 | devices = [devices[0]]
118 | trainer["devices"] = devices
119 | cfg.data.num_workers = 0
120 | if wandb:
121 | wandb = None
122 |
123 | # Create objects
124 | data_config = DataConfig(**cfg.data)
125 | data_module = BoltzTrainingDataModule(data_config)
126 | model_module = cfg.model
127 |
128 | if cfg.pretrained and not cfg.resume:
129 | # Load the pretrained weights into the confidence module
130 | if cfg.load_confidence_from_trunk:
131 | checkpoint = torch.load(cfg.pretrained, map_location="cpu")
132 |
133 | # Modify parameter names in the state_dict
134 | new_state_dict = {}
135 | for key, value in checkpoint["state_dict"].items():
136 | if not key.startswith("structure_module") and not key.startswith(
137 | "distogram_module"
138 | ):
139 | new_key = "confidence_module." + key
140 | new_state_dict[new_key] = value
141 | new_state_dict.update(checkpoint["state_dict"])
142 |
143 |             # Update the checkpoint and save it to a temporary file
144 |             checkpoint["state_dict"] = new_state_dict
145 |             torch.save(checkpoint, file_path := cfg.pretrained + ".tmp")
146 |         else:
147 |             file_path = cfg.pretrained
148 | print(f"Loading model from {file_path}")
149 | model_module = type(model_module).load_from_checkpoint(
150 | file_path, strict=False, **(model_module.hparams)
151 | )
152 |
153 | if cfg.load_confidence_from_trunk:
154 | os.remove(file_path)
155 |
156 | # Create checkpoint callback
157 | callbacks = []
158 | dirpath = cfg.output
159 | if not cfg.disable_checkpoint:
160 | mc = ModelCheckpoint(
161 | monitor="val/lddt",
162 | save_top_k=cfg.save_top_k,
163 | save_last=True,
164 | mode="max",
165 | every_n_epochs=1,
166 | )
167 | callbacks = [mc]
168 |
169 | # Create wandb logger
170 | loggers = []
171 | if wandb:
172 | wdb_logger = WandbLogger(
173 | group=wandb["name"],
174 | save_dir=cfg.output,
175 | project=wandb["project"],
176 | entity=wandb["entity"],
177 | log_model=False,
178 | )
179 | loggers.append(wdb_logger)
180 | # Save the config to wandb
181 |
182 | @rank_zero_only
183 | def save_config_to_wandb() -> None:
184 | config_out = Path(wdb_logger.experiment.dir) / "run.yaml"
185 | with Path.open(config_out, "w") as f:
186 | OmegaConf.save(raw_config, f)
187 | wdb_logger.experiment.save(str(config_out))
188 |
189 | save_config_to_wandb()
190 |
191 | # Set up trainer
192 | strategy = "auto"
193 | if (isinstance(devices, int) and devices > 1) or (
194 | isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1
195 | ):
196 | strategy = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)
197 |
198 | trainer = pl.Trainer(
199 | default_root_dir=str(dirpath),
200 | strategy=strategy,
201 | callbacks=callbacks,
202 | logger=loggers,
203 | enable_checkpointing=not cfg.disable_checkpoint,
204 | reload_dataloaders_every_n_epochs=1,
205 | **trainer,
206 | )
207 |
208 | if not cfg.strict_loading:
209 | model_module.strict_loading = False
210 |
211 | if cfg.validation_only:
212 | trainer.validate(
213 | model_module,
214 | datamodule=data_module,
215 | ckpt_path=cfg.resume,
216 | )
217 | else:
218 | trainer.fit(
219 | model_module,
220 | datamodule=data_module,
221 | ckpt_path=cfg.resume,
222 | )
223 |
224 |
225 | if __name__ == "__main__":
226 | arg1 = sys.argv[1]
227 | arg2 = sys.argv[2:]
228 | train(arg1, arg2)
229 |
--------------------------------------------------------------------------------
/src/boltz/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import PackageNotFoundError, version
2 |
3 | try: # noqa: SIM105
4 | __version__ = version("boltz")
5 | except PackageNotFoundError:
6 | # package is not installed
7 | pass
8 |
--------------------------------------------------------------------------------
/src/boltz/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/crop/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/crop/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/crop/cropper.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional
3 |
4 | import numpy as np
5 |
6 | from boltz.data.types import Tokenized
7 |
8 |
9 | class Cropper(ABC):
10 | """Abstract base class for cropper."""
11 |
12 | @abstractmethod
13 | def crop(
14 | self,
15 | data: Tokenized,
16 | max_tokens: int,
17 | random: np.random.RandomState,
18 | max_atoms: Optional[int] = None,
19 | chain_id: Optional[int] = None,
20 | interface_id: Optional[int] = None,
21 | ) -> Tokenized:
22 | """Crop the data to a maximum number of tokens.
23 |
24 | Parameters
25 | ----------
26 | data : Tokenized
27 | The tokenized data.
28 | max_tokens : int
29 | The maximum number of tokens to crop.
30 | random : np.random.RandomState
31 | The random state for reproducibility.
32 | max_atoms : Optional[int]
33 | The maximum number of atoms to consider.
34 | chain_id : Optional[int]
35 | The chain ID to crop.
36 | interface_id : Optional[int]
37 | The interface ID to crop.
38 |
39 | Returns
40 | -------
41 | Tokenized
42 | The cropped data.
43 |
44 | """
45 | raise NotImplementedError
46 |
--------------------------------------------------------------------------------
/src/boltz/data/feature/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/feature/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/feature/pad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn.functional import pad
4 |
5 |
6 | def pad_dim(data: Tensor, dim: int, pad_len: float, value: float = 0) -> Tensor:
7 | """Pad a tensor along a given dimension.
8 |
9 | Parameters
10 | ----------
11 | data : Tensor
12 | The input tensor.
13 | dim : int
14 | The dimension to pad.
15 | pad_len : float
16 | The padding length.
17 |     value : float, optional
18 | The value to pad with.
19 |
20 | Returns
21 | -------
22 | Tensor
23 | The padded tensor.
24 |
25 | """
26 | if pad_len == 0:
27 | return data
28 |
29 | total_dims = len(data.shape)
30 | padding = [0] * (2 * (total_dims - dim))
31 | padding[2 * (total_dims - 1 - dim) + 1] = pad_len
32 | return pad(data, tuple(padding), value=value)
33 |
34 |
35 | def pad_to_max(data: list[Tensor], value: float = 0) -> tuple[Tensor, Tensor]:
36 | """Pad the data in all dimensions to the maximum found.
37 |
38 | Parameters
39 | ----------
40 | data : List[Tensor]
41 | List of tensors to pad.
42 | value : float
43 | The value to use for padding.
44 |
45 | Returns
46 | -------
47 | Tensor
48 | The padded tensor.
49 | Tensor
50 | The padding mask.
51 |
52 | """
53 | if isinstance(data[0], str):
54 | return data, 0
55 |
56 | # Check if all have the same shape
57 | if all(d.shape == data[0].shape for d in data):
58 | return torch.stack(data, dim=0), 0
59 |
60 | # Get the maximum in each dimension
61 | num_dims = len(data[0].shape)
62 | max_dims = [max(d.shape[i] for d in data) for i in range(num_dims)]
63 |
64 | # Get the padding lengths
65 | pad_lengths = []
66 | for d in data:
67 | dims = []
68 | for i in range(num_dims):
69 | dims.append(0)
70 | dims.append(max_dims[num_dims - i - 1] - d.shape[num_dims - i - 1])
71 | pad_lengths.append(dims)
72 |
73 | # Pad the data
74 | padding = [
75 | pad(torch.ones_like(d), pad_len, value=0)
76 | for d, pad_len in zip(data, pad_lengths)
77 | ]
78 | data = [pad(d, pad_len, value=value) for d, pad_len in zip(data, pad_lengths)]
79 |
80 | # Stack the data
81 | padding = torch.stack(padding, dim=0)
82 | data = torch.stack(data, dim=0)
83 |
84 | return data, padding
85 |
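A minimal sketch of pad_to_max: every tensor is right-padded to the per-dimension maximum and the results are stacked, with a mask of ones at real entries and zeros at padding (when all shapes already match, the mask degenerates to the scalar 0):

    import torch
    from boltz.data.feature.pad import pad_to_max

    a = torch.ones(3, 4)
    b = torch.ones(5, 2)
    data, mask = pad_to_max([a, b], value=0)
    # data.shape == (2, 5, 4); mask[0, :3, :4] is all ones, zeros elsewhere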
--------------------------------------------------------------------------------
/src/boltz/data/filter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/filter/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/filter/dynamic/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/date.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from typing import Literal
3 |
4 | from boltz.data.types import Record
5 | from boltz.data.filter.dynamic.filter import DynamicFilter
6 |
7 |
8 | class DateFilter(DynamicFilter):
9 | """A filter that filters complexes based on their date.
10 |
11 | The date can be the deposition, release, or revision date.
12 |     If that date is unavailable, an earlier one is used as a fallback.
13 |
14 | If no date is available, the complex is rejected.
15 |
16 | """
17 |
18 | def __init__(
19 | self,
20 | date: str,
21 | ref: Literal["deposited", "revised", "released"],
22 | ) -> None:
23 | """Initialize the filter.
24 |
25 | Parameters
26 | ----------
27 |         date : str
28 |             The maximum allowed date for PDB entries.
29 | ref : Literal["deposited", "revised", "released"]
30 | The reference date to use.
31 |
32 | """
33 | self.filter_date = datetime.fromisoformat(date)
34 | self.ref = ref
35 |
36 | if ref not in ["deposited", "revised", "released"]:
37 | msg = (
38 | "Invalid reference date. Must be ",
39 | "deposited, revised, or released",
40 | )
41 | raise ValueError(msg)
42 |
43 | def filter(self, record: Record) -> bool:
44 | """Filter a record based on its date.
45 |
46 | Parameters
47 | ----------
48 | record : Record
49 | The record to filter.
50 |
51 | Returns
52 | -------
53 | bool
54 |             True if the record passes the filter.
55 |
56 | """
57 | structure = record.structure
58 |
59 | if self.ref == "deposited":
60 | date = structure.deposited
61 | elif self.ref == "released":
62 | date = structure.released
63 | if not date:
64 | date = structure.deposited
65 | elif self.ref == "revised":
66 | date = structure.revised
67 | if not date and structure.released:
68 | date = structure.released
69 | elif not date:
70 | date = structure.deposited
71 |
72 | if date is None or date == "":
73 | return False
74 |
75 | date = datetime.fromisoformat(date)
76 | return date <= self.filter_date
77 |
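A sketch mirroring the filter block in the training configs; here record is assumed to be a boltz.data.types.Record loaded elsewhere:

    from boltz.data.filter.dynamic.date import DateFilter

    # keep entries released on or before the training cutoff
    date_filter = DateFilter(date="2021-09-30", ref="released")
    keep = date_filter.filter(record)  # True means the record passes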
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/filter.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from boltz.data.types import Record
4 |
5 |
6 | class DynamicFilter(ABC):
7 | """Base class for data filters."""
8 |
9 | @abstractmethod
10 | def filter(self, record: Record) -> bool:
11 | """Filter a data record.
12 |
13 | Parameters
14 | ----------
15 | record : Record
16 | The object to consider filtering in / out.
17 |
18 | Returns
19 | -------
20 | bool
21 | True if the data passes the filter, False otherwise.
22 |
23 | """
24 | raise NotImplementedError
25 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/max_residues.py:
--------------------------------------------------------------------------------
1 | from boltz.data.types import Record
2 | from boltz.data.filter.dynamic.filter import DynamicFilter
3 |
4 |
5 | class MaxResiduesFilter(DynamicFilter):
6 | """A filter that filters structures based on their size."""
7 |
8 | def __init__(self, min_residues: int = 1, max_residues: int = 500) -> None:
9 | """Initialize the filter.
10 |
11 | Parameters
12 | ----------
13 |         min_residues : int
14 |             The minimum number of residues allowed.
15 |         max_residues : int
16 |             The maximum number of residues allowed.
17 |
18 | """
19 | self.min_residues = min_residues
20 | self.max_residues = max_residues
21 |
22 | def filter(self, record: Record) -> bool:
23 | """Filter structures based on their resolution.
24 |
25 | Parameters
26 | ----------
27 | record : Record
28 | The record to filter.
29 |
30 | Returns
31 | -------
32 | bool
33 |             True if the record passes the filter.
34 |
35 | """
36 | num_residues = sum(chain.num_residues for chain in record.chains)
37 | return num_residues <= self.max_residues and num_residues >= self.min_residues
38 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/resolution.py:
--------------------------------------------------------------------------------
1 | from boltz.data.types import Record
2 | from boltz.data.filter.dynamic.filter import DynamicFilter
3 |
4 |
5 | class ResolutionFilter(DynamicFilter):
6 | """A filter that filters complexes based on their resolution."""
7 |
8 | def __init__(self, resolution: float = 9.0) -> None:
9 | """Initialize the filter.
10 |
11 | Parameters
12 | ----------
13 | resolution : float, optional
14 | The maximum allowed resolution.
15 |
16 | """
17 | self.resolution = resolution
18 |
19 | def filter(self, record: Record) -> bool:
20 | """Filter complexes based on their resolution.
21 |
22 | Parameters
23 | ----------
24 | record : Record
25 | The record to filter.
26 |
27 | Returns
28 | -------
29 | bool
30 |             True if the record passes the filter.
31 |
32 | """
33 | structure = record.structure
34 | return structure.resolution <= self.resolution
35 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/size.py:
--------------------------------------------------------------------------------
1 | from boltz.data.types import Record
2 | from boltz.data.filter.dynamic.filter import DynamicFilter
3 |
4 |
5 | class SizeFilter(DynamicFilter):
6 | """A filter that filters structures based on their size."""
7 |
8 | def __init__(self, min_chains: int = 1, max_chains: int = 300) -> None:
9 | """Initialize the filter.
10 |
11 | Parameters
12 | ----------
13 | min_chains : int
14 | The minimum number of chains allowed.
15 | max_chains : int
16 | The maximum number of chains allowed.
17 |
18 | """
19 | self.min_chains = min_chains
20 | self.max_chains = max_chains
21 |
22 | def filter(self, record: Record) -> bool:
23 | """Filter structures based on their resolution.
24 |
25 | Parameters
26 | ----------
27 | record : Record
28 | The record to filter.
29 |
30 | Returns
31 | -------
32 | bool
33 |             True if the record passes the filter.
34 |
35 | """
36 | num_chains = record.structure.num_chains
37 | num_valid = sum(1 for chain in record.chains if chain.valid)
38 | return num_chains <= self.max_chains and num_valid >= self.min_chains
39 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/subset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from boltz.data.types import Record
4 | from boltz.data.filter.dynamic.filter import DynamicFilter
5 |
6 |
7 | class SubsetFilter(DynamicFilter):
8 | """Filter a data record based on a subset of the data."""
9 |
10 | def __init__(self, subset: str, reverse: bool = False) -> None:
11 | """Initialize the filter.
12 |
13 | Parameters
14 | ----------
15 |         subset : str
16 |             Path to a file of record IDs, one per line; reverse inverts the match.
17 |
18 | """
19 | with Path(subset).open("r") as f:
20 | subset = f.read().splitlines()
21 |
22 | self.subset = {s.lower() for s in subset}
23 | self.reverse = reverse
24 |
25 | def filter(self, record: Record) -> bool:
26 | """Filter a data record.
27 |
28 | Parameters
29 | ----------
30 | record : Record
31 | The object to consider filtering in / out.
32 |
33 | Returns
34 | -------
35 | bool
36 | True if the data passes the filter, False otherwise.
37 |
38 | """
39 | if self.reverse:
40 | return record.id.lower() not in self.subset
41 | else: # noqa: RET505
42 | return record.id.lower() in self.subset
43 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/static/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/filter/static/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/filter/static/filter.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 |
5 | from boltz.data.types import Structure
6 |
7 |
8 | class StaticFilter(ABC):
9 | """Base class for structure filters."""
10 |
11 | @abstractmethod
12 | def filter(self, structure: Structure) -> np.ndarray:
13 | """Filter chains in a structure.
14 |
15 | Parameters
16 | ----------
17 | structure : Structure
18 | The structure to filter chains from.
19 |
20 | Returns
21 | -------
22 | np.ndarray
23 | The chains to keep, as a boolean mask.
24 |
25 | """
26 | raise NotImplementedError
27 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/static/ligand.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from boltz.data import const
4 | from boltz.data.types import Structure
5 | from boltz.data.filter.static.filter import StaticFilter
6 |
7 | LIGAND_EXCLUSION = {
8 | "144",
9 | "15P",
10 | "1PE",
11 | "2F2",
12 | "2JC",
13 | "3HR",
14 | "3SY",
15 | "7N5",
16 | "7PE",
17 | "9JE",
18 | "AAE",
19 | "ABA",
20 | "ACE",
21 | "ACN",
22 | "ACT",
23 | "ACY",
24 | "AZI",
25 | "BAM",
26 | "BCN",
27 | "BCT",
28 | "BDN",
29 | "BEN",
30 | "BME",
31 | "BO3",
32 | "BTB",
33 | "BTC",
34 | "BU1",
35 | "C8E",
36 | "CAD",
37 | "CAQ",
38 | "CBM",
39 | "CCN",
40 | "CIT",
41 | "CL",
42 | "CLR",
43 | "CM",
44 | "CMO",
45 | "CO3",
46 | "CPT",
47 | "CXS",
48 | "D10",
49 | "DEP",
50 | "DIO",
51 | "DMS",
52 | "DN",
53 | "DOD",
54 | "DOX",
55 | "EDO",
56 | "EEE",
57 | "EGL",
58 | "EOH",
59 | "EOX",
60 | "EPE",
61 | "ETF",
62 | "FCY",
63 | "FJO",
64 | "FLC",
65 | "FMT",
66 | "FW5",
67 | "GOL",
68 | "GSH",
69 | "GTT",
70 | "GYF",
71 | "HED",
72 | "IHP",
73 | "IHS",
74 | "IMD",
75 | "IOD",
76 | "IPA",
77 | "IPH",
78 | "LDA",
79 | "MB3",
80 | "MEG",
81 | "MES",
82 | "MLA",
83 | "MLI",
84 | "MOH",
85 | "MPD",
86 | "MRD",
87 | "MSE",
88 | "MYR",
89 | "N",
90 | "NA",
91 | "NH2",
92 | "NH4",
93 | "NHE",
94 | "NO3",
95 | "O4B",
96 | "OHE",
97 | "OLA",
98 | "OLC",
99 | "OMB",
100 | "OME",
101 | "OXA",
102 | "P6G",
103 | "PE3",
104 | "PE4",
105 | "PEG",
106 | "PEO",
107 | "PEP",
108 | "PG0",
109 | "PG4",
110 | "PGE",
111 | "PGR",
112 | "PLM",
113 | "PO4",
114 | "POL",
115 | "POP",
116 | "PVO",
117 | "SAR",
118 | "SCN",
119 | "SEO",
120 | "SEP",
121 | "SIN",
122 | "SO4",
123 | "SPD",
124 | "SPM",
125 | "SR",
126 | "STE",
127 | "STO",
128 | "STU",
129 | "TAR",
130 | "TBU",
131 | "TME",
132 | "TPO",
133 | "TRS",
134 | "UNK",
135 | "UNL",
136 | "UNX",
137 | "UPL",
138 | "URE",
139 | }
140 |
141 |
142 | class ExcludedLigands(StaticFilter):
143 | """Filter excluded ligands."""
144 |
145 | def filter(self, structure: Structure) -> np.ndarray:
146 | """Filter excluded ligands.
147 |
148 | Parameters
149 | ----------
150 | structure : Structure
151 | The structure to filter chains from.
152 |
153 | Returns
154 | -------
155 | np.ndarray
156 | The chains to keep, as a boolean mask.
157 |
158 | """
159 | valid = np.ones(len(structure.chains), dtype=bool)
160 |
161 | for i, chain in enumerate(structure.chains):
162 | if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]:
163 | continue
164 |
165 | res_start = chain["res_idx"]
166 | res_end = res_start + chain["res_num"]
167 | residues = structure.residues[res_start:res_end]
168 | if any(res["name"] in LIGAND_EXCLUSION for res in residues):
169 | valid[i] = 0
170 |
171 | return valid
172 |
--------------------------------------------------------------------------------
/src/boltz/data/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/module/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/module/inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import pytorch_lightning as pl
5 | import torch
6 | from torch import Tensor
7 | from torch.utils.data import DataLoader
8 |
9 | from boltz.data import const
10 | from boltz.data.feature.featurizer import BoltzFeaturizer
11 | from boltz.data.feature.pad import pad_to_max
12 | from boltz.data.tokenize.boltz import BoltzTokenizer
13 | from boltz.data.types import MSA, Input, Manifest, Record, Structure
14 |
15 |
16 | def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
17 | """Load the given input data.
18 |
19 | Parameters
20 | ----------
21 | record : Record
22 | The record to load.
23 | target_dir : Path
24 | The path to the data directory.
25 | msa_dir : Path
26 | The path to msa directory.
27 |
28 | Returns
29 | -------
30 | Input
31 | The loaded input.
32 |
33 | """
34 | # Load the structure
35 | structure = np.load(target_dir / f"{record.id}.npz")
36 | structure = Structure(
37 | atoms=structure["atoms"],
38 | bonds=structure["bonds"],
39 | residues=structure["residues"],
40 | chains=structure["chains"],
41 | connections=structure["connections"],
42 | interfaces=structure["interfaces"],
43 | mask=structure["mask"],
44 | )
45 |
46 | msas = {}
47 | for chain in record.chains:
48 | msa_id = chain.msa_id
49 | # Load the MSA for this chain, if any
50 | if msa_id != -1:
51 | msa = np.load(msa_dir / f"{msa_id}.npz")
52 | msas[chain.chain_id] = MSA(**msa)
53 |
54 | return Input(structure, msas)
55 |
56 |
57 | def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
58 | """Collate the data.
59 |
60 | Parameters
61 | ----------
62 | data : List[Dict[str, Tensor]]
63 | The data to collate.
64 |
65 | Returns
66 | -------
67 | Dict[str, Tensor]
68 | The collated data.
69 |
70 | """
71 | # Get the keys
72 | keys = data[0].keys()
73 |
74 | # Collate the data
75 | collated = {}
76 | for key in keys:
77 | values = [d[key] for d in data]
78 |
79 | if key not in [
80 | "all_coords",
81 | "all_resolved_mask",
82 | "crop_to_all_atom_map",
83 | "chain_symmetries",
84 | "amino_acids_symmetries",
85 | "ligand_symmetries",
86 | "record",
87 | ]:
88 | # Check if all have the same shape
89 | shape = values[0].shape
90 | if not all(v.shape == shape for v in values):
91 | values, _ = pad_to_max(values, 0)
92 | else:
93 | values = torch.stack(values, dim=0)
94 |
95 | # Stack the values
96 | collated[key] = values
97 |
98 | return collated
99 |
100 |
101 | class PredictionDataset(torch.utils.data.Dataset):
102 | """Base iterable dataset."""
103 |
104 | def __init__(
105 | self,
106 | manifest: Manifest,
107 | target_dir: Path,
108 | msa_dir: Path,
109 | ) -> None:
110 | """Initialize the training dataset.
111 |
112 | Parameters
113 | ----------
114 | manifest : Manifest
115 | The manifest to load data from.
116 | target_dir : Path
117 | The path to the target directory.
118 | msa_dir : Path
119 | The path to the msa directory.
120 |
121 | """
122 | super().__init__()
123 | self.manifest = manifest
124 | self.target_dir = target_dir
125 | self.msa_dir = msa_dir
126 | self.tokenizer = BoltzTokenizer()
127 | self.featurizer = BoltzFeaturizer()
128 |
129 | def __getitem__(self, idx: int) -> dict:
130 | """Get an item from the dataset.
131 |
132 | Returns
133 | -------
134 | Dict[str, Tensor]
135 | The sampled data features.
136 |
137 | """
138 | # Get a sample from the dataset
139 | record = self.manifest.records[idx]
140 |
141 | # Get the structure
142 | try:
143 | input_data = load_input(record, self.target_dir, self.msa_dir)
144 | except Exception as e: # noqa: BLE001
145 | print(f"Failed to load input for {record.id} with error {e}. Skipping.") # noqa: T201
146 | return self.__getitem__(0)
147 |
148 | # Tokenize structure
149 | try:
150 | tokenized = self.tokenizer.tokenize(input_data)
151 | except Exception as e: # noqa: BLE001
152 | print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") # noqa: T201
153 | return self.__getitem__(0)
154 |
155 | # Inference specific options
156 | options = record.inference_options
157 | if options is None:
158 | binders, pocket = None, None
159 | else:
160 | binders, pocket = options.binders, options.pocket
161 |
162 | # Compute features
163 | try:
164 | features = self.featurizer.process(
165 | tokenized,
166 | training=False,
167 | max_atoms=None,
168 | max_tokens=None,
169 | max_seqs=const.max_msa_seqs,
170 | pad_to_max_seqs=False,
171 | symmetries={},
172 | compute_symmetries=False,
173 | inference_binder=binders,
174 | inference_pocket=pocket,
175 | )
176 | except Exception as e: # noqa: BLE001
177 | print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201
178 | return self.__getitem__(0)
179 |
180 | features["record"] = record
181 | return features
182 |
183 | def __len__(self) -> int:
184 | """Get the length of the dataset.
185 |
186 | Returns
187 | -------
188 | int
189 | The length of the dataset.
190 |
191 | """
192 | return len(self.manifest.records)
193 |
194 |
195 | class BoltzInferenceDataModule(pl.LightningDataModule):
196 | """DataModule for Boltz inference."""
197 |
198 | def __init__(
199 | self,
200 | manifest: Manifest,
201 | target_dir: Path,
202 | msa_dir: Path,
203 | num_workers: int,
204 | ) -> None:
205 | """Initialize the DataModule.
206 |
207 | Parameters
208 | ----------
209 |         manifest : Manifest
210 |             The manifest to load, with data in target_dir and msa_dir.
211 |
212 | """
213 | super().__init__()
214 | self.num_workers = num_workers
215 | self.manifest = manifest
216 | self.target_dir = target_dir
217 | self.msa_dir = msa_dir
218 |
219 | def predict_dataloader(self) -> DataLoader:
220 | """Get the training dataloader.
221 |
222 | Returns
223 | -------
224 | DataLoader
225 |             The prediction dataloader.
226 |
227 | """
228 | dataset = PredictionDataset(
229 | manifest=self.manifest,
230 | target_dir=self.target_dir,
231 | msa_dir=self.msa_dir,
232 | )
233 | return DataLoader(
234 | dataset,
235 | batch_size=1,
236 | num_workers=self.num_workers,
237 | pin_memory=True,
238 | shuffle=False,
239 | collate_fn=collate,
240 | )
241 |
242 | def transfer_batch_to_device(
243 | self,
244 | batch: dict,
245 | device: torch.device,
246 | dataloader_idx: int, # noqa: ARG002
247 | ) -> dict:
248 | """Transfer a batch to the given device.
249 |
250 | Parameters
251 | ----------
252 | batch : Dict
253 | The batch to transfer.
254 | device : torch.device
255 | The device to transfer to.
256 | dataloader_idx : int
257 | The dataloader index.
258 |
259 | Returns
260 | -------
261 |         dict
262 | The transferred batch.
263 |
264 | """
265 | for key in batch:
266 | if key not in [
267 | "all_coords",
268 | "all_resolved_mask",
269 | "crop_to_all_atom_map",
270 | "chain_symmetries",
271 | "amino_acids_symmetries",
272 | "ligand_symmetries",
273 | "record",
274 | ]:
275 | batch[key] = batch[key].to(device)
276 | return batch
277 |
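A sketch of collate with a hypothetical feature key: entries of unequal length are zero-padded to a common shape before stacking, while the symmetry and record keys listed above are passed through as plain lists:

    import torch
    from boltz.data.module.inference import collate

    batch = collate([
        {"token_index": torch.zeros(3, dtype=torch.long)},
        {"token_index": torch.zeros(5, dtype=torch.long)},
    ])
    # batch["token_index"].shape == (2, 5)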
--------------------------------------------------------------------------------
/src/boltz/data/msa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/msa/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/parse/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/parse/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/parse/a3m.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | from pathlib import Path
3 | from typing import Optional, TextIO
4 |
5 | import numpy as np
6 |
7 | from boltz.data import const
8 | from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
9 |
10 |
11 | def _parse_a3m( # noqa: C901
12 | lines: TextIO,
13 | taxonomy: Optional[dict[str, str]],
14 | max_seqs: Optional[int] = None,
15 | ) -> MSA:
16 | """Process an MSA file.
17 |
18 | Parameters
19 | ----------
20 | lines : TextIO
21 | The lines of the MSA file.
22 | taxonomy : dict[str, str]
23 | The taxonomy database, if available.
24 | max_seqs : int, optional
25 | The maximum number of sequences.
26 |
27 | Returns
28 | -------
29 | MSA
30 | The MSA object.
31 |
32 | """
33 | visited = set()
34 | sequences = []
35 | deletions = []
36 | residues = []
37 |
38 | seq_idx = 0
39 | for line in lines:
40 | line: str
41 | line = line.strip() # noqa: PLW2901
42 | if not line or line.startswith("#"):
43 | continue
44 |
45 | # Get taxonomy, if annotated
46 | if line.startswith(">"):
47 | header = line.split()[0]
48 | if taxonomy and header.startswith(">UniRef100"):
49 | uniref_id = header.split("_")[1]
50 | taxonomy_id = taxonomy.get(uniref_id)
51 | if taxonomy_id is None:
52 | taxonomy_id = -1
53 | else:
54 | taxonomy_id = -1
55 | continue
56 |
57 | # Skip if duplicate sequence
58 | str_seq = line.replace("-", "").upper()
59 | if str_seq not in visited:
60 | visited.add(str_seq)
61 | else:
62 | continue
63 |
64 | # Process sequence
65 | residue = []
66 | deletion = []
67 | count = 0
68 | res_idx = 0
69 | for c in line:
70 | if c != "-" and c.islower():
71 | count += 1
72 | continue
73 | token = const.prot_letter_to_token[c]
74 | token = const.token_ids[token]
75 | residue.append(token)
76 | if count > 0:
77 | deletion.append((res_idx, count))
78 | count = 0
79 | res_idx += 1
80 |
81 | res_start = len(residues)
82 | res_end = res_start + len(residue)
83 |
84 | del_start = len(deletions)
85 | del_end = del_start + len(deletion)
86 |
87 | sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
88 | residues.extend(residue)
89 | deletions.extend(deletion)
90 |
91 | seq_idx += 1
92 | if (max_seqs is not None) and (seq_idx >= max_seqs):
93 | break
94 |
95 | # Create MSA object
96 | msa = MSA(
97 | residues=np.array(residues, dtype=MSAResidue),
98 | deletions=np.array(deletions, dtype=MSADeletion),
99 | sequences=np.array(sequences, dtype=MSASequence),
100 | )
101 | return msa
102 |
103 |
104 | def parse_a3m(
105 | path: Path,
106 | taxonomy: Optional[dict[str, str]],
107 | max_seqs: Optional[int] = None,
108 | ) -> MSA:
109 | """Process an A3M file.
110 |
111 | Parameters
112 | ----------
113 | path : Path
114 | The path to the a3m(.gz) file.
115 |     taxonomy : dict[str, str], optional
116 | The taxonomy database.
117 | max_seqs : int, optional
118 | The maximum number of sequences.
119 |
120 | Returns
121 | -------
122 | MSA
123 | The MSA object.
124 |
125 | """
126 | # Read the file
127 | if path.suffix == ".gz":
128 | with gzip.open(str(path), "rt") as f:
129 | msa = _parse_a3m(f, taxonomy, max_seqs)
130 | else:
131 | with path.open("r") as f:
132 | msa = _parse_a3m(f, taxonomy, max_seqs)
133 |
134 | return msa
135 |
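Lowercase characters in an A3M row mark insertions relative to the query; the parser folds them into per-position deletion counts instead of emitting tokens. A usage sketch against one of the bundled example alignments, assuming the repo root as the working directory:

    from pathlib import Path
    from boltz.data.parse.a3m import parse_a3m

    msa = parse_a3m(Path("examples/msa/seq1.a3m"), taxonomy=None, max_seqs=4096)
    print(len(msa.sequences), "unique sequences kept")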
--------------------------------------------------------------------------------
/src/boltz/data/parse/csv.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 |
4 | import numpy as np
5 | import pandas as pd
6 |
7 | from boltz.data import const
8 | from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
9 |
10 |
11 | def parse_csv(
12 | path: Path,
13 | max_seqs: Optional[int] = None,
14 | ) -> MSA:
15 | """Process an A3M file.
16 |
17 | Parameters
18 | ----------
19 | path : Path
20 |         The path to the CSV file.
21 | max_seqs : int, optional
22 | The maximum number of sequences.
23 |
24 | Returns
25 | -------
26 | MSA
27 | The MSA object.
28 |
29 | """
30 | # Read file
31 | data = pd.read_csv(path)
32 |
33 | # Check columns
34 | if tuple(sorted(data.columns)) != ("key", "sequence"):
35 | msg = "Invalid CSV format, expected columns: ['sequence', 'key']"
36 | raise ValueError(msg)
37 |
38 |     # Parse the aligned sequences
39 | visited = set()
40 | sequences = []
41 | deletions = []
42 | residues = []
43 |
44 | seq_idx = 0
45 | for line, key in zip(data["sequence"], data["key"]):
46 | line: str
47 | line = line.strip() # noqa: PLW2901
48 | if not line:
49 | continue
50 |
51 | # Get taxonomy, if annotated
52 | taxonomy_id = -1
53 | if (str(key) != "nan") and (key is not None) and (key != ""):
54 | taxonomy_id = key
55 |
56 | # Skip if duplicate sequence
57 | str_seq = line.replace("-", "").upper()
58 | if str_seq not in visited:
59 | visited.add(str_seq)
60 | else:
61 | continue
62 |
63 | # Process sequence
64 | residue = []
65 | deletion = []
66 | count = 0
67 | res_idx = 0
68 | for c in line:
69 | if c != "-" and c.islower():
70 | count += 1
71 | continue
72 | token = const.prot_letter_to_token[c]
73 | token = const.token_ids[token]
74 | residue.append(token)
75 | if count > 0:
76 | deletion.append((res_idx, count))
77 | count = 0
78 | res_idx += 1
79 |
80 | res_start = len(residues)
81 | res_end = res_start + len(residue)
82 |
83 | del_start = len(deletions)
84 | del_end = del_start + len(deletion)
85 |
86 | sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
87 | residues.extend(residue)
88 | deletions.extend(deletion)
89 |
90 | seq_idx += 1
91 | if (max_seqs is not None) and (seq_idx >= max_seqs):
92 | break
93 |
94 | # Create MSA object
95 | msa = MSA(
96 | residues=np.array(residues, dtype=MSAResidue),
97 | deletions=np.array(deletions, dtype=MSADeletion),
98 | sequences=np.array(sequences, dtype=MSASequence),
99 | )
100 | return msa
101 |
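The parser requires exactly the columns key and sequence; a non-empty key becomes the row's taxonomy id, and an empty key falls back to -1. A minimal input file (the sequences are illustrative):

    sequence,key
    MADQLTEEQIAEFKEAFSLF,9606
    MADQLTEEQIAEFKEAFSLD,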
--------------------------------------------------------------------------------
/src/boltz/data/parse/fasta.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Mapping
2 | from pathlib import Path
3 |
4 | from Bio import SeqIO
5 | from rdkit.Chem.rdchem import Mol
6 |
7 | from boltz.data.parse.yaml import parse_boltz_schema
8 | from boltz.data.types import Target
9 |
10 |
11 | def parse_fasta(path: Path, ccd: Mapping[str, Mol]) -> Target: # noqa: C901
12 | """Parse a fasta file.
13 |
14 | The name of the fasta file is used as the name of this job.
15 | We rely on the fasta record id to determine the entity type.
16 |
17 | > CHAIN_ID|ENTITY_TYPE|MSA_ID
18 | SEQUENCE
19 | > CHAIN_ID|ENTITY_TYPE|MSA_ID
20 | ...
21 |
22 | Where ENTITY_TYPE is either protein, rna, dna, ccd or smiles,
23 | and CHAIN_ID is the chain identifier, which should be unique.
24 | The MSA_ID is optional and should only be used on proteins.
25 |
26 | Parameters
27 | ----------
28 |     path : Path
29 | Path to the fasta file.
30 | ccd : Dict
31 | Dictionary of CCD components.
32 |
33 | Returns
34 | -------
35 | Target
36 | The parsed target.
37 |
38 | """
39 | # Read fasta file
40 | with path.open("r") as f:
41 | records = list(SeqIO.parse(f, "fasta"))
42 |
43 | # Make sure all records have a chain id and entity
44 | for seq_record in records:
45 | if "|" not in seq_record.id:
46 | msg = f"Invalid record id: {seq_record.id}"
47 | raise ValueError(msg)
48 |
49 | header = seq_record.id.split("|")
50 | assert len(header) >= 2, f"Invalid record id: {seq_record.id}"
51 |
52 | chain_id, entity_type = header[:2]
53 | if entity_type.lower() not in {"protein", "dna", "rna", "ccd", "smiles"}:
54 | msg = f"Invalid entity type: {entity_type}"
55 | raise ValueError(msg)
56 | if chain_id == "":
57 | msg = "Empty chain id in input fasta!"
58 | raise ValueError(msg)
59 | if entity_type == "":
60 | msg = "Empty entity type in input fasta!"
61 | raise ValueError(msg)
62 |
63 | # Convert to yaml format
64 | sequences = []
65 | for seq_record in records:
66 | # Get chain id, entity type and sequence
67 | header = seq_record.id.split("|")
68 | chain_id, entity_type = header[:2]
69 | if len(header) == 3 and header[2] != "":
70 | assert (
71 | entity_type.lower() == "protein"
72 | ), "MSA_ID is only allowed for proteins"
73 | msa_id = header[2]
74 | else:
75 | msa_id = None
76 |
77 | entity_type = entity_type.upper()
78 | seq = str(seq_record.seq)
79 |
80 | if entity_type == "PROTEIN":
81 | molecule = {
82 | "protein": {
83 | "id": chain_id,
84 | "sequence": seq,
85 | "modifications": [],
86 | "msa": msa_id,
87 | },
88 | }
89 | elif entity_type == "RNA":
90 | molecule = {
91 | "rna": {
92 | "id": chain_id,
93 | "sequence": seq,
94 | "modifications": [],
95 | },
96 | }
97 | elif entity_type == "DNA":
98 | molecule = {
99 | "dna": {
100 | "id": chain_id,
101 | "sequence": seq,
102 | "modifications": [],
103 | }
104 | }
105 |         elif entity_type == "CCD":
106 | molecule = {
107 | "ligand": {
108 | "id": chain_id,
109 | "ccd": seq,
110 | }
111 | }
112 |         elif entity_type == "SMILES":
113 | molecule = {
114 | "ligand": {
115 | "id": chain_id,
116 | "smiles": seq,
117 | }
118 | }
119 |
120 | sequences.append(molecule)
121 |
122 | data = {
123 | "sequences": sequences,
124 | "bonds": [],
125 | "version": 1,
126 | }
127 |
128 | name = path.stem
129 | return parse_boltz_schema(name, data, ccd)
130 |
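A usage sketch for the format described in the docstring above; the sequences are illustrative only, and the CCD mapping is assumed to be loaded elsewhere (e.g. from the cached components dictionary):

```python
from pathlib import Path

fasta_text = (
    ">A|protein|seq1\n"
    "MADQLTEEQIAEFKEAFSLF\n"
    ">B|rna\n"
    "GCAUAGC\n"
    ">C|smiles\n"
    "CC1=CC=CC=C1\n"
)
path = Path("example.fasta")
path.write_text(fasta_text)

# from boltz.data.parse.fasta import parse_fasta
# target = parse_fasta(path, ccd)  # ccd: Mapping[str, Mol] of CCD components
```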
--------------------------------------------------------------------------------
/src/boltz/data/parse/yaml.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import yaml
4 | from rdkit.Chem.rdchem import Mol
5 |
6 | from boltz.data.parse.schema import parse_boltz_schema
7 | from boltz.data.types import Target
8 |
9 |
10 | def parse_yaml(path: Path, ccd: dict[str, Mol]) -> Target:
11 | """Parse a Boltz input yaml / json.
12 |
13 | The input file should be a yaml file with the following format:
14 |
15 | sequences:
16 | - protein:
17 | id: A
18 | sequence: "MADQLTEEQIAEFKEAFSLF"
19 | - protein:
20 | id: [B, C]
21 | sequence: "AKLSILPWGHC"
22 | - rna:
23 | id: D
24 | sequence: "GCAUAGC"
25 | - ligand:
26 | id: E
27 | smiles: "CC1=CC=CC=C1"
28 | - ligand:
29 | id: [F, G]
30 | ccd: []
31 | constraints:
32 | - bond:
33 | atom1: [A, 1, CA]
34 | atom2: [A, 2, N]
35 | - pocket:
36 | binder: E
37 | contacts: [[B, 1], [B, 2]]
38 | version: 1
39 |
40 | Parameters
41 | ----------
42 | path : Path
43 | Path to the YAML input format.
44 |     ccd : dict[str, Mol]
45 |         Dictionary of CCD components.
46 |
47 | Returns
48 | -------
49 | Target
50 | The parsed target.
51 |
52 | """
53 | with path.open("r") as file:
54 | data = yaml.safe_load(file)
55 |
56 | name = path.stem
57 | return parse_boltz_schema(name, data, ccd)
58 |
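And the equivalent for the YAML entry point, again assuming a CCD dictionary is available; the document below is a trimmed version of the schema shown in the docstring:

```python
from pathlib import Path

doc = """\
sequences:
  - protein:
      id: A
      sequence: "MADQLTEEQIAEFKEAFSLF"
  - ligand:
      id: B
      smiles: "CC1=CC=CC=C1"
version: 1
"""
path = Path("example.yaml")
path.write_text(doc)

# from boltz.data.parse.yaml import parse_yaml
# target = parse_yaml(path, ccd)  # ccd: dict[str, Mol] of CCD components
```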
--------------------------------------------------------------------------------
/src/boltz/data/sample/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/sample/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/sample/cluster.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Iterator, List, Tuple
2 |
3 | import numpy as np
4 | from numpy.random import RandomState
5 |
6 | from boltz.data import const
7 | from boltz.data.types import ChainInfo, InterfaceInfo, Record
8 | from boltz.data.sample.sampler import Sample, Sampler
9 |
10 |
11 | def get_chain_cluster(chain: ChainInfo, record: Record) -> str: # noqa: ARG001
12 | """Get the cluster id for a chain.
13 |
14 | Parameters
15 | ----------
16 | chain : ChainInfo
17 | The chain id to get the cluster id for.
18 | record : Record
19 | The record the interface is part of.
20 |
21 | Returns
22 | -------
23 | str
24 | The cluster id of the chain.
25 |
26 | """
27 | return chain.cluster_id
28 |
29 |
30 | def get_interface_cluster(interface: InterfaceInfo, record: Record) -> Tuple[str, str]:
31 | """Get the cluster id for an interface.
32 |
33 | Parameters
34 | ----------
35 | interface : InterfaceInfo
36 | The interface to get the cluster id for.
37 | record : Record
38 | The record the interface is part of.
39 |
40 | Returns
41 | -------
42 |     Tuple[str, str]
43 |         The sorted pair of chain cluster ids for the interface.
44 |
45 | """
46 | chain1 = record.chains[interface.chain_1]
47 | chain2 = record.chains[interface.chain_2]
48 |
49 | cluster_1 = str(chain1.cluster_id)
50 | cluster_2 = str(chain2.cluster_id)
51 |
52 | cluster_id = (cluster_1, cluster_2)
53 | cluster_id = tuple(sorted(cluster_id))
54 |
55 | return cluster_id
56 |
57 |
58 | def get_chain_weight(
59 | chain: ChainInfo,
60 | record: Record, # noqa: ARG001
61 | clusters: Dict[str, int],
62 | beta_chain: float,
63 | alpha_prot: float,
64 | alpha_nucl: float,
65 | alpha_ligand: float,
66 | ) -> float:
67 | """Get the weight of a chain.
68 |
69 | Parameters
70 | ----------
71 | chain : ChainInfo
72 | The chain to get the weight for.
73 | record : Record
74 | The record the chain is part of.
75 | clusters : Dict[str, int]
76 | The cluster sizes.
77 | beta_chain : float
78 | The beta value for chains.
79 | alpha_prot : float
80 | The alpha value for proteins.
81 | alpha_nucl : float
82 | The alpha value for nucleic acids.
83 | alpha_ligand : float
84 | The alpha value for ligands.
85 |
86 | Returns
87 | -------
88 | float
89 | The weight of the chain.
90 |
91 | """
92 | prot_id = const.chain_type_ids["PROTEIN"]
93 | rna_id = const.chain_type_ids["RNA"]
94 | dna_id = const.chain_type_ids["DNA"]
95 | ligand_id = const.chain_type_ids["NONPOLYMER"]
96 |
97 | weight = beta_chain / clusters[chain.cluster_id]
98 | if chain.mol_type == prot_id:
99 | weight *= alpha_prot
100 | elif chain.mol_type in [rna_id, dna_id]:
101 | weight *= alpha_nucl
102 | elif chain.mol_type == ligand_id:
103 | weight *= alpha_ligand
104 |
105 | return weight
106 |
107 |
108 | def get_interface_weight(
109 | interface: InterfaceInfo,
110 | record: Record,
111 |     clusters: Dict[Tuple[str, str], int],
112 | beta_interface: float,
113 | alpha_prot: float,
114 | alpha_nucl: float,
115 | alpha_ligand: float,
116 | ) -> float:
117 | """Get the weight of an interface.
118 |
119 | Parameters
120 | ----------
121 | interface : InterfaceInfo
122 | The interface to get the weight for.
123 | record : Record
124 | The record the interface is part of.
125 |     clusters : Dict[Tuple[str, str], int]
126 | The cluster sizes.
127 | beta_interface : float
128 | The beta value for interfaces.
129 | alpha_prot : float
130 | The alpha value for proteins.
131 | alpha_nucl : float
132 | The alpha value for nucleic acids.
133 | alpha_ligand : float
134 | The alpha value for ligands.
135 |
136 | Returns
137 | -------
138 | float
139 | The weight of the interface.
140 |
141 | """
142 | prot_id = const.chain_type_ids["PROTEIN"]
143 | rna_id = const.chain_type_ids["RNA"]
144 | dna_id = const.chain_type_ids["DNA"]
145 | ligand_id = const.chain_type_ids["NONPOLYMER"]
146 |
147 | chain1 = record.chains[interface.chain_1]
148 | chain2 = record.chains[interface.chain_2]
149 |
150 |     n_prot = chain1.mol_type == prot_id
151 | n_nuc = chain1.mol_type in [rna_id, dna_id]
152 | n_ligand = chain1.mol_type == ligand_id
153 |
154 | n_prot += chain2.mol_type == prot_id
155 | n_nuc += chain2.mol_type in [rna_id, dna_id]
156 | n_ligand += chain2.mol_type == ligand_id
157 |
158 | weight = beta_interface / clusters[get_interface_cluster(interface, record)]
159 | weight *= alpha_prot * n_prot + alpha_nucl * n_nuc + alpha_ligand * n_ligand
160 | return weight
161 |
162 |
163 | class ClusterSampler(Sampler):
164 | """The weighted sampling approach, as described in AF3.
165 |
166 | Each chain / interface is given a weight according
167 | to the following formula, and sampled accordingly:
168 |
169 |     w = b / n_clust * (a_prot * n_prot + a_nuc * n_nuc
170 |                        + a_ligand * n_ligand)
171 |
172 | """
173 |
174 | def __init__(
175 | self,
176 | alpha_prot: float = 3.0,
177 | alpha_nucl: float = 3.0,
178 | alpha_ligand: float = 1.0,
179 | beta_chain: float = 0.5,
180 | beta_interface: float = 1.0,
181 | ) -> None:
182 | """Initialize the sampler.
183 |
184 | Parameters
185 | ----------
186 | alpha_prot : float, optional
187 | The alpha value for proteins.
188 | alpha_nucl : float, optional
189 | The alpha value for nucleic acids.
190 | alpha_ligand : float, optional
191 | The alpha value for ligands.
192 | beta_chain : float, optional
193 | The beta value for chains.
194 | beta_interface : float, optional
195 | The beta value for interfaces.
196 |
197 | """
198 | self.alpha_prot = alpha_prot
199 | self.alpha_nucl = alpha_nucl
200 | self.alpha_ligand = alpha_ligand
201 | self.beta_chain = beta_chain
202 | self.beta_interface = beta_interface
203 |
204 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]: # noqa: C901, PLR0912
205 | """Sample a structure from the dataset infinitely.
206 |
207 | Parameters
208 | ----------
209 | records : List[Record]
210 | The records to sample from.
211 | random : RandomState
212 | The random state for reproducibility.
213 |
214 | Yields
215 | ------
216 | Sample
217 | A data sample.
218 |
219 | """
220 | # Compute chain cluster sizes
221 | chain_clusters: Dict[str, int] = {}
222 | for record in records:
223 | for chain in record.chains:
224 | if not chain.valid:
225 | continue
226 | cluster_id = get_chain_cluster(chain, record)
227 | if cluster_id not in chain_clusters:
228 | chain_clusters[cluster_id] = 0
229 | chain_clusters[cluster_id] += 1
230 |
231 | # Compute interface clusters sizes
232 |         interface_clusters: Dict[Tuple[str, str], int] = {}
233 | for record in records:
234 | for interface in record.interfaces:
235 | if not interface.valid:
236 | continue
237 | cluster_id = get_interface_cluster(interface, record)
238 | if cluster_id not in interface_clusters:
239 | interface_clusters[cluster_id] = 0
240 | interface_clusters[cluster_id] += 1
241 |
242 | # Compute weights
243 | items, weights = [], []
244 | for record in records:
245 | for chain_id, chain in enumerate(record.chains):
246 | if not chain.valid:
247 | continue
248 | weight = get_chain_weight(
249 | chain,
250 | record,
251 | chain_clusters,
252 | self.beta_chain,
253 | self.alpha_prot,
254 | self.alpha_nucl,
255 | self.alpha_ligand,
256 | )
257 | items.append((record, 0, chain_id))
258 | weights.append(weight)
259 |
260 | for int_id, interface in enumerate(record.interfaces):
261 | if not interface.valid:
262 | continue
263 | weight = get_interface_weight(
264 | interface,
265 | record,
266 | interface_clusters,
267 | self.beta_interface,
268 | self.alpha_prot,
269 | self.alpha_nucl,
270 | self.alpha_ligand,
271 | )
272 | items.append((record, 1, int_id))
273 | weights.append(weight)
274 |
275 | # Sample infinitely
276 | weights = np.array(weights) / np.sum(weights)
277 | while True:
278 | item_idx = random.choice(len(items), p=weights)
279 | record, kind, index = items[item_idx]
280 | if kind == 0:
281 | yield Sample(record=record, chain_id=index)
282 | else:
283 | yield Sample(record=record, interface_id=index)
284 |
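To make the AF3-style weighting concrete, here is the arithmetic with the default hyperparameters, for a hypothetical protein chain in a cluster of 4 and a protein-ligand interface in a cluster of 2 (all weights are then normalized into one categorical distribution over chains and interfaces):

```python
# Defaults from ClusterSampler.__init__
alpha_prot, alpha_nucl, alpha_ligand = 3.0, 3.0, 1.0
beta_chain, beta_interface = 0.5, 1.0

# Protein chain whose cluster has 4 members
w_chain = beta_chain / 4 * alpha_prot
assert w_chain == 0.375

# Protein-ligand interface whose cluster has 2 members
n_prot, n_nuc, n_ligand = 1, 0, 1
w_iface = beta_interface / 2 * (
    alpha_prot * n_prot + alpha_nucl * n_nuc + alpha_ligand * n_ligand
)
assert w_iface == 2.0  # interfaces are upweighted relative to single chains
```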
--------------------------------------------------------------------------------
/src/boltz/data/sample/distillation.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, List
2 |
3 | from numpy.random import RandomState
4 |
5 | from boltz.data.types import Record
6 | from boltz.data.sample.sampler import Sample, Sampler
7 |
8 |
9 | class DistillationSampler(Sampler):
10 | """A sampler for monomer distillation data."""
11 |
12 | def __init__(self, small_size: int = 200, small_prob: float = 0.01) -> None:
13 | """Initialize the sampler.
14 |
15 | Parameters
16 | ----------
17 | small_size : int, optional
18 | The maximum size to be considered small.
19 | small_prob : float, optional
20 | The probability of sampling a small item.
21 |
22 | """
23 | self._size = small_size
24 | self._prob = small_prob
25 |
26 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
27 | """Sample a structure from the dataset infinitely.
28 |
29 | Parameters
30 | ----------
31 | records : List[Record]
32 | The records to sample from.
33 | random : RandomState
34 | The random state for reproducibility.
35 |
36 | Yields
37 | ------
38 | Sample
39 | A data sample.
40 |
41 | """
42 | # Remove records with invalid chains
43 | records = [r for r in records if r.chains[0].valid]
44 |
45 |         # Split into small and large proteins. We assume that there is only
46 | # one chain per record, as is the case for monomer distillation
47 | small = [r for r in records if r.chains[0].num_residues <= self._size]
48 | large = [r for r in records if r.chains[0].num_residues > self._size]
49 |
50 | # Sample infinitely
51 | while True:
52 | # Sample small or large
53 | samples = small if random.rand() < self._prob else large
54 |
55 | # Sample item from the list
56 | index = random.randint(0, len(samples))
57 | yield Sample(record=samples[index])
58 |
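A quick sketch of the small/large behavior. The stand-in records below are hypothetical objects that only carry the two fields the sampler actually reads (`chains[0].valid` and `chains[0].num_residues`):

```python
from types import SimpleNamespace

from numpy.random import RandomState

from boltz.data.sample.distillation import DistillationSampler

records = [
    SimpleNamespace(chains=[SimpleNamespace(valid=True, num_residues=n)])
    for n in (50, 150, 300, 800)
]

sampler = DistillationSampler(small_size=200, small_prob=0.01)
stream = sampler.sample(records, RandomState(0))

# With small_prob=0.01, roughly 99% of yielded records have > 200 residues
sample = next(stream)
print(sample.record.chains[0].num_residues)
```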
--------------------------------------------------------------------------------
/src/boltz/data/sample/random.py:
--------------------------------------------------------------------------------
1 | from dataclasses import replace
2 | from typing import Iterator, List
3 |
4 | from numpy.random import RandomState
5 |
6 | from boltz.data.types import Record
7 | from boltz.data.sample.sampler import Sample, Sampler
8 |
9 |
10 | class RandomSampler(Sampler):
11 | """A simple random sampler with replacement."""
12 |
13 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
14 | """Sample a structure from the dataset infinitely.
15 |
16 | Parameters
17 | ----------
18 | records : List[Record]
19 | The records to sample from.
20 | random : RandomState
21 | The random state for reproducibility.
22 |
23 | Yields
24 | ------
25 | Sample
26 | A data sample.
27 |
28 | """
29 | while True:
30 | # Sample item from the list
31 | index = random.randint(0, len(records))
32 | record = records[index]
33 |
34 | # Remove invalid chains and interfaces
35 | chains = [c for c in record.chains if c.valid]
36 | interfaces = [i for i in record.interfaces if i.valid]
37 | record = replace(record, chains=chains, interfaces=interfaces)
38 |
39 | yield Sample(record=record)
40 |
--------------------------------------------------------------------------------
/src/boltz/data/sample/sampler.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass
3 | from typing import Iterator, List, Optional
4 |
5 | from numpy.random import RandomState
6 |
7 | from boltz.data.types import Record
8 |
9 |
10 | @dataclass
11 | class Sample:
12 | """A sample with optional chain and interface IDs.
13 |
14 | Attributes
15 | ----------
16 | record : Record
17 | The record.
18 | chain_id : Optional[int]
19 | The chain ID.
20 | interface_id : Optional[int]
21 | The interface ID.
22 | """
23 |
24 | record: Record
25 | chain_id: Optional[int] = None
26 | interface_id: Optional[int] = None
27 |
28 |
29 | class Sampler(ABC):
30 | """Abstract base class for samplers."""
31 |
32 | @abstractmethod
33 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
34 | """Sample a structure from the dataset infinitely.
35 |
36 | Parameters
37 | ----------
38 | records : List[Record]
39 | The records to sample from.
40 | random : RandomState
41 | The random state for reproducibility.
42 |
43 | Yields
44 | ------
45 | Sample
46 | A data sample.
47 |
48 | """
49 | raise NotImplementedError
50 |
--------------------------------------------------------------------------------
/src/boltz/data/tokenize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/tokenize/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/tokenize/boltz.py:
--------------------------------------------------------------------------------
1 | from dataclasses import astuple, dataclass
2 |
3 | import numpy as np
4 |
5 | from boltz.data import const
6 | from boltz.data.tokenize.tokenizer import Tokenizer
7 | from boltz.data.types import Input, Token, TokenBond, Tokenized
8 |
9 |
10 | @dataclass
11 | class TokenData:
12 | """TokenData datatype."""
13 |
14 | token_idx: int
15 | atom_idx: int
16 | atom_num: int
17 | res_idx: int
18 | res_type: int
19 | sym_id: int
20 | asym_id: int
21 | entity_id: int
22 | mol_type: int
23 | center_idx: int
24 | disto_idx: int
25 | center_coords: np.ndarray
26 | disto_coords: np.ndarray
27 | resolved_mask: bool
28 | disto_mask: bool
29 |
30 |
31 | class BoltzTokenizer(Tokenizer):
32 | """Tokenize an input structure for training."""
33 |
34 | def tokenize(self, data: Input) -> Tokenized:
35 | """Tokenize the input data.
36 |
37 | Parameters
38 | ----------
39 | data : Input
40 | The input data.
41 |
42 | Returns
43 | -------
44 | Tokenized
45 | The tokenized data.
46 |
47 | """
48 | # Get structure data
49 | struct = data.structure
50 |
51 | # Create token data
52 | token_data = []
53 |
54 | # Keep track of atom_idx to token_idx
55 | token_idx = 0
56 | atom_to_token = {}
57 |
58 | # Filter to valid chains only
59 | chains = struct.chains[struct.mask]
60 |
61 | for chain in chains:
62 | # Get residue indices
63 | res_start = chain["res_idx"]
64 | res_end = chain["res_idx"] + chain["res_num"]
65 |
66 | for res in struct.residues[res_start:res_end]:
67 | # Get atom indices
68 | atom_start = res["atom_idx"]
69 | atom_end = res["atom_idx"] + res["atom_num"]
70 |
71 | # Standard residues are tokens
72 | if res["is_standard"]:
73 | # Get center and disto atoms
74 | center = struct.atoms[res["atom_center"]]
75 | disto = struct.atoms[res["atom_disto"]]
76 |
77 | # Token is present if centers are
78 | is_present = res["is_present"] & center["is_present"]
79 | is_disto_present = res["is_present"] & disto["is_present"]
80 |
81 |                     # Get center and disto coordinates
82 | c_coords = center["coords"]
83 | d_coords = disto["coords"]
84 |
85 | # Create token
86 | token = TokenData(
87 | token_idx=token_idx,
88 | atom_idx=res["atom_idx"],
89 | atom_num=res["atom_num"],
90 | res_idx=res["res_idx"],
91 | res_type=res["res_type"],
92 | sym_id=chain["sym_id"],
93 | asym_id=chain["asym_id"],
94 | entity_id=chain["entity_id"],
95 | mol_type=chain["mol_type"],
96 | center_idx=res["atom_center"],
97 | disto_idx=res["atom_disto"],
98 | center_coords=c_coords,
99 | disto_coords=d_coords,
100 | resolved_mask=is_present,
101 | disto_mask=is_disto_present,
102 | )
103 | token_data.append(astuple(token))
104 |
105 | # Update atom_idx to token_idx
106 | for atom_idx in range(atom_start, atom_end):
107 | atom_to_token[atom_idx] = token_idx
108 |
109 | token_idx += 1
110 |
111 | # Non-standard are tokenized per atom
112 | else:
113 | # We use the unk protein token as res_type
114 | unk_token = const.unk_token["PROTEIN"]
115 | unk_id = const.token_ids[unk_token]
116 |
117 | # Get atom coordinates
118 | atom_data = struct.atoms[atom_start:atom_end]
119 | atom_coords = atom_data["coords"]
120 |
121 | # Tokenize each atom
122 | for i, atom in enumerate(atom_data):
123 | # Token is present if atom is
124 | is_present = res["is_present"] & atom["is_present"]
125 | index = atom_start + i
126 |
127 | # Create token
128 | token = TokenData(
129 | token_idx=token_idx,
130 | atom_idx=index,
131 | atom_num=1,
132 | res_idx=res["res_idx"],
133 | res_type=unk_id,
134 | sym_id=chain["sym_id"],
135 | asym_id=chain["asym_id"],
136 | entity_id=chain["entity_id"],
137 | mol_type=chain["mol_type"],
138 | center_idx=index,
139 | disto_idx=index,
140 | center_coords=atom_coords[i],
141 | disto_coords=atom_coords[i],
142 | resolved_mask=is_present,
143 | disto_mask=is_present,
144 | )
145 | token_data.append(astuple(token))
146 |
147 | # Update atom_idx to token_idx
148 | atom_to_token[index] = token_idx
149 | token_idx += 1
150 |
151 | # Create token bonds
152 | token_bonds = []
153 |
154 | # Add atom-atom bonds from ligands
155 | for bond in struct.bonds:
156 | if (
157 | bond["atom_1"] not in atom_to_token
158 | or bond["atom_2"] not in atom_to_token
159 | ):
160 | continue
161 | token_bond = (
162 | atom_to_token[bond["atom_1"]],
163 | atom_to_token[bond["atom_2"]],
164 | )
165 | token_bonds.append(token_bond)
166 |
167 | # Add connection bonds (covalent)
168 | for conn in struct.connections:
169 | if (
170 | conn["atom_1"] not in atom_to_token
171 | or conn["atom_2"] not in atom_to_token
172 | ):
173 | continue
174 | token_bond = (
175 | atom_to_token[conn["atom_1"]],
176 | atom_to_token[conn["atom_2"]],
177 | )
178 | token_bonds.append(token_bond)
179 |
180 | token_data = np.array(token_data, dtype=Token)
181 | token_bonds = np.array(token_bonds, dtype=TokenBond)
182 | tokenized = Tokenized(
183 | token_data,
184 | token_bonds,
185 | data.structure,
186 | data.msa,
187 | )
188 | return tokenized
189 |
--------------------------------------------------------------------------------
/src/boltz/data/tokenize/tokenizer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from boltz.data.types import Input, Tokenized
4 |
5 |
6 | class Tokenizer(ABC):
7 | """Tokenize an input structure for training."""
8 |
9 | @abstractmethod
10 | def tokenize(self, data: Input) -> Tokenized:
11 | """Tokenize the input data.
12 |
13 | Parameters
14 | ----------
15 | data : Input
16 | The input data.
17 |
18 | Returns
19 | -------
20 | Tokenized
21 | The tokenized data.
22 |
23 | """
24 | raise NotImplementedError
25 |
--------------------------------------------------------------------------------
/src/boltz/data/write/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/data/write/__init__.py
--------------------------------------------------------------------------------
/src/boltz/data/write/mmcif.py:
--------------------------------------------------------------------------------
1 | import io
2 | from collections.abc import Iterator
3 | from typing import Optional
4 |
5 | import ihm
6 | import modelcif
7 | from modelcif import Assembly, AsymUnit, Entity, System, dumper
8 | from modelcif.model import AbInitioModel, Atom, ModelGroup
9 | from rdkit import Chem
10 | from torch import Tensor
11 |
12 | from boltz.data import const
13 | from boltz.data.types import Structure
14 | from boltz.data.write.utils import generate_tags
15 |
16 |
17 | def to_mmcif(structure: Structure, plddts: Optional[Tensor] = None) -> str: # noqa: C901, PLR0915, PLR0912
18 | """Write a structure into an MMCIF file.
19 |
20 | Parameters
21 | ----------
22 | structure : Structure
23 | The input structure
24 |
25 | Returns
26 | -------
27 | str
28 | the output MMCIF file
29 |
30 | """
31 | system = System()
32 |
33 | # Load periodic table for element mapping
34 | periodic_table = Chem.GetPeriodicTable()
35 |
36 | # Map entities to chain_ids
37 | entity_to_chains = {}
38 | entity_to_moltype = {}
39 |
40 | for chain in structure.chains:
41 | entity_id = chain["entity_id"]
42 | mol_type = chain["mol_type"]
43 | entity_to_chains.setdefault(entity_id, []).append(chain)
44 | entity_to_moltype[entity_id] = mol_type
45 |
46 | # Map entities to sequences
47 | sequences = {}
48 | for entity in entity_to_chains:
49 | # Get the first chain
50 | chain = entity_to_chains[entity][0]
51 |
52 | # Get the sequence
53 | res_start = chain["res_idx"]
54 | res_end = chain["res_idx"] + chain["res_num"]
55 | residues = structure.residues[res_start:res_end]
56 | sequence = [str(res["name"]) for res in residues]
57 | sequences[entity] = sequence
58 |
59 | # Create entity objects
60 | lig_entity = None
61 | entities_map = {}
62 | for entity, sequence in sequences.items():
63 | mol_type = entity_to_moltype[entity]
64 |
65 | if mol_type == const.chain_type_ids["PROTEIN"]:
66 | alphabet = ihm.LPeptideAlphabet()
67 | chem_comp = lambda x: ihm.LPeptideChemComp(id=x, code=x, code_canonical="X") # noqa: E731
68 | elif mol_type == const.chain_type_ids["DNA"]:
69 | alphabet = ihm.DNAAlphabet()
70 | chem_comp = lambda x: ihm.DNAChemComp(id=x, code=x, code_canonical="N") # noqa: E731
71 | elif mol_type == const.chain_type_ids["RNA"]:
72 | alphabet = ihm.RNAAlphabet()
73 | chem_comp = lambda x: ihm.RNAChemComp(id=x, code=x, code_canonical="N") # noqa: E731
74 | elif len(sequence) > 1:
75 | alphabet = {}
76 | chem_comp = lambda x: ihm.SaccharideChemComp(id=x) # noqa: E731
77 | else:
78 | alphabet = {}
79 | chem_comp = lambda x: ihm.NonPolymerChemComp(id=x) # noqa: E731
80 |
81 | # Handle smiles
82 | if len(sequence) == 1 and (sequence[0] == "LIG"):
83 | if lig_entity is None:
84 | seq = [chem_comp(sequence[0])]
85 | lig_entity = Entity(seq)
86 | model_e = lig_entity
87 | else:
88 | seq = [
89 | alphabet[item] if item in alphabet else chem_comp(item)
90 | for item in sequence
91 | ]
92 | model_e = Entity(seq)
93 |
94 | for chain in entity_to_chains[entity]:
95 | chain_idx = chain["asym_id"]
96 | entities_map[chain_idx] = model_e
97 |
98 | # We don't assume that symmetry is perfect, so we dump everything
99 | # into the asymmetric unit, and produce just a single assembly
100 | chain_tags = generate_tags()
101 | asym_unit_map = {}
102 | for chain in structure.chains:
103 | # Define the model assembly
104 | chain_idx = chain["asym_id"]
105 | chain_tag = next(chain_tags)
106 | asym = AsymUnit(
107 | entities_map[chain_idx],
108 |             details=f"Model subunit {chain_tag}",
109 | id=chain_tag,
110 | )
111 | asym_unit_map[chain_idx] = asym
112 | modeled_assembly = Assembly(asym_unit_map.values(), name="Modeled assembly")
113 |
114 | class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
115 | name = "pLDDT"
116 | software = None
117 | description = "Predicted lddt"
118 |
119 | class _MyModel(AbInitioModel):
120 | def get_atoms(self) -> Iterator[Atom]:
121 | # Add all atom sites.
122 | res_num = 0
123 | for chain in structure.chains:
124 | # We rename the chains in alphabetical order
125 | het = chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]
126 | chain_idx = chain["asym_id"]
127 | res_start = chain["res_idx"]
128 | res_end = chain["res_idx"] + chain["res_num"]
129 |
130 | residues = structure.residues[res_start:res_end]
131 | for residue in residues:
132 | atom_start = residue["atom_idx"]
133 | atom_end = residue["atom_idx"] + residue["atom_num"]
134 | atoms = structure.atoms[atom_start:atom_end]
135 | atom_coords = atoms["coords"]
136 | for i, atom in enumerate(atoms):
137 | # This should not happen on predictions, but just in case.
138 | if not atom["is_present"]:
139 | continue
140 |
141 | name = atom["name"]
142 | name = [chr(c + 32) for c in name if c != 0]
143 | name = "".join(name)
144 | element = periodic_table.GetElementSymbol(
145 | atom["element"].item()
146 | )
147 | element = element.upper()
148 | residue_index = residue["res_idx"] + 1
149 | pos = atom_coords[i]
150 | yield Atom(
151 | asym_unit=asym_unit_map[chain_idx],
152 | type_symbol=element,
153 | seq_id=residue_index,
154 | atom_id=name,
155 | x=f"{pos[0]:.5f}",
156 | y=f"{pos[1]:.5f}",
157 | z=f"{pos[2]:.5f}",
158 | het=het,
159 | biso=1
160 | if plddts is None
161 | else round(plddts[res_num].item(), 2),
162 | occupancy=1,
163 | )
164 |
165 | res_num += 1
166 |
167 | def add_plddt(self, plddts):
168 | res_num = 0
169 | for chain in structure.chains:
170 | chain_idx = chain["asym_id"]
171 | res_start = chain["res_idx"]
172 | res_end = chain["res_idx"] + chain["res_num"]
173 | residues = structure.residues[res_start:res_end]
174 | # We rename the chains in alphabetical order
175 | for residue in residues:
176 | residue_idx = residue["res_idx"] + 1
177 | self.qa_metrics.append(
178 | _LocalPLDDT(
179 | asym_unit_map[chain_idx].residue(residue_idx),
180 | plddts[res_num].item(),
181 | )
182 | )
183 | res_num += 1
184 |
185 | # Add the model and modeling protocol to the file and write them out:
186 | model = _MyModel(assembly=modeled_assembly, name="Model")
187 | if plddts is not None:
188 | model.add_plddt(plddts)
189 |
190 | model_group = ModelGroup([model], name="All models")
191 | system.model_groups.append(model_group)
192 |
193 | fh = io.StringIO()
194 | dumper.write(fh, [system])
195 | return fh.getvalue()
196 |
--------------------------------------------------------------------------------
/src/boltz/data/write/pdb.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from rdkit import Chem
4 | from torch import Tensor
5 |
6 | from boltz.data import const
7 | from boltz.data.types import Structure
8 | from boltz.data.write.utils import generate_tags
9 |
10 |
11 | def to_pdb(structure: Structure, plddts: Optional[Tensor] = None) -> str: # noqa: PLR0915
12 | """Write a structure into a PDB file.
13 |
14 | Parameters
15 | ----------
16 | structure : Structure
17 | The input structure
18 |
19 | Returns
20 | -------
21 | str
22 | the output PDB file
23 |
24 | """
25 | pdb_lines = []
26 |
27 | atom_index = 1
28 | atom_reindex_ter = []
29 | chain_tags = generate_tags()
30 |
31 | # Load periodic table for element mapping
32 | periodic_table = Chem.GetPeriodicTable()
33 |
34 | # Add all atom sites.
35 | res_num = 0
36 | for chain in structure.chains:
37 | # We rename the chains in alphabetical order
38 | chain_idx = chain["asym_id"]
39 | chain_tag = next(chain_tags)
40 |
41 | res_start = chain["res_idx"]
42 | res_end = chain["res_idx"] + chain["res_num"]
43 |
44 | residues = structure.residues[res_start:res_end]
45 | for residue in residues:
46 | atom_start = residue["atom_idx"]
47 | atom_end = residue["atom_idx"] + residue["atom_num"]
48 | atoms = structure.atoms[atom_start:atom_end]
49 | atom_coords = atoms["coords"]
50 | for i, atom in enumerate(atoms):
51 | # This should not happen on predictions, but just in case.
52 | if not atom["is_present"]:
53 | continue
54 |
55 | record_type = (
56 | "ATOM"
57 | if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]
58 | else "HETATM"
59 | )
60 | name = atom["name"]
61 | name = [chr(c + 32) for c in name if c != 0]
62 | name = "".join(name)
63 | name = name if len(name) == 4 else f" {name}" # noqa: PLR2004
64 | alt_loc = ""
65 | insertion_code = ""
66 | occupancy = 1.00
67 | element = periodic_table.GetElementSymbol(atom["element"].item())
68 | element = element.upper()
69 | charge = ""
70 | residue_index = residue["res_idx"] + 1
71 | pos = atom_coords[i]
72 | res_name_3 = (
73 | "LIG" if record_type == "HETATM" else str(residue["name"][:3])
74 | )
75 | b_factor = 1.00 if plddts is None else round(plddts[res_num].item(), 2)
76 |
77 | # PDB is a columnar format, every space matters here!
78 | atom_line = (
79 | f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
80 | f"{res_name_3:>3} {chain_tag:>1}"
81 | f"{residue_index:>4}{insertion_code:>1} "
82 | f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
83 | f"{occupancy:>6.2f}{b_factor:>6.2f} "
84 | f"{element:>2}{charge:>2}"
85 | )
86 | pdb_lines.append(atom_line)
87 | atom_reindex_ter.append(atom_index)
88 | atom_index += 1
89 |
90 | res_num += 1
91 |
92 | should_terminate = chain_idx < (len(structure.chains) - 1)
93 | if should_terminate:
94 | # Close the chain.
95 | chain_end = "TER"
96 | chain_termination_line = (
97 | f"{chain_end:<6}{atom_index:>5} "
98 | f"{res_name_3:>3} "
99 | f"{chain_tag:>1}{residue_index:>4}"
100 | )
101 | pdb_lines.append(chain_termination_line)
102 | atom_index += 1
103 |
104 | # Dump CONECT records.
105 | for bonds in [structure.bonds, structure.connections]:
106 | for bond in bonds:
107 | atom1 = structure.atoms[bond["atom_1"]]
108 | atom2 = structure.atoms[bond["atom_2"]]
109 | if not atom1["is_present"] or not atom2["is_present"]:
110 | continue
111 | atom1_idx = atom_reindex_ter[bond["atom_1"]]
112 | atom2_idx = atom_reindex_ter[bond["atom_2"]]
113 | conect_line = f"CONECT{atom1_idx:>5}{atom2_idx:>5}"
114 | pdb_lines.append(conect_line)
115 |
116 | pdb_lines.append("END")
117 | pdb_lines.append("")
118 | pdb_lines = [line.ljust(80) for line in pdb_lines]
119 | return "\n".join(pdb_lines)
120 |
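Since every column in a PDB record is significant, a rendered example helps. This standalone sketch reuses the same format string with hypothetical values; the spacer widths follow the PDB ATOM column spec (coordinates in columns 31-54, element in 77-78):

```python
record_type, atom_index, name = "ATOM", 1, " CA "
alt_loc, insertion_code, occupancy, charge = "", "", 1.00, ""
res_name_3, chain_tag, residue_index = "MET", "A", 1
pos, b_factor, element = (11.104, 6.134, -6.504), 91.30, "C"

atom_line = (
    f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
    f"{res_name_3:>3} {chain_tag:>1}"
    f"{residue_index:>4}{insertion_code:>1}   "
    f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
    f"{occupancy:>6.2f}{b_factor:>6.2f}          "
    f"{element:>2}{charge:>2}"
)
print(atom_line.ljust(80))  # a fixed-width ATOM record, padded to 80 columns
```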
--------------------------------------------------------------------------------
/src/boltz/data/write/utils.py:
--------------------------------------------------------------------------------
1 | import string
2 | from collections.abc import Iterator
3 |
4 |
5 | def generate_tags() -> Iterator[str]:
6 | """Generate chain tags.
7 |
8 | Yields
9 | ------
10 | str
11 | The next chain tag
12 |
13 | """
14 | for i in range(1, 4):
15 | for j in range(len(string.ascii_uppercase) ** i):
16 | tag = ""
17 | for k in range(i):
18 | tag += string.ascii_uppercase[
19 | j
20 | // (len(string.ascii_uppercase) ** k)
21 | % len(string.ascii_uppercase)
22 | ]
23 | yield tag
24 |
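Note the ordering: single letters come first, then two- and three-letter tags in which the first letter cycles fastest, giving 26 + 26^2 + 26^3 unique tags in total:

```python
from itertools import islice

from boltz.data.write.utils import generate_tags

tags = list(islice(generate_tags(), 30))
print(tags[:4])     # ['A', 'B', 'C', 'D']
print(tags[26:30])  # ['AA', 'BA', 'CA', 'DA']
```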
--------------------------------------------------------------------------------
/src/boltz/data/write/writer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, replace
2 | import json
3 | from pathlib import Path
4 | from typing import Literal
5 |
6 | import numpy as np
7 | from pytorch_lightning import LightningModule, Trainer
8 | from pytorch_lightning.callbacks import BasePredictionWriter
9 | import torch
10 | from torch import Tensor
11 |
12 | from boltz.data.types import (
13 | Interface,
14 | Record,
15 | Structure,
16 | )
17 | from boltz.data.write.mmcif import to_mmcif
18 | from boltz.data.write.pdb import to_pdb
19 |
20 |
21 | class BoltzWriter(BasePredictionWriter):
22 | """Custom writer for predictions."""
23 |
24 | def __init__(
25 | self,
26 | data_dir: str,
27 | output_dir: str,
28 | output_format: Literal["pdb", "mmcif"] = "mmcif",
29 | ) -> None:
30 | """Initialize the writer.
31 |
32 | Parameters
33 | ----------
34 | output_dir : str
35 | The directory to save the predictions.
36 |
37 | """
38 | super().__init__(write_interval="batch")
39 | if output_format not in ["pdb", "mmcif"]:
40 | msg = f"Invalid output format: {output_format}"
41 | raise ValueError(msg)
42 |
43 | self.data_dir = Path(data_dir)
44 | self.output_dir = Path(output_dir)
45 | self.output_format = output_format
46 | self.failed = 0
47 |
48 | # Create the output directories
49 | self.output_dir.mkdir(parents=True, exist_ok=True)
50 |
51 | def write_on_batch_end(
52 | self,
53 | trainer: Trainer, # noqa: ARG002
54 | pl_module: LightningModule, # noqa: ARG002
55 | prediction: dict[str, Tensor],
56 | batch_indices: list[int], # noqa: ARG002
57 | batch: dict[str, Tensor],
58 | batch_idx: int, # noqa: ARG002
59 | dataloader_idx: int, # noqa: ARG002
60 | ) -> None:
61 | """Write the predictions to disk."""
62 | if prediction["exception"]:
63 | self.failed += 1
64 | return
65 |
66 | # Get the records
67 | records: list[Record] = batch["record"]
68 |
69 | # Get the predictions
70 | coords = prediction["coords"]
71 | coords = coords.unsqueeze(0)
72 |
73 | pad_masks = prediction["masks"]
74 |
75 | # Get ranking
76 | argsort = torch.argsort(prediction["confidence_score"], descending=True)
77 | idx_to_rank = {idx.item(): rank for rank, idx in enumerate(argsort)}
78 |
79 | # Iterate over the records
80 | for record, coord, pad_mask in zip(records, coords, pad_masks):
81 | # Load the structure
82 | path = self.data_dir / f"{record.id}.npz"
83 | structure: Structure = Structure.load(path)
84 |
85 | # Compute chain map with masked removed, to be used later
86 | chain_map = {}
87 | for i, mask in enumerate(structure.mask):
88 | if mask:
89 | chain_map[len(chain_map)] = i
90 |
91 | # Remove masked chains completely
92 | structure = structure.remove_invalid_chains()
93 |
94 | for model_idx in range(coord.shape[0]):
95 | # Get model coord
96 | model_coord = coord[model_idx]
97 | # Unpad
98 | coord_unpad = model_coord[pad_mask.bool()]
99 | coord_unpad = coord_unpad.cpu().numpy()
100 |
101 | # New atom table
102 | atoms = structure.atoms
103 | atoms["coords"] = coord_unpad
104 | atoms["is_present"] = True
105 |
106 |                 # New residue table
107 | residues = structure.residues
108 | residues["is_present"] = True
109 |
110 | # Update the structure
111 | interfaces = np.array([], dtype=Interface)
112 | new_structure: Structure = replace(
113 | structure,
114 | atoms=atoms,
115 | residues=residues,
116 | interfaces=interfaces,
117 | )
118 |
119 | # Update chain info
120 | chain_info = []
121 | for chain in new_structure.chains:
122 | old_chain_idx = chain_map[chain["asym_id"]]
123 | old_chain_info = record.chains[old_chain_idx]
124 | new_chain_info = replace(
125 | old_chain_info,
126 | chain_id=int(chain["asym_id"]),
127 | valid=True,
128 | )
129 | chain_info.append(new_chain_info)
130 |
131 | # Save the structure
132 | struct_dir = self.output_dir / record.id
133 | struct_dir.mkdir(exist_ok=True)
134 |
135 | # Get plddt's
136 | plddts = None
137 | if "plddt" in prediction:
138 | plddts = prediction["plddt"][model_idx]
139 |
140 | # Create path name
141 | outname = f"{record.id}_model_{idx_to_rank[model_idx]}"
142 |
143 | # Save the structure
144 | if self.output_format == "pdb":
145 | path = struct_dir / f"{outname}.pdb"
146 | with path.open("w") as f:
147 | f.write(to_pdb(new_structure, plddts=plddts))
148 | elif self.output_format == "mmcif":
149 | path = struct_dir / f"{outname}.cif"
150 | with path.open("w") as f:
151 | f.write(to_mmcif(new_structure, plddts=plddts))
152 | else:
153 | path = struct_dir / f"{outname}.npz"
154 | np.savez_compressed(path, **asdict(new_structure))
155 |
156 | # Save confidence summary
157 | if "plddt" in prediction:
158 | path = (
159 | struct_dir
160 | / f"confidence_{record.id}_model_{idx_to_rank[model_idx]}.json"
161 | )
162 | confidence_summary_dict = {}
163 | for key in [
164 | "confidence_score",
165 | "ptm",
166 | "iptm",
167 | "ligand_iptm",
168 | "protein_iptm",
169 | "complex_plddt",
170 | "complex_iplddt",
171 | "complex_pde",
172 | "complex_ipde",
173 | ]:
174 | confidence_summary_dict[key] = prediction[key][model_idx].item()
175 | confidence_summary_dict["chains_ptm"] = {
176 | idx: prediction["pair_chains_iptm"][idx][idx][model_idx].item()
177 | for idx in prediction["pair_chains_iptm"]
178 | }
179 | confidence_summary_dict["pair_chains_iptm"] = {
180 | idx1: {
181 | idx2: prediction["pair_chains_iptm"][idx1][idx2][
182 | model_idx
183 | ].item()
184 | for idx2 in prediction["pair_chains_iptm"][idx1]
185 | }
186 | for idx1 in prediction["pair_chains_iptm"]
187 | }
188 | with path.open("w") as f:
189 | f.write(
190 | json.dumps(
191 | confidence_summary_dict,
192 | indent=4,
193 | )
194 | )
195 |
196 | # Save plddt
197 | plddt = prediction["plddt"][model_idx]
198 | path = (
199 | struct_dir
200 | / f"plddt_{record.id}_model_{idx_to_rank[model_idx]}.npz"
201 | )
202 | np.savez_compressed(path, plddt=plddt.cpu().numpy())
203 |
204 | # Save pae
205 | if "pae" in prediction:
206 | pae = prediction["pae"][model_idx]
207 | path = (
208 | struct_dir
209 | / f"pae_{record.id}_model_{idx_to_rank[model_idx]}.npz"
210 | )
211 | np.savez_compressed(path, pae=pae.cpu().numpy())
212 |
213 | # Save pde
214 | if "pde" in prediction:
215 | pde = prediction["pde"][model_idx]
216 | path = (
217 | struct_dir
218 | / f"pde_{record.id}_model_{idx_to_rank[model_idx]}.npz"
219 | )
220 | np.savez_compressed(path, pde=pde.cpu().numpy())
221 |
222 | def on_predict_epoch_end(
223 | self,
224 | trainer: Trainer, # noqa: ARG002
225 | pl_module: LightningModule, # noqa: ARG002
226 | ) -> None:
227 | """Print the number of failed examples."""
228 | # Print number of failed examples
229 | print(f"Number of failed examples: {self.failed}") # noqa: T201
230 |
--------------------------------------------------------------------------------
/src/boltz/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/model/__init__.py
--------------------------------------------------------------------------------
/src/boltz/model/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/model/layers/__init__.py
--------------------------------------------------------------------------------
/src/boltz/model/layers/attention.py:
--------------------------------------------------------------------------------
1 | from einops.layers.torch import Rearrange
2 | import torch
3 | from torch import Tensor, nn
4 |
5 | import boltz.model.layers.initialize as init
6 |
7 |
8 | class AttentionPairBias(nn.Module):
9 | """Attention pair bias layer."""
10 |
11 | def __init__(
12 | self,
13 | c_s: int,
14 | c_z: int,
15 | num_heads: int,
16 | inf: float = 1e6,
17 | initial_norm: bool = True,
18 | ) -> None:
19 | """Initialize the attention pair bias layer.
20 |
21 | Parameters
22 | ----------
23 | c_s : int
24 | The input sequence dimension.
25 | c_z : int
26 | The input pairwise dimension.
27 | num_heads : int
28 | The number of heads.
29 | inf : float, optional
30 | The inf value, by default 1e6
31 | initial_norm: bool, optional
32 | Whether to apply layer norm to the input, by default True
33 |
34 | """
35 | super().__init__()
36 |
37 | assert c_s % num_heads == 0
38 |
39 | self.c_s = c_s
40 | self.num_heads = num_heads
41 | self.head_dim = c_s // num_heads
42 | self.inf = inf
43 |
44 | self.initial_norm = initial_norm
45 | if self.initial_norm:
46 | self.norm_s = nn.LayerNorm(c_s)
47 |
48 | self.proj_q = nn.Linear(c_s, c_s)
49 | self.proj_k = nn.Linear(c_s, c_s, bias=False)
50 | self.proj_v = nn.Linear(c_s, c_s, bias=False)
51 | self.proj_g = nn.Linear(c_s, c_s, bias=False)
52 |
53 | self.proj_z = nn.Sequential(
54 | nn.LayerNorm(c_z),
55 | nn.Linear(c_z, num_heads, bias=False),
56 | Rearrange("b ... h -> b h ..."),
57 | )
58 |
59 | self.proj_o = nn.Linear(c_s, c_s, bias=False)
60 | init.final_init_(self.proj_o.weight)
61 |
62 | def forward(
63 | self,
64 | s: Tensor,
65 | z: Tensor,
66 | mask: Tensor,
67 | multiplicity: int = 1,
68 | to_keys=None,
69 | model_cache=None,
70 | ) -> Tensor:
71 | """Forward pass.
72 |
73 | Parameters
74 | ----------
75 | s : torch.Tensor
76 | The input sequence tensor (B, S, D)
77 | z : torch.Tensor
78 | The input pairwise tensor (B, N, N, D)
79 |         mask : torch.Tensor
80 |             The sequence mask tensor (B, N)
81 | multiplicity : int, optional
82 | The diffusion batch size, by default 1
83 |
84 | Returns
85 | -------
86 | torch.Tensor
87 | The output sequence tensor.
88 |
89 | """
90 | B = s.shape[0]
91 |
92 | # Layer norms
93 | if self.initial_norm:
94 | s = self.norm_s(s)
95 |
96 | if to_keys is not None:
97 | k_in = to_keys(s)
98 | mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)
99 | else:
100 | k_in = s
101 |
102 | # Compute projections
103 | q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
104 | k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
105 | v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)
106 |
107 | # Caching z projection during diffusion roll-out
108 | if model_cache is None or "z" not in model_cache:
109 | z = self.proj_z(z)
110 |
111 | if model_cache is not None:
112 | model_cache["z"] = z
113 | else:
114 | z = model_cache["z"]
115 | z = z.repeat_interleave(multiplicity, 0)
116 |
117 | g = self.proj_g(s).sigmoid()
118 |
119 | with torch.autocast("cuda", enabled=False):
120 | # Compute attention weights
121 | attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
122 | attn = attn / (self.head_dim**0.5) + z.float()
123 | attn = attn + (1 - mask[:, None, None].float()) * -self.inf
124 | attn = attn.softmax(dim=-1)
125 |
126 | # Compute output
127 | o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)
128 | o = o.reshape(B, -1, self.c_s)
129 | o = self.proj_o(g * o)
130 |
131 | return o
132 |
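A shape-level smoke test of the layer, assuming the mask convention used in the forward pass (a (B, N) sequence mask broadcast over heads and queries); the dimensions are arbitrary:

```python
import torch

from boltz.model.layers.attention import AttentionPairBias

layer = AttentionPairBias(c_s=64, c_z=32, num_heads=8)
s = torch.randn(2, 16, 64)      # (B, N, c_s) sequence features
z = torch.randn(2, 16, 16, 32)  # (B, N, N, c_z) pairwise features
mask = torch.ones(2, 16)        # (B, N) sequence mask

out = layer(s, z, mask)
assert out.shape == (2, 16, 64)
```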
--------------------------------------------------------------------------------
/src/boltz/model/layers/dropout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 |
5 | def get_dropout_mask(
6 | dropout: float,
7 | z: Tensor,
8 | training: bool,
9 | columnwise: bool = False,
10 | ) -> Tensor:
11 | """Get the dropout mask.
12 |
13 | Parameters
14 | ----------
15 | dropout : float
16 | The dropout rate
17 | z : torch.Tensor
18 | The tensor to apply dropout to
19 | training : bool
20 | Whether the model is in training mode
21 | columnwise : bool, optional
22 | Whether to apply dropout columnwise
23 |
24 | Returns
25 | -------
26 | torch.Tensor
27 | The dropout mask
28 |
29 | """
30 | dropout = dropout * training
31 | v = z[:, 0:1, :, 0:1] if columnwise else z[:, :, 0:1, 0:1]
32 | d = torch.rand_like(v) > dropout
33 | d = d * 1.0 / (1.0 - dropout)
34 | return d
35 |
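The mask has a single non-trivial axis, so whole rows (or columns) of the pair tensor are dropped together, and the 1/(1-p) rescaling keeps the expected value unchanged; a quick sketch:

```python
import torch

from boltz.model.layers.dropout import get_dropout_mask

z = torch.ones(1, 8, 8, 4)  # (B, N, N, D) pairwise tensor
row_mask = get_dropout_mask(0.25, z, training=True)
col_mask = get_dropout_mask(0.25, z, training=True, columnwise=True)

print(row_mask.shape, col_mask.shape)  # (1, 8, 1, 1) and (1, 1, 8, 1)
out = z * row_mask  # each row is either zeroed or scaled by 1/0.75
```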
--------------------------------------------------------------------------------
/src/boltz/model/layers/initialize.py:
--------------------------------------------------------------------------------
1 | """Utility functions for initializing weights and biases."""
2 |
3 | # Copyright 2021 AlQuraishi Laboratory
4 | # Copyright 2021 DeepMind Technologies Limited
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import math
19 | import numpy as np
20 | from scipy.stats import truncnorm
21 | import torch
22 |
23 |
24 | def _prod(nums):
25 | out = 1
26 | for n in nums:
27 | out = out * n
28 | return out
29 |
30 |
31 | def _calculate_fan(linear_weight_shape, fan="fan_in"):
32 | fan_out, fan_in = linear_weight_shape
33 |
34 | if fan == "fan_in":
35 | f = fan_in
36 | elif fan == "fan_out":
37 | f = fan_out
38 | elif fan == "fan_avg":
39 | f = (fan_in + fan_out) / 2
40 | else:
41 | raise ValueError("Invalid fan option")
42 |
43 | return f
44 |
45 |
46 | def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
47 | shape = weights.shape
48 | f = _calculate_fan(shape, fan)
49 | scale = scale / max(1, f)
50 | a = -2
51 | b = 2
52 | std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
53 | size = _prod(shape)
54 | samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
55 | samples = np.reshape(samples, shape)
56 | with torch.no_grad():
57 | weights.copy_(torch.tensor(samples, device=weights.device))
58 |
59 |
60 | def lecun_normal_init_(weights):
61 | trunc_normal_init_(weights, scale=1.0)
62 |
63 |
64 | def he_normal_init_(weights):
65 | trunc_normal_init_(weights, scale=2.0)
66 |
67 |
68 | def glorot_uniform_init_(weights):
69 | torch.nn.init.xavier_uniform_(weights, gain=1)
70 |
71 |
72 | def final_init_(weights):
73 | with torch.no_grad():
74 | weights.fill_(0.0)
75 |
76 |
77 | def gating_init_(weights):
78 | with torch.no_grad():
79 | weights.fill_(0.0)
80 |
81 |
82 | def bias_init_zero_(bias):
83 | with torch.no_grad():
84 | bias.fill_(0.0)
85 |
86 |
87 | def bias_init_one_(bias):
88 | with torch.no_grad():
89 | bias.fill_(1.0)
90 |
91 |
92 | def normal_init_(weights):
93 | torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
94 |
95 |
96 | def ipa_point_weights_init_(weights):
97 | with torch.no_grad():
98 | softplus_inverse_1 = 0.541324854612918
99 | weights.fill_(softplus_inverse_1)
100 |
--------------------------------------------------------------------------------
/src/boltz/model/layers/outer_product_mean.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, nn
3 |
4 | import boltz.model.layers.initialize as init
5 |
6 |
7 | class OuterProductMean(nn.Module):
8 | """Outer product mean layer."""
9 |
10 | def __init__(self, c_in: int, c_hidden: int, c_out: int) -> None:
11 | """Initialize the outer product mean layer.
12 |
13 | Parameters
14 | ----------
15 | c_in : int
16 | The input dimension.
17 | c_hidden : int
18 | The hidden dimension.
19 | c_out : int
20 | The output dimension.
21 |
22 | """
23 | super().__init__()
24 | self.c_hidden = c_hidden
25 | self.norm = nn.LayerNorm(c_in)
26 | self.proj_a = nn.Linear(c_in, c_hidden, bias=False)
27 | self.proj_b = nn.Linear(c_in, c_hidden, bias=False)
28 | self.proj_o = nn.Linear(c_hidden * c_hidden, c_out)
29 | init.final_init_(self.proj_o.weight)
30 | init.final_init_(self.proj_o.bias)
31 |
32 | def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor:
33 | """Forward pass.
34 |
35 | Parameters
36 | ----------
37 | m : torch.Tensor
38 | The sequence tensor (B, S, N, c_in).
39 | mask : torch.Tensor
40 | The mask tensor (B, S, N).
41 |
42 | Returns
43 | -------
44 | torch.Tensor
45 | The output tensor (B, N, N, c_out).
46 |
47 | """
48 | # Expand mask
49 | mask = mask.unsqueeze(-1).to(m)
50 |
51 | # Compute projections
52 | m = self.norm(m)
53 | a = self.proj_a(m) * mask
54 | b = self.proj_b(m) * mask
55 |
56 | # Compute outer product mean
57 | if chunk_size is not None and not self.training:
58 | # Compute pairwise mask
59 | for i in range(0, mask.shape[1], 64):
60 | if i == 0:
61 | num_mask = (
62 | mask[:, i : i + 64, None, :] * mask[:, i : i + 64, :, None]
63 | ).sum(1)
64 | else:
65 | num_mask += (
66 | mask[:, i : i + 64, None, :] * mask[:, i : i + 64, :, None]
67 | ).sum(1)
68 | num_mask = num_mask.clamp(min=1)
69 |
70 |             # Compute sequentially in chunks
71 | for i in range(0, self.c_hidden, chunk_size):
72 | a_chunk = a[:, :, :, i : i + chunk_size]
73 | sliced_weight_proj_o = self.proj_o.weight[
74 | :, i * self.c_hidden : (i + chunk_size) * self.c_hidden
75 | ]
76 |
77 | z = torch.einsum("bsic,bsjd->bijcd", a_chunk, b)
78 | z = z.reshape(*z.shape[:3], -1)
79 | z = z / num_mask
80 |
81 | # Project to output
82 | if i == 0:
83 | z_out = z.to(m) @ sliced_weight_proj_o.T
84 | else:
85 | z_out = z_out + z.to(m) @ sliced_weight_proj_o.T
86 |             return z_out + self.proj_o.bias  # include the bias omitted by the weight slices
87 | else:
88 | mask = mask[:, :, None, :] * mask[:, :, :, None]
89 | num_mask = mask.sum(1).clamp(min=1)
90 | z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float())
91 | z = z.reshape(*z.shape[:3], -1)
92 | z = z / num_mask
93 |
94 | # Project to output
95 | z = self.proj_o(z.to(m))
96 | return z
97 |
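At inference time the chunked path should reproduce the dense path, since slicing the hidden dimension of `a` corresponds exactly to slicing columns of the flattened outer product. A quick equivalence check (the output projection is zero-initialized by `final_init_`, so it is randomized here to make the comparison non-trivial):

```python
import torch
from torch import nn

from boltz.model.layers.outer_product_mean import OuterProductMean

opm = OuterProductMean(c_in=32, c_hidden=8, c_out=16).eval()
nn.init.normal_(opm.proj_o.weight)

m = torch.randn(1, 4, 10, 32)  # (B, S, N, c_in)
mask = torch.ones(1, 4, 10)    # (B, S, N)

with torch.no_grad():
    dense = opm(m, mask)
    chunked = opm(m, mask, chunk_size=4)
print(torch.allclose(dense, chunked, atol=1e-5))  # True
```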
--------------------------------------------------------------------------------
/src/boltz/model/layers/pair_averaging.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, nn
3 |
4 | import boltz.model.layers.initialize as init
5 |
6 |
7 | class PairWeightedAveraging(nn.Module):
8 | """Pair weighted averaging layer."""
9 |
10 | def __init__(
11 | self,
12 | c_m: int,
13 | c_z: int,
14 | c_h: int,
15 | num_heads: int,
16 | inf: float = 1e6,
17 | ) -> None:
18 | """Initialize the pair weighted averaging layer.
19 |
20 | Parameters
21 | ----------
22 | c_m: int
23 | The dimension of the input sequence.
24 | c_z: int
25 | The dimension of the input pairwise tensor.
26 | c_h: int
27 | The dimension of the hidden.
28 | num_heads: int
29 | The number of heads.
30 | inf: float
31 | The value to use for masking, default 1e6.
32 |
33 | """
34 | super().__init__()
35 | self.c_m = c_m
36 | self.c_z = c_z
37 | self.c_h = c_h
38 | self.num_heads = num_heads
39 | self.inf = inf
40 |
41 | self.norm_m = nn.LayerNorm(c_m)
42 | self.norm_z = nn.LayerNorm(c_z)
43 |
44 | self.proj_m = nn.Linear(c_m, c_h * num_heads, bias=False)
45 | self.proj_g = nn.Linear(c_m, c_h * num_heads, bias=False)
46 | self.proj_z = nn.Linear(c_z, num_heads, bias=False)
47 | self.proj_o = nn.Linear(c_h * num_heads, c_m, bias=False)
48 | init.final_init_(self.proj_o.weight)
49 |
50 | def forward(
51 |         self, m: Tensor, z: Tensor, mask: Tensor, chunk_heads: bool = False
52 | ) -> Tensor:
53 | """Forward pass.
54 |
55 | Parameters
56 | ----------
57 | m : torch.Tensor
58 | The input sequence tensor (B, S, N, D)
59 | z : torch.Tensor
60 | The input pairwise tensor (B, N, N, D)
61 | mask : torch.Tensor
62 | The pairwise mask tensor (B, N, N)
63 |
64 | Returns
65 | -------
66 | torch.Tensor
67 | The output sequence tensor (B, S, N, D)
68 |
69 | """
70 | # Compute layer norms
71 | m = self.norm_m(m)
72 | z = self.norm_z(z)
73 |
74 | if chunk_heads and not self.training:
75 | # Compute heads sequentially
76 |             # Each head's gated output is projected and summed into o_out
77 | for head_idx in range(self.num_heads):
78 | sliced_weight_proj_m = self.proj_m.weight[
79 | head_idx * self.c_h : (head_idx + 1) * self.c_h, :
80 | ]
81 | sliced_weight_proj_g = self.proj_g.weight[
82 | head_idx * self.c_h : (head_idx + 1) * self.c_h, :
83 | ]
84 | sliced_weight_proj_z = self.proj_z.weight[head_idx : (head_idx + 1), :]
85 | sliced_weight_proj_o = self.proj_o.weight[
86 | :, head_idx * self.c_h : (head_idx + 1) * self.c_h
87 | ]
88 |
89 | # Project input tensors
90 | v: Tensor = m @ sliced_weight_proj_m.T
91 | v = v.reshape(*v.shape[:3], 1, self.c_h)
92 | v = v.permute(0, 3, 1, 2, 4)
93 |
94 | # Compute weights
95 | b: Tensor = z @ sliced_weight_proj_z.T
96 | b = b.permute(0, 3, 1, 2)
97 | b = b + (1 - mask[:, None]) * -self.inf
98 | w = torch.softmax(b, dim=-1)
99 |
100 | # Compute gating
101 | g: Tensor = m @ sliced_weight_proj_g.T
102 | g = g.sigmoid()
103 |
104 | # Compute output
105 | o = torch.einsum("bhij,bhsjd->bhsid", w, v)
106 | o = o.permute(0, 2, 3, 1, 4)
107 | o = o.reshape(*o.shape[:3], 1 * self.c_h)
108 |                 o = g * o
109 |                 if head_idx == 0:
110 |                     o_out = o @ sliced_weight_proj_o.T
111 |                 else:
112 |                     o_out += o @ sliced_weight_proj_o.T
113 | return o_out
114 | else:
115 | # Project input tensors
116 | v: Tensor = self.proj_m(m)
117 | v = v.reshape(*v.shape[:3], self.num_heads, self.c_h)
118 | v = v.permute(0, 3, 1, 2, 4)
119 |
120 | # Compute weights
121 | b: Tensor = self.proj_z(z)
122 | b = b.permute(0, 3, 1, 2)
123 | b = b + (1 - mask[:, None]) * -self.inf
124 | w = torch.softmax(b, dim=-1)
125 |
126 | # Compute gating
127 | g: Tensor = self.proj_g(m)
128 | g = g.sigmoid()
129 |
130 | # Compute output
131 | o = torch.einsum("bhij,bhsjd->bhsid", w, v)
132 | o = o.permute(0, 2, 3, 1, 4)
133 | o = o.reshape(*o.shape[:3], self.num_heads * self.c_h)
134 | o = self.proj_o(g * o)
135 | return o
136 |
--------------------------------------------------------------------------------
/src/boltz/model/layers/transition.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from torch import Tensor, nn
4 |
5 | import boltz.model.layers.initialize as init
6 |
7 |
8 | class Transition(nn.Module):
9 | """Perform a two-layer MLP."""
10 |
11 | def __init__(
12 | self,
13 | dim: int = 128,
14 | hidden: int = 512,
15 | out_dim: Optional[int] = None,
16 | ) -> None:
17 | """Initialize the TransitionUpdate module.
18 |
19 | Parameters
20 | ----------
21 | dim: int
22 | The dimension of the input, default 128
23 | hidden: int
24 | The dimension of the hidden, default 512
25 | out_dim: Optional[int]
26 | The dimension of the output, default None
27 |
28 | """
29 | super().__init__()
30 | if out_dim is None:
31 | out_dim = dim
32 |
33 | self.norm = nn.LayerNorm(dim, eps=1e-5)
34 | self.fc1 = nn.Linear(dim, hidden, bias=False)
35 | self.fc2 = nn.Linear(dim, hidden, bias=False)
36 | self.fc3 = nn.Linear(hidden, out_dim, bias=False)
37 | self.silu = nn.SiLU()
38 | self.hidden = hidden
39 |
40 | init.bias_init_one_(self.norm.weight)
41 | init.bias_init_zero_(self.norm.bias)
42 |
43 | init.lecun_normal_init_(self.fc1.weight)
44 | init.lecun_normal_init_(self.fc2.weight)
45 | init.final_init_(self.fc3.weight)
46 |
47 |     def forward(self, x: Tensor, chunk_size: Optional[int] = None) -> Tensor:
48 | """Perform a forward pass.
49 |
50 | Parameters
51 | ----------
52 | x: torch.Tensor
53 | The input data of shape (..., D)
54 |
55 | Returns
56 | -------
57 | x: torch.Tensor
58 | The output data of shape (..., D)
59 |
60 | """
61 | x = self.norm(x)
62 |
63 | if chunk_size is None or self.training:
64 | x = self.silu(self.fc1(x)) * self.fc2(x)
65 | x = self.fc3(x)
66 | return x
67 | else:
68 | # Compute in chunks
69 | for i in range(0, self.hidden, chunk_size):
70 | fc1_slice = self.fc1.weight[i : i + chunk_size, :]
71 | fc2_slice = self.fc2.weight[i : i + chunk_size, :]
72 | fc3_slice = self.fc3.weight[:, i : i + chunk_size]
73 | x_chunk = self.silu((x @ fc1_slice.T)) * (x @ fc2_slice.T)
74 | if i == 0:
75 | x_out = x_chunk @ fc3_slice.T
76 | else:
77 | x_out = x_out + x_chunk @ fc3_slice.T
78 | return x_out
79 |
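A short usage sketch (import path assumed from this repo's layout): outside training, passing chunk_size computes the hidden projection slice-by-slice, which should match the full forward pass.

import torch
from boltz.model.layers.transition import Transition

layer = Transition(dim=128, hidden=512).eval()  # chunking is skipped in training mode
x = torch.randn(2, 16, 16, 128)
with torch.no_grad():
    full = layer(x)
    chunked = layer(x, chunk_size=128)  # hidden dim processed in 4 slices
print(torch.allclose(full, chunked, atol=1e-5))  # expected: True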
--------------------------------------------------------------------------------
/src/boltz/model/layers/triangular_attention/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/model/layers/triangular_attention/__init__.py
--------------------------------------------------------------------------------
/src/boltz/model/layers/triangular_attention/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from functools import partial, partialmethod
17 | from typing import List, Optional
18 |
19 | import torch
20 | import torch.nn as nn
21 |
22 | from boltz.model.layers.triangular_attention.primitives import (
23 | Attention,
24 | LayerNorm,
25 | Linear,
26 | )
27 | from boltz.model.layers.triangular_attention.utils import (
28 | chunk_layer,
29 | permute_final_dims,
30 | )
31 |
32 |
33 | class TriangleAttention(nn.Module):
34 | def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
35 | """
36 | Args:
37 | c_in:
38 | Input channel dimension
39 | c_hidden:
40 | Overall hidden channel dimension (not per-head)
41 | no_heads:
42 | Number of attention heads
43 | """
44 | super(TriangleAttention, self).__init__()
45 |
46 | self.c_in = c_in
47 | self.c_hidden = c_hidden
48 | self.no_heads = no_heads
49 | self.starting = starting
50 | self.inf = inf
51 |
52 | self.layer_norm = LayerNorm(self.c_in)
53 |
54 | self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
55 |
56 | self.mha = Attention(
57 | self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
58 | )
59 |
60 | @torch.jit.ignore
61 | def _chunk(
62 | self,
63 | x: torch.Tensor,
64 | biases: List[torch.Tensor],
65 | chunk_size: int,
66 | use_memory_efficient_kernel: bool = False,
67 | use_deepspeed_evo_attention: bool = False,
68 | use_lma: bool = False,
69 | inplace_safe: bool = False,
70 | ) -> torch.Tensor:
71 |         """Run the attention over x in chunks to bound peak memory."""
72 | mha_inputs = {
73 | "q_x": x,
74 | "kv_x": x,
75 | "biases": biases,
76 | }
77 |
78 | return chunk_layer(
79 | partial(
80 | self.mha,
81 | use_memory_efficient_kernel=use_memory_efficient_kernel,
82 | use_deepspeed_evo_attention=use_deepspeed_evo_attention,
83 | use_lma=use_lma,
84 | ),
85 | mha_inputs,
86 | chunk_size=chunk_size,
87 | no_batch_dims=len(x.shape[:-2]),
88 | _out=x if inplace_safe else None,
89 | )
90 |
91 | def forward(
92 | self,
93 | x: torch.Tensor,
94 | mask: Optional[torch.Tensor] = None,
95 | chunk_size: Optional[int] = None,
96 | use_memory_efficient_kernel: bool = False,
97 | use_deepspeed_evo_attention: bool = False,
98 | use_lma: bool = False,
99 | inplace_safe: bool = False,
100 | ) -> torch.Tensor:
101 | """
102 | Args:
103 | x:
104 | [*, I, J, C_in] input tensor (e.g. the pair representation)
105 | Returns:
106 | [*, I, J, C_in] output tensor
107 | """
108 | if mask is None:
109 | # [*, I, J]
110 | mask = x.new_ones(
111 | x.shape[:-1],
112 | )
113 |
114 | if not self.starting:
115 | x = x.transpose(-2, -3)
116 | mask = mask.transpose(-1, -2)
117 |
118 | # [*, I, J, C_in]
119 | x = self.layer_norm(x)
120 |
121 | # [*, I, 1, 1, J]
122 | mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
123 |
124 | # [*, H, I, J]
125 | triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
126 |
127 | # [*, 1, H, I, J]
128 | triangle_bias = triangle_bias.unsqueeze(-4)
129 |
130 | biases = [mask_bias, triangle_bias]
131 |
132 | if chunk_size is not None:
133 | x = self._chunk(
134 | x,
135 | biases,
136 | chunk_size,
137 | use_memory_efficient_kernel=use_memory_efficient_kernel,
138 | use_deepspeed_evo_attention=use_deepspeed_evo_attention,
139 | use_lma=use_lma,
140 | inplace_safe=inplace_safe,
141 | )
142 | else:
143 | x = self.mha(
144 | q_x=x,
145 | kv_x=x,
146 | biases=biases,
147 | use_memory_efficient_kernel=use_memory_efficient_kernel,
148 | use_deepspeed_evo_attention=use_deepspeed_evo_attention,
149 | use_lma=use_lma,
150 | )
151 |
152 | if not self.starting:
153 | x = x.transpose(-2, -3)
154 |
155 | return x
156 |
157 |
158 | # Implements Algorithm 13
159 | TriangleAttentionStartingNode = TriangleAttention
160 |
161 |
162 | class TriangleAttentionEndingNode(TriangleAttention):
163 | """Implement Algorithm 14."""
164 |
165 | __init__ = partialmethod(TriangleAttention.__init__, starting=False)
166 |
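A hedged usage sketch of the two variants (toy shapes; the residual wiring is illustrative, not the repo's exact trunk code): the starting-node form attends over rows of the pair representation, while the ending-node form transposes first.

import torch
from boltz.model.layers.triangular_attention.attention import (
    TriangleAttention,
    TriangleAttentionEndingNode,
)

start = TriangleAttention(c_in=64, c_hidden=16, no_heads=4)          # Algorithm 13
end = TriangleAttentionEndingNode(c_in=64, c_hidden=16, no_heads=4)  # Algorithm 14
z = torch.randn(1, 32, 32, 64)  # pair representation [B, I, J, C_in]
mask = torch.ones(1, 32, 32)
with torch.no_grad():
    z = z + start(z, mask=mask, chunk_size=16)  # chunked to bound peak memory
    z = z + end(z, mask=mask)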
--------------------------------------------------------------------------------
/src/boltz/model/layers/triangular_mult.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, nn
3 |
4 | from boltz.model.layers import initialize as init
5 |
6 |
7 | class TriangleMultiplicationOutgoing(nn.Module):
8 |     """Triangle multiplicative update using outgoing edges."""
9 |
10 | def __init__(self, dim: int = 128) -> None:
11 |         """Initialize the TriangleMultiplicationOutgoing module.
12 |
13 | Parameters
14 | ----------
15 | dim: int
16 | The dimension of the input, default 128
17 |
18 | """
19 | super().__init__()
20 |
21 | self.norm_in = nn.LayerNorm(dim, eps=1e-5)
22 | self.p_in = nn.Linear(dim, 2 * dim, bias=False)
23 | self.g_in = nn.Linear(dim, 2 * dim, bias=False)
24 |
25 | self.norm_out = nn.LayerNorm(dim)
26 | self.p_out = nn.Linear(dim, dim, bias=False)
27 | self.g_out = nn.Linear(dim, dim, bias=False)
28 |
29 | init.bias_init_one_(self.norm_in.weight)
30 | init.bias_init_zero_(self.norm_in.bias)
31 |
32 | init.lecun_normal_init_(self.p_in.weight)
33 | init.gating_init_(self.g_in.weight)
34 |
35 | init.bias_init_one_(self.norm_out.weight)
36 | init.bias_init_zero_(self.norm_out.bias)
37 |
38 | init.final_init_(self.p_out.weight)
39 | init.gating_init_(self.g_out.weight)
40 |
41 | def forward(self, x: Tensor, mask: Tensor) -> Tensor:
42 | """Perform a forward pass.
43 |
44 | Parameters
45 | ----------
46 | x: torch.Tensor
47 | The input data of shape (B, N, N, D)
48 | mask: torch.Tensor
49 | The input mask of shape (B, N, N)
50 |
51 | Returns
52 | -------
53 | x: torch.Tensor
54 | The output data of shape (B, N, N, D)
55 |
56 | """
57 |         # Input projection with gating: D -> 2D
58 | x = self.norm_in(x)
59 | x_in = x
60 | x = self.p_in(x) * self.g_in(x).sigmoid()
61 |
62 | # Apply mask
63 | x = x * mask.unsqueeze(-1)
64 |
65 | # Split input and cast to float
66 | a, b = torch.chunk(x.float(), 2, dim=-1)
67 |
68 | # Triangular projection
69 | x = torch.einsum("bikd,bjkd->bijd", a, b)
70 |
71 | # Output gating
72 | x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
73 |
74 | return x
75 |
76 |
77 | class TriangleMultiplicationIncoming(nn.Module):
78 |     """Triangle multiplicative update using incoming edges."""
79 |
80 | def __init__(self, dim: int = 128) -> None:
81 |         """Initialize the TriangleMultiplicationIncoming module.
82 |
83 | Parameters
84 | ----------
85 | dim: int
86 | The dimension of the input, default 128
87 |
88 | """
89 | super().__init__()
90 |
91 | self.norm_in = nn.LayerNorm(dim, eps=1e-5)
92 | self.p_in = nn.Linear(dim, 2 * dim, bias=False)
93 | self.g_in = nn.Linear(dim, 2 * dim, bias=False)
94 |
95 | self.norm_out = nn.LayerNorm(dim)
96 | self.p_out = nn.Linear(dim, dim, bias=False)
97 | self.g_out = nn.Linear(dim, dim, bias=False)
98 |
99 | init.bias_init_one_(self.norm_in.weight)
100 | init.bias_init_zero_(self.norm_in.bias)
101 |
102 | init.lecun_normal_init_(self.p_in.weight)
103 | init.gating_init_(self.g_in.weight)
104 |
105 | init.bias_init_one_(self.norm_out.weight)
106 | init.bias_init_zero_(self.norm_out.bias)
107 |
108 | init.final_init_(self.p_out.weight)
109 | init.gating_init_(self.g_out.weight)
110 |
111 | def forward(self, x: Tensor, mask: Tensor) -> Tensor:
112 | """Perform a forward pass.
113 |
114 | Parameters
115 | ----------
116 | x: torch.Tensor
117 | The input data of shape (B, N, N, D)
118 | mask: torch.Tensor
119 | The input mask of shape (B, N, N)
120 |
121 | Returns
122 | -------
123 | x: torch.Tensor
124 | The output data of shape (B, N, N, D)
125 |
126 | """
127 |         # Input projection with gating: D -> 2D
128 | x = self.norm_in(x)
129 | x_in = x
130 | x = self.p_in(x) * self.g_in(x).sigmoid()
131 |
132 | # Apply mask
133 | x = x * mask.unsqueeze(-1)
134 |
135 | # Split input and cast to float
136 | a, b = torch.chunk(x.float(), 2, dim=-1)
137 |
138 | # Triangular projection
139 | x = torch.einsum("bkid,bkjd->bijd", a, b)
140 |
141 | # Output gating
142 | x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
143 |
144 | return x
145 |
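The only difference between the two classes is the einsum contraction: the outgoing update combines edges leaving nodes i and j, while the incoming update combines edges arriving at them. A small standalone sketch of the relationship (toy tensors):

import torch

a = torch.randn(1, 5, 5, 3)
b = torch.randn(1, 5, 5, 3)
out = torch.einsum("bikd,bjkd->bijd", a, b)  # outgoing edges
inc = torch.einsum("bkid,bkjd->bijd", a, b)  # incoming edges

# The incoming contraction equals the outgoing one on transposed inputs
out_t = torch.einsum("bikd,bjkd->bijd", a.transpose(1, 2), b.transpose(1, 2))
print(torch.allclose(inc, out_t, atol=1e-6))  # expected: True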
--------------------------------------------------------------------------------
/src/boltz/model/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/model/loss/__init__.py
--------------------------------------------------------------------------------
/src/boltz/model/loss/diffusion.py:
--------------------------------------------------------------------------------
1 | # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
2 |
3 | from einops import einsum
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def weighted_rigid_align(
9 | true_coords,
10 | pred_coords,
11 | weights,
12 | mask,
13 | ):
14 | """Compute weighted alignment.
15 |
16 | Parameters
17 | ----------
18 | true_coords: torch.Tensor
19 | The ground truth atom coordinates
20 | pred_coords: torch.Tensor
21 | The predicted atom coordinates
22 | weights: torch.Tensor
23 | The weights for alignment
24 | mask: torch.Tensor
25 | The atoms mask
26 |
27 | Returns
28 | -------
29 | torch.Tensor
30 | Aligned coordinates
31 |
32 | """
33 |
34 | batch_size, num_points, dim = true_coords.shape
35 | weights = (mask * weights).unsqueeze(-1)
36 |
37 | # Compute weighted centroids
38 | true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
39 | dim=1, keepdim=True
40 | )
41 | pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
42 | dim=1, keepdim=True
43 | )
44 |
45 | # Center the coordinates
46 | true_coords_centered = true_coords - true_centroid
47 | pred_coords_centered = pred_coords - pred_centroid
48 |
49 | if num_points < (dim + 1):
50 | print(
51 |             "Warning: The size of one of the point clouds is < dim+1. "
52 | + "`WeightedRigidAlign` cannot return a unique rotation."
53 | )
54 |
55 | # Compute the weighted covariance matrix
56 | cov_matrix = einsum(
57 | weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j"
58 | )
59 |
60 |     # Compute the SVD of the covariance matrix; SVD and determinant require float32
61 | original_dtype = cov_matrix.dtype
62 | cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
63 | U, S, V = torch.linalg.svd(
64 | cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
65 | )
66 | V = V.mH
67 |
68 | # Catch ambiguous rotation by checking the magnitude of singular values
69 | if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
70 | print(
71 | "Warning: Excessively low rank of "
72 | + "cross-correlation between aligned point clouds. "
73 | + "`WeightedRigidAlign` cannot return a unique rotation."
74 | )
75 |
76 | # Compute the rotation matrix
77 | rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)
78 |
79 | # Ensure proper rotation matrix with determinant 1
80 | F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
81 | None
82 | ].repeat(batch_size, 1, 1)
83 | F[:, -1, -1] = torch.det(rot_matrix)
84 | rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l")
85 | rot_matrix = rot_matrix.to(dtype=original_dtype)
86 |
87 | # Apply the rotation and translation
88 | aligned_coords = (
89 | einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j")
90 | + pred_centroid
91 | )
92 | aligned_coords.detach_()
93 |
94 | return aligned_coords
95 |
96 |
97 | def smooth_lddt_loss(
98 | pred_coords,
99 | true_coords,
100 | is_nucleotide,
101 | coords_mask,
102 | nucleic_acid_cutoff: float = 30.0,
103 | other_cutoff: float = 15.0,
104 | multiplicity: int = 1,
105 | ):
106 |     """Compute the smooth LDDT loss.
107 |
108 | Parameters
109 | ----------
110 | pred_coords: torch.Tensor
111 | The predicted atom coordinates
112 | true_coords: torch.Tensor
113 | The ground truth atom coordinates
114 |     is_nucleotide: torch.Tensor
115 |         Mask indicating which atoms belong to nucleotide chains
116 | coords_mask: torch.Tensor
117 | The atoms mask
118 | nucleic_acid_cutoff: float
119 | The nucleic acid cutoff
120 | other_cutoff: float
121 | The non nucleic acid cutoff
122 |     multiplicity: int
123 |         The diffusion batch multiplicity
124 |
125 |     Returns
126 |     -------
127 |     torch.Tensor
128 |         The smooth LDDT loss
129 | """
130 | B, N, _ = true_coords.shape
131 | true_dists = torch.cdist(true_coords, true_coords)
132 | is_nucleotide = is_nucleotide.repeat_interleave(multiplicity, 0)
133 |
134 | coords_mask = coords_mask.repeat_interleave(multiplicity, 0)
135 | is_nucleotide_pair = is_nucleotide.unsqueeze(-1).expand(
136 | -1, -1, is_nucleotide.shape[-1]
137 | )
138 |
139 | mask = (
140 | is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
141 | + (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
142 | )
143 | mask = mask * (1 - torch.eye(pred_coords.shape[1], device=pred_coords.device))
144 | mask = mask * (coords_mask.unsqueeze(-1) * coords_mask.unsqueeze(-2))
145 |
146 | # Compute distances between all pairs of atoms
147 | pred_dists = torch.cdist(pred_coords, pred_coords)
148 | dist_diff = torch.abs(true_dists - pred_dists)
149 |
150 | # Compute epsilon values
151 | eps = (
152 | (
153 | (
154 | F.sigmoid(0.5 - dist_diff)
155 | + F.sigmoid(1.0 - dist_diff)
156 | + F.sigmoid(2.0 - dist_diff)
157 | + F.sigmoid(4.0 - dist_diff)
158 | )
159 | / 4.0
160 | )
161 | .view(multiplicity, B // multiplicity, N, N)
162 | .mean(dim=0)
163 | )
164 |
165 | # Calculate masked averaging
166 | eps = eps.repeat_interleave(multiplicity, 0)
167 | num = (eps * mask).sum(dim=(-1, -2))
168 | den = mask.sum(dim=(-1, -2)).clamp(min=1)
169 | lddt = num / den
170 |
171 | return 1.0 - lddt.mean()
172 |
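A toy end-to-end sketch of the two functions above (import path assumed from this repo's layout): the ground truth is rigidly aligned onto the prediction, then the smooth LDDT loss compares pairwise distances.

import torch
from boltz.model.loss.diffusion import smooth_lddt_loss, weighted_rigid_align

B, N = 2, 32
true_coords = torch.randn(B, N, 3)
pred_coords = torch.randn(B, N, 3)
weights = torch.ones(B, N)
mask = torch.ones(B, N)

aligned = weighted_rigid_align(true_coords, pred_coords, weights, mask)  # detached
loss = smooth_lddt_loss(
    pred_coords,
    aligned,
    is_nucleotide=torch.zeros(B, N),  # all-protein toy system
    coords_mask=mask,
)
print(loss)  # scalar in (0, 1)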
--------------------------------------------------------------------------------
/src/boltz/model/loss/distogram.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 |
7 | def distogram_loss(
8 | output: Dict[str, Tensor],
9 | feats: Dict[str, Tensor],
10 | ) -> Tuple[Tensor, Tensor]:
11 | """Compute the distogram loss.
12 |
13 | Parameters
14 | ----------
15 | output : Dict[str, Tensor]
16 | Output of the model
17 | feats : Dict[str, Tensor]
18 | Input features
19 |
20 | Returns
21 | -------
22 | Tensor
23 | The globally averaged loss.
24 | Tensor
25 | Per example loss.
26 |
27 | """
28 | # Get predicted distograms
29 | pred = output["pdistogram"]
30 |
31 | # Compute target distogram
32 | target = feats["disto_target"]
33 |
34 | # Combine target mask and padding mask
35 | mask = feats["token_disto_mask"]
36 | mask = mask[:, None, :] * mask[:, :, None]
37 | mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
38 |
39 | # Compute the distogram loss
40 | errors = -1 * torch.sum(
41 | target * torch.nn.functional.log_softmax(pred, dim=-1),
42 | dim=-1,
43 | )
44 | denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
45 | mean = errors * mask
46 | mean = torch.sum(mean, dim=-1)
47 | mean = mean / denom[..., None]
48 | batch_loss = torch.sum(mean, dim=-1)
49 | global_loss = torch.mean(batch_loss)
50 | return global_loss, batch_loss
51 |
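A minimal sketch of calling distogram_loss with toy inputs (the dictionary keys follow the function above; shapes and bin count are hypothetical):

import torch
from boltz.model.loss.distogram import distogram_loss

B, N, num_bins = 2, 16, 64
output = {"pdistogram": torch.randn(B, N, N, num_bins)}
feats = {
    "disto_target": torch.nn.functional.one_hot(
        torch.randint(num_bins, (B, N, N)), num_bins
    ).float(),
    "token_disto_mask": torch.ones(B, N),
}
global_loss, batch_loss = distogram_loss(output, feats)  # scalar, (B,)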
--------------------------------------------------------------------------------
/src/boltz/model/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/model/modules/__init__.py
--------------------------------------------------------------------------------
/src/boltz/model/modules/confidence_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from boltz.data import const
5 | from boltz.model.loss.confidence import compute_frame_pred
6 |
7 |
8 | def compute_aggregated_metric(logits, end=1.0):
9 | """Compute the metric from the logits.
10 |
11 | Parameters
12 | ----------
13 | logits : torch.Tensor
14 | The logits of the metric
15 | end : float
16 | Max value of the metric, by default 1.0
17 |
18 | Returns
19 | -------
20 | Tensor
21 | The metric value
22 |
23 | """
24 | num_bins = logits.shape[-1]
25 | bin_width = end / num_bins
26 | bounds = torch.arange(
27 | start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
28 | )
29 | probs = nn.functional.softmax(logits, dim=-1)
30 | plddt = torch.sum(
31 | probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
32 | dim=-1,
33 | )
34 | return plddt
35 |
36 |
37 | def tm_function(d, Nres):
38 | """Compute the rescaling function for pTM.
39 |
40 | Parameters
41 | ----------
42 |     d : torch.Tensor
43 |         The input distances (the error values to rescale)
44 | Nres : torch.Tensor
45 | The number of residues
46 |
47 | Returns
48 | -------
49 | Tensor
50 | Output of the function
51 |
52 | """
53 | d0 = 1.24 * (torch.clip(Nres, min=19) - 15) ** (1 / 3) - 1.8
54 | return 1 / (1 + (d / d0) ** 2)
55 |
56 |
57 | def compute_ptms(logits, x_preds, feats, multiplicity):
58 | """Compute pTM and ipTM scores.
59 |
60 | Parameters
61 | ----------
62 | logits : torch.Tensor
63 | pae logits
64 | x_preds : torch.Tensor
65 | The predicted coordinates
66 | feats : Dict[str, torch.Tensor]
67 | The input features
68 | multiplicity : int
69 | The batch size of the diffusion roll-out
70 |
71 | Returns
72 | -------
73 | Tensor
74 | pTM score
75 | Tensor
76 | ipTM score
77 | Tensor
78 | ligand ipTM score
79 | Tensor
80 |         protein ipTM score
81 |     Dict[int, Dict[int, Tensor]]: chain-pair ipTM scores
82 | """
83 | # Compute mask for collinear and overlapping tokens
84 | _, mask_collinear_pred = compute_frame_pred(
85 | x_preds, feats["frames_idx"], feats, multiplicity, inference=True
86 | )
87 | mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
88 | maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1])
89 | pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None]
90 | asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
91 | pair_mask_iptm = (
92 | maski[:, :, None]
93 | * (asym_id[:, None, :] != asym_id[:, :, None])
94 | * mask_pad[:, None, :]
95 | * mask_pad[:, :, None]
96 | )
97 |
98 | # Extract pae values
99 | num_bins = logits.shape[-1]
100 | bin_width = 32.0 / num_bins
101 | end = 32.0
102 | pae_value = torch.arange(
103 | start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
104 | ).unsqueeze(0)
105 | N_res = mask_pad.sum(dim=-1, keepdim=True)
106 |
107 | # compute pTM and ipTM
108 | tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2)
109 | probs = nn.functional.softmax(logits, dim=-1)
110 | tm_expected_value = torch.sum(
111 | probs * tm_value,
112 | dim=-1,
113 | ) # shape (B, N, N)
114 | ptm = torch.max(
115 | torch.sum(tm_expected_value * pair_mask_ptm, dim=-1)
116 | / (torch.sum(pair_mask_ptm, dim=-1) + 1e-5),
117 | dim=1,
118 | ).values
119 | iptm = torch.max(
120 | torch.sum(tm_expected_value * pair_mask_iptm, dim=-1)
121 | / (torch.sum(pair_mask_iptm, dim=-1) + 1e-5),
122 | dim=1,
123 | ).values
124 |
125 | # compute ligand and protein ipTM
126 | token_type = feats["mol_type"]
127 | token_type = token_type.repeat_interleave(multiplicity, 0)
128 | is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
129 | is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()
130 |
131 | ligand_iptm_mask = (
132 | maski[:, :, None]
133 | * (asym_id[:, None, :] != asym_id[:, :, None])
134 | * mask_pad[:, None, :]
135 | * mask_pad[:, :, None]
136 | * (
137 | (is_ligand_token[:, :, None] * is_protein_token[:, None, :])
138 | + (is_protein_token[:, :, None] * is_ligand_token[:, None, :])
139 | )
140 | )
141 |     protein_iptm_mask = (
142 | maski[:, :, None]
143 | * (asym_id[:, None, :] != asym_id[:, :, None])
144 | * mask_pad[:, None, :]
145 | * mask_pad[:, :, None]
146 | * (is_protein_token[:, :, None] * is_protein_token[:, None, :])
147 | )
148 |
149 | ligand_iptm = torch.max(
150 | torch.sum(tm_expected_value * ligand_iptm_mask, dim=-1)
151 | / (torch.sum(ligand_iptm_mask, dim=-1) + 1e-5),
152 | dim=1,
153 | ).values
154 | protein_iptm = torch.max(
155 |         torch.sum(tm_expected_value * protein_iptm_mask, dim=-1)
156 |         / (torch.sum(protein_iptm_mask, dim=-1) + 1e-5),
157 | dim=1,
158 | ).values
159 |
160 | # Compute pair chain ipTM
161 | chain_pair_iptm = {}
162 | asym_ids_list = torch.unique(asym_id).tolist()
163 | for idx1 in asym_ids_list:
164 | chain_iptm = {}
165 | for idx2 in asym_ids_list:
166 | mask_pair_chain = (
167 | maski[:, :, None]
168 | * (asym_id[:, None, :] == idx1)
169 | * (asym_id[:, :, None] == idx2)
170 | * mask_pad[:, None, :]
171 | * mask_pad[:, :, None]
172 | )
173 |
174 | chain_iptm[idx2] = torch.max(
175 | torch.sum(tm_expected_value * mask_pair_chain, dim=-1)
176 | / (torch.sum(mask_pair_chain, dim=-1) + 1e-5),
177 | dim=1,
178 | ).values
179 | chain_pair_iptm[idx1] = chain_iptm
180 |
181 | return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm
182 |
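compute_aggregated_metric simply takes the expected value of the bin centers under the softmax of the logits. A short sketch (the bin counts here are illustrative assumptions):

import torch
from boltz.model.modules.confidence_utils import compute_aggregated_metric

plddt_logits = torch.randn(2, 100, 50)           # (batch, tokens, bins)
plddt = compute_aggregated_metric(plddt_logits)  # expected values in (0, 1)

pae_logits = torch.randn(2, 100, 100, 64)        # pairwise logits
pae = compute_aggregated_metric(pae_logits, end=32.0)  # values in (0, 32)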
--------------------------------------------------------------------------------
/src/boltz/model/optim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cddlab/boltz_ext/9d88b09392f04dc2e9cbe05f0e77204b78fbc0c0/src/boltz/model/optim/__init__.py
--------------------------------------------------------------------------------
/src/boltz/model/optim/scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
5 |     """Implements the learning rate schedule defined in AF3.
6 |
7 | A linear warmup is followed by a plateau at the maximum
8 | learning rate and then exponential decay. Note that the
9 | initial learning rate of the optimizer in question is
10 | ignored; use this class' base_lr parameter to specify
11 | the starting point of the warmup.
12 |
13 | """
14 |
15 | def __init__(
16 | self,
17 | optimizer: torch.optim.Optimizer,
18 | last_epoch: int = -1,
19 | verbose: bool = False,
20 | base_lr: float = 0.0,
21 | max_lr: float = 1.8e-3,
22 | warmup_no_steps: int = 1000,
23 | start_decay_after_n_steps: int = 50000,
24 | decay_every_n_steps: int = 50000,
25 | decay_factor: float = 0.95,
26 | ) -> None:
27 | """Initialize the learning rate scheduler.
28 |
29 | Parameters
30 | ----------
31 | optimizer : torch.optim.Optimizer
32 | The optimizer.
33 | last_epoch : int, optional
34 | The last epoch, by default -1
35 | verbose : bool, optional
36 | Whether to print verbose output, by default False
37 | base_lr : float, optional
38 | The base learning rate, by default 0.0
39 | max_lr : float, optional
40 | The maximum learning rate, by default 1.8e-3
41 | warmup_no_steps : int, optional
42 | The number of warmup steps, by default 1000
43 | start_decay_after_n_steps : int, optional
44 | The number of steps after which to start decay, by default 50000
45 | decay_every_n_steps : int, optional
46 | The number of steps after which to decay, by default 50000
47 | decay_factor : float, optional
48 | The decay factor, by default 0.95
49 |
50 | """
51 | step_counts = {
52 | "warmup_no_steps": warmup_no_steps,
53 | "start_decay_after_n_steps": start_decay_after_n_steps,
54 | }
55 |
56 | for k, v in step_counts.items():
57 | if v < 0:
58 | msg = f"{k} must be nonnegative"
59 | raise ValueError(msg)
60 |
61 | if warmup_no_steps > start_decay_after_n_steps:
62 | msg = "warmup_no_steps must not exceed start_decay_after_n_steps"
63 | raise ValueError(msg)
64 |
65 | self.optimizer = optimizer
66 | self.last_epoch = last_epoch
67 | self.verbose = verbose
68 | self.base_lr = base_lr
69 | self.max_lr = max_lr
70 | self.warmup_no_steps = warmup_no_steps
71 | self.start_decay_after_n_steps = start_decay_after_n_steps
72 | self.decay_every_n_steps = decay_every_n_steps
73 | self.decay_factor = decay_factor
74 |
75 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
76 |
77 | def state_dict(self) -> dict:
78 | state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
79 | return state_dict
80 |
81 | def load_state_dict(self, state_dict):
82 | self.__dict__.update(state_dict)
83 |
84 | def get_lr(self):
85 | if not self._get_lr_called_within_step:
86 | msg = (
87 | "To get the last learning rate computed by the scheduler, use "
88 | "get_last_lr()"
89 | )
90 | raise RuntimeError(msg)
91 |
92 | step_no = self.last_epoch
93 |
94 | if step_no <= self.warmup_no_steps:
95 | lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
96 | elif step_no > self.start_decay_after_n_steps:
97 | steps_since_decay = step_no - self.start_decay_after_n_steps
98 | exp = (steps_since_decay // self.decay_every_n_steps) + 1
99 | lr = self.max_lr * (self.decay_factor**exp)
100 | else: # plateau
101 | lr = self.max_lr
102 |
103 |         return [lr for _ in self.optimizer.param_groups]
104 |
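A hedged usage sketch: the scheduler is stepped once per optimizer step, ramping linearly to max_lr over warmup_no_steps, holding at the plateau, then decaying by decay_factor every decay_every_n_steps.

import torch
from boltz.model.optim.scheduler import AlphaFoldLRScheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1.0)  # initial lr is ignored
scheduler = AlphaFoldLRScheduler(optimizer, max_lr=1.8e-3, warmup_no_steps=1000)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())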
--------------------------------------------------------------------------------
/tests/model/layers/test_outer_product_mean.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning
2 | import torch
3 | import torch.nn as nn
4 |
5 | import unittest
6 |
7 | from boltz.model.layers.outer_product_mean import OuterProductMean
8 |
9 |
10 | class OuterProductMeanTest(unittest.TestCase):
11 | def setUp(self):
12 | self.c_in = 32
13 | self.c_hidden = 16
14 | self.c_out = 64
15 |
16 | torch.set_grad_enabled(False)
17 | pytorch_lightning.seed_everything(1100)
18 | self.layer = OuterProductMean(self.c_in, self.c_hidden, self.c_out)
19 |
20 | # Initialize layer
21 | for name, param in self.layer.named_parameters():
22 | nn.init.normal_(param, mean=1., std=1.)
23 |
24 |
25 | def test_chunk(self):
26 | chunk_sizes = [16, 33, 64, 83, 100]
27 | B, S, N = 1, 49, 84
28 | m = torch.randn(size=(B, S, N, self.c_in))
29 |         mask = torch.randint(low=0, high=2, size=(B, S, N))  # 0/1 mask; randint's high is exclusive
30 |
31 | with torch.no_grad():
32 | exp_output = self.layer(m=m, mask=mask)
33 | for chunk_size in chunk_sizes:
34 | with self.subTest(chunk_size=chunk_size):
35 | act_output = self.layer(m=m, mask=mask, chunk_size=chunk_size)
36 | assert torch.allclose(exp_output, act_output, atol=1e-8)
--------------------------------------------------------------------------------
/tests/model/layers/test_triangle_attention.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning
2 | import torch
3 | import torch.nn as nn
4 |
5 | import unittest
6 |
7 | from boltz.model.layers.triangular_attention.attention import TriangleAttention
8 |
9 |
10 | class TriangleAttentionTest(unittest.TestCase):
11 | def setUp(self):
12 | self.c_in = 128
13 | self.c_hidden = 32
14 | self.no_heads = 1
15 |
16 | torch.set_grad_enabled(False)
17 | pytorch_lightning.seed_everything(1100)
18 | self.layer = TriangleAttention(self.c_in, self.c_hidden, self.no_heads)
19 |
20 | # Initialize layer
21 | for name, param in self.layer.named_parameters():
22 | nn.init.normal_(param, mean=1., std=1.)
23 |
24 |
25 | def test_chunk(self):
26 | chunk_sizes = [16, 33, 64, 100]
27 | B, N = 1, 99
28 | m = torch.randn(size=(B, N, N, self.c_in))
29 |         mask = torch.randint(low=0, high=2, size=(B, N, N))  # 0/1 mask; randint's high is exclusive
30 |
31 | with torch.no_grad():
32 | exp_output = self.layer(x=m, mask=mask)
33 | for chunk_size in chunk_sizes:
34 | with self.subTest(chunk_size=chunk_size):
35 | act_output = self.layer(x=m, mask=mask, chunk_size=chunk_size)
36 | assert torch.allclose(exp_output, act_output, atol=1e-8)
--------------------------------------------------------------------------------
/tests/test_regression.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from dataclasses import asdict
4 | import pprint
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | import pytest
10 | import unittest
11 |
12 | from lightning_fabric import seed_everything
13 |
14 | from boltz.main import MODEL_URL
15 | from boltz.model.model import Boltz1
16 |
17 | import test_utils
18 |
19 | tests_dir = os.path.dirname(os.path.abspath(__file__))
20 | test_data_dir = os.path.join(tests_dir, 'data')
21 |
22 | @pytest.mark.regression
23 | class RegressionTester(unittest.TestCase):
24 |
25 | @classmethod
26 | def setUpClass(cls):
27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 | cache = os.path.expanduser('~/.boltz')
29 | checkpoint_url = MODEL_URL
30 | model_name = checkpoint_url.split("/")[-1]
31 | checkpoint = os.path.join(cache, model_name)
32 | if not os.path.exists(checkpoint):
33 | test_utils.download_file(checkpoint_url, checkpoint)
34 |
35 | regression_feats_path = os.path.join(test_data_dir, 'ligand_regression_feats.pkl')
36 | if not os.path.exists(regression_feats_path):
37 | regression_feats_url = "https://www.dropbox.com/scl/fi/1avbcvoor5jcnvpt07tp6/ligand_regression_feats.pkl?rlkey=iwtm9gpxgrbp51jbizq937pqf&st=jnbky253&dl=1"
38 | test_utils.download_file(regression_feats_url, regression_feats_path)
39 |
40 | regression_feats = torch.load(regression_feats_path, map_location=device)
41 | model_module: nn.Module = Boltz1.load_from_checkpoint(checkpoint, map_location=device)
42 | model_module.to(device)
43 | model_module.eval()
44 |
45 | coords = regression_feats["feats"]["coords"]
46 | # Coords should be rank 4
47 | if len(coords.shape) == 3:
48 | coords = coords.unsqueeze(0)
49 | regression_feats["feats"]["coords"] = coords
50 | for key, val in regression_feats["feats"].items():
51 | if hasattr(val, "to"):
52 | regression_feats["feats"][key] = val.to(device)
53 |
54 | cls.model_module = model_module.to(device)
55 | cls.regression_feats = regression_feats
56 |
57 | def test_input_embedder(self):
58 | exp_s_inputs = self.regression_feats["s_inputs"]
59 | act_s_inputs = self.model_module.input_embedder(self.regression_feats["feats"])
60 |
61 | assert torch.allclose(exp_s_inputs, act_s_inputs, atol=1e-5)
62 |
63 | def test_rel_pos(self):
64 | exp_rel_pos_encoding = self.regression_feats["relative_position_encoding"]
65 | act_rel_pos_encoding = self.model_module.rel_pos(self.regression_feats["feats"])
66 |
67 | assert torch.allclose(exp_rel_pos_encoding, act_rel_pos_encoding, atol=1e-5)
68 |
69 | @pytest.mark.slow
70 | def test_structure_output(self):
71 | exp_structure_output = self.regression_feats["structure_output"]
72 | s = self.regression_feats["s"]
73 | z = self.regression_feats["z"]
74 | s_inputs = self.regression_feats["s_inputs"]
75 | feats = self.regression_feats["feats"]
76 | relative_position_encoding = self.regression_feats["relative_position_encoding"]
77 | multiplicity_diffusion_train = self.regression_feats["multiplicity_diffusion_train"]
78 |
79 | self.model_module.structure_module.coordinate_augmentation = False
80 | self.model_module.structure_module.sigma_data = 0.0
81 |
82 | seed_everything(self.regression_feats["seed"])
83 | act_structure_output = self.model_module.structure_module(
84 | s_trunk=s,
85 | z_trunk=z,
86 | s_inputs=s_inputs,
87 | feats=feats,
88 | relative_position_encoding=relative_position_encoding,
89 | multiplicity=multiplicity_diffusion_train,
90 | )
91 |
92 | act_keys = act_structure_output.keys()
93 | exp_keys = exp_structure_output.keys()
94 | assert act_keys == exp_keys
95 |
96 | # Other keys have some randomness, so we will only check the keys that
97 | # we can make deterministic with sigma_data = 0.0 (above).
98 | check_keys = ["noised_atom_coords", "aligned_true_atom_coords"]
99 | for key in check_keys:
100 | exp_val = exp_structure_output[key]
101 | act_val = act_structure_output[key]
102 | assert exp_val.shape == act_val.shape, f"Shape mismatch in {key}"
103 | assert torch.allclose(exp_val, act_val, atol=1e-4)
104 |
105 |
106 | if __name__ == '__main__':
107 | unittest.main()
108 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import requests
4 |
5 | def download_file(url, filepath, verbose=True):
6 | if verbose:
7 | print(f"Downloading {url} to {filepath}")
8 | response = requests.get(url)
9 |
10 | target_dir = os.path.dirname(filepath)
11 | if target_dir and not os.path.exists(target_dir):
12 | os.makedirs(target_dir)
13 |
14 | # Check if the request was successful
15 | if response.status_code == 200:
16 | with open(filepath, 'wb') as file:
17 | file.write(response.content)
18 |     else:
19 |         raise RuntimeError(f"Failed to download file. Status code: {response.status_code}")
20 |
21 | return filepath
22 |
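For multi-gigabyte checkpoints, a streamed variant (an alternative sketch, not part of the repo) avoids holding the whole payload in memory and fails fast on HTTP errors:

import os
import requests

def download_file_streamed(url, filepath, chunk_size=1 << 20):
    """Stream the payload to disk in 1 MiB chunks, raising on HTTP errors."""
    target_dir = os.path.dirname(filepath)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        with open(filepath, 'wb') as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    return filepath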
--------------------------------------------------------------------------------