├── src
│   └── boltz
│       ├── data
│       │   ├── __init__.py
│       │   ├── crop
│       │   │   ├── __init__.py
│       │   │   ├── cropper.py
│       │   │   └── boltz.py
│       │   ├── filter
│       │   │   ├── __init__.py
│       │   │   ├── dynamic
│       │   │   │   ├── __init__.py
│       │   │   │   ├── filter.py
│       │   │   │   ├── resolution.py
│       │   │   │   ├── max_residues.py
│       │   │   │   ├── size.py
│       │   │   │   ├── subset.py
│       │   │   │   └── date.py
│       │   │   └── static
│       │   │       ├── __init__.py
│       │   │       ├── filter.py
│       │   │       ├── ligand.py
│       │   │       └── polymer.py
│       │   ├── module
│       │   │   ├── __init__.py
│       │   │   └── inference.py
│       │   ├── parse
│       │   │   ├── __init__.py
│       │   │   ├── yaml.py
│       │   │   ├── a3m.py
│       │   │   └── fasta.py
│       │   ├── sample
│       │   │   ├── __init__.py
│       │   │   ├── random.py
│       │   │   ├── sampler.py
│       │   │   ├── distillation.py
│       │   │   └── cluster.py
│       │   ├── write
│       │   │   ├── __init__.py
│       │   │   ├── utils.py
│       │   │   ├── pdb.py
│       │   │   ├── writer.py
│       │   │   └── mmcif.py
│       │   ├── feature
│       │   │   ├── __init__.py
│       │   │   └── pad.py
│       │   ├── tokenize
│       │   │   ├── __init__.py
│       │   │   ├── tokenizer.py
│       │   │   └── boltz.py
│       │   └── const.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── loss
│       │   │   ├── __init__.py
│       │   │   ├── distogram.py
│       │   │   └── diffusion.py
│       │   ├── optim
│       │   │   ├── __init__.py
│       │   │   └── scheduler.py
│       │   ├── layers
│       │   │   ├── __init__.py
│       │   │   ├── triangular_attention
│       │   │   │   ├── __init__.py
│       │   │   │   └── attention.py
│       │   │   ├── dropout.py
│       │   │   ├── transition.py
│       │   │   ├── initialize.py
│       │   │   ├── outer_product_mean.py
│       │   │   ├── attention.py
│       │   │   ├── triangular_mult.py
│       │   │   └── pair_averaging.py
│       │   └── modules
│       │       ├── __init__.py
│       │       ├── confidence_utils.py
│       │       └── transformers.py
│       └── __init__.py
├── docs
│   ├── boltz1_pred_figure.png
│   ├── training.md
│   └── prediction.md
├── examples
│   ├── ligand.yaml
│   └── ligand.fasta
├── scripts
│   └── train
│       ├── assets
│       │   ├── casp15_ids.txt
│       │   ├── test_ids.txt
│       │   └── validation_ids.txt
│       ├── configs
│       │   ├── structure.yaml
│       │   └── confidence.yaml
│       └── train.py
├── LICENSE
├── README.md
├── pyproject.toml
└── .gitignore
/src/boltz/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/crop/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/module/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/parse/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/sample/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/write/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/model/loss/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/model/optim/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/feature/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/tokenize/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/model/layers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/model/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/static/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/boltz/model/layers/triangular_attention/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/boltz1_pred_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/boltz/main/docs/boltz1_pred_figure.png
--------------------------------------------------------------------------------
/src/boltz/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import PackageNotFoundError, version
2 |
3 | try: # noqa: SIM105
4 | __version__ = version("boltz")
5 | except PackageNotFoundError:
6 | # package is not installed
7 | pass
8 |
--------------------------------------------------------------------------------
/src/boltz/data/tokenize/tokenizer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from boltz.data.types import Input, Tokenized
4 |
5 |
6 | class Tokenizer(ABC):
7 | """Tokenize an input structure for training."""
8 |
9 | @abstractmethod
10 | def tokenize(self, data: Input) -> Tokenized:
11 | """Tokenize the input data.
12 |
13 | Parameters
14 | ----------
15 | data : Input
16 | The input data.
17 |
18 | Returns
19 | -------
20 | Tokenized
21 | The tokenized data.
22 |
23 | """
24 | raise NotImplementedError
25 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/filter.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from boltz.data.types import Record
4 |
5 |
6 | class DynamicFilter(ABC):
7 | """Base class for data filters."""
8 |
9 | @abstractmethod
10 | def filter(self, record: Record) -> bool:
11 | """Filter a data record.
12 |
13 | Parameters
14 | ----------
15 | record : Record
16 | The object to consider filtering in / out.
17 |
18 | Returns
19 | -------
20 | bool
21 | True if the data passes the filter, False otherwise.
22 |
23 | """
24 | raise NotImplementedError
25 |
--------------------------------------------------------------------------------
/src/boltz/data/write/utils.py:
--------------------------------------------------------------------------------
1 | import string
2 | from collections.abc import Iterator
3 |
4 |
5 | def generate_tags() -> Iterator[str]:
6 | """Generate chain tags.
7 |
8 | Yields
9 | ------
10 | str
11 | The next chain tag
12 |
13 | """
14 | for i in range(1, 4):
15 | for j in range(len(string.ascii_uppercase) ** i):
16 | tag = ""
17 | for k in range(i):
18 | tag += string.ascii_uppercase[
19 | j
20 | // (len(string.ascii_uppercase) ** k)
21 | % len(string.ascii_uppercase)
22 | ]
23 | yield tag
24 |
--------------------------------------------------------------------------------
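As a quick illustration of the tag ordering (a hypothetical snippet, not part of the repository): `generate_tags` yields the 26 single-letter tags first, and in multi-letter tags the *first* letter varies fastest.

```python
from itertools import islice

from boltz.data.write.utils import generate_tags

tags = list(islice(generate_tags(), 30))
assert tags[:4] == ["A", "B", "C", "D"]        # 26 single-letter tags first
assert tags[26:30] == ["AA", "BA", "CA", "DA"]  # then the first letter varies fastest
```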
/examples/ligand.yaml:
--------------------------------------------------------------------------------
1 | version: 1 # Optional, defaults to 1
2 | sequences:
3 | - protein:
4 | id: [A, B]
5 | sequence: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
6 | msa: ./examples/msa/seq1.a3m
7 | - ligand:
8 | id: [C, D]
9 | ccd: SAH
10 | - ligand:
11 | id: [E, F]
12 | smiles: N[C@@H](Cc1ccc(O)cc1)C(=O)O
13 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/static/filter.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 |
5 | from boltz.data.types import Structure
6 |
7 |
8 | class StaticFilter(ABC):
9 | """Base class for structure filters."""
10 |
11 | @abstractmethod
12 | def filter(self, structure: Structure) -> np.ndarray:
13 | """Filter chains in a structure.
14 |
15 | Parameters
16 | ----------
17 | structure : Structure
18 | The structure to filter chains from.
19 |
20 | Returns
21 | -------
22 | np.ndarray
23 | The chains to keep, as a boolean mask.
24 |
25 | """
26 | raise NotImplementedError
27 |
--------------------------------------------------------------------------------
/scripts/train/assets/casp15_ids.txt:
--------------------------------------------------------------------------------
1 | T1112
2 | T1118v1
3 | T1154
4 | T1137s1
5 | T1188
6 | T1157s1
7 | T1137s6
8 | R1117
9 | H1106
10 | T1106s2
11 | R1149
12 | T1158
13 | T1137s2
14 | T1145
15 | T1121
16 | T1123
17 | T1113
18 | R1156
19 | T1114s1
20 | T1183
21 | R1107
22 | T1137s7
23 | T1124
24 | T1178
25 | T1147
26 | R1128
27 | T1161
28 | R1108
29 | T1194
30 | T1185s2
31 | T1176
32 | T1158v3
33 | T1137s4
34 | T1160
35 | T1120
36 | H1185
37 | T1134s1
38 | T1119
39 | H1151
40 | T1137s8
41 | T1133
42 | T1187
43 | H1157
44 | T1122
45 | T1104
46 | T1158v2
47 | T1137s5
48 | T1129s2
49 | T1174
50 | T1157s2
51 | T1155
52 | T1158v4
53 | T1152
54 | T1137s9
55 | T1134s2
56 | T1125
57 | R1116
58 | H1134
59 | R1136
60 | T1159
61 | T1137s3
62 | T1185s1
63 | T1179
64 | T1106s1
65 | T1132
66 | T1185s4
67 | T1114s3
68 | T1114s2
69 | T1151s2
70 | T1158v1
71 | R1117v2
72 | T1173
73 |
--------------------------------------------------------------------------------
/src/boltz/model/layers/dropout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 |
5 | def get_dropout_mask(
6 | dropout: float,
7 | z: Tensor,
8 | training: bool,
9 | columnwise: bool = False,
10 | ) -> Tensor:
11 | """Get the dropout mask.
12 |
13 | Parameters
14 | ----------
15 | dropout : float
16 | The dropout rate
17 | z : torch.Tensor
18 | The tensor to apply dropout to
19 | training : bool
20 | Whether the model is in training mode
21 | columnwise : bool, optional
22 | Whether to apply dropout columnwise
23 |
24 | Returns
25 | -------
26 | torch.Tensor
27 | The dropout mask
28 |
29 | """
30 | dropout = dropout * training
31 | v = z[:, 0:1, :, 0:1] if columnwise else z[:, :, 0:1, 0:1]
32 | d = torch.rand_like(v) > dropout
33 | d = d * 1.0 / (1.0 - dropout)
34 | return d
35 |
--------------------------------------------------------------------------------
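A minimal usage sketch (hypothetical; the `(batch, rows, columns, channels)` layout is assumed from the mask indexing): the mask is drawn once per row (or per column when `columnwise=True`) and is already rescaled by `1 / (1 - p)`, so applying it is a single multiply.

```python
import torch

from boltz.model.layers.dropout import get_dropout_mask

z = torch.randn(2, 8, 8, 16)  # (batch, rows, cols, channels)

# One Bernoulli draw per row, rescaled by 1 / (1 - p); broadcasts over
# columns and channels when multiplied into z
mask = get_dropout_mask(0.25, z, training=True)
z = z * mask
```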
/examples/ligand.fasta:
--------------------------------------------------------------------------------
1 | >A|protein|./examples/msa/seq1.a3m
2 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
3 | >B|protein|./examples/msa/seq1.a3m
4 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
5 | >C|ccd
6 | SAH
7 | >D|ccd
8 | SAH
9 | >E|smiles
10 | N[C@@H](Cc1ccc(O)cc1)C(=O)O
11 | >F|smiles
12 | N[C@@H](Cc1ccc(O)cc1)C(=O)O
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/resolution.py:
--------------------------------------------------------------------------------
1 | from boltz.data.types import Record
2 | from boltz.data.filter.dynamic.filter import DynamicFilter
3 |
4 |
5 | class ResolutionFilter(DynamicFilter):
6 | """A filter that filters complexes based on their resolution."""
7 |
8 | def __init__(self, resolution: float = 9.0) -> None:
9 | """Initialize the filter.
10 |
11 | Parameters
12 | ----------
13 | resolution : float, optional
14 | The maximum allowed resolution.
15 |
16 | """
17 | self.resolution = resolution
18 |
19 | def filter(self, record: Record) -> bool:
20 | """Filter complexes based on their resolution.
21 |
22 | Parameters
23 | ----------
24 | record : Record
25 | The record to filter.
26 |
27 | Returns
28 | -------
29 | bool
30 | True if the record passes the filter, False otherwise.
31 |
32 | """
33 | structure = record.structure
34 | return structure.resolution <= self.resolution
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/max_residues.py:
--------------------------------------------------------------------------------
1 | from boltz.data.types import Record
2 | from boltz.data.filter.dynamic.filter import DynamicFilter
3 |
4 |
5 | class MaxResiduesFilter(DynamicFilter):
6 | """A filter that filters structures based on their size."""
7 |
8 | def __init__(self, min_residues: int = 1, max_residues: int = 500) -> None:
9 | """Initialize the filter.
10 |
11 | Parameters
12 | ----------
13 | min_residues : int
14 | The minimum number of residues allowed.
15 | max_residues : int
16 | The maximum number of residues allowed.
17 |
18 | """
19 | self.min_residues = min_residues
20 | self.max_residues = max_residues
21 |
22 | def filter(self, record: Record) -> bool:
23 | """Filter structures based on their resolution.
24 |
25 | Parameters
26 | ----------
27 | record : Record
28 | The record to filter.
29 |
30 | Returns
31 | -------
32 | bool
33 | True if the record passes the filter, False otherwise.
34 |
35 | """
36 | num_residues = sum(chain.num_residues for chain in record.chains)
37 | return num_residues <= self.max_residues and num_residues >= self.min_residues
38 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/size.py:
--------------------------------------------------------------------------------
1 | from boltz.data.types import Record
2 | from boltz.data.filter.dynamic.filter import DynamicFilter
3 |
4 |
5 | class SizeFilter(DynamicFilter):
6 | """A filter that filters structures based on their size."""
7 |
8 | def __init__(self, min_chains: int = 1, max_chains: int = 300) -> None:
9 | """Initialize the filter.
10 |
11 | Parameters
12 | ----------
13 | min_chains : int
14 | The minimum number of chains allowed.
15 | max_chains : int
16 | The maximum number of chains allowed.
17 |
18 | """
19 | self.min_chains = min_chains
20 | self.max_chains = max_chains
21 |
22 | def filter(self, record: Record) -> bool:
23 | """Filter structures based on their resolution.
24 |
25 | Parameters
26 | ----------
27 | record : Record
28 | The record to filter.
29 |
30 | Returns
31 | -------
32 | bool
33 | True if the record passes the filter, False otherwise.
34 |
35 | """
36 | num_chains = record.structure.num_chains
37 | num_valid = sum(1 for chain in record.chains if chain.valid)
38 | return num_chains <= self.max_chains and num_valid >= self.min_chains
39 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/subset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from boltz.data.types import Record
4 | from boltz.data.filter.dynamic.filter import DynamicFilter
5 |
6 |
7 | class SubsetFilter(DynamicFilter):
8 | """Filter a data record based on a subset of the data."""
9 |
10 | def __init__(self, subset: str, reverse: bool = False) -> None:
11 | """Initialize the filter.
12 |
13 | Parameters
14 | ----------
15 | subset : str
16 | Path to a file listing the IDs to consider, one per line.
17 | reverse : bool, optional. Whether to invert the filter, keeping IDs not in the subset.
18 | """
19 | with Path(subset).open("r") as f:
20 | subset = f.read().splitlines()
21 |
22 | self.subset = {s.lower() for s in subset}
23 | self.reverse = reverse
24 |
25 | def filter(self, record: Record) -> bool:
26 | """Filter a data record.
27 |
28 | Parameters
29 | ----------
30 | record : Record
31 | The object to consider filtering in / out.
32 |
33 | Returns
34 | -------
35 | bool
36 | True if the data passes the filter, False otherwise.
37 |
38 | """
39 | if self.reverse:
40 | return record.id.lower() not in self.subset
41 | else: # noqa: RET505
42 | return record.id.lower() in self.subset
43 |
--------------------------------------------------------------------------------
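For example (a hypothetical snippet), the CASP15 ID list shipped under `scripts/train/assets/` can be used either to keep only those targets or, with `reverse=True`, to exclude them:

```python
from boltz.data.filter.dynamic.subset import SubsetFilter

keep_casp15 = SubsetFilter("scripts/train/assets/casp15_ids.txt")
drop_casp15 = SubsetFilter("scripts/train/assets/casp15_ids.txt", reverse=True)

# keep_casp15.filter(record) is True only when record.id (case-insensitively)
# appears in the file; drop_casp15 returns the opposite
```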
/src/boltz/data/sample/random.py:
--------------------------------------------------------------------------------
1 | from dataclasses import replace
2 | from typing import Iterator, List
3 |
4 | from numpy.random import RandomState
5 |
6 | from boltz.data.types import Record
7 | from boltz.data.sample.sampler import Sample, Sampler
8 |
9 |
10 | class RandomSampler(Sampler):
11 | """A simple random sampler with replacement."""
12 |
13 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
14 | """Sample a structure from the dataset infinitely.
15 |
16 | Parameters
17 | ----------
18 | records : List[Record]
19 | The records to sample from.
20 | random : RandomState
21 | The random state for reproducibility.
22 |
23 | Yields
24 | ------
25 | Sample
26 | A data sample.
27 |
28 | """
29 | while True:
30 | # Sample item from the list
31 | index = random.randint(0, len(records))
32 | record = records[index]
33 |
34 | # Remove invalid chains and interfaces
35 | chains = [c for c in record.chains if c.valid]
36 | interfaces = [i for i in record.interfaces if i.valid]
37 | record = replace(record, chains=chains, interfaces=interfaces)
38 |
39 | yield Sample(record=record)
40 |
--------------------------------------------------------------------------------
/src/boltz/data/sample/sampler.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass
3 | from typing import Iterator, List, Optional
4 |
5 | from numpy.random import RandomState
6 |
7 | from boltz.data.types import Record
8 |
9 |
10 | @dataclass
11 | class Sample:
12 | """A sample with optional chain and interface IDs.
13 |
14 | Attributes
15 | ----------
16 | record : Record
17 | The record.
18 | chain_id : Optional[int]
19 | The chain ID.
20 | interface_id : Optional[int]
21 | The interface ID.
22 | """
23 |
24 | record: Record
25 | chain_id: Optional[int] = None
26 | interface_id: Optional[int] = None
27 |
28 |
29 | class Sampler(ABC):
30 | """Abstract base class for samplers."""
31 |
32 | @abstractmethod
33 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
34 | """Sample a structure from the dataset infinitely.
35 |
36 | Parameters
37 | ----------
38 | records : List[Record]
39 | The records to sample from.
40 | random : RandomState
41 | The random state for reproducibility.
42 |
43 | Yields
44 | ------
45 | Sample
46 | A data sample.
47 |
48 | """
49 | raise NotImplementedError
50 |
--------------------------------------------------------------------------------
/src/boltz/data/crop/cropper.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional
3 |
4 | import numpy as np
5 |
6 | from boltz.data.types import Tokenized
7 |
8 |
9 | class Cropper(ABC):
10 | """Abstract base class for cropper."""
11 |
12 | @abstractmethod
13 | def crop(
14 | self,
15 | data: Tokenized,
16 | max_tokens: int,
17 | random: np.random.RandomState,
18 | max_atoms: Optional[int] = None,
19 | chain_id: Optional[int] = None,
20 | interface_id: Optional[int] = None,
21 | ) -> Tokenized:
22 | """Crop the data to a maximum number of tokens.
23 |
24 | Parameters
25 | ----------
26 | data : Tokenized
27 | The tokenized data.
28 | max_tokens : int
29 | The maximum number of tokens in the crop.
30 | random : np.random.RandomState
31 | The random state for reproducibility.
32 | max_atoms : Optional[int]
33 | The maximum number of atoms to consider.
34 | chain_id : Optional[int]
35 | The chain ID to crop.
36 | interface_id : Optional[int]
37 | The interface ID to crop.
38 |
39 | Returns
40 | -------
41 | Tokenized
42 | The cropped data.
43 |
44 | """
45 | raise NotImplementedError
46 |
--------------------------------------------------------------------------------
/src/boltz/model/loss/distogram.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 |
7 | def distogram_loss(
8 | output: Dict[str, Tensor],
9 | feats: Dict[str, Tensor],
10 | ) -> Tuple[Tensor, Tensor]:
11 | """Compute the distogram loss.
12 |
13 | Parameters
14 | ----------
15 | output : Dict[str, Tensor]
16 | Output of the model
17 | feats : Dict[str, Tensor]
18 | Input features
19 |
20 | Returns
21 | -------
22 | Tensor
23 | The globally averaged loss.
24 | Tensor
25 | Per example loss.
26 |
27 | """
28 | # Get predicted distograms
29 | pred = output["pdistogram"]
30 |
31 | # Compute target distogram
32 | target = feats["disto_target"]
33 |
34 | # Combine target mask and padding mask
35 | mask = feats["token_disto_mask"]
36 | mask = mask[:, None, :] * mask[:, :, None]
37 | mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
38 |
39 | # Compute the distogram loss
40 | errors = -1 * torch.sum(
41 | target * torch.nn.functional.log_softmax(pred, dim=-1),
42 | dim=-1,
43 | )
44 | denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
45 | mean = errors * mask
46 | mean = torch.sum(mean, dim=-1)
47 | mean = mean / denom[..., None]
48 | batch_loss = torch.sum(mean, dim=-1)
49 | global_loss = torch.mean(batch_loss)
50 | return global_loss, batch_loss
51 |
--------------------------------------------------------------------------------
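A shape-level sketch of the loss (hypothetical; the batch size, token count, and 64 distance bins are illustrative assumptions):

```python
import torch

from boltz.model.loss.distogram import distogram_loss

B, N, bins = 2, 16, 64  # batch, tokens, distance bins (illustrative sizes)

feats = {
    # One-hot target distogram over the distance bins
    "disto_target": torch.nn.functional.one_hot(
        torch.randint(bins, (B, N, N)), num_classes=bins
    ).float(),
    "token_disto_mask": torch.ones(B, N),
}
output = {"pdistogram": torch.randn(B, N, N, bins)}

global_loss, batch_loss = distogram_loss(output, feats)
assert global_loss.ndim == 0 and batch_loss.shape == (B,)
```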
/src/boltz/data/parse/yaml.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import yaml
4 | from rdkit.Chem.rdchem import Mol
5 |
6 | from boltz.data.parse.schema import parse_boltz_schema
7 | from boltz.data.types import Target
8 |
9 |
10 | def parse_yaml(path: Path, ccd: dict[str, Mol]) -> Target:
11 | """Parse a Boltz input yaml / json.
12 |
13 | The input file should be a yaml file with the following format:
14 |
15 | sequences:
16 | - protein:
17 | id: A
18 | sequence: "MADQLTEEQIAEFKEAFSLF"
19 | - protein:
20 | id: [B, C]
21 | sequence: "AKLSILPWGHC"
22 | - rna:
23 | id: D
24 | sequence: "GCAUAGC"
25 | - ligand:
26 | id: E
27 | smiles: "CC1=CC=CC=C1"
28 | - ligand:
29 | id: [F, G]
30 | ccd: []
31 | constraints:
32 | - bond:
33 | atom1: [A, 1, CA]
34 | atom2: [A, 2, N]
35 | - pocket:
36 | binder: E
37 | contacts: [[B, 1], [B, 2]]
38 | version: 1
39 |
40 | Parameters
41 | ----------
42 | path : Path
43 | Path to the YAML input format.
44 | ccd : dict[str, Mol]
45 | Dictionary of CCD components.
46 |
47 | Returns
48 | -------
49 | Target
50 | The parsed target.
51 |
52 | """
53 | with path.open("r") as file:
54 | data = yaml.safe_load(file)
55 |
56 | name = path.stem
57 | return parse_boltz_schema(name, data, ccd)
58 |
--------------------------------------------------------------------------------
/src/boltz/data/sample/distillation.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, List
2 |
3 | from numpy.random import RandomState
4 |
5 | from boltz.data.types import Record
6 | from boltz.data.sample.sampler import Sample, Sampler
7 |
8 |
9 | class DistillationSampler(Sampler):
10 | """A sampler for monomer distillation data."""
11 |
12 | def __init__(self, small_size: int = 200, small_prob: float = 0.01) -> None:
13 | """Initialize the sampler.
14 |
15 | Parameters
16 | ----------
17 | small_size : int, optional
18 | The maximum size to be considered small.
19 | small_prob : float, optional
20 | The probability of sampling a small item.
21 |
22 | """
23 | self._size = small_size
24 | self._prob = small_prob
25 |
26 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
27 | """Sample a structure from the dataset infinitely.
28 |
29 | Parameters
30 | ----------
31 | records : List[Record]
32 | The records to sample from.
33 | random : RandomState
34 | The random state for reproducibility.
35 |
36 | Yields
37 | ------
38 | Sample
39 | A data sample.
40 |
41 | """
42 | # Remove records with invalid chains
43 | records = [r for r in records if r.chains[0].valid]
44 |
45 | # Split into small and large proteins. We assume that there is only
46 | # one chain per record, as is the case for monomer distillation
47 | small = [r for r in records if r.chains[0].num_residues <= self._size]
48 | large = [r for r in records if r.chains[0].num_residues > self._size]
49 |
50 | # Sample infinitely
51 | while True:
52 | # Sample small or large
53 | samples = small if random.rand() < self._prob else large
54 |
55 | # Sample item from the list
56 | index = random.randint(0, len(samples))
57 | yield Sample(record=samples[index])
58 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Boltz-1
2 |
3 | Democratizing Biomolecular Interaction Modeling
4 |
5 |
6 | 
7 |
8 | Boltz-1 is an open-source model that predicts the 3D structure of proteins, RNA, DNA, and small molecules. It handles modified residues, covalent ligands, and glycans, and can condition the generation on pocket residues.
9 |
10 | For more information about the model, see our [technical report](https://gcorso.github.io/assets/boltz1.pdf).
11 |
12 | ## Installation
13 | Install boltz from PyPI (recommended):
14 |
15 | ```
16 | pip install boltz
17 | ```
18 |
19 | or directly from GitHub for daily updates:
20 |
21 | ```
22 | git clone https://github.com/jwohlwend/boltz.git
23 | cd boltz; pip install -e .
24 | ```
25 | > Note: we recommend installing boltz in a fresh Python environment
26 |
27 | ## Inference
28 |
29 | You can run inference using Boltz-1 with:
30 |
31 | ```
32 | boltz predict input_path
33 | ```
34 |
35 | Boltz currently accepts three input formats:
36 |
37 | 1. A FASTA file, for most use cases
38 |
39 | 2. A comprehensive YAML schema, for more complex use cases
40 |
41 | 3. A directory containing files of the above formats, for batched processing
42 |
43 | To see all available options, run `boltz predict --help`. For more information on these input formats, see our [prediction instructions](docs/prediction.md).
44 |
45 | ## Training
46 |
47 | If you're interested in retraining the model, see our [training instructions](docs/training.md).
48 |
49 | ## Contributing
50 |
51 | We welcome external contributions and are eager to engage with the community. Connect with us on our [Slack channel](https://boltz-community.slack.com/archives/C0818M6DWH2) to discuss advancements, share insights, and foster collaboration around Boltz-1.
52 |
53 | ## Coming very soon
54 |
55 | - [ ] Pocket conditioning support
56 | - [ ] More examples
57 | - [ ] Full data processing pipeline
58 | - [ ] Colab notebook for inference
59 | - [ ] Confidence model checkpoint
60 | - [ ] Support for custom paired MSA
61 | - [ ] Kernel integration
62 |
63 | ## License
64 |
65 | Our model and code are released under the MIT License and can be freely used for both academic and commercial purposes.
66 |
--------------------------------------------------------------------------------
/src/boltz/data/filter/dynamic/date.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from typing import Literal
3 |
4 | from boltz.data.types import Record
5 | from boltz.data.filter.dynamic.filter import DynamicFilter
6 |
7 |
8 | class DateFilter(DynamicFilter):
9 | """A filter that filters complexes based on their date.
10 |
11 | The date can be the deposition, release, or revision date.
12 | If the date is not available, the previous date is used.
13 |
14 | If no date is available, the complex is rejected.
15 |
16 | """
17 |
18 | def __init__(
19 | self,
20 | date: str,
21 | ref: Literal["deposited", "revised", "released"],
22 | ) -> None:
23 | """Initialize the filter.
24 |
25 | Parameters
26 | ----------
27 | date : str
28 | The maximum allowed date of PDB entries, in ISO format.
29 | ref : Literal["deposited", "revised", "released"]
30 | The reference date to use.
31 |
32 | """
33 | self.filter_date = datetime.fromisoformat(date)
34 | self.ref = ref
35 |
36 | if ref not in ["deposited", "revised", "released"]:
37 | msg = (
38 | "Invalid reference date. Must be ",
39 | "deposited, revised, or released",
40 | )
41 | raise ValueError(msg)
42 |
43 | def filter(self, record: Record) -> bool:
44 | """Filter a record based on its date.
45 |
46 | Parameters
47 | ----------
48 | record : Record
49 | The record to filter.
50 |
51 | Returns
52 | -------
53 | bool
54 | True if the record passes the filter, False otherwise.
55 |
56 | """
57 | structure = record.structure
58 |
59 | if self.ref == "deposited":
60 | date = structure.deposited
61 | elif self.ref == "released":
62 | date = structure.released
63 | if not date:
64 | date = structure.deposited
65 | elif self.ref == "revised":
66 | date = structure.revised
67 | if not date and structure.released:
68 | date = structure.released
69 | elif not date:
70 | date = structure.deposited
71 |
72 | if date is None or date == "":
73 | return False
74 |
75 | date = datetime.fromisoformat(date)
76 | return date <= self.filter_date
77 |
--------------------------------------------------------------------------------
/src/boltz/data/feature/pad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn.functional import pad
4 |
5 |
6 | def pad_dim(data: Tensor, dim: int, pad_len: int, value: float = 0) -> Tensor:
7 | """Pad a tensor along a given dimension.
8 |
9 | Parameters
10 | ----------
11 | data : Tensor
12 | The input tensor.
13 | dim : int
14 | The dimension to pad.
15 | pad_len : int
16 | The number of elements to pad at the end of the dimension.
17 | value : float, optional
18 | The value to pad with.
19 |
20 | Returns
21 | -------
22 | Tensor
23 | The padded tensor.
24 |
25 | """
26 | if pad_len == 0:
27 | return data
28 |
29 | total_dims = len(data.shape)
30 | padding = [0] * (2 * (total_dims - dim))
31 | padding[2 * (total_dims - 1 - dim) + 1] = pad_len
32 | return pad(data, tuple(padding), value=value)
33 |
34 |
35 | def pad_to_max(data: list[Tensor], value: float = 0) -> tuple[Tensor, Tensor]:
36 | """Pad the data in all dimensions to the maximum found.
37 |
38 | Parameters
39 | ----------
40 | data : List[Tensor]
41 | List of tensors to pad.
42 | value : float
43 | The value to use for padding.
44 |
45 | Returns
46 | -------
47 | Tensor
48 | The padded tensor.
49 | Tensor
50 | The padding mask.
51 |
52 | """
53 | if isinstance(data[0], str):
54 | return data, 0
55 |
56 | # Check if all have the same shape
57 | if all(d.shape == data[0].shape for d in data):
58 | return torch.stack(data, dim=0), 0
59 |
60 | # Get the maximum in each dimension
61 | num_dims = len(data[0].shape)
62 | max_dims = [max(d.shape[i] for d in data) for i in range(num_dims)]
63 |
64 | # Get the padding lengths
65 | pad_lengths = []
66 | for d in data:
67 | dims = []
68 | for i in range(num_dims):
69 | dims.append(0)
70 | dims.append(max_dims[num_dims - i - 1] - d.shape[num_dims - i - 1])
71 | pad_lengths.append(dims)
72 |
73 | # Pad the data
74 | padding = [
75 | pad(torch.ones_like(d), pad_len, value=0)
76 | for d, pad_len in zip(data, pad_lengths)
77 | ]
78 | data = [pad(d, pad_len, value=value) for d, pad_len in zip(data, pad_lengths)]
79 |
80 | # Stack the data
81 | padding = torch.stack(padding, dim=0)
82 | data = torch.stack(data, dim=0)
83 |
84 | return data, padding
85 |
--------------------------------------------------------------------------------
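A small sketch of both helpers (hypothetical): `pad_dim` appends `pad_len` entries at the end of one dimension, and `pad_to_max` right-pads a ragged list of tensors to a common shape before stacking.

```python
import torch

from boltz.data.feature.pad import pad_dim, pad_to_max

a = torch.ones(3, 4)
b = torch.ones(5, 2)

# Append two rows of zeros at the end of dim 0: (3, 4) -> (5, 4)
assert pad_dim(a, dim=0, pad_len=2).shape == (5, 4)

# Right-pad both tensors to (5, 4) and stack; `mask` is 1 on real
# entries and 0 on padding
data, mask = pad_to_max([a, b], value=0)
assert data.shape == (2, 5, 4) and mask.shape == (2, 5, 4)
```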
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "boltz"
7 | version = "0.1.0"
8 | requires-python = ">=3.9"
9 | description = "Boltz-1"
10 | readme = "README.md"
11 | dependencies = [
12 | "torch>=2.2",
13 | "numpy==1.26.3",
14 | "hydra-core==1.3.2",
15 | "pytorch-lightning==2.4.0",
16 | "rdkit==2024.3.6",
17 | "dm-tree==0.1.8",
18 | "requests==2.32.3",
19 | "pandas==2.2.3",
20 | "types-requests",
21 | "einops==0.8.0",
22 | "einx==0.3.0",
23 | "fairscale==0.4.13",
24 | "mashumaro==3.14",
25 | "modelcif==1.2",
26 | "wandb==0.18.7",
27 | "click==8.1.7",
28 | "pyyaml==6.0.2",
29 | "biopython==1.84",
30 | "scipy==1.13.1",
31 | ]
32 |
33 | [project.scripts]
34 | boltz = "boltz.main:cli"
35 |
36 | [project.optional-dependencies]
37 | lint = ["ruff"]
38 |
39 | [tool.ruff]
40 | src = ["src"]
41 | extend-exclude = ["conf.py"]
42 | target-version = "py39"
43 | lint.select = ["ALL"]
44 | lint.ignore = [
45 | "COM812", # Conflicts with the formatter
46 | "ISC001", # Conflicts with the formatter
47 | "ANN101", # "missing-type-self"
48 | "RET504", # Unnecessary assignment to `x` before `return` statementRuff
49 | "S101", # Use of `assert` detected
50 | "D100", # Missing docstring in public module
51 | "D104", # Missing docstring in public package
52 | "PT001", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
53 | "PT004", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
54 | "PT005", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
55 | "PT023", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
56 | "FBT001",
57 | "FBT002",
58 | "PLR0913", # Too many arguments to init (> 5)
59 | ]
60 |
61 | [tool.ruff.lint.per-file-ignores]
62 | "**/__init__.py" = [
63 | "F401", # Imported but unused
64 | "F403", # Wildcard imports
65 | ]
66 | "docs/**" = [
67 | "INP001", # Requires __init__.py but folder is not a package.
68 | ]
69 | "scripts/**" = [
70 | "INP001", # Requires __init__.py but folder is not a package.
71 | ]
72 |
73 | [tool.ruff.lint.pyupgrade]
74 | # Preserve types, even if a file imports `from __future__ import annotations`(https://github.com/astral-sh/ruff/issues/5434)
75 | keep-runtime-typing = true
76 |
77 | [tool.ruff.lint.pydocstyle]
78 | convention = "numpy"
79 |
--------------------------------------------------------------------------------
/docs/training.md:
--------------------------------------------------------------------------------
1 | # Training
2 |
3 | ## Download processed data
4 |
5 | Instructions on how to download the processed dataset for training are coming soon; we are currently uploading the data to shareable storage and will update this page when ready.
6 |
7 | ## Modify the configuration file
8 |
9 | The training script requires a configuration file to run. This file specifies the paths to the data, the output directory, and other parameters of the data, model and training process.
10 |
11 | We provide under `scripts/train/configs` template configuration files analogous to the ones we used for training the structure model (`structure.yaml`) and the confidence model (`confidence.yaml`).
12 |
13 | The following are the main parameters that you should modify in the configuration file to get the structure model to train:
14 |
15 | ```yaml
16 | trainer:
17 | devices: 1
18 |
19 | output: SET_PATH_HERE # Path to the output directory
20 | resume: PATH_TO_CHECKPOINT_FILE # Path to a checkpoint file to resume training from if any null otherwise
21 |
22 | data:
23 | datasets:
24 | - _target_: boltz.data.module.training.DatasetConfig
25 | target_dir: PATH_TO_TARGETS_DIR # Path to the directory containing the processed structure files
26 | msa_dir: PATH_TO_MSA_DIR # Path to the directory containing the processed MSA files
27 |
28 | symmetries: PATH_TO_SYMMETRY_FILE # Path to the file containing the molecule symmetry information
29 | max_tokens: 512 # Maximum number of tokens in the input sequence
30 | max_atoms: 4608 # Maximum number of atoms in the input structure
31 | ```
32 |
33 | `max_tokens` and `max_atoms` are the maximum number of tokens and atoms in the crop. Depending on the size of the GPUs you are using (as well as the desired training speed), you may want to adjust these values. Other recommended pairs are 256 and 2304, or 384 and 3456, respectively.
34 |
35 | ## Run the training script
36 |
37 | Before running the full training, we recommend using the debug flag. This turns off DDP (uses a single device), sets `num_workers` to 0 so everything runs in a single process, and disables wandb:
38 |
39 | python scripts/train/train.py scripts/train/configs/structure.yaml debug=1
40 |
41 | Once that seems to run okay, you can kill it and launch the training run:
42 |
43 | python scripts/train/train.py scripts/train/configs/structure.yaml
44 |
45 | We also provide a different configuration file to train the confidence model:
46 |
47 | python scripts/train/train.py scripts/train/configs/confidence.yaml
--------------------------------------------------------------------------------
/src/boltz/model/layers/transition.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from torch import Tensor, nn
4 |
5 | import boltz.model.layers.initialize as init
6 |
7 |
8 | class Transition(nn.Module):
9 | """Perform a two-layer MLP."""
10 |
11 | def __init__(
12 | self,
13 | dim: int = 128,
14 | hidden: int = 512,
15 | out_dim: Optional[int] = None,
16 | chunk_size: Optional[int] = None,
17 | ) -> None:
18 | """Initialize the TransitionUpdate module.
19 |
20 | Parameters
21 | ----------
22 | dim: int
23 | The dimension of the input, default 128
24 | hidden: int
25 | The dimension of the hidden, default 512
26 | out_dim: Optional[int]
27 | The dimension of the output, default None
28 | chunk_size: Optional[int]
29 | The chunk size for inference, default None
30 |
31 | """
32 | super().__init__()
33 | if out_dim is None:
34 | out_dim = dim
35 |
36 | self.norm = nn.LayerNorm(dim, eps=1e-5)
37 | self.fc1 = nn.Linear(dim, hidden, bias=False)
38 | self.fc2 = nn.Linear(dim, hidden, bias=False)
39 | self.fc3 = nn.Linear(hidden, out_dim, bias=False)
40 | self.silu = nn.SiLU()
41 | self.hidden = hidden
42 | self.chunk_size = chunk_size
43 |
44 | init.bias_init_one_(self.norm.weight)
45 | init.bias_init_zero_(self.norm.bias)
46 |
47 | init.lecun_normal_init_(self.fc1.weight)
48 | init.lecun_normal_init_(self.fc2.weight)
49 | init.final_init_(self.fc3.weight)
50 |
51 | def forward(self, x: Tensor) -> Tensor:
52 | """Perform a forward pass.
53 |
54 | Parameters
55 | ----------
56 | x: torch.Tensor
57 | The input data of shape (..., D)
58 |
59 | Returns
60 | -------
61 | x: torch.Tensor
62 | The output data of shape (..., D)
63 |
64 | """
65 | x = self.norm(x)
66 |
67 | if self.chunk_size is None or self.training:
68 | x = self.silu(self.fc1(x)) * self.fc2(x)
69 | x = self.fc3(x)
70 | return x
71 | else:
72 | # Compute in chunks
73 | for i in range(0, self.hidden, self.chunk_size):
74 | fc1_slice = self.fc1.weight[i : i + self.chunk_size, :]
75 | fc2_slice = self.fc2.weight[i : i + self.chunk_size, :]
76 | fc3_slice = self.fc3.weight[:, i : i + self.chunk_size]
77 | x_chunk = self.silu((x @ fc1_slice.T)) * (x @ fc2_slice.T)
78 | if i == 0:
79 | x_out = x_chunk @ fc3_slice.T
80 | else:
81 | x_out = x_out + x_chunk @ fc3_slice.T
82 | return x_out
83 |
--------------------------------------------------------------------------------
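A sketch checking that the chunked inference path matches the regular two-layer MLP (hypothetical snippet; `fc3` is zero-initialized by `final_init_`, so we perturb it to make the comparison non-trivial):

```python
import torch

from boltz.model.layers.transition import Transition

t = Transition(dim=128, hidden=512, chunk_size=128).eval()
torch.nn.init.normal_(t.fc3.weight)  # fc3 starts at zero; perturb it

x = torch.randn(2, 10, 128)
with torch.no_grad():
    h = t.norm(x)
    full = t.fc3(t.silu(t.fc1(h)) * t.fc2(h))  # unchunked computation
    chunked = t(x)                             # chunked eval path

assert torch.allclose(full, chunked, atol=1e-5)
```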
/src/boltz/model/layers/initialize.py:
--------------------------------------------------------------------------------
1 | """Utility functions for initializing weights and biases."""
2 |
3 | # Copyright 2021 AlQuraishi Laboratory
4 | # Copyright 2021 DeepMind Technologies Limited
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import math
19 | import numpy as np
20 | from scipy.stats import truncnorm
21 | import torch
22 |
23 |
24 | def _prod(nums):
25 | out = 1
26 | for n in nums:
27 | out = out * n
28 | return out
29 |
30 |
31 | def _calculate_fan(linear_weight_shape, fan="fan_in"):
32 | fan_out, fan_in = linear_weight_shape
33 |
34 | if fan == "fan_in":
35 | f = fan_in
36 | elif fan == "fan_out":
37 | f = fan_out
38 | elif fan == "fan_avg":
39 | f = (fan_in + fan_out) / 2
40 | else:
41 | raise ValueError("Invalid fan option")
42 |
43 | return f
44 |
45 |
46 | def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
47 | shape = weights.shape
48 | f = _calculate_fan(shape, fan)
49 | scale = scale / max(1, f)
50 | a = -2
51 | b = 2
52 | std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
53 | size = _prod(shape)
54 | samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
55 | samples = np.reshape(samples, shape)
56 | with torch.no_grad():
57 | weights.copy_(torch.tensor(samples, device=weights.device))
58 |
59 |
60 | def lecun_normal_init_(weights):
61 | trunc_normal_init_(weights, scale=1.0)
62 |
63 |
64 | def he_normal_init_(weights):
65 | trunc_normal_init_(weights, scale=2.0)
66 |
67 |
68 | def glorot_uniform_init_(weights):
69 | torch.nn.init.xavier_uniform_(weights, gain=1)
70 |
71 |
72 | def final_init_(weights):
73 | with torch.no_grad():
74 | weights.fill_(0.0)
75 |
76 |
77 | def gating_init_(weights):
78 | with torch.no_grad():
79 | weights.fill_(0.0)
80 |
81 |
82 | def bias_init_zero_(bias):
83 | with torch.no_grad():
84 | bias.fill_(0.0)
85 |
86 |
87 | def bias_init_one_(bias):
88 | with torch.no_grad():
89 | bias.fill_(1.0)
90 |
91 |
92 | def normal_init_(weights):
93 | torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
94 |
95 |
96 | def ipa_point_weights_init_(weights):
97 | with torch.no_grad():
98 | softplus_inverse_1 = 0.541324854612918
99 | weights.fill_(softplus_inverse_1)
100 |
--------------------------------------------------------------------------------
/src/boltz/model/layers/outer_product_mean.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import Tensor, nn
5 |
6 | import boltz.model.layers.initialize as init
7 | class OuterProductMean(nn.Module):
8 | """Outer product mean layer."""
9 |
10 | def __init__(
11 | self, c_in: int, c_hidden: int, c_out: int, chunk_size: Optional[int] = None
12 | ) -> None:
13 | """Initialize the outer product mean layer.
14 |
15 | Parameters
16 | ----------
17 | c_in : int
18 | The input dimension.
19 | c_hidden : int
20 | The hidden dimension.
21 | c_out : int
22 | The output dimension.
23 | chunk_size : int, optional
24 | The inference chunk size, by default None.
25 |
26 | """
27 | super().__init__()
28 | self.chunk_size = chunk_size
29 | self.c_hidden = c_hidden
30 | self.norm = nn.LayerNorm(c_in)
31 | self.proj_a = nn.Linear(c_in, c_hidden, bias=False)
32 | self.proj_b = nn.Linear(c_in, c_hidden, bias=False)
33 | self.proj_o = nn.Linear(c_hidden * c_hidden, c_out)
34 | init.final_init_(self.proj_o.weight)
35 | init.final_init_(self.proj_o.bias)
36 |
37 | def forward(self, m: Tensor, mask: Tensor) -> Tensor:
38 | """Forward pass.
39 |
40 | Parameters
41 | ----------
42 | m : torch.Tensor
43 | The sequence tensor (B, S, N, c_in).
44 | mask : torch.Tensor
45 | The mask tensor (B, S, N).
46 |
47 | Returns
48 | -------
49 | torch.Tensor
50 | The output tensor (B, N, N, c_out).
51 |
52 | """
53 | # Expand mask
54 | mask = mask.unsqueeze(-1).to(m)
55 |
56 | # Compute projections
57 | m = self.norm(m)
58 | a = self.proj_a(m) * mask
59 | b = self.proj_b(m) * mask
60 |
61 | # Compute pairwise mask
62 | mask = mask[:, :, None, :] * mask[:, :, :, None]
63 |
64 | # Compute outer product mean
65 | if self.chunk_size is not None and not self.training:
66 | # Compute sequentially in chunks
67 | for i in range(0, self.c_hidden, self.chunk_size):
68 | a_chunk = a[:, :, :, i : i + self.chunk_size]
69 | sliced_weight_proj_o = self.proj_o.weight[
70 | :, i * self.c_hidden : (i + self.chunk_size) * self.c_hidden
71 | ]
72 |
73 | z = torch.einsum("bsic,bsjd->bijcd", a_chunk, b)
74 | z = z.reshape(*z.shape[:3], -1)
75 | z = z / mask.sum(dim=1).clamp(min=1)
76 |
77 | # Project to output
78 | if i == 0:
79 | z_out = z.to(m) @ sliced_weight_proj_o.T
80 | else:
81 | z_out = z_out + z.to(m) @ sliced_weight_proj_o.T
82 | return z_out + self.proj_o.bias  # bias is not included in the chunked matmuls
83 | else:
84 | z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float())
85 | z = z.reshape(*z.shape[:3], -1)
86 | z = z / mask.sum(dim=1).clamp(min=1)
87 |
88 | # Project to output
89 | z = self.proj_o(z.to(m))
90 | return z
91 |
--------------------------------------------------------------------------------
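A shape-level sketch (hypothetical sizes): the layer reduces over the MSA dimension `S`, producing a pairwise update.

```python
import torch

from boltz.model.layers.outer_product_mean import OuterProductMean

opm = OuterProductMean(c_in=64, c_hidden=32, c_out=128)

m = torch.randn(2, 5, 10, 64)  # (batch B, MSA rows S, tokens N, c_in)
mask = torch.ones(2, 5, 10)    # (B, S, N)

z = opm(m, mask)
assert z.shape == (2, 10, 10, 128)  # (B, N, N, c_out)
```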
/src/boltz/data/filter/static/ligand.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from boltz.data import const
4 | from boltz.data.types import Structure
5 | from boltz.data.filter.static.filter import StaticFilter
6 |
7 | LIGAND_EXCLUSION = {
8 | "144",
9 | "15P",
10 | "1PE",
11 | "2F2",
12 | "2JC",
13 | "3HR",
14 | "3SY",
15 | "7N5",
16 | "7PE",
17 | "9JE",
18 | "AAE",
19 | "ABA",
20 | "ACE",
21 | "ACN",
22 | "ACT",
23 | "ACY",
24 | "AZI",
25 | "BAM",
26 | "BCN",
27 | "BCT",
28 | "BDN",
29 | "BEN",
30 | "BME",
31 | "BO3",
32 | "BTB",
33 | "BTC",
34 | "BU1",
35 | "C8E",
36 | "CAD",
37 | "CAQ",
38 | "CBM",
39 | "CCN",
40 | "CIT",
41 | "CL",
42 | "CLR",
43 | "CM",
44 | "CMO",
45 | "CO3",
46 | "CPT",
47 | "CXS",
48 | "D10",
49 | "DEP",
50 | "DIO",
51 | "DMS",
52 | "DN",
53 | "DOD",
54 | "DOX",
55 | "EDO",
56 | "EEE",
57 | "EGL",
58 | "EOH",
59 | "EOX",
60 | "EPE",
61 | "ETF",
62 | "FCY",
63 | "FJO",
64 | "FLC",
65 | "FMT",
66 | "FW5",
67 | "GOL",
68 | "GSH",
69 | "GTT",
70 | "GYF",
71 | "HED",
72 | "IHP",
73 | "IHS",
74 | "IMD",
75 | "IOD",
76 | "IPA",
77 | "IPH",
78 | "LDA",
79 | "MB3",
80 | "MEG",
81 | "MES",
82 | "MLA",
83 | "MLI",
84 | "MOH",
85 | "MPD",
86 | "MRD",
87 | "MSE",
88 | "MYR",
89 | "N",
90 | "NA",
91 | "NH2",
92 | "NH4",
93 | "NHE",
94 | "NO3",
95 | "O4B",
96 | "OHE",
97 | "OLA",
98 | "OLC",
99 | "OMB",
100 | "OME",
101 | "OXA",
102 | "P6G",
103 | "PE3",
104 | "PE4",
105 | "PEG",
106 | "PEO",
107 | "PEP",
108 | "PG0",
109 | "PG4",
110 | "PGE",
111 | "PGR",
112 | "PLM",
113 | "PO4",
114 | "POL",
115 | "POP",
116 | "PVO",
117 | "SAR",
118 | "SCN",
119 | "SEO",
120 | "SEP",
121 | "SIN",
122 | "SO4",
123 | "SPD",
124 | "SPM",
125 | "SR",
126 | "STE",
127 | "STO",
128 | "STU",
129 | "TAR",
130 | "TBU",
131 | "TME",
132 | "TPO",
133 | "TRS",
134 | "UNK",
135 | "UNL",
136 | "UNX",
137 | "UPL",
138 | "URE",
139 | }
140 |
141 |
142 | class ExcludedLigands(StaticFilter):
143 | """Filter excluded ligands."""
144 |
145 | def filter(self, structure: Structure) -> np.ndarray:
146 | """Filter excluded ligands.
147 |
148 | Parameters
149 | ----------
150 | structure : Structure
151 | The structure to filter chains from.
152 |
153 | Returns
154 | -------
155 | np.ndarray
156 | The chains to keep, as a boolean mask.
157 |
158 | """
159 | valid = np.ones(len(structure.chains), dtype=bool)
160 |
161 | for i, chain in enumerate(structure.chains):
162 | if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]:
163 | continue
164 |
165 | res_start = chain["res_idx"]
166 | res_end = res_start + chain["res_num"]
167 | residues = structure.residues[res_start:res_end]
168 | if any(res["name"] in LIGAND_EXCLUSION for res in residues):
169 | valid[i] = 0
170 |
171 | return valid
172 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
--------------------------------------------------------------------------------
/src/boltz/data/parse/a3m.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | from pathlib import Path
3 | from typing import Optional, TextIO
4 |
5 | import numpy as np
6 |
7 | from boltz.data import const
8 | from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
9 |
10 |
11 | def _parse_a3m( # noqa: C901
12 | lines: TextIO,
13 | taxonomy: Optional[dict[str, str]],
14 | max_seqs: Optional[int] = None,
15 | ) -> MSA:
16 | """Process an MSA file.
17 |
18 | Parameters
19 | ----------
20 | lines : TextIO
21 | The lines of the MSA file.
22 | taxonomy : dict[str, str]
23 | The taxonomy database, if available.
24 | max_seqs : int, optional
25 | The maximum number of sequences.
26 |
27 | Returns
28 | -------
29 | MSA
30 | The MSA object.
31 |
32 | """
33 | visited = set()
34 | sequences = []
35 | deletions = []
36 | residues = []
37 |
38 | seq_idx = 0
39 | for line in lines:
40 | line: str
41 | line = line.strip() # noqa: PLW2901
42 | if not line or line.startswith("#"):
43 | continue
44 |
45 | # Get taxonomy, if annotated
46 | if line.startswith(">"):
47 | header = line.split()[0]
48 | if taxonomy and header.startswith(">UniRef100"):
49 | uniref_id = header.split("_")[1]
50 | taxonomy_id = taxonomy.get(uniref_id)
51 | if taxonomy_id is None:
52 | taxonomy_id = -1
53 | else:
54 | taxonomy_id = -1
55 | continue
56 |
57 | # Skip if duplicate sequence
58 | str_seq = line.replace("-", "").upper()
59 | if str_seq not in visited:
60 | visited.add(str_seq)
61 | else:
62 | continue
63 |
64 | # Process sequence
65 | residue = []
66 | deletion = []
67 | count = 0
68 | res_idx = 0
69 | for c in line:
70 | if c != "-" and c.islower():
71 | count += 1
72 | continue
73 | token = const.prot_letter_to_token[c]
74 | token = const.token_ids[token]
75 | residue.append(token)
76 | if count > 0:
77 | deletion.append((res_idx, count))
78 | count = 0
79 | res_idx += 1
80 |
81 | res_start = len(residues)
82 | res_end = res_start + len(residue)
83 |
84 | del_start = len(deletions)
85 | del_end = del_start + len(deletion)
86 |
87 | sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
88 | residues.extend(residue)
89 | deletions.extend(deletion)
90 |
91 | seq_idx += 1
92 | if (max_seqs is not None) and (seq_idx >= max_seqs):
93 | break
94 |
95 | # Create MSA object
96 | msa = MSA(
97 | residues=np.array(residues, dtype=MSAResidue),
98 | deletions=np.array(deletions, dtype=MSADeletion),
99 | sequences=np.array(sequences, dtype=MSASequence),
100 | )
101 | return msa
102 |
103 |
104 | def parse_a3m(
105 | path: Path,
106 | taxonomy: Optional[dict[str, str]],
107 | max_seqs: Optional[int] = None,
108 | ) -> MSA:
109 | """Process an A3M file.
110 |
111 | Parameters
112 | ----------
113 | path : Path
114 | The path to the a3m(.gz) file.
115 |     taxonomy : dict[str, str], optional
116 |         The taxonomy database, if available.
117 | max_seqs : int, optional
118 | The maximum number of sequences.
119 |
120 | Returns
121 | -------
122 | MSA
123 | The MSA object.
124 |
125 | """
126 | # Read the file
127 | if path.suffix == ".gz":
128 | with gzip.open(str(path), "rt") as f:
129 | msa = _parse_a3m(f, taxonomy, max_seqs)
130 | else:
131 | with path.open("r") as f:
132 | msa = _parse_a3m(f, taxonomy, max_seqs)
133 |
134 | return msa
135 |
--------------------------------------------------------------------------------
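The deletion bookkeeping above is the standard A3M convention: lowercase letters are insertions relative to the query and are collapsed into (residue index, run length) pairs rather than kept in the sequence. A minimal standalone sketch of that encoding (illustrative only; encode_a3m_row is not part of the package and works on raw characters instead of token ids):

def encode_a3m_row(row: str) -> tuple[str, list[tuple[int, int]]]:
    # Fold runs of lowercase (inserted) characters into (index, count)
    # pairs, mirroring the loop in _parse_a3m above.
    residues, deletions = [], []
    count = 0
    for c in row:
        if c != "-" and c.islower():
            count += 1  # run of inserted residues
            continue
        if count > 0:
            deletions.append((len(residues), count))
            count = 0
        residues.append(c)
    return "".join(residues), deletions

print(encode_a3m_row("AC-deFG"))  # ('AC-FG', [(3, 2)])

Keeping insertions out of the residue stream is what lets the MSA be stored as flat arrays indexed by the (res_start, res_end, del_start, del_end) offsets recorded per sequence.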
/src/boltz/model/optim/scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
5 |     """Implements the learning rate schedule defined in AF3.
6 |
7 | A linear warmup is followed by a plateau at the maximum
8 | learning rate and then exponential decay. Note that the
9 | initial learning rate of the optimizer in question is
10 | ignored; use this class' base_lr parameter to specify
11 | the starting point of the warmup.
12 |
13 | """
14 |
15 | def __init__(
16 | self,
17 | optimizer: torch.optim.Optimizer,
18 | last_epoch: int = -1,
19 | verbose: bool = False,
20 | base_lr: float = 0.0,
21 | max_lr: float = 1.8e-3,
22 | warmup_no_steps: int = 1000,
23 | start_decay_after_n_steps: int = 50000,
24 | decay_every_n_steps: int = 50000,
25 | decay_factor: float = 0.95,
26 | ) -> None:
27 | """Initialize the learning rate scheduler.
28 |
29 | Parameters
30 | ----------
31 | optimizer : torch.optim.Optimizer
32 | The optimizer.
33 | last_epoch : int, optional
34 | The last epoch, by default -1
35 | verbose : bool, optional
36 | Whether to print verbose output, by default False
37 | base_lr : float, optional
38 | The base learning rate, by default 0.0
39 | max_lr : float, optional
40 | The maximum learning rate, by default 1.8e-3
41 | warmup_no_steps : int, optional
42 | The number of warmup steps, by default 1000
43 | start_decay_after_n_steps : int, optional
44 | The number of steps after which to start decay, by default 50000
45 | decay_every_n_steps : int, optional
46 | The number of steps after which to decay, by default 50000
47 | decay_factor : float, optional
48 | The decay factor, by default 0.95
49 |
50 | """
51 | step_counts = {
52 | "warmup_no_steps": warmup_no_steps,
53 | "start_decay_after_n_steps": start_decay_after_n_steps,
54 | }
55 |
56 | for k, v in step_counts.items():
57 | if v < 0:
58 | msg = f"{k} must be nonnegative"
59 | raise ValueError(msg)
60 |
61 | if warmup_no_steps > start_decay_after_n_steps:
62 | msg = "warmup_no_steps must not exceed start_decay_after_n_steps"
63 | raise ValueError(msg)
64 |
65 | self.optimizer = optimizer
66 | self.last_epoch = last_epoch
67 | self.verbose = verbose
68 | self.base_lr = base_lr
69 | self.max_lr = max_lr
70 | self.warmup_no_steps = warmup_no_steps
71 | self.start_decay_after_n_steps = start_decay_after_n_steps
72 | self.decay_every_n_steps = decay_every_n_steps
73 | self.decay_factor = decay_factor
74 |
75 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
76 |
77 | def state_dict(self) -> dict:
78 | state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
79 | return state_dict
80 |
81 | def load_state_dict(self, state_dict):
82 | self.__dict__.update(state_dict)
83 |
84 | def get_lr(self):
85 | if not self._get_lr_called_within_step:
86 | msg = (
87 | "To get the last learning rate computed by the scheduler, use "
88 | "get_last_lr()"
89 | )
90 | raise RuntimeError(msg)
91 |
92 | step_no = self.last_epoch
93 |
94 | if step_no <= self.warmup_no_steps:
95 |             lr = self.base_lr + (step_no / self.warmup_no_steps) * (self.max_lr - self.base_lr)
96 | elif step_no > self.start_decay_after_n_steps:
97 | steps_since_decay = step_no - self.start_decay_after_n_steps
98 | exp = (steps_since_decay // self.decay_every_n_steps) + 1
99 | lr = self.max_lr * (self.decay_factor**exp)
100 | else: # plateau
101 | lr = self.max_lr
102 |
103 | return [lr for group in self.optimizer.param_groups]
104 |
--------------------------------------------------------------------------------
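A quick way to see the warmup/plateau/decay shape is to drive the scheduler with a dummy optimizer. A sketch, assuming the package is importable; the small step counts stand in for the 1000/50000 defaults used in the training configs:

import torch

from boltz.model.optim.scheduler import AlphaFoldLRScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=0.0)  # initial lr is ignored
scheduler = AlphaFoldLRScheduler(
    optimizer,
    max_lr=1.8e-3,
    warmup_no_steps=10,
    start_decay_after_n_steps=100,
    decay_every_n_steps=100,
)
for _ in range(300):
    optimizer.step()
    scheduler.step()
# Linear warmup to max_lr, plateau, then a 0.95 decay every 100 steps.
print(scheduler.get_last_lr())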
/src/boltz/data/parse/fasta.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Mapping
2 | from pathlib import Path
3 |
4 | from Bio import SeqIO
5 | from rdkit.Chem.rdchem import Mol
6 |
7 | from boltz.data.parse.yaml import parse_boltz_schema
8 | from boltz.data.types import Target
9 |
10 |
11 | def parse_fasta(path: Path, ccd: Mapping[str, Mol]) -> Target: # noqa: C901
12 | """Parse a fasta file.
13 |
14 | The name of the fasta file is used as the name of this job.
15 | We rely on the fasta record id to determine the entity type.
16 |
17 | > CHAIN_ID|ENTITY_TYPE|MSA_ID
18 | SEQUENCE
19 | > CHAIN_ID|ENTITY_TYPE|MSA_ID
20 | ...
21 |
22 | Where ENTITY_TYPE is either protein, rna, dna, ccd or smiles,
23 | and CHAIN_ID is the chain identifier, which should be unique.
24 | The MSA_ID is optional and should only be used on proteins.
25 |
26 | Parameters
27 | ----------
28 |     path : Path
29 |         Path to the fasta file.
30 |     ccd : Mapping[str, Mol]
31 |         Mapping from CCD code to RDKit molecule.
32 |
33 | Returns
34 | -------
35 | Target
36 | The parsed target.
37 |
38 | """
39 | # Read fasta file
40 | with path.open("r") as f:
41 | records = list(SeqIO.parse(f, "fasta"))
42 |
43 | # Make sure all records have a chain id and entity
44 | for seq_record in records:
45 | if "|" not in seq_record.id:
46 | msg = f"Invalid record id: {seq_record.id}"
47 | raise ValueError(msg)
48 |
49 | header = seq_record.id.split("|")
50 | assert len(header) >= 2, f"Invalid record id: {seq_record.id}"
51 |
52 | chain_id, entity_type = header[:2]
53 | if entity_type.lower() not in {"protein", "dna", "rna", "ccd", "smiles"}:
54 | msg = f"Invalid entity type: {entity_type}"
55 | raise ValueError(msg)
56 | if chain_id == "":
57 | msg = "Empty chain id in input fasta!"
58 | raise ValueError(msg)
59 | if entity_type == "":
60 | msg = "Empty entity type in input fasta!"
61 | raise ValueError(msg)
62 |
63 | # Convert to yaml format
64 | sequences = []
65 | for seq_record in records:
66 | # Get chain id, entity type and sequence
67 | header = seq_record.id.split("|")
68 | chain_id, entity_type = header[:2]
69 |         msa_id = None
70 |         if len(header) == 3 and header[2] != "":
71 |             msg = "MSA_ID is only allowed for proteins"
72 |             assert entity_type.lower() == "protein", msg
73 |             msa_id = header[2]
74 |
75 | entity_type = entity_type.upper()
76 | seq = str(seq_record.seq)
77 |
78 | if entity_type == "PROTEIN":
79 | molecule = {
80 | "protein": {
81 | "id": chain_id,
82 | "sequence": seq,
83 | "modifications": [],
84 | "msa": msa_id,
85 | },
86 | }
87 | elif entity_type == "RNA":
88 | molecule = {
89 | "rna": {
90 | "id": chain_id,
91 | "sequence": seq,
92 | "modifications": [],
93 | },
94 | }
95 | elif entity_type == "DNA":
96 | molecule = {
97 | "dna": {
98 | "id": chain_id,
99 | "sequence": seq,
100 | "modifications": [],
101 | }
102 | }
103 |         elif entity_type == "CCD":
104 | molecule = {
105 | "ligand": {
106 | "id": chain_id,
107 | "ccd": seq,
108 | }
109 | }
110 |         elif entity_type == "SMILES":
111 | molecule = {
112 | "ligand": {
113 | "id": chain_id,
114 | "smiles": seq,
115 | }
116 | }
117 |
118 | sequences.append(molecule)
119 |
120 | data = {
121 | "sequences": sequences,
122 | "bonds": [],
123 | "version": 1,
124 | }
125 |
126 | name = path.stem
127 | return parse_boltz_schema(name, data, ccd)
128 |
--------------------------------------------------------------------------------
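For reference, an input following the documented header convention might look like the sketch below. Whether an empty CCD mapping suffices depends on what parse_boltz_schema needs for the given entities, so treat the final call as illustrative rather than guaranteed; "msa_a" is a placeholder MSA identifier:

from pathlib import Path

from boltz.data.parse.fasta import parse_fasta

fasta = (
    ">A|protein|msa_a\n"
    "MVLSPADKTNVKAAW\n"
    ">B|smiles|\n"
    "N[C@@H](Cc1ccc(O)cc1)C(=O)O\n"
)
path = Path("example.fasta")
path.write_text(fasta)
target = parse_fasta(path, ccd={})  # illustrative; a real CCD mapping may be required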
/src/boltz/model/layers/attention.py:
--------------------------------------------------------------------------------
1 | from einops.layers.torch import Rearrange
2 | import torch
3 | from torch import Tensor, nn
4 |
5 | import boltz.model.layers.initialize as init
6 |
7 |
8 | class AttentionPairBias(nn.Module):
9 | """Attention pair bias layer."""
10 |
11 | def __init__(
12 | self,
13 | c_s: int,
14 | c_z: int,
15 | num_heads: int,
16 | inf: float = 1e6,
17 | initial_norm: bool = True,
18 | ) -> None:
19 | """Initialize the attention pair bias layer.
20 |
21 | Parameters
22 | ----------
23 | c_s : int
24 | The input sequence dimension.
25 | c_z : int
26 | The input pairwise dimension.
27 | num_heads : int
28 | The number of heads.
29 | inf : float, optional
30 | The inf value, by default 1e6
31 | initial_norm: bool, optional
32 | Whether to apply layer norm to the input, by default True
33 |
34 | """
35 | super().__init__()
36 |
37 | assert c_s % num_heads == 0
38 |
39 | self.c_s = c_s
40 | self.num_heads = num_heads
41 | self.head_dim = c_s // num_heads
42 | self.inf = inf
43 |
44 | self.initial_norm = initial_norm
45 | if self.initial_norm:
46 | self.norm_s = nn.LayerNorm(c_s)
47 |
48 | self.proj_q = nn.Linear(c_s, c_s)
49 | self.proj_k = nn.Linear(c_s, c_s, bias=False)
50 | self.proj_v = nn.Linear(c_s, c_s, bias=False)
51 | self.proj_g = nn.Linear(c_s, c_s, bias=False)
52 |
53 | self.proj_z = nn.Sequential(
54 | nn.LayerNorm(c_z),
55 | nn.Linear(c_z, num_heads, bias=False),
56 | Rearrange("b ... h -> b h ..."),
57 | )
58 |
59 | self.proj_o = nn.Linear(c_s, c_s, bias=False)
60 | init.final_init_(self.proj_o.weight)
61 |
62 | def forward(
63 | self,
64 | s: Tensor,
65 | z: Tensor,
66 | mask: Tensor,
67 | multiplicity: int = 1,
68 | to_keys=None,
69 | model_cache=None,
70 | ) -> Tensor:
71 | """Forward pass.
72 |
73 | Parameters
74 | ----------
75 | s : torch.Tensor
76 | The input sequence tensor (B, S, D)
77 | z : torch.Tensor
78 | The input pairwise tensor (B, N, N, D)
79 | mask : torch.Tensor
80 |             The key padding mask tensor (B, N)
81 | multiplicity : int, optional
82 | The diffusion batch size, by default 1
83 |
84 | Returns
85 | -------
86 | torch.Tensor
87 | The output sequence tensor.
88 |
89 | """
90 | B = s.shape[0]
91 |
92 | # Layer norms
93 | if self.initial_norm:
94 | s = self.norm_s(s)
95 |
96 | if to_keys is not None:
97 | k_in = to_keys(s)
98 | mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)
99 | else:
100 | k_in = s
101 |
102 | # Compute projections
103 | q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
104 | k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
105 | v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)
106 |
107 | # Caching z projection during diffusion roll-out
108 | if model_cache is None or "z" not in model_cache:
109 | z = self.proj_z(z)
110 |
111 | if model_cache is not None:
112 | model_cache["z"] = z
113 | else:
114 | z = model_cache["z"]
115 | z = z.repeat_interleave(multiplicity, 0)
116 |
117 | g = self.proj_g(s).sigmoid()
118 |
119 | with torch.autocast("cuda", enabled=False):
120 | # Compute attention weights
121 | attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
122 | attn = attn / (self.head_dim**0.5) + z.float()
123 | attn = attn + (1 - mask[:, None, None].float()) * -self.inf
124 | attn = attn.softmax(dim=-1)
125 |
126 | # Compute output
127 | o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)
128 | o = o.reshape(B, -1, self.c_s)
129 | o = self.proj_o(g * o)
130 |
131 | return o
132 |
--------------------------------------------------------------------------------
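A shape check for AttentionPairBias (a sketch; the dimensions are arbitrary, and the mask is the per-key padding mask consumed by the forward pass):

import torch

from boltz.model.layers.attention import AttentionPairBias

layer = AttentionPairBias(c_s=64, c_z=32, num_heads=8)
s = torch.randn(2, 16, 64)      # (B, N, c_s) sequence activations
z = torch.randn(2, 16, 16, 32)  # (B, N, N, c_z) pairwise activations
mask = torch.ones(2, 16)        # (B, N) key padding mask
out = layer(s, z, mask)
print(out.shape)                # torch.Size([2, 16, 64])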
/src/boltz/data/write/pdb.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 |
3 | from boltz.data import const
4 | from boltz.data.types import Structure
5 | from boltz.data.write.utils import generate_tags
6 |
7 |
8 | def to_pdb(structure: Structure) -> str: # noqa: PLR0915
9 | """Write a structure into a PDB file.
10 |
11 | Parameters
12 | ----------
13 | structure : Structure
14 | The input structure
15 |
16 | Returns
17 | -------
18 | str
19 | the output PDB file
20 |
21 | """
22 | pdb_lines = []
23 |
24 | atom_index = 1
25 | atom_reindex_ter = []
26 | chain_tags = generate_tags()
27 |
28 | # Load periodic table for element mapping
29 | periodic_table = Chem.GetPeriodicTable()
30 |
31 | # Add all atom sites.
32 | for chain in structure.chains:
33 | # We rename the chains in alphabetical order
34 | chain_idx = chain["asym_id"]
35 | chain_tag = next(chain_tags)
36 |
37 | res_start = chain["res_idx"]
38 | res_end = chain["res_idx"] + chain["res_num"]
39 |
40 | residues = structure.residues[res_start:res_end]
41 | for residue in residues:
42 | atom_start = residue["atom_idx"]
43 | atom_end = residue["atom_idx"] + residue["atom_num"]
44 | atoms = structure.atoms[atom_start:atom_end]
45 | atom_coords = atoms["coords"]
46 | for i, atom in enumerate(atoms):
47 | # This should not happen on predictions, but just in case.
48 | if not atom["is_present"]:
49 | continue
50 |
51 | record_type = (
52 | "ATOM"
53 | if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]
54 | else "HETATM"
55 | )
56 | name = atom["name"]
57 | name = [chr(c + 32) for c in name if c != 0]
58 | name = "".join(name)
59 | name = name if len(name) == 4 else f" {name}" # noqa: PLR2004
60 | alt_loc = ""
61 | insertion_code = ""
62 | occupancy = 1.00
63 | element = periodic_table.GetElementSymbol(atom["element"].item())
64 | element = element.upper()
65 | charge = ""
66 | residue_index = residue["res_idx"] + 1
67 | pos = atom_coords[i]
68 | res_name_3 = (
69 | "LIG" if record_type == "HETATM" else str(residue["name"][:3])
70 | )
71 | b_factor = 1.00
72 |
73 | # PDB is a columnar format, every space matters here!
74 | atom_line = (
75 | f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
76 | f"{res_name_3:>3} {chain_tag:>1}"
77 |                     f"{residue_index:>4}{insertion_code:>1}   "
78 | f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
79 |                     f"{occupancy:>6.2f}{b_factor:>6.2f}          "
80 | f"{element:>2}{charge:>2}"
81 | )
82 | pdb_lines.append(atom_line)
83 | atom_reindex_ter.append(atom_index)
84 | atom_index += 1
85 |
86 | should_terminate = chain_idx < (len(structure.chains) - 1)
87 | if should_terminate:
88 | # Close the chain.
89 | chain_end = "TER"
90 | chain_termination_line = (
91 |                     f"{chain_end:<6}{atom_index:>5}      "
92 | f"{res_name_3:>3} "
93 | f"{chain_tag:>1}{residue_index:>4}"
94 | )
95 | pdb_lines.append(chain_termination_line)
96 | atom_index += 1
97 |
98 | # Dump CONECT records.
99 | for bonds in [structure.bonds, structure.connections]:
100 | for bond in bonds:
101 | atom1 = structure.atoms[bond["atom_1"]]
102 | atom2 = structure.atoms[bond["atom_2"]]
103 | if not atom1["is_present"] or not atom2["is_present"]:
104 | continue
105 | atom1_idx = atom_reindex_ter[bond["atom_1"]]
106 | atom2_idx = atom_reindex_ter[bond["atom_2"]]
107 | conect_line = f"CONECT{atom1_idx:>5}{atom2_idx:>5}"
108 | pdb_lines.append(conect_line)
109 |
110 | pdb_lines.append("END")
111 | pdb_lines.append("")
112 | pdb_lines = [line.ljust(80) for line in pdb_lines]
113 | return "\n".join(pdb_lines)
114 |
--------------------------------------------------------------------------------
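Since every column in a PDB record matters, it can help to see the ATOM f-string applied to literal values. A standalone sketch using the same field widths as above (the values are made up):

record_type, atom_index, name = "ATOM", 1, " CA "
alt_loc, res_name_3, chain_tag = "", "ALA", "A"
residue_index, insertion_code = 1, ""
pos, occupancy, b_factor = (11.104, 6.134, -6.504), 1.00, 1.00
element, charge = "C", ""

# Field widths follow the PDB column spec: serial in 7-11, name in
# 13-16, resSeq in 23-26, coordinates starting at column 31, etc.
atom_line = (
    f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
    f"{res_name_3:>3} {chain_tag:>1}"
    f"{residue_index:>4}{insertion_code:>1}   "
    f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
    f"{occupancy:>6.2f}{b_factor:>6.2f}          "
    f"{element:>2}{charge:>2}"
)
print(atom_line)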
/src/boltz/model/layers/triangular_mult.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, nn
3 |
4 | from boltz.model.layers import initialize as init
5 |
6 |
7 | class TriangleMultiplicationOutgoing(nn.Module):
8 | """TriangleMultiplicationOutgoing."""
9 |
10 | def __init__(self, dim: int = 128) -> None:
11 | """Initialize the TriangularUpdate module.
12 |
13 | Parameters
14 | ----------
15 | dim: int
16 | The dimension of the input, default 128
17 |
18 | """
19 | super().__init__()
20 |
21 | self.norm_in = nn.LayerNorm(dim, eps=1e-5)
22 | self.p_in = nn.Linear(dim, 2 * dim, bias=False)
23 | self.g_in = nn.Linear(dim, 2 * dim, bias=False)
24 |
25 | self.norm_out = nn.LayerNorm(dim)
26 | self.p_out = nn.Linear(dim, dim, bias=False)
27 | self.g_out = nn.Linear(dim, dim, bias=False)
28 |
29 | init.bias_init_one_(self.norm_in.weight)
30 | init.bias_init_zero_(self.norm_in.bias)
31 |
32 | init.lecun_normal_init_(self.p_in.weight)
33 | init.gating_init_(self.g_in.weight)
34 |
35 | init.bias_init_one_(self.norm_out.weight)
36 | init.bias_init_zero_(self.norm_out.bias)
37 |
38 | init.final_init_(self.p_out.weight)
39 | init.gating_init_(self.g_out.weight)
40 |
41 | def forward(self, x: Tensor, mask: Tensor) -> Tensor:
42 | """Perform a forward pass.
43 |
44 | Parameters
45 | ----------
46 | x: torch.Tensor
47 | The input data of shape (B, N, N, D)
48 | mask: torch.Tensor
49 | The input mask of shape (B, N, N)
50 |
51 | Returns
52 | -------
53 | x: torch.Tensor
54 | The output data of shape (B, N, N, D)
55 |
56 | """
57 | # Input gating: D -> D
58 | x = self.norm_in(x)
59 | x_in = x
60 | x = self.p_in(x) * self.g_in(x).sigmoid()
61 |
62 | # Apply mask
63 | x = x * mask.unsqueeze(-1)
64 |
65 | # Split input and cast to float
66 | a, b = torch.chunk(x.float(), 2, dim=-1)
67 |
68 | # Triangular projection
69 | x = torch.einsum("bikd,bjkd->bijd", a, b)
70 |
71 | # Output gating
72 | x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
73 |
74 | return x
75 |
76 |
77 | class TriangleMultiplicationIncoming(nn.Module):
78 | """TriangleMultiplicationIncoming."""
79 |
80 | def __init__(self, dim: int = 128) -> None:
81 | """Initialize the TriangularUpdate module.
82 |
83 | Parameters
84 | ----------
85 | dim: int
86 | The dimension of the input, default 128
87 |
88 | """
89 | super().__init__()
90 |
91 | self.norm_in = nn.LayerNorm(dim, eps=1e-5)
92 | self.p_in = nn.Linear(dim, 2 * dim, bias=False)
93 | self.g_in = nn.Linear(dim, 2 * dim, bias=False)
94 |
95 | self.norm_out = nn.LayerNorm(dim)
96 | self.p_out = nn.Linear(dim, dim, bias=False)
97 | self.g_out = nn.Linear(dim, dim, bias=False)
98 |
99 | init.bias_init_one_(self.norm_in.weight)
100 | init.bias_init_zero_(self.norm_in.bias)
101 |
102 | init.lecun_normal_init_(self.p_in.weight)
103 | init.gating_init_(self.g_in.weight)
104 |
105 | init.bias_init_one_(self.norm_out.weight)
106 | init.bias_init_zero_(self.norm_out.bias)
107 |
108 | init.final_init_(self.p_out.weight)
109 | init.gating_init_(self.g_out.weight)
110 |
111 | def forward(self, x: Tensor, mask: Tensor) -> Tensor:
112 | """Perform a forward pass.
113 |
114 | Parameters
115 | ----------
116 | x: torch.Tensor
117 | The input data of shape (B, N, N, D)
118 | mask: torch.Tensor
119 | The input mask of shape (B, N, N)
120 |
121 | Returns
122 | -------
123 | x: torch.Tensor
124 | The output data of shape (B, N, N, D)
125 |
126 | """
127 | # Input gating: D -> D
128 | x = self.norm_in(x)
129 | x_in = x
130 | x = self.p_in(x) * self.g_in(x).sigmoid()
131 |
132 | # Apply mask
133 | x = x * mask.unsqueeze(-1)
134 |
135 | # Split input and cast to float
136 | a, b = torch.chunk(x.float(), 2, dim=-1)
137 |
138 | # Triangular projection
139 | x = torch.einsum("bkid,bkjd->bijd", a, b)
140 |
141 | # Output gating
142 | x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
143 |
144 | return x
145 |
--------------------------------------------------------------------------------
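The two modules are identical apart from the contraction pattern: the outgoing update mixes edges (i, k) and (j, k) over their shared neighbor k, while the incoming update uses (k, i) and (k, j). The difference is easiest to see directly on the einsums (sketch):

import torch

a = torch.randn(1, 5, 5, 4)
b = torch.randn(1, 5, 5, 4)
out_going = torch.einsum("bikd,bjkd->bijd", a, b)  # contract over column k
in_coming = torch.einsum("bkid,bkjd->bijd", a, b)  # contract over row k
# Both produce a (B, N, N, D) update for edge (i, j).
print(out_going.shape, in_coming.shape)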
/src/boltz/model/layers/pair_averaging.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, nn
3 |
4 | import boltz.model.layers.initialize as init
5 |
6 |
7 | class PairWeightedAveraging(nn.Module):
8 | """Pair weighted averaging layer."""
9 |
10 | def __init__(
11 | self,
12 | c_m: int,
13 | c_z: int,
14 | c_h: int,
15 | num_heads: int,
16 | inf: float = 1e6,
17 | chunk_heads: bool = False,
18 | ) -> None:
19 | """Initialize the pair weighted averaging layer.
20 |
21 | Parameters
22 | ----------
23 | c_m: int
24 | The dimension of the input sequence.
25 | c_z: int
26 | The dimension of the input pairwise tensor.
27 | c_h: int
28 | The dimension of the hidden.
29 | num_heads: int
30 | The number of heads.
31 | inf: float
32 | The value to use for masking, default 1e6.
33 | chunk_heads: bool
34 | Whether to sequentially compute heads at inference, default False.
35 |
36 | """
37 | super().__init__()
38 | self.c_m = c_m
39 | self.c_z = c_z
40 | self.c_h = c_h
41 | self.num_heads = num_heads
42 | self.inf = inf
43 | self.chunk_heads = chunk_heads
44 |
45 | self.norm_m = nn.LayerNorm(c_m)
46 | self.norm_z = nn.LayerNorm(c_z)
47 |
48 | self.proj_m = nn.Linear(c_m, c_h * num_heads, bias=False)
49 | self.proj_g = nn.Linear(c_m, c_h * num_heads, bias=False)
50 | self.proj_z = nn.Linear(c_z, num_heads, bias=False)
51 | self.proj_o = nn.Linear(c_h * num_heads, c_m, bias=False)
52 | init.final_init_(self.proj_o.weight)
53 |
54 | def forward(self, m: Tensor, z: Tensor, mask: Tensor) -> Tensor:
55 | """Forward pass.
56 |
57 | Parameters
58 | ----------
59 | m : torch.Tensor
60 | The input sequence tensor (B, S, N, D)
61 | z : torch.Tensor
62 | The input pairwise tensor (B, N, N, D)
63 | mask : torch.Tensor
64 | The pairwise mask tensor (B, N, N)
65 |
66 | Returns
67 | -------
68 | torch.Tensor
69 | The output sequence tensor (B, S, N, D)
70 |
71 | """
72 | # Compute layer norms
73 | m = self.norm_m(m)
74 | z = self.norm_z(z)
75 |
76 | if self.chunk_heads and not self.training:
77 | # Compute heads sequentially
78 |             o_out = None  # accumulates the per-head output contributions
79 | for head_idx in range(self.num_heads):
80 | sliced_weight_proj_m = self.proj_m.weight[
81 | head_idx * self.c_h : (head_idx + 1) * self.c_h, :
82 | ]
83 | sliced_weight_proj_g = self.proj_g.weight[
84 | head_idx * self.c_h : (head_idx + 1) * self.c_h, :
85 | ]
86 | sliced_weight_proj_z = self.proj_z.weight[head_idx : (head_idx + 1), :]
87 | sliced_weight_proj_o = self.proj_o.weight[
88 | :, head_idx * self.c_h : (head_idx + 1) * self.c_h
89 | ]
90 |
91 | # Project input tensors
92 | v: Tensor = m @ sliced_weight_proj_m.T
93 | v = v.reshape(*v.shape[:3], 1, self.c_h)
94 | v = v.permute(0, 3, 1, 2, 4)
95 |
96 | # Compute weights
97 | b: Tensor = z @ sliced_weight_proj_z.T
98 | b = b.permute(0, 3, 1, 2)
99 | b = b + (1 - mask[:, None]) * -self.inf
100 | w = torch.softmax(b, dim=-1)
101 |
102 | # Compute gating
103 | g: Tensor = m @ sliced_weight_proj_g.T
104 | g = g.sigmoid()
105 |
106 | # Compute output
107 | o = torch.einsum("bhij,bhsjd->bhsid", w, v)
108 | o = o.permute(0, 2, 3, 1, 4)
109 | o = o.reshape(*o.shape[:3], 1 * self.c_h)
110 | o_chunks = g * o
111 | if head_idx == 0:
112 | o_out = o_chunks @ sliced_weight_proj_o.T
113 | else:
114 | o_out += o_chunks @ sliced_weight_proj_o.T
115 | return o_out
116 | else:
117 | # Project input tensors
118 | v: Tensor = self.proj_m(m)
119 | v = v.reshape(*v.shape[:3], self.num_heads, self.c_h)
120 | v = v.permute(0, 3, 1, 2, 4)
121 |
122 | # Compute weights
123 | b: Tensor = self.proj_z(z)
124 | b = b.permute(0, 3, 1, 2)
125 | b = b + (1 - mask[:, None]) * -self.inf
126 | w = torch.softmax(b, dim=-1)
127 |
128 | # Compute gating
129 | g: Tensor = self.proj_g(m)
130 | g = g.sigmoid()
131 |
132 | # Compute output
133 | o = torch.einsum("bhij,bhsjd->bhsid", w, v)
134 | o = o.permute(0, 2, 3, 1, 4)
135 | o = o.reshape(*o.shape[:3], self.num_heads * self.c_h)
136 | o = self.proj_o(g * o)
137 | return o
138 |
--------------------------------------------------------------------------------
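A quick equivalence check between the chunked and dense paths (a sketch; chunk_heads only takes effect outside training mode, so the layer is put in eval mode first):

import torch

from boltz.model.layers.pair_averaging import PairWeightedAveraging

layer = PairWeightedAveraging(c_m=32, c_z=16, c_h=8, num_heads=4, chunk_heads=True)
layer.eval()
m = torch.randn(1, 3, 10, 32)   # (B, S, N, c_m) MSA activations
z = torch.randn(1, 10, 10, 16)  # (B, N, N, c_z) pairwise activations
mask = torch.ones(1, 10, 10)    # (B, N, N) pairwise mask
with torch.no_grad():
    chunked = layer(m, z, mask)
    layer.chunk_heads = False
    dense = layer(m, z, mask)
print(torch.allclose(chunked, dense, atol=1e-5))  # expected: True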
/scripts/train/configs/structure.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | accelerator: gpu
3 | devices: 1
4 | precision: 32
5 | gradient_clip_val: 10.0
6 | max_epochs: -1
7 |
8 | # Optionally set wandb here
9 | # wandb:
10 | # name: boltz
11 | # project: boltz
12 | # entity: boltz
13 |
14 | output: SET_PATH_HERE
15 | resume: PATH_TO_CHECKPOINT_FILE
16 | disable_checkpoint: false
17 | matmul_precision: null
18 | save_top_k: -1
19 |
20 | data:
21 | datasets:
22 | - _target_: boltz.data.module.training.DatasetConfig
23 | target_dir: PATH_TO_TARGETS_DIR
24 | msa_dir: PATH_TO_MSA_DIR
25 | prob: 1.0
26 | sampler:
27 | _target_: boltz.data.sample.cluster.ClusterSampler
28 | cropper:
29 | _target_: boltz.data.crop.boltz.BoltzCropper
30 | min_neighborhood: 0
31 | max_neighborhood: 40
32 | split: ./scripts/train/assets/validation_ids.txt
33 |
34 | filters:
35 | - _target_: boltz.data.filter.dynamic.size.SizeFilter
36 | min_chains: 1
37 | max_chains: 300
38 | - _target_: boltz.data.filter.dynamic.date.DateFilter
39 | date: "2021-09-30"
40 | ref: released
41 | - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
42 | resolution: 9.0
43 |
44 | tokenizer:
45 | _target_: boltz.data.tokenize.boltz.BoltzTokenizer
46 | featurizer:
47 | _target_: boltz.data.feature.featurizer.BoltzFeaturizer
48 |
49 | symmetries: PATH_TO_SYMMETRY_FILE
50 | max_tokens: 512
51 | max_atoms: 4608
52 | max_seqs: 2048
53 | pad_to_max_tokens: true
54 | pad_to_max_atoms: true
55 | pad_to_max_seqs: true
56 | samples_per_epoch: 100000
57 | batch_size: 1
58 | num_workers: 4
59 | random_seed: 42
60 | pin_memory: true
61 | overfit: null
62 | crop_validation: false
63 | return_train_symmetries: false
64 | return_val_symmetries: true
65 | train_binder_pocket_conditioned_prop: 0.3
66 | val_binder_pocket_conditioned_prop: 0.3
67 | binder_pocket_cutoff: 6.0
68 | binder_pocket_sampling_geometric_p: 0.3
69 | min_dist: 2.0
70 | max_dist: 22.0
71 | num_bins: 64
72 | atoms_per_window_queries: 32
73 |
74 | model:
75 | _target_: boltz.model.model.Boltz1
76 | atom_s: 128
77 | atom_z: 16
78 | token_s: 384
79 | token_z: 128
80 | num_bins: 64
81 | atom_feature_dim: 389
82 | atoms_per_window_queries: 32
83 | atoms_per_window_keys: 128
84 | compile_pairformer: false
85 | nucleotide_rmsd_weight: 5.0
86 | ligand_rmsd_weight: 10.0
87 | ema: true
88 | ema_decay: 0.999
89 |
90 | embedder_args:
91 | atom_encoder_depth: 3
92 | atom_encoder_heads: 4
93 |
94 | msa_args:
95 | msa_s: 64
96 | msa_blocks: 4
97 | msa_dropout: 0.15
98 | z_dropout: 0.25
99 | pairwise_head_width: 32
100 | pairwise_num_heads: 4
101 | activation_checkpointing: true
102 | offload_to_cpu: false
103 |
104 | pairformer_args:
105 | num_blocks: 48
106 | num_heads: 16
107 | dropout: 0.25
108 | activation_checkpointing: true
109 | offload_to_cpu: false
110 |
111 | score_model_args:
112 | sigma_data: 16
113 | dim_fourier: 256
114 | atom_encoder_depth: 3
115 | atom_encoder_heads: 4
116 | token_transformer_depth: 24
117 | token_transformer_heads: 16
118 | atom_decoder_depth: 3
119 | atom_decoder_heads: 4
120 | conditioning_transition_layers: 2
121 | activation_checkpointing: true
122 | offload_to_cpu: false
123 |
124 | confidence_prediction: false
125 | confidence_model_args:
126 | use_gaussian: false
127 | num_dist_bins: 64
128 | max_dist: 22
129 | add_s_to_z_prod: true
130 | add_s_input_to_s: true
131 | use_s_diffusion: true
132 | add_z_input_to_z: true
133 |
134 | confidence_args:
135 | num_plddt_bins: 50
136 | num_pde_bins: 64
137 | num_pae_bins: 64
138 | relative_confidence: none
139 |
140 | training_args:
141 | recycling_steps: 3
142 | sampling_steps: 20
143 | diffusion_multiplicity: 16
144 | diffusion_samples: 2
145 | confidence_loss_weight: 1e-4
146 | diffusion_loss_weight: 4.0
147 | distogram_loss_weight: 3e-2
148 | adam_beta_1: 0.9
149 | adam_beta_2: 0.95
150 | adam_eps: 0.00000001
151 | lr_scheduler: af3
152 | base_lr: 0.0
153 | max_lr: 0.0018
154 | lr_warmup_no_steps: 1000
155 | lr_start_decay_after_n_steps: 50000
156 | lr_decay_every_n_steps: 50000
157 | lr_decay_factor: 0.95
158 |
159 | validation_args:
160 | recycling_steps: 3
161 | sampling_steps: 200
162 | diffusion_samples: 5
163 | symmetry_correction: true
164 |
165 | diffusion_process_args:
166 | sigma_min: 0.0004
167 | sigma_max: 160.0
168 | sigma_data: 16.0
169 | rho: 7
170 | P_mean: -1.2
171 | P_std: 1.5
172 | gamma_0: 0.8
173 | gamma_min: 1.0
174 | noise_scale: 1.0
175 | step_scale: 1.0
176 | coordinate_augmentation: true
177 | alignment_reverse_diff: true
178 | synchronize_sigmas: true
179 | use_inference_model_cache: true
180 |
181 | diffusion_loss_args:
182 | add_smooth_lddt_loss: true
183 | nucleotide_loss_weight: 5.0
184 | ligand_loss_weight: 10.0
185 |
--------------------------------------------------------------------------------
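The _target_ keys follow the Hydra instantiation convention, so a node like the cropper config above can be built programmatically. A sketch, assuming hydra-core and omegaconf are installed alongside the package:

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "boltz.data.crop.boltz.BoltzCropper",
        "min_neighborhood": 0,
        "max_neighborhood": 40,
    }
)
cropper = instantiate(cfg)  # builds a BoltzCropper from the config node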
/src/boltz/model/modules/confidence_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from boltz.data import const
5 | from boltz.model.loss.confidence import compute_frame_pred
6 |
7 |
8 | def compute_aggregated_metric(logits, end=1.0):
9 | """Compute the metric from the logits.
10 |
11 | Parameters
12 | ----------
13 | logits : torch.Tensor
14 | The logits of the metric
15 | end : float
16 | Max value of the metric, by default 1.0
17 |
18 | Returns
19 | -------
20 | Tensor
21 | The metric value
22 |
23 | """
24 | num_bins = logits.shape[-1]
25 | bin_width = end / num_bins
26 | bounds = torch.arange(
27 | start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
28 | )
29 | probs = nn.functional.softmax(logits, dim=-1)
30 | plddt = torch.sum(
31 | probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
32 | dim=-1,
33 | )
34 | return plddt
35 |
36 |
37 | def tm_function(d, Nres):
38 | """Compute the rescaling function for pTM.
39 |
40 | Parameters
41 | ----------
42 | d : torch.Tensor
43 | The input
44 | Nres : torch.Tensor
45 | The number of residues
46 |
47 | Returns
48 | -------
49 | Tensor
50 | Output of the function
51 |
52 | """
53 | d0 = 1.24 * (torch.clip(Nres, min=19) - 15) ** (1 / 3) - 1.8
54 | return 1 / (1 + (d / d0) ** 2)
55 |
56 |
57 | def compute_ptms(logits, x_preds, feats, multiplicity):
58 | """Compute pTM and ipTM scores.
59 |
60 | Parameters
61 | ----------
62 | logits : torch.Tensor
63 | pae logits
64 | x_preds : torch.Tensor
65 | The predicted coordinates
66 | feats : Dict[str, torch.Tensor]
67 | The input features
68 | multiplicity : int
69 | The batch size of the diffusion roll-out
70 |
71 | Returns
72 | -------
73 | Tensor
74 | pTM score
75 | Tensor
76 | ipTM score
77 | Tensor
78 | ligand ipTM score
79 | Tensor
80 | protein ipTM score
81 |
82 | """
83 | # Compute mask for collinear and overlapping tokens
84 | _, mask_collinear_pred = compute_frame_pred(
85 | x_preds, feats["frames_idx"], feats, multiplicity, inference=True
86 | )
87 | mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
88 | maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1])
89 | pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None]
90 | asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
91 | pair_mask_iptm = (
92 | maski[:, :, None]
93 | * (asym_id[:, None, :] != asym_id[:, :, None])
94 | * mask_pad[:, None, :]
95 | * mask_pad[:, :, None]
96 | )
97 |
98 | # Extract pae values
99 | num_bins = logits.shape[-1]
100 | bin_width = 32.0 / num_bins
101 | end = 32.0
102 | pae_value = torch.arange(
103 | start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
104 | ).unsqueeze(0)
105 | N_res = mask_pad.sum(dim=-1, keepdim=True)
106 |
107 | # compute pTM and ipTM
108 | tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2)
109 | probs = nn.functional.softmax(logits, dim=-1)
110 | tm_expected_value = torch.sum(
111 | probs * tm_value,
112 | dim=-1,
113 | ) # shape (B, N, N)
114 | ptm = torch.max(
115 | torch.sum(tm_expected_value * pair_mask_ptm, dim=-1)
116 | / (torch.sum(pair_mask_ptm, dim=-1) + 1e-5),
117 | dim=1,
118 | ).values
119 | iptm = torch.max(
120 | torch.sum(tm_expected_value * pair_mask_iptm, dim=-1)
121 | / (torch.sum(pair_mask_iptm, dim=-1) + 1e-5),
122 | dim=1,
123 | ).values
124 |
125 | # compute ligand and protein ipTM
126 | token_type = feats["mol_type"]
127 | token_type = token_type.repeat_interleave(multiplicity, 0)
128 | is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
129 | is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()
130 |
131 | ligand_iptm_mask = (
132 | maski[:, :, None]
133 | * (asym_id[:, None, :] != asym_id[:, :, None])
134 | * mask_pad[:, None, :]
135 | * mask_pad[:, :, None]
136 | * (
137 | (is_ligand_token[:, :, None] * is_protein_token[:, None, :])
138 | + (is_protein_token[:, :, None] * is_ligand_token[:, None, :])
139 | )
140 | )
141 |     protein_iptm_mask = (
142 | maski[:, :, None]
143 | * (asym_id[:, None, :] != asym_id[:, :, None])
144 | * mask_pad[:, None, :]
145 | * mask_pad[:, :, None]
146 | * (is_protein_token[:, :, None] * is_protein_token[:, None, :])
147 | )
148 |
149 | ligand_iptm = torch.max(
150 | torch.sum(tm_expected_value * ligand_iptm_mask, dim=-1)
151 | / (torch.sum(ligand_iptm_mask, dim=-1) + 1e-5),
152 | dim=1,
153 | ).values
154 | protein_iptm = torch.max(
155 |         torch.sum(tm_expected_value * protein_iptm_mask, dim=-1)
156 |         / (torch.sum(protein_iptm_mask, dim=-1) + 1e-5),
157 | dim=1,
158 | ).values
159 |
160 | return ptm, iptm, ligand_iptm, protein_iptm
161 |
--------------------------------------------------------------------------------
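compute_aggregated_metric is just the expectation of the bin centers under the softmax distribution, which a uniform-logits check makes concrete (sketch):

import torch

from boltz.model.modules.confidence_utils import compute_aggregated_metric

logits = torch.zeros(2, 100, 50)  # uniform over 50 pLDDT bins
plddt = compute_aggregated_metric(logits, end=1.0)
# With uniform probabilities the expectation is the mean bin center, 0.5.
print(plddt.shape, plddt[0, 0].item())  # torch.Size([2, 100]) ~0.5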
/scripts/train/configs/confidence.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | accelerator: gpu
3 | devices: 1
4 | precision: 32
5 | gradient_clip_val: 10.0
6 | max_epochs: -1
7 |
8 | # Optionally set wandb here
9 | # wandb:
10 | # name: boltz
11 | # project: boltz
12 | # entity: boltz
13 |
14 |
15 | output: SET_PATH_HERE
16 | pretrained: PATH_TO_STRUCTURE_CHECKPOINT_FILE
17 | resume: null
18 | disable_checkpoint: false
19 | matmul_precision: null
20 | save_top_k: -1
21 | load_confidence_from_trunk: true
22 |
23 | data:
24 | datasets:
25 | - _target_: boltz.data.module.training.DatasetConfig
26 | target_dir: PATH_TO_TARGETS_DIR
27 | msa_dir: PATH_TO_MSA_DIR
28 | prob: 1.0
29 | sampler:
30 | _target_: boltz.data.sample.cluster.ClusterSampler
31 | cropper:
32 | _target_: boltz.data.crop.boltz.BoltzCropper
33 | min_neighborhood: 0
34 | max_neighborhood: 40
35 | split: ./scripts/train/assets/validation_ids.txt
36 |
37 | filters:
38 | - _target_: boltz.data.filter.dynamic.size.SizeFilter
39 | min_chains: 1
40 | max_chains: 300
41 | - _target_: boltz.data.filter.dynamic.date.DateFilter
42 | date: "2021-09-30"
43 | ref: released
44 | - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
45 | resolution: 4.0
46 |
47 | tokenizer:
48 | _target_: boltz.data.tokenize.boltz.BoltzTokenizer
49 | featurizer:
50 | _target_: boltz.data.feature.featurizer.BoltzFeaturizer
51 |
52 | symmetries: PATH_TO_SYMMETRY_FILE
53 | max_tokens: 512
54 | max_atoms: 4608
55 | max_seqs: 2048
56 | pad_to_max_tokens: true
57 | pad_to_max_atoms: true
58 | pad_to_max_seqs: true
59 | samples_per_epoch: 100000
60 | batch_size: 1
61 | num_workers: 4
62 | random_seed: 42
63 | pin_memory: true
64 | overfit: null
65 | crop_validation: true
66 | return_train_symmetries: true
67 | return_val_symmetries: true
68 | train_binder_pocket_conditioned_prop: 0.3
69 | val_binder_pocket_conditioned_prop: 0.3
70 | binder_pocket_cutoff: 6.0
71 | binder_pocket_sampling_geometric_p: 0.3
72 | min_dist: 2.0
73 | max_dist: 22.0
74 | num_bins: 64
75 | atoms_per_window_queries: 32
76 |
77 | model:
78 |   _target_: boltz.model.model.Boltz1
79 | atom_s: 128
80 | atom_z: 16
81 | token_s: 384
82 | token_z: 128
83 | num_bins: 64
84 | atom_feature_dim: 389
85 | atoms_per_window_queries: 32
86 | atoms_per_window_keys: 128
87 | compile_pairformer: false
88 | nucleotide_rmsd_weight: 5.0
89 | ligand_rmsd_weight: 10.0
90 | ema: true
91 | ema_decay: 0.999
92 |
93 | embedder_args:
94 | atom_encoder_depth: 3
95 | atom_encoder_heads: 4
96 |
97 | msa_args:
98 | msa_s: 64
99 | msa_blocks: 4
100 | msa_dropout: 0.15
101 | z_dropout: 0.25
102 | pairwise_head_width: 32
103 | pairwise_num_heads: 4
104 | activation_checkpointing: true
105 | offload_to_cpu: false
106 |
107 | pairformer_args:
108 | num_blocks: 48
109 | num_heads: 16
110 | dropout: 0.25
111 | activation_checkpointing: true
112 | offload_to_cpu: false
113 |
114 | score_model_args:
115 | sigma_data: 16
116 | dim_fourier: 256
117 | atom_encoder_depth: 3
118 | atom_encoder_heads: 4
119 | token_transformer_depth: 24
120 | token_transformer_heads: 16
121 | atom_decoder_depth: 3
122 | atom_decoder_heads: 4
123 | conditioning_transition_layers: 2
124 | activation_checkpointing: true
125 | offload_to_cpu: false
126 |
127 | structure_prediction_training: false
128 | run_trunk_and_structure: true
129 | confidence_prediction: true
130 | alpha_pae: 1
131 | confidence_imitate_trunk: true
132 | confidence_model_args:
133 | use_gaussian: false
134 | num_dist_bins: 64
135 | max_dist: 22
136 | add_s_to_z_prod: true
137 | add_s_input_to_s: true
138 | use_s_diffusion: true
139 | add_z_input_to_z: true
140 |
141 | confidence_args:
142 | num_plddt_bins: 50
143 | num_pde_bins: 64
144 | num_pae_bins: 64
145 | relative_confidence: none
146 |
147 | training_args:
148 | recycling_steps: 3
149 | sampling_steps: 200
150 | diffusion_multiplicity: 16
151 | diffusion_samples: 1
152 | confidence_loss_weight: 3e-3
153 | diffusion_loss_weight: 4.0
154 | distogram_loss_weight: 3e-2
155 | adam_beta_1: 0.9
156 | adam_beta_2: 0.95
157 | adam_eps: 0.00000001
158 | lr_scheduler: af3
159 | base_lr: 0.0
160 | max_lr: 0.0018
161 | lr_warmup_no_steps: 1000
162 | lr_start_decay_after_n_steps: 50000
163 | lr_decay_every_n_steps: 50000
164 | lr_decay_factor: 0.95
165 | symmetry_correction: true
166 |
167 | validation_args:
168 | recycling_steps: 3
169 | sampling_steps: 200
170 | diffusion_samples: 5
171 | symmetry_correction: true
172 |
173 | diffusion_process_args:
174 | sigma_min: 0.0004
175 | sigma_max: 160.0
176 | sigma_data: 16.0
177 | rho: 7
178 | P_mean: -1.2
179 | P_std: 1.5
180 | gamma_0: 0.8
181 | gamma_min: 1.0
182 | noise_scale: 1.0
183 | step_scale: 1.0
184 | coordinate_augmentation: true
185 | alignment_reverse_diff: true
186 | synchronize_sigmas: true
187 | use_inference_model_cache: true
188 |
189 | diffusion_loss_args:
190 | add_smooth_lddt_loss: true
191 | nucleotide_loss_weight: 5.0
192 | ligand_loss_weight: 10.0
193 |
--------------------------------------------------------------------------------
/src/boltz/model/layers/triangular_attention/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from functools import partial, partialmethod
17 | from typing import List, Optional
18 |
19 | import torch
20 | import torch.nn as nn
21 |
22 | from boltz.model.layers.triangular_attention.primitives import (
23 | Attention,
24 | LayerNorm,
25 | Linear,
26 | )
27 | from boltz.model.layers.triangular_attention.utils import (
28 | chunk_layer,
29 | permute_final_dims,
30 | )
31 |
32 |
33 | class TriangleAttention(nn.Module):
34 | def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
35 | """
36 | Args:
37 |             c_in: Input channel dimension
38 |             c_hidden: Overall hidden channel dimension (not per-head)
39 |             no_heads: Number of attention heads
40 |             starting: Whether to attend around the starting node (True)
41 |                 or the ending node (False)
42 |             inf: Large value used for attention masking, default 1e9
43 |         """
44 | super(TriangleAttention, self).__init__()
45 |
46 | self.c_in = c_in
47 | self.c_hidden = c_hidden
48 | self.no_heads = no_heads
49 | self.starting = starting
50 | self.inf = inf
51 |
52 | self.layer_norm = LayerNorm(self.c_in)
53 |
54 | self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
55 |
56 | self.mha = Attention(
57 | self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
58 | )
59 |
60 | @torch.jit.ignore
61 | def _chunk(
62 | self,
63 | x: torch.Tensor,
64 | biases: List[torch.Tensor],
65 | chunk_size: int,
66 | use_memory_efficient_kernel: bool = False,
67 | use_deepspeed_evo_attention: bool = False,
68 | use_lma: bool = False,
69 | inplace_safe: bool = False,
70 | ) -> torch.Tensor:
71 | "triangle! triangle!"
72 | mha_inputs = {
73 | "q_x": x,
74 | "kv_x": x,
75 | "biases": biases,
76 | }
77 |
78 | return chunk_layer(
79 | partial(
80 | self.mha,
81 | use_memory_efficient_kernel=use_memory_efficient_kernel,
82 | use_deepspeed_evo_attention=use_deepspeed_evo_attention,
83 | use_lma=use_lma,
84 | ),
85 | mha_inputs,
86 | chunk_size=chunk_size,
87 | no_batch_dims=len(x.shape[:-2]),
88 | _out=x if inplace_safe else None,
89 | )
90 |
91 | def forward(
92 | self,
93 | x: torch.Tensor,
94 | mask: Optional[torch.Tensor] = None,
95 | chunk_size: Optional[int] = None,
96 | use_memory_efficient_kernel: bool = False,
97 | use_deepspeed_evo_attention: bool = False,
98 | use_lma: bool = False,
99 | inplace_safe: bool = False,
100 | ) -> torch.Tensor:
101 | """
102 | Args:
103 | x:
104 | [*, I, J, C_in] input tensor (e.g. the pair representation)
105 | Returns:
106 | [*, I, J, C_in] output tensor
107 | """
108 | if mask is None:
109 | # [*, I, J]
110 | mask = x.new_ones(
111 | x.shape[:-1],
112 | )
113 |
114 | if not self.starting:
115 | x = x.transpose(-2, -3)
116 | mask = mask.transpose(-1, -2)
117 |
118 | # [*, I, J, C_in]
119 | x = self.layer_norm(x)
120 |
121 | # [*, I, 1, 1, J]
122 | mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
123 |
124 | # [*, H, I, J]
125 | triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
126 |
127 | # [*, 1, H, I, J]
128 | triangle_bias = triangle_bias.unsqueeze(-4)
129 |
130 | biases = [mask_bias, triangle_bias]
131 |
132 | if chunk_size is not None:
133 | x = self._chunk(
134 | x,
135 | biases,
136 | chunk_size,
137 | use_memory_efficient_kernel=use_memory_efficient_kernel,
138 | use_deepspeed_evo_attention=use_deepspeed_evo_attention,
139 | use_lma=use_lma,
140 | inplace_safe=inplace_safe,
141 | )
142 | else:
143 | x = self.mha(
144 | q_x=x,
145 | kv_x=x,
146 | biases=biases,
147 | use_memory_efficient_kernel=use_memory_efficient_kernel,
148 | use_deepspeed_evo_attention=use_deepspeed_evo_attention,
149 | use_lma=use_lma,
150 | )
151 |
152 | if not self.starting:
153 | x = x.transpose(-2, -3)
154 |
155 | return x
156 |
157 |
158 | # Implements Algorithm 13
159 | TriangleAttentionStartingNode = TriangleAttention
160 |
161 |
162 | class TriangleAttentionEndingNode(TriangleAttention):
163 | """Implement Algorithm 14."""
164 |
165 | __init__ = partialmethod(TriangleAttention.__init__, starting=False)
166 |
--------------------------------------------------------------------------------
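A usage sketch for the two variants (arbitrary dimensions; both return a tensor shaped like the input pair representation):

import torch

from boltz.model.layers.triangular_attention.attention import (
    TriangleAttentionEndingNode,
    TriangleAttentionStartingNode,
)

x = torch.randn(1, 12, 12, 64)  # [*, I, J, C_in] pair representation
mask = torch.ones(1, 12, 12)
start = TriangleAttentionStartingNode(c_in=64, c_hidden=16, no_heads=4)
end = TriangleAttentionEndingNode(c_in=64, c_hidden=16, no_heads=4)
y = end(start(x, mask=mask), mask=mask)
print(y.shape)                  # torch.Size([1, 12, 12, 64])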
/scripts/train/assets/test_ids.txt:
--------------------------------------------------------------------------------
1 | 8BZ4
2 | 8URN
3 | 7U71
4 | 7Z64
5 | 7Y3Z
6 | 8SOT
7 | 8GH8
8 | 8IIB
9 | 7U08
10 | 8EB5
11 | 8G49
12 | 8K7Y
13 | 7QQD
14 | 8EIL
15 | 8JQE
16 | 8V1K
17 | 7ZRZ
18 | 7YN2
19 | 8D40
20 | 8RXO
21 | 8SXS
22 | 7UDL
23 | 8ADD
24 | 7Z3I
25 | 7YUK
26 | 7XWY
27 | 8F9Y
28 | 8WO7
29 | 8C27
30 | 8I3J
31 | 8HVC
32 | 8SXU
33 | 8K1I
34 | 8FTV
35 | 8ERC
36 | 8DVQ
37 | 8DTQ
38 | 8J12
39 | 8D0P
40 | 8POG
41 | 8HN0
42 | 7QPK
43 | 8AGR
44 | 8GXR
45 | 8K7X
46 | 8BL6
47 | 8HAW
48 | 8SRO
49 | 8HHM
50 | 8C26
51 | 7SPQ
52 | 8SME
53 | 7XGV
54 | 8GTY
55 | 8Q42
56 | 8BRY
57 | 8HDV
58 | 8B3Z
59 | 7XNJ
60 | 8EEL
61 | 8IOI
62 | 8Q70
63 | 8Y4U
64 | 8ANT
65 | 8IUB
66 | 8D49
67 | 8CPQ
68 | 8BAT
69 | 8E2B
70 | 8IWP
71 | 8IJT
72 | 7Y01
73 | 8CJG
74 | 8HML
75 | 8WU2
76 | 8VRM
77 | 8J1J
78 | 8DAJ
79 | 8SUT
80 | 8PTJ
81 | 8IVZ
82 | 8SDZ
83 | 7YDQ
84 | 8JU7
85 | 8K34
86 | 8B6Q
87 | 8F7N
88 | 8IBZ
89 | 7WOI
90 | 8R7D
91 | 8T65
92 | 8IQC
93 | 8SIU
94 | 8QK8
95 | 8HIG
96 | 7Y43
97 | 8IN8
98 | 8IBW
99 | 8GOY
100 | 7ZAO
101 | 8J9G
102 | 7ZCA
103 | 8HIO
104 | 8EFZ
105 | 8IQ8
106 | 8OQ0
107 | 8HHL
108 | 7XMW
109 | 8GI1
110 | 8AYR
111 | 7ZCB
112 | 8BRD
113 | 8IN6
114 | 8I3F
115 | 8HIU
116 | 8ER5
117 | 8WIL
118 | 7YPR
119 | 8UA2
120 | 8BW6
121 | 8IL8
122 | 8J3R
123 | 8K1F
124 | 8OHI
125 | 8WCT
126 | 8AN0
127 | 8BDQ
128 | 7FCT
129 | 8J69
130 | 8HTX
131 | 8PE3
132 | 8K5U
133 | 8AXT
134 | 8PSO
135 | 8JHR
136 | 8GY0
137 | 8QCW
138 | 8K3D
139 | 8P6J
140 | 8J0Q
141 | 7XS3
142 | 8DHJ
143 | 8EIN
144 | 7WKP
145 | 8GAQ
146 | 7WRN
147 | 8AHD
148 | 7SC4
149 | 8B3E
150 | 8AAS
151 | 8UZ8
152 | 8Q1K
153 | 8K5K
154 | 8B45
155 | 8PT7
156 | 7ZPN
157 | 8UQ9
158 | 8TJG
159 | 8TN8
160 | 8B2E
161 | 7XFZ
162 | 8FW7
163 | 8B3W
164 | 7T4W
165 | 8SVA
166 | 7YL4
167 | 8GLD
168 | 8OEI
169 | 8GMX
170 | 8OWF
171 | 8FNR
172 | 8IRQ
173 | 8JDG
174 | 7UXA
175 | 8TKA
176 | 7YH1
177 | 8HUZ
178 | 8TA2
179 | 8E5D
180 | 7YUN
181 | 7UOI
182 | 7WMY
183 | 8AA9
184 | 8ISZ
185 | 8EXA
186 | 8E7F
187 | 8B2S
188 | 8TP8
189 | 8GSY
190 | 7XRX
191 | 8SY3
192 | 8CIL
193 | 8WBR
194 | 7XF1
195 | 7YPO
196 | 8AXF
197 | 7QNL
198 | 8OYY
199 | 7R1N
200 | 8H5S
201 | 8B6U
202 | 8IBX
203 | 8Q43
204 | 8OW8
205 | 7XSG
206 | 8U0M
207 | 8IOO
208 | 8HR5
209 | 8BVK
210 | 8P0C
211 | 7TL6
212 | 8J48
213 | 8S0U
214 | 8K8A
215 | 8G53
216 | 7XYO
217 | 8POF
218 | 8U1K
219 | 8HF2
220 | 8K4L
221 | 8JAH
222 | 8KGZ
223 | 8BNB
224 | 7UG2
225 | 8A0A
226 | 8Q3Z
227 | 8XBI
228 | 8JNM
229 | 8GPS
230 | 8K1R
231 | 8Q66
232 | 7YLQ
233 | 7YNX
234 | 8IMD
235 | 7Y8H
236 | 8OXU
237 | 8BVE
238 | 8B4E
239 | 8V14
240 | 7R5I
241 | 8IR2
242 | 8UK7
243 | 8EBB
244 | 7XCC
245 | 8AEP
246 | 7YDW
247 | 8XX9
248 | 7VS6
249 | 8K3F
250 | 8CQM
251 | 7XH4
252 | 8BH9
253 | 7VXT
254 | 8SM9
255 | 8HGU
256 | 8PSQ
257 | 8SSU
258 | 8VXA
259 | 8GSX
260 | 8GHZ
261 | 8BJ3
262 | 8C9V
263 | 8T66
264 | 7XPC
265 | 8RH3
266 | 8CMQ
267 | 8AGG
268 | 8ERM
269 | 8P6M
270 | 8BUX
271 | 7S2J
272 | 8G32
273 | 8AXJ
274 | 8CID
275 | 8CPK
276 | 8P5Q
277 | 8HP8
278 | 7YUJ
279 | 8PT2
280 | 7YK3
281 | 7YYG
282 | 8ABV
283 | 7XL7
284 | 7YLZ
285 | 8JWS
286 | 8IW5
287 | 8SM6
288 | 8BBZ
289 | 8EOV
290 | 8PXC
291 | 7UWV
292 | 8A9N
293 | 7YH5
294 | 8DEO
295 | 7X2X
296 | 8W7P
297 | 8B5W
298 | 8CIH
299 | 8RB4
300 | 8HLG
301 | 8J8H
302 | 8UA5
303 | 7YKM
304 | 8S9W
305 | 7YPD
306 | 8GA6
307 | 7YPQ
308 | 8X7X
309 | 8HI8
310 | 8H7A
311 | 8C4D
312 | 8XAT
313 | 8W8S
314 | 8HM4
315 | 8H3Z
316 | 7W91
317 | 8GPP
318 | 8TNM
319 | 7YSI
320 | 8OML
321 | 8BBR
322 | 7YOJ
323 | 8JZX
324 | 8I3X
325 | 8AU6
326 | 8ITO
327 | 7SFY
328 | 8B6P
329 | 7Y8S
330 | 8ESL
331 | 8DSP
332 | 8CLZ
333 | 8F72
334 | 8QLD
335 | 8K86
336 | 8G8E
337 | 8QDO
338 | 8ANU
339 | 8PT6
340 | 8F5D
341 | 8DQ6
342 | 8IFK
343 | 8OJN
344 | 8SSC
345 | 7QRR
346 | 8E55
347 | 7TPU
348 | 7UQU
349 | 8HFP
350 | 7XGT
351 | 8A39
352 | 8CB2
353 | 8ACR
354 | 8G5S
355 | 7TZL
356 | 8T4R
357 | 8H18
358 | 7UI4
359 | 8Q41
360 | 8K76
361 | 7WUY
362 | 8VXC
363 | 8GYG
364 | 8IMS
365 | 8IKS
366 | 8X51
367 | 7Y7O
368 | 8PX4
369 | 8BF8
370 | 7XMJ
371 | 8GDW
372 | 7YTU
373 | 8CH4
374 | 7XHZ
375 | 7YH4
376 | 8PSN
377 | 8A16
378 | 8FBJ
379 | 7Y9G
380 | 8JI2
381 | 7YR9
382 | 8SW0
383 | 8A90
384 | 8X6V
385 | 8H8P
386 | 7WJU
387 | 8PSS
388 | 8HL8
389 | 8FJD
390 | 8PM4
391 | 7UK8
392 | 8DX0
393 | 8PHB
394 | 8FBN
395 | 8FXF
396 | 8GKH
397 | 8ENR
398 | 8PTH
399 | 8CBV
400 | 8GKV
401 | 8CQO
402 | 8OK3
403 | 8GSR
404 | 8TPK
405 | 8H1J
406 | 8QFL
407 | 8CHW
408 | 7V34
409 | 8HE2
410 | 7ZIE
411 | 8A50
412 | 7Z8E
413 | 8ILL
414 | 7WWC
415 | 7XVI
416 | 8Q2A
417 | 8HNO
418 | 8PR6
419 | 7XCA
420 | 7XGS
421 | 8H55
422 | 8FJE
423 | 7UNH
424 | 8AY2
425 | 8ARD
426 | 8HBR
427 | 8EWG
428 | 8D4A
429 | 8FIT
430 | 8E5E
431 | 8PMU
432 | 8F5G
433 | 8AMU
434 | 8CPN
435 | 7QPL
436 | 8EHN
437 | 8SQU
438 | 8F70
439 | 8FX9
440 | 7UR2
441 | 8T1M
442 | 7ZDS
443 | 7YH2
444 | 8B6A
445 | 8CHX
446 | 8G0N
447 | 8GY4
448 | 7YKG
449 | 8BH8
450 | 8BVI
451 | 7XF2
452 | 8BFY
453 | 8IA3
454 | 8JW3
455 | 8OQJ
456 | 8TFS
457 | 7Y1S
458 | 8HBB
459 | 8AF9
460 | 8IP1
461 | 7XZ3
462 | 8T0P
463 | 7Y16
464 | 8BRP
465 | 8JNX
466 | 8JP0
467 | 8EC3
468 | 8PZH
469 | 7URP
470 | 8B4D
471 | 8JFR
472 | 8GYR
473 | 7XFS
474 | 8SMQ
475 | 7WNH
476 | 8H0L
477 | 8OWI
478 | 8HFC
479 | 7X6G
480 | 8FKL
481 | 8PAG
482 | 8UPI
483 | 8D4B
484 | 8BCK
485 | 8JFU
486 | 8FUQ
487 | 8IF8
488 | 8PAQ
489 | 8HDU
490 | 8W9O
491 | 8ACA
492 | 7YIA
493 | 7ZFR
494 | 7Y9A
495 | 8TTO
496 | 7YFX
497 | 8B2H
498 | 8PSU
499 | 8ACC
500 | 8JMR
501 | 8IHA
502 | 7UYX
503 | 8DWJ
504 | 8BY5
505 | 8EZW
506 | 8A82
507 | 8TVL
508 | 8R79
509 | 8R8A
510 | 8AHZ
511 | 8AYV
512 | 8JHU
513 | 8Q44
514 | 8ARE
515 | 8OLJ
516 | 7Y95
517 | 7XP0
518 | 8EX9
519 | 8BID
520 | 8Q40
521 | 7QSJ
522 | 7UBA
523 | 7XFU
524 | 8OU1
525 | 8G2V
526 | 8YA7
527 | 8GMZ
528 | 8T8L
529 | 8CK0
530 | 7Y4H
531 | 8IOM
532 | 7ZLQ
533 | 8BZ2
534 | 8B4C
535 | 8DZJ
536 | 8CEG
537 | 8IBY
538 | 8T3J
539 | 8IVI
540 | 8ITN
541 | 8CR7
542 | 8TGH
543 | 8OKH
544 | 7UI8
545 | 8EHT
546 | 8ADC
547 | 8T4C
548 | 7XBJ
549 | 8CLU
550 | 7QA1
551 |
--------------------------------------------------------------------------------
/src/boltz/model/loss/diffusion.py:
--------------------------------------------------------------------------------
1 | # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
2 |
3 | from einops import einsum
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def weighted_rigid_align(
9 | true_coords,
10 | pred_coords,
11 | weights,
12 | mask,
13 | ):
14 | """Compute weighted alignment.
15 |
16 | Parameters
17 | ----------
18 | true_coords: torch.Tensor
19 | The ground truth atom coordinates
20 | pred_coords: torch.Tensor
21 | The predicted atom coordinates
22 | weights: torch.Tensor
23 | The weights for alignment
24 | mask: torch.Tensor
25 | The atoms mask
26 |
27 | Returns
28 | -------
29 | torch.Tensor
30 | Aligned coordinates
31 |
32 | """
33 |
34 | batch_size, num_points, dim = true_coords.shape
35 | weights = (mask * weights).unsqueeze(-1)
36 |
37 | # Compute weighted centroids
38 | true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
39 | dim=1, keepdim=True
40 | )
41 | pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
42 | dim=1, keepdim=True
43 | )
44 |
45 | # Center the coordinates
46 | true_coords_centered = true_coords - true_centroid
47 | pred_coords_centered = pred_coords - pred_centroid
48 |
49 | if num_points < (dim + 1):
50 | print(
51 | "Warning: The size of one of the point clouds is <= dim+1. "
52 | + "`WeightedRigidAlign` cannot return a unique rotation."
53 | )
54 |
55 | # Compute the weighted covariance matrix
56 | cov_matrix = einsum(
57 | weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j"
58 | )
59 |
60 | # Compute the SVD of the covariance matrix, required float32 for svd and determinant
61 | original_dtype = cov_matrix.dtype
62 | cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
63 | U, S, V = torch.linalg.svd(
64 | cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
65 | )
66 | V = V.mH
67 |
68 | # Catch ambiguous rotation by checking the magnitude of singular values
69 | if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
70 | print(
71 | "Warning: Excessively low rank of "
72 | + "cross-correlation between aligned point clouds. "
73 | + "`WeightedRigidAlign` cannot return a unique rotation."
74 | )
75 |
76 | # Compute the rotation matrix
77 | rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)
78 |
79 | # Ensure proper rotation matrix with determinant 1
80 |     reflect = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
81 |         None
82 |     ].repeat(batch_size, 1, 1)
83 |     reflect[:, -1, -1] = torch.det(rot_matrix)
84 |     rot_matrix = einsum(U, reflect, V, "b i j, b j k, b l k -> b i l")
85 | rot_matrix = rot_matrix.to(dtype=original_dtype)
86 |
87 | # Apply the rotation and translation
88 | aligned_coords = (
89 | einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j")
90 | + pred_centroid
91 | )
92 | aligned_coords.detach_()
93 |
94 | return aligned_coords
95 |
96 |
97 | def smooth_lddt_loss(
98 | pred_coords,
99 | true_coords,
100 | is_nucleotide,
101 | coords_mask,
102 | nucleic_acid_cutoff: float = 30.0,
103 | other_cutoff: float = 15.0,
104 | multiplicity: int = 1,
105 | ):
106 |     """Compute the smooth LDDT loss.
107 |
108 | Parameters
109 | ----------
110 | pred_coords: torch.Tensor
111 | The predicted atom coordinates
112 | true_coords: torch.Tensor
113 | The ground truth atom coordinates
114 | is_nucleotide: torch.Tensor
115 |         Mask indicating which atoms are nucleotides
116 | coords_mask: torch.Tensor
117 | The atoms mask
118 | nucleic_acid_cutoff: float
119 | The nucleic acid cutoff
120 | other_cutoff: float
121 | The non nucleic acid cutoff
122 | multiplicity: int
123 |         The diffusion batch multiplicity
124 |
125 |     Returns
126 |     -------
127 |     torch.Tensor
128 |         The smooth LDDT loss
129 | """
130 | B, N, _ = true_coords.shape
131 | true_dists = torch.cdist(true_coords, true_coords)
132 | is_nucleotide = is_nucleotide.repeat_interleave(multiplicity, 0)
133 |
134 | coords_mask = coords_mask.repeat_interleave(multiplicity, 0)
135 | is_nucleotide_pair = is_nucleotide.unsqueeze(-1).expand(
136 | -1, -1, is_nucleotide.shape[-1]
137 | )
138 |
139 | mask = (
140 | is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
141 | + (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
142 | )
143 | mask = mask * (1 - torch.eye(pred_coords.shape[1], device=pred_coords.device))
144 | mask = mask * (coords_mask.unsqueeze(-1) * coords_mask.unsqueeze(-2))
145 |
146 | # Compute distances between all pairs of atoms
147 | pred_dists = torch.cdist(pred_coords, pred_coords)
148 | dist_diff = torch.abs(true_dists - pred_dists)
149 |
150 | # Compute epsilon values
151 | eps = (
152 | (
153 | (
154 | F.sigmoid(0.5 - dist_diff)
155 | + F.sigmoid(1.0 - dist_diff)
156 | + F.sigmoid(2.0 - dist_diff)
157 | + F.sigmoid(4.0 - dist_diff)
158 | )
159 | / 4.0
160 | )
161 | .view(multiplicity, B // multiplicity, N, N)
162 | .mean(dim=0)
163 | )
164 |
165 | # Calculate masked averaging
166 | eps = eps.repeat_interleave(multiplicity, 0)
167 | num = (eps * mask).sum(dim=(-1, -2))
168 | den = mask.sum(dim=(-1, -2)).clamp(min=1)
169 | lddt = num / den
170 |
171 | return 1.0 - lddt.mean()
172 |
--------------------------------------------------------------------------------
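A sanity check for weighted_rigid_align: aligning a point cloud onto a rigidly transformed copy of itself should reproduce the transformed copy, since the Kabsch solution recovers the rotation exactly (sketch):

import math

import torch

from boltz.model.loss.diffusion import weighted_rigid_align

true = torch.randn(1, 20, 3)
c, s = math.cos(0.7), math.sin(0.7)
rot = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
pred = true @ rot.T + 5.0            # rotated and translated copy
weights = torch.ones(1, 20)
mask = torch.ones(1, 20)
aligned = weighted_rigid_align(true, pred, weights, mask)
print(torch.allclose(aligned, pred, atol=1e-4))  # expected: True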
/scripts/train/assets/validation_ids.txt:
--------------------------------------------------------------------------------
1 | 7UTN
2 | 7F9H
3 | 7TZV
4 | 7ZHH
5 | 7SOV
6 | 7EOF
7 | 7R8H
8 | 8AW3
9 | 7F2F
10 | 8BAO
11 | 7BCB
12 | 7D8T
13 | 7D3T
14 | 7BHY
15 | 7YZ7
16 | 8DC2
17 | 7SOW
18 | 8CTL
19 | 7SOS
20 | 7V6W
21 | 7Z55
22 | 7NQF
23 | 7VTN
24 | 7KSP
25 | 7BJQ
26 | 7YZC
27 | 7Y3L
28 | 7TDX
29 | 7R8I
30 | 7OYK
31 | 7TZ1
32 | 7KIJ
33 | 7T8K
34 | 7KII
35 | 7YZA
36 | 7VP4
37 | 7KIK
38 | 7M5W
39 | 7Q94
40 | 7BCA
41 | 7YZB
42 | 7OG0
43 | 7VTI
44 | 7SOP
45 | 7S03
46 | 7YZG
47 | 7TXC
48 | 7VP5
49 | 7Y3I
50 | 7TDW
51 | 8B0R
52 | 7R8G
53 | 7FEF
54 | 7VP1
55 | 7VP3
56 | 7RGU
57 | 7DV2
58 | 7YZD
59 | 7OFZ
60 | 7Y3K
61 | 7TEC
62 | 7WQ5
63 | 7VP2
64 | 7EDB
65 | 7VP7
66 | 7PDV
67 | 7XHT
68 | 7R6R
69 | 8CSH
70 | 8CSZ
71 | 7V9O
72 | 7Q1C
73 | 8EDC
74 | 7PWI
75 | 7FI1
76 | 7ESI
77 | 7F0Y
78 | 7EYR
79 | 7ZVA
80 | 7WEG
81 | 7E4N
82 | 7U5Q
83 | 7FAV
84 | 7LJ2
85 | 7S6F
86 | 7B3N
87 | 7V4P
88 | 7AJO
89 | 7WH1
90 | 8DQP
91 | 7STT
92 | 7VQ7
93 | 7E4J
94 | 7RIS
95 | 7FH8
96 | 7BMW
97 | 7RD0
98 | 7V54
99 | 7LKC
100 | 7OU1
101 | 7QOD
102 | 7PX1
103 | 7EBY
104 | 7U1V
105 | 7PLP
106 | 7T8N
107 | 7SJK
108 | 7RGB
109 | 7TEM
110 | 7UG9
111 | 7B7A
112 | 7TM2
113 | 7Z74
114 | 7PCM
115 | 7V8G
116 | 7EUU
117 | 7VTL
118 | 7ZEI
119 | 7ZC0
120 | 7DZ9
121 | 8B2M
122 | 7NE9
123 | 7ALV
124 | 7M96
125 | 7O6T
126 | 7SKO
127 | 7Z2V
128 | 7OWX
129 | 7SHW
130 | 7TNI
131 | 7ZQY
132 | 7MDF
133 | 7EXR
134 | 7W6B
135 | 7EQF
136 | 7WWO
137 | 7FBW
138 | 8EHE
139 | 7CLE
140 | 7T80
141 | 7WMV
142 | 7SMG
143 | 7WSJ
144 | 7DBU
145 | 7VHY
146 | 7W5F
147 | 7SHG
148 | 7VU3
149 | 7ATH
150 | 7FGZ
151 | 7ADS
152 | 7REO
153 | 7T7H
154 | 7X0N
155 | 7TCU
156 | 7SKH
157 | 7EF6
158 | 7TBV
159 | 7B29
160 | 7VO5
161 | 7TM1
162 | 7QLD
163 | 7BB9
164 | 7SZ8
165 | 7RLM
166 | 7WWP
167 | 7NBV
168 | 7PLD
169 | 7DNM
170 | 7SFZ
171 | 7EAW
172 | 7QNQ
173 | 7SZX
174 | 7U2S
175 | 7WZX
176 | 7TYG
177 | 7QCE
178 | 7DCN
179 | 7WJL
180 | 7VV6
181 | 7TJ4
182 | 7VI8
183 | 8AKP
184 | 7WAO
185 | 7N7V
186 | 7EYO
187 | 7VTD
188 | 7VEG
189 | 7QY5
190 | 7ELV
191 | 7P0J
192 | 7YX8
193 | 7U4H
194 | 7TBD
195 | 7WME
196 | 7RI3
197 | 7TOH
198 | 7ZVM
199 | 7PUL
200 | 7VBO
201 | 7DM0
202 | 7XN9
203 | 7ALY
204 | 7LTB
205 | 8A28
206 | 7UBZ
207 | 8DTE
208 | 7TA2
209 | 7QST
210 | 7AN1
211 | 7FIB
212 | 8BAL
213 | 7TMJ
214 | 7REV
215 | 7PZJ
216 | 7T9X
217 | 7SUU
218 | 7KJQ
219 | 7V6P
220 | 7QA3
221 | 7ULC
222 | 7Y3X
223 | 7TMU
224 | 7OA7
225 | 7PO9
226 | 7Q20
227 | 8H2C
228 | 7VW1
229 | 7VLJ
230 | 8EP4
231 | 7P57
232 | 7QUL
233 | 7ZQE
234 | 7UJU
235 | 7WG1
236 | 7DMK
237 | 7Y8X
238 | 7EHG
239 | 7W13
240 | 7NL4
241 | 7R4J
242 | 7AOV
243 | 7RFT
244 | 7VUF
245 | 7F72
246 | 8DSR
247 | 7MK3
248 | 7MQQ
249 | 7R55
250 | 7T85
251 | 7NCY
252 | 7ZHL
253 | 7E1N
254 | 7W8F
255 | 7PGK
256 | 8GUN
257 | 7P8D
258 | 7PUK
259 | 7N9D
260 | 7XWN
261 | 7ZHA
262 | 7TVP
263 | 7VI6
264 | 7PW6
265 | 7YM0
266 | 7RWK
267 | 8DKR
268 | 7WGU
269 | 7LJI
270 | 7THW
271 | 7OB6
272 | 7N3Z
273 | 7T3S
274 | 7PAB
275 | 7F9F
276 | 7PPP
277 | 7AD5
278 | 7VGM
279 | 7WBO
280 | 7RWM
281 | 7QFI
282 | 7T91
283 | 7ANU
284 | 7UX0
285 | 7USR
286 | 7RDN
287 | 7VW5
288 | 7Q4T
289 | 7W3R
290 | 8DKQ
291 | 7RCX
292 | 7UOF
293 | 7OKR
294 | 7NX1
295 | 6ZBS
296 | 7VEV
297 | 8E8U
298 | 7WJ6
299 | 7MP4
300 | 7RPY
301 | 7R5Z
302 | 7VLM
303 | 7SNE
304 | 7WDW
305 | 8E19
306 | 7PP2
307 | 7Z5H
308 | 7P7I
309 | 7LJJ
310 | 7QPC
311 | 7VJS
312 | 7QOE
313 | 7KZH
314 | 7F6N
315 | 7TMI
316 | 7POH
317 | 8DKS
318 | 7YMO
319 | 6S5I
320 | 7N6O
321 | 7LYU
322 | 7POK
323 | 7BLK
324 | 7TCY
325 | 7W19
326 | 8B55
327 | 7SMU
328 | 7QFK
329 | 7T5T
330 | 7EPQ
331 | 7DCK
332 | 7S69
333 | 6ZSV
334 | 7ZGT
335 | 7TJ1
336 | 7V09
337 | 7ZHD
338 | 7ALL
339 | 7P1Y
340 | 7T71
341 | 7MNK
342 | 7W5Q
343 | 7PZ2
344 | 7QSQ
345 | 7QI3
346 | 7NZZ
347 | 7Q47
348 | 8D08
349 | 7QH5
350 | 7RXQ
351 | 7F45
352 | 8D07
353 | 8EHC
354 | 7PZT
355 | 7K3C
356 | 7ZGI
357 | 7MC4
358 | 7NPQ
359 | 7VD7
360 | 7XAN
361 | 7FDP
362 | 8A0K
363 | 7TXO
364 | 7ZB1
365 | 7V5V
366 | 7WWS
367 | 7PBK
368 | 8EBG
369 | 7N0J
370 | 7UMA
371 | 7T1S
372 | 8EHB
373 | 7DWC
374 | 7K6W
375 | 7WEJ
376 | 7LRH
377 | 7ZCV
378 | 7RKC
379 | 7X8C
380 | 7PV1
381 | 7UGK
382 | 7ULN
383 | 7A66
384 | 7R7M
385 | 7M0Q
386 | 7BGS
387 | 7UPP
388 | 7O62
389 | 7VKK
390 | 7L6Y
391 | 7VG4
392 | 7V2V
393 | 7ETN
394 | 7ZTB
395 | 7AOO
396 | 7OH2
397 | 7E0M
398 | 7PEG
399 | 8CUK
400 | 7ZP0
401 | 7T6A
402 | 7BTM
403 | 7DOV
404 | 7VVV
405 | 7P22
406 | 7RUO
407 | 7E40
408 | 7O5Y
409 | 7XPK
410 | 7R0K
411 | 8D04
412 | 7TYD
413 | 7LSV
414 | 7XSI
415 | 7RTZ
416 | 7UXR
417 | 7QH3
418 | 8END
419 | 8CYK
420 | 7MRJ
421 | 7DJL
422 | 7S5B
423 | 7XUX
424 | 7EV8
425 | 7R6S
426 | 7UH4
427 | 7R9X
428 | 7F7P
429 | 7ACW
430 | 7SPN
431 | 7W70
432 | 7Q5G
433 | 7DXN
434 | 7DK9
435 | 8DT0
436 | 7FDN
437 | 7DGX
438 | 7UJB
439 | 7X4O
440 | 7F4O
441 | 7T9W
442 | 8AID
443 | 7ERQ
444 | 7EQB
445 | 7YDG
446 | 7ETR
447 | 8D27
448 | 7OUU
449 | 7R5Y
450 | 7T8I
451 | 7UZT
452 | 7X8V
453 | 7QLH
454 | 7SAF
455 | 7EN6
456 | 8D4Y
457 | 7ESJ
458 | 7VWO
459 | 7SBE
460 | 7VYU
461 | 7RVJ
462 | 7FCL
463 | 7WUO
464 | 7WWF
465 | 7VMT
466 | 7SHJ
467 | 7SKP
468 | 7KOU
469 | 6ZSU
470 | 7VGW
471 | 7X45
472 | 8GYZ
473 | 8BFE
474 | 8DGL
475 | 7Z3H
476 | 8BD1
477 | 8A0J
478 | 7JRK
479 | 7QII
480 | 7X39
481 | 7Y6B
482 | 7OIY
483 | 7SBI
484 | 8A3I
485 | 7NLI
486 | 7F4U
487 | 7TVY
488 | 7X0O
489 | 7VMH
490 | 7EPN
491 | 7WBK
492 | 8BFJ
493 | 7XFP
494 | 7LXQ
495 | 7TIL
496 | 7O61
497 | 8B8B
498 | 7W2Q
499 | 8APR
500 | 7WZE
501 | 7NYQ
502 | 7RMX
503 | 7PGE
504 | 8F43
505 | 7N2K
506 | 7UXG
507 | 7SXN
508 | 7T5U
509 | 7R22
510 | 7E3T
511 | 7PTB
512 | 7OA8
513 | 7X5T
514 | 7PL7
515 | 7SQ5
516 | 7VBS
517 | 8D03
518 | 7TAE
519 | 7T69
520 | 7WF6
521 | 7LBU
522 | 8A06
523 | 8DA2
524 | 7QFL
525 | 7KUW
526 | 7X9R
527 | 7XT3
528 | 7RB4
529 | 7PT5
530 | 7RPS
531 | 7RXU
532 | 7TDY
533 | 7W89
534 | 7N9I
535 | 7T1M
536 | 7OBM
537 | 7K3X
538 | 7ZJC
539 | 8BDP
540 | 7V8W
541 | 7DJK
542 | 7W1K
543 | 7QFG
544 | 7DGY
545 | 7ZTQ
546 | 7F8A
547 | 7NEK
548 | 7CG9
549 | 7KOB
550 | 7TN7
551 | 8DYS
552 | 7WVR
553 |
--------------------------------------------------------------------------------
/src/boltz/data/write/writer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, replace
2 | from pathlib import Path
3 | from typing import Literal
4 |
5 | import numpy as np
6 | from pytorch_lightning import LightningModule, Trainer
7 | from pytorch_lightning.callbacks import BasePredictionWriter
8 | from torch import Tensor
9 |
10 | from boltz.data.types import (
11 | Interface,
12 | Record,
13 | Structure,
14 | )
15 | from boltz.data.write.mmcif import to_mmcif
16 | from boltz.data.write.pdb import to_pdb
17 |
18 |
19 | class BoltzWriter(BasePredictionWriter):
20 | """Custom writer for predictions."""
21 |
22 | def __init__(
23 | self,
24 | data_dir: str,
25 | output_dir: str,
26 | output_format: Literal["pdb", "mmcif"] = "mmcif",
27 | ) -> None:
28 | """Initialize the writer.
 29 |         Parameters
 30 |         ----------
 31 |         data_dir : str
 32 |             The directory containing the processed structures.
 33 |         output_dir : str
 34 |             The directory to save the predictions.
 35 |         """
36 | super().__init__(write_interval="batch")
37 | if output_format not in ["pdb", "mmcif"]:
38 | msg = f"Invalid output format: {output_format}"
39 | raise ValueError(msg)
40 |
41 | self.data_dir = Path(data_dir)
42 | self.output_dir = Path(output_dir)
43 | self.output_format = output_format
44 | self.failed = 0
45 |
46 | # Create the output directories
47 | self.output_dir.mkdir(parents=True, exist_ok=True)
48 |
49 | def write_on_batch_end(
50 | self,
51 | trainer: Trainer, # noqa: ARG002
52 | pl_module: LightningModule, # noqa: ARG002
53 | prediction: dict[str, Tensor],
54 | batch_indices: list[int], # noqa: ARG002
55 | batch: dict[str, Tensor],
56 | batch_idx: int, # noqa: ARG002
57 | dataloader_idx: int, # noqa: ARG002
58 | ) -> None:
59 | """Write the predictions to disk."""
60 | if prediction["exception"]:
61 | self.failed += 1
62 | return
63 |
64 | # Get the records
65 | records: list[Record] = batch["record"]
66 |
67 | # Get the predictions
68 | coords = prediction["coords"]
69 | coords = coords.unsqueeze(0)
70 |
71 | pad_masks = prediction["masks"]
72 | if prediction.get("confidence") is not None:
73 | confidences = prediction["confidence"]
74 | confidences = confidences.reshape(len(records), -1).tolist()
75 | else:
76 | confidences = [0.0 for _ in range(len(records))]
77 |
78 | # Iterate over the records
79 | for record, coord, pad_mask, _confidence in zip(
80 | records, coords, pad_masks, confidences
81 | ):
82 | # Load the structure
83 | path = self.data_dir / f"{record.id}.npz"
84 | structure: Structure = Structure.load(path)
85 |
 86 |             # Compute chain map with masked chains removed, for later use
87 | chain_map = {}
88 | for i, mask in enumerate(structure.mask):
89 | if mask:
90 | chain_map[len(chain_map)] = i
91 |
92 | # Remove masked chains completely
93 | structure = structure.remove_invalid_chains()
94 |
95 | for model_idx in range(coord.shape[0]):
96 | # Get model coord
97 | model_coord = coord[model_idx]
98 | # Unpad
99 | coord_unpad = model_coord[pad_mask.bool()]
100 | coord_unpad = coord_unpad.cpu().numpy()
101 |
102 | # New atom table
103 | atoms = structure.atoms
104 | atoms["coords"] = coord_unpad
105 | atoms["is_present"] = True
106 |
107 |                 # New residue table
108 | residues = structure.residues
109 | residues["is_present"] = True
110 |
111 | # Update the structure
112 | interfaces = np.array([], dtype=Interface)
113 | new_structure: Structure = replace(
114 | structure,
115 | atoms=atoms,
116 | residues=residues,
117 | interfaces=interfaces,
118 | )
119 |
120 | # Update chain info
121 | chain_info = []
122 | for chain in new_structure.chains:
123 | old_chain_idx = chain_map[chain["asym_id"]]
124 | old_chain_info = record.chains[old_chain_idx]
125 | new_chain_info = replace(
126 | old_chain_info,
127 | chain_id=int(chain["asym_id"]),
128 | valid=True,
129 | )
130 | chain_info.append(new_chain_info)
131 |
132 | # Save the structure
133 | struct_dir = self.output_dir / record.id
134 | struct_dir.mkdir(exist_ok=True)
135 |
136 | if self.output_format == "pdb":
137 | path = struct_dir / f"{record.id}_model_{model_idx}.pdb"
138 | with path.open("w") as f:
139 | f.write(to_pdb(new_structure))
140 | elif self.output_format == "mmcif":
141 | path = struct_dir / f"{record.id}_model_{model_idx}.cif"
142 | with path.open("w") as f:
143 | f.write(to_mmcif(new_structure))
144 | else:
145 | path = struct_dir / f"{record.id}_model_{model_idx}.npz"
146 | np.savez_compressed(path, **asdict(new_structure))
147 |
148 | def on_predict_epoch_end(
149 | self,
150 |         trainer: Trainer,  # noqa: ARG002
151 | pl_module: LightningModule, # noqa: ARG002
152 | ) -> None:
153 | """Print the number of failed examples."""
154 | # Print number of failed examples
155 | print(f"Number of failed examples: {self.failed}") # noqa: T201
156 |
--------------------------------------------------------------------------------
/src/boltz/data/tokenize/boltz.py:
--------------------------------------------------------------------------------
1 | from dataclasses import astuple, dataclass
2 |
3 | import numpy as np
4 |
5 | from boltz.data import const
6 | from boltz.data.tokenize.tokenizer import Tokenizer
7 | from boltz.data.types import Input, Token, TokenBond, Tokenized
8 |
9 |
10 | @dataclass
11 | class TokenData:
12 | """TokenData datatype."""
13 |
14 | token_idx: int
15 | atom_idx: int
16 | atom_num: int
17 | res_idx: int
18 | res_type: int
19 | sym_id: int
20 | asym_id: int
21 | entity_id: int
22 | mol_type: int
23 | center_idx: int
24 | disto_idx: int
25 | center_coords: np.ndarray
26 | disto_coords: np.ndarray
27 | resolved_mask: bool
28 | disto_mask: bool
29 |
30 |
31 | class BoltzTokenizer(Tokenizer):
32 | """Tokenize an input structure for training."""
33 |
34 | def tokenize(self, data: Input) -> Tokenized:
35 | """Tokenize the input data.
36 |
37 | Parameters
38 | ----------
 39 |         data : Input
40 | The input data.
41 |
42 | Returns
43 | -------
44 | Tokenized
45 | The tokenized data.
46 |
47 | """
48 | # Get structure data
49 | struct = data.structure
50 |
51 | # Create token data
52 | token_data = []
53 |
54 | # Keep track of atom_idx to token_idx
55 | token_idx = 0
56 | atom_to_token = {}
57 |
58 | # Filter to valid chains only
59 | chains = struct.chains[struct.mask]
60 |
61 | for chain in chains:
62 | # Get residue indices
63 | res_start = chain["res_idx"]
64 | res_end = chain["res_idx"] + chain["res_num"]
65 |
66 | for res in struct.residues[res_start:res_end]:
67 | # Get atom indices
68 | atom_start = res["atom_idx"]
69 | atom_end = res["atom_idx"] + res["atom_num"]
70 |
71 | # Standard residues are tokens
72 | if res["is_standard"]:
73 | # Get center and disto atoms
74 | center = struct.atoms[res["atom_center"]]
75 | disto = struct.atoms[res["atom_disto"]]
76 |
77 | # Token is present if centers are
78 | is_present = res["is_present"] & center["is_present"]
79 | is_disto_present = res["is_present"] & disto["is_present"]
80 |
81 | # Apply chain transformation
82 | c_coords = center["coords"]
83 | d_coords = disto["coords"]
84 |
85 | # Create token
86 | token = TokenData(
87 | token_idx=token_idx,
88 | atom_idx=res["atom_idx"],
89 | atom_num=res["atom_num"],
90 | res_idx=res["res_idx"],
91 | res_type=res["res_type"],
92 | sym_id=chain["sym_id"],
93 | asym_id=chain["asym_id"],
94 | entity_id=chain["entity_id"],
95 | mol_type=chain["mol_type"],
96 | center_idx=res["atom_center"],
97 | disto_idx=res["atom_disto"],
98 | center_coords=c_coords,
99 | disto_coords=d_coords,
100 | resolved_mask=is_present,
101 | disto_mask=is_disto_present,
102 | )
103 | token_data.append(astuple(token))
104 |
105 | # Update atom_idx to token_idx
106 | for atom_idx in range(atom_start, atom_end):
107 | atom_to_token[atom_idx] = token_idx
108 |
109 | token_idx += 1
110 |
111 | # Non-standard are tokenized per atom
112 | else:
113 | # We use the unk protein token as res_type
114 | unk_token = const.unk_token["PROTEIN"]
115 | unk_id = const.token_ids[unk_token]
116 |
117 | # Get atom coordinates
118 | atom_data = struct.atoms[atom_start:atom_end]
119 | atom_coords = atom_data["coords"]
120 |
121 | # Tokenize each atom
122 | for i, atom in enumerate(atom_data):
123 | # Token is present if atom is
124 | is_present = res["is_present"] & atom["is_present"]
125 | index = atom_start + i
126 |
127 | # Create token
128 | token = TokenData(
129 | token_idx=token_idx,
130 | atom_idx=index,
131 | atom_num=1,
132 | res_idx=res["res_idx"],
133 | res_type=unk_id,
134 | sym_id=chain["sym_id"],
135 | asym_id=chain["asym_id"],
136 | entity_id=chain["entity_id"],
137 | mol_type=chain["mol_type"],
138 | center_idx=index,
139 | disto_idx=index,
140 | center_coords=atom_coords[i],
141 | disto_coords=atom_coords[i],
142 | resolved_mask=is_present,
143 | disto_mask=is_present,
144 | )
145 | token_data.append(astuple(token))
146 |
147 | # Update atom_idx to token_idx
148 | atom_to_token[index] = token_idx
149 | token_idx += 1
150 |
151 | # Create token bonds
152 | token_bonds = []
153 |
154 | # Add atom-atom bonds from ligands
155 | for bond in struct.bonds:
156 | if (
157 | bond["atom_1"] not in atom_to_token
158 | or bond["atom_2"] not in atom_to_token
159 | ):
160 | continue
161 | token_bond = (
162 | atom_to_token[bond["atom_1"]],
163 | atom_to_token[bond["atom_2"]],
164 | )
165 | token_bonds.append(token_bond)
166 |
167 | # Add connection bonds (covalent)
168 | for conn in struct.connections:
169 | if (
170 | conn["atom_1"] not in atom_to_token
171 | or conn["atom_2"] not in atom_to_token
172 | ):
173 | continue
174 | token_bond = (
175 | atom_to_token[conn["atom_1"]],
176 | atom_to_token[conn["atom_2"]],
177 | )
178 | token_bonds.append(token_bond)
179 |
180 |         # Consider adding the missing bond between modified and standard
181 |         # residues? It may be unnecessary: the bond is probably always the
182 |         # same, and the model can infer it from the residue indices.
183 | token_data = np.array(token_data, dtype=Token)
184 | token_bonds = np.array(token_bonds, dtype=TokenBond)
185 | tokenized = Tokenized(
186 | token_data,
187 | token_bonds,
188 | data.structure,
189 | data.msa,
190 | )
191 | return tokenized
192 |
--------------------------------------------------------------------------------
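The tokenization rule in `BoltzTokenizer.tokenize` above reduces to: a standard residue becomes a single token, while a non-standard residue (e.g. a ligand or modified residue) contributes one token per atom. A toy illustration with hypothetical records, not the real `Structure` arrays:

```python
# Hypothetical residue records mirroring the fields the tokenizer reads
residues = [
    {"name": "ALA", "is_standard": True, "atom_num": 5},
    {"name": "GLY", "is_standard": True, "atom_num": 4},
    {"name": "SAH", "is_standard": False, "atom_num": 26},  # ligand
]

# One token per standard residue, one token per atom otherwise
num_tokens = sum(1 if r["is_standard"] else r["atom_num"] for r in residues)
print(num_tokens)  # 28 = 1 + 1 + 26
```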
/scripts/train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 | from typing import Optional
6 |
7 | import hydra
8 | import omegaconf
9 | import pytorch_lightning as pl
10 | import torch
11 | import torch.multiprocessing
12 | from omegaconf import OmegaConf, listconfig
13 | from pytorch_lightning import LightningModule
14 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
15 | from pytorch_lightning.loggers import WandbLogger
16 | from pytorch_lightning.strategies import DDPStrategy
17 | from pytorch_lightning.utilities import rank_zero_only
18 |
19 | from boltz.data.module.training import BoltzTrainingDataModule, DataConfig
20 |
21 |
22 | @dataclass
23 | class TrainConfig:
24 | """Train configuration.
25 |
26 | Attributes
27 | ----------
28 | data : DataConfig
29 | The data configuration.
30 | model : ModelConfig
31 | The model configuration.
32 | output : str
33 | The output directory.
34 | trainer : Optional[dict]
35 | The trainer configuration.
36 | resume : Optional[str]
37 | The resume checkpoint.
38 | pretrained : Optional[str]
39 | The pretrained model.
40 | wandb : Optional[dict]
41 | The wandb configuration.
42 | disable_checkpoint : bool
43 | Disable checkpoint.
44 | matmul_precision : Optional[str]
45 | The matmul precision.
46 | find_unused_parameters : Optional[bool]
47 | Find unused parameters.
48 | save_top_k : Optional[int]
49 | Save top k checkpoints.
50 | validation_only : bool
51 | Run validation only.
52 | debug : bool
53 | Debug mode.
54 | strict_loading : bool
55 | Fail on mismatched checkpoint weights.
56 | load_confidence_from_trunk: Optional[bool]
57 | Load pre-trained confidence weights from trunk.
58 |
59 | """
60 |
61 | data: DataConfig
62 | model: LightningModule
63 | output: str
64 | trainer: Optional[dict] = None
65 | resume: Optional[str] = None
66 | pretrained: Optional[str] = None
67 | wandb: Optional[dict] = None
68 | disable_checkpoint: bool = False
69 | matmul_precision: Optional[str] = None
70 | find_unused_parameters: Optional[bool] = False
71 | save_top_k: Optional[int] = 1
72 | validation_only: bool = False
73 | debug: bool = False
74 | strict_loading: bool = True
75 | load_confidence_from_trunk: Optional[bool] = False
76 |
77 |
78 | def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR0915
79 | """Run training.
80 |
81 | Parameters
82 | ----------
83 | raw_config : str
84 | The input yaml configuration.
85 | args : list[str]
86 | Any command line overrides.
87 |
88 | """
89 | # Load the configuration
90 | raw_config = omegaconf.OmegaConf.load(raw_config)
91 |
92 | # Apply input arguments
93 | args = omegaconf.OmegaConf.from_dotlist(args)
94 | raw_config = omegaconf.OmegaConf.merge(raw_config, args)
95 |
96 | # Instantiate the task
97 | cfg = hydra.utils.instantiate(raw_config)
98 | cfg = TrainConfig(**cfg)
99 |
100 | # Set matmul precision
101 | if cfg.matmul_precision is not None:
102 | torch.set_float32_matmul_precision(cfg.matmul_precision)
103 |
104 | # Create trainer dict
105 | trainer = cfg.trainer
106 | if trainer is None:
107 | trainer = {}
108 |
109 | # Flip some arguments in debug mode
110 | devices = trainer.get("devices", 1)
111 |
112 | wandb = cfg.wandb
113 | if cfg.debug:
114 | if isinstance(devices, int):
115 | devices = 1
116 | elif isinstance(devices, (list, listconfig.ListConfig)):
117 | devices = [devices[0]]
118 | trainer["devices"] = devices
119 | cfg.data.num_workers = 0
120 | if wandb:
121 | wandb = None
122 |
123 | # Create objects
124 | data_config = DataConfig(**cfg.data)
125 | data_module = BoltzTrainingDataModule(data_config)
126 | model_module = cfg.model
127 |
128 | if cfg.pretrained and not cfg.resume:
129 | # Load the pretrained weights into the confidence module
130 | if cfg.load_confidence_from_trunk:
131 | checkpoint = torch.load(cfg.pretrained, map_location="cpu")
132 |             # Modify parameter names in the state_dict
133 |             new_state_dict = {}
134 |             for key, value in checkpoint["state_dict"].items():
135 |                 if not key.startswith("structure_module") and not key.startswith(
136 |                     "distogram_module"
137 |                 ):
138 |                     new_key = "confidence_module." + key
139 |                     new_state_dict[new_key] = value
140 |             new_state_dict.update(checkpoint["state_dict"])
141 |             # Update the checkpoint and save it to a temporary file, which
142 |             # is loaded below and then deleted (".tmp.ckpt" is illustrative)
143 |             checkpoint["state_dict"] = new_state_dict
144 |             file_path = cfg.pretrained + ".tmp.ckpt"
145 |             torch.save(checkpoint, file_path)
146 |         else:
147 |             file_path = cfg.pretrained
148 | print(f"Loading model from {file_path}")
149 | model_module = type(model_module).load_from_checkpoint(
150 | file_path, strict=False, **(model_module.hparams)
151 | )
152 |
153 | if cfg.load_confidence_from_trunk:
154 | os.remove(file_path)
155 |
156 | # Create checkpoint callback
157 | callbacks = []
158 | dirpath = cfg.output
159 | if not cfg.disable_checkpoint:
160 | mc = ModelCheckpoint(
161 | monitor="val/lddt",
162 | save_top_k=cfg.save_top_k,
163 | save_last=True,
164 | mode="max",
165 | every_n_epochs=1,
166 | )
167 | callbacks = [mc]
168 |
169 | # Create wandb logger
170 | loggers = []
171 | if wandb:
172 | wdb_logger = WandbLogger(
173 | group=wandb["name"],
174 | save_dir=cfg.output,
175 | project=wandb["project"],
176 | entity=wandb["entity"],
177 | log_model=False,
178 | )
179 | loggers.append(wdb_logger)
180 | # Save the config to wandb
181 |
182 | @rank_zero_only
183 | def save_config_to_wandb() -> None:
184 | config_out = Path(wdb_logger.experiment.dir) / "run.yaml"
185 | with Path.open(config_out, "w") as f:
186 | OmegaConf.save(raw_config, f)
187 | wdb_logger.experiment.save(str(config_out))
188 |
189 | save_config_to_wandb()
190 |
191 | # Set up trainer
192 | strategy = "auto"
193 | if (isinstance(devices, int) and devices > 1) or (
194 | isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1
195 | ):
196 | strategy = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)
197 |
198 | trainer = pl.Trainer(
199 | default_root_dir=str(dirpath),
200 | strategy=strategy,
201 | callbacks=callbacks,
202 | logger=loggers,
203 | enable_checkpointing=not cfg.disable_checkpoint,
204 | reload_dataloaders_every_n_epochs=1,
205 | **trainer,
206 | )
207 |
208 | if not cfg.strict_loading:
209 | model_module.strict_loading = False
210 |
211 | if cfg.validation_only:
212 | trainer.validate(
213 | model_module,
214 | datamodule=data_module,
215 | ckpt_path=cfg.resume,
216 | )
217 | else:
218 | trainer.fit(
219 | model_module,
220 | datamodule=data_module,
221 | ckpt_path=cfg.resume,
222 | )
223 |
224 |
225 | if __name__ == "__main__":
226 | arg1 = sys.argv[1]
227 | arg2 = sys.argv[2:]
228 | train(arg1, arg2)
229 |
--------------------------------------------------------------------------------
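The confidence-from-trunk logic above rests on a simple state-dict rewrite. A hedged sketch of that step in isolation, assuming illustrative checkpoint paths:

```python
import torch

# Load a trunk checkpoint (path illustrative)
checkpoint = torch.load("trunk.ckpt", map_location="cpu")
state = checkpoint["state_dict"]

# Duplicate every non-structure/distogram parameter under the confidence
# module prefix; the original keys are kept and win on any clash
renamed = {
    "confidence_module." + key: value
    for key, value in state.items()
    if not key.startswith(("structure_module", "distogram_module"))
}
renamed.update(state)

checkpoint["state_dict"] = renamed
torch.save(checkpoint, "trunk_confidence.ckpt")  # illustrative temp file
```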
/src/boltz/data/write/mmcif.py:
--------------------------------------------------------------------------------
1 | import io
2 | from typing import Iterator
3 |
4 | import ihm
5 | from modelcif import Assembly, AsymUnit, Entity, System, dumper
6 | from modelcif.model import AbInitioModel, Atom, ModelGroup
7 | from rdkit import Chem
8 |
9 | from boltz.data import const
10 | from boltz.data.types import Structure
11 | from boltz.data.write.utils import generate_tags
12 |
13 |
14 | def to_mmcif(structure: Structure) -> str: # noqa: C901
15 | """Write a structure into an MMCIF file.
16 |
17 | Parameters
18 | ----------
19 | structure : Structure
20 | The input structure
21 |
22 | Returns
23 | -------
24 | str
 25 |         The mmCIF file contents.
26 |
27 | """
28 | system = System()
29 |
30 | # Load periodic table for element mapping
31 | periodic_table = Chem.GetPeriodicTable()
32 |
33 | # Map entities to chain_ids
34 | entity_to_chains = {}
35 | entity_to_moltype = {}
36 |
37 | for chain in structure.chains:
38 | entity_id = chain["entity_id"]
39 | mol_type = chain["mol_type"]
40 | entity_to_chains.setdefault(entity_id, []).append(chain)
41 | entity_to_moltype[entity_id] = mol_type
42 |
43 | # Map entities to sequences
44 | sequences = {}
45 | for entity in entity_to_chains:
46 | # Get the first chain
47 | chain = entity_to_chains[entity][0]
48 |
49 | # Get the sequence
50 | res_start = chain["res_idx"]
51 | res_end = chain["res_idx"] + chain["res_num"]
52 | residues = structure.residues[res_start:res_end]
53 | sequence = [str(res["name"]) for res in residues]
54 | sequences[entity] = sequence
55 |
56 | # Create entity objects
57 | entities_map = {}
58 | for entity, sequence in sequences.items():
59 | mol_type = entity_to_moltype[entity]
60 |
61 | if mol_type == const.chain_type_ids["PROTEIN"]:
62 | alphabet = ihm.LPeptideAlphabet()
63 | chem_comp = lambda x: ihm.LPeptideChemComp(id=x, code=x, code_canonical="X") # noqa: E731
64 | elif mol_type == const.chain_type_ids["DNA"]:
65 | alphabet = ihm.DNAAlphabet()
66 | chem_comp = lambda x: ihm.DNAChemComp(id=x, code=x, code_canonical="N") # noqa: E731
67 | elif mol_type == const.chain_type_ids["RNA"]:
68 | alphabet = ihm.RNAAlphabet()
69 | chem_comp = lambda x: ihm.RNAChemComp(id=x, code=x, code_canonical="N") # noqa: E731
70 | elif len(sequence) > 1:
71 | alphabet = {}
72 | chem_comp = lambda x: ihm.SaccharideChemComp(id=x) # noqa: E731
73 | else:
74 | alphabet = {}
75 | chem_comp = lambda x: ihm.NonPolymerChemComp(id=x) # noqa: E731
76 |
77 | seq = [
78 | alphabet[item] if item in alphabet else chem_comp(item) for item in sequence
79 | ]
80 | model_e = Entity(seq)
81 | for chain in entity_to_chains[entity]:
82 | chain_idx = chain["asym_id"]
83 | entities_map[chain_idx] = model_e
84 |
85 | # We don't assume that symmetry is perfect, so we dump everything
86 | # into the asymmetric unit, and produce just a single assembly
87 | chain_tags = generate_tags()
88 | asym_unit_map = {}
89 | for chain in structure.chains:
90 | # Define the model assembly
91 | chain_idx = chain["asym_id"]
92 | chain_tag = next(chain_tags)
93 | asym = AsymUnit(
94 | entities_map[chain_idx],
95 | details="Model subunit %s" % chain_tag,
96 | id=chain_tag,
97 | )
98 | asym_unit_map[chain_idx] = asym
99 | modeled_assembly = Assembly(asym_unit_map.values(), name="Modeled assembly")
100 |
101 | # class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
102 | # name = "pLDDT"
103 | # software = None
104 | # description = "Predicted lddt"
105 |
106 | # class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
107 | # name = "pLDDT"
108 | # software = None
109 | # description = "Global pLDDT, mean of per-residue pLDDTs"
110 |
111 | class _MyModel(AbInitioModel):
112 | def get_atoms(self) -> Iterator[Atom]:
113 | # Add all atom sites.
114 | for chain in structure.chains:
115 | # We rename the chains in alphabetical order
116 | het = chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]
117 | chain_idx = chain["asym_id"]
118 | res_start = chain["res_idx"]
119 | res_end = chain["res_idx"] + chain["res_num"]
120 |
121 | residues = structure.residues[res_start:res_end]
122 | for residue in residues:
123 | atom_start = residue["atom_idx"]
124 | atom_end = residue["atom_idx"] + residue["atom_num"]
125 | atoms = structure.atoms[atom_start:atom_end]
126 | atom_coords = atoms["coords"]
127 | for i, atom in enumerate(atoms):
128 | # This should not happen on predictions, but just in case.
129 | if not atom["is_present"]:
130 | continue
131 |
132 | name = atom["name"]
133 | name = [chr(c + 32) for c in name if c != 0]
134 | name = "".join(name)
135 | element = periodic_table.GetElementSymbol(
136 | atom["element"].item()
137 | )
138 | element = element.upper()
139 | residue_index = residue["res_idx"] + 1
140 | pos = atom_coords[i]
141 | yield Atom(
142 | asym_unit=asym_unit_map[chain_idx],
143 | type_symbol=element,
144 | seq_id=residue_index,
145 | atom_id=name,
146 | x=pos[0],
147 | y=pos[1],
148 | z=pos[2],
149 | het=het,
150 | biso=1.00,
151 | occupancy=1.00,
152 | )
153 |
154 | def add_scores(self):
155 | return
156 | # local scores
157 | # plddt_per_residue = {}
158 | # for i in range(n):
159 | # for mask, b_factor in zip(atom_mask[i], b_factors[i]):
160 | # if mask < 0.5:
161 | # continue
162 | # # add 1 per residue, not 1 per atom
163 | # if chain_index[i] not in plddt_per_residue:
164 | # # first time a chain index is seen: add the key and start the residue dict
165 | # plddt_per_residue[chain_index[i]] = {residue_index[i]: b_factor}
166 | # if residue_index[i] not in plddt_per_residue[chain_index[i]]:
167 | # plddt_per_residue[chain_index[i]][residue_index[i]] = b_factor
168 | # plddts = []
169 | # for chain_idx in plddt_per_residue:
170 | # for residue_idx in plddt_per_residue[chain_idx]:
171 | # plddt = plddt_per_residue[chain_idx][residue_idx]
172 | # plddts.append(plddt)
173 | # self.qa_metrics.append(
174 | # _LocalPLDDT(
175 | # asym_unit_map[chain_idx].residue(residue_idx), plddt
176 | # )
177 | # )
178 | # # global score
179 | # self.qa_metrics.append((_GlobalPLDDT(np.mean(plddts))))
180 |
181 | # Add the model and modeling protocol to the file and write them out:
182 | model = _MyModel(assembly=modeled_assembly, name="Model")
183 | # model.add_scores()
184 |
185 | model_group = ModelGroup([model], name="All models")
186 | system.model_groups.append(model_group)
187 |
188 | fh = io.StringIO()
189 | dumper.write(fh, [system])
190 | return fh.getvalue()
191 |
--------------------------------------------------------------------------------
/src/boltz/data/module/inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import pytorch_lightning as pl
5 | import torch
6 | from torch import Tensor
7 | from torch.utils.data import DataLoader
8 |
9 | from boltz.data import const
10 | from boltz.data.feature.featurizer import BoltzFeaturizer
11 | from boltz.data.feature.pad import pad_to_max
12 | from boltz.data.tokenize.boltz import BoltzTokenizer
13 | from boltz.data.types import MSA, Input, Manifest, Record, Structure
14 |
15 |
16 | def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
17 | """Load the given input data.
18 |
19 | Parameters
20 | ----------
21 | record : Record
22 | The record to load.
23 | target_dir : Path
24 | The path to the data directory.
25 | msa_dir : Path
26 | The path to msa directory.
27 |
28 | Returns
29 | -------
30 | Input
31 | The loaded input.
32 |
33 | """
34 | # Load the structure
35 | structure = np.load(target_dir / f"{record.id}.npz")
36 | structure = Structure(
37 | atoms=structure["atoms"],
38 | bonds=structure["bonds"],
39 | residues=structure["residues"],
40 | chains=structure["chains"],
41 | connections=structure["connections"],
42 | interfaces=structure["interfaces"],
43 | mask=structure["mask"],
44 | )
45 |
46 | msas = {}
47 | for chain in record.chains:
48 | msa_id = chain.msa_id
49 | # Load the MSA for this chain, if any
50 | if msa_id != -1:
51 | msa = np.load(msa_dir / f"{msa_id}.npz")
52 | msas[chain.chain_id] = MSA(**msa)
53 |
54 | return Input(structure, msas)
55 |
56 |
57 | def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
58 | """Collate the data.
59 |
60 | Parameters
61 | ----------
62 | data : List[Dict[str, Tensor]]
63 | The data to collate.
64 |
65 | Returns
66 | -------
67 | Dict[str, Tensor]
68 | The collated data.
69 |
70 | """
71 | # Get the keys
72 | keys = data[0].keys()
73 |
74 | # Collate the data
75 | collated = {}
76 | for key in keys:
77 | values = [d[key] for d in data]
78 |
79 | if key not in [
80 | "all_coords",
81 | "all_resolved_mask",
82 | "crop_to_all_atom_map",
83 | "chain_symmetries",
84 | "amino_acids_symmetries",
85 | "ligand_symmetries",
86 | "record",
87 | ]:
88 | # Check if all have the same shape
89 | shape = values[0].shape
90 | if not all(v.shape == shape for v in values):
91 | values, _ = pad_to_max(values, 0)
92 | else:
93 | values = torch.stack(values, dim=0)
94 |
95 | # Stack the values
96 | collated[key] = values
97 |
98 | return collated
99 |
100 |
101 | class PredictionDataset(torch.utils.data.Dataset):
102 |     """Dataset used for inference predictions."""
103 |
104 | def __init__(
105 | self,
106 | manifest: Manifest,
107 | target_dir: Path,
108 | msa_dir: Path,
109 | ) -> None:
110 |         """Initialize the prediction dataset.
111 |
112 | Parameters
113 | ----------
114 | manifest : Manifest
115 | The manifest to load data from.
116 | target_dir : Path
117 | The path to the target directory.
118 | msa_dir : Path
119 | The path to the msa directory.
120 |
121 | """
122 | super().__init__()
123 | self.manifest = manifest
124 | self.target_dir = target_dir
125 | self.msa_dir = msa_dir
126 | self.tokenizer = BoltzTokenizer()
127 | self.featurizer = BoltzFeaturizer()
128 |
129 | def __getitem__(self, idx: int) -> dict:
130 | """Get an item from the dataset.
131 |
132 | Returns
133 | -------
134 | Dict[str, Tensor]
135 | The sampled data features.
136 |
137 | """
138 | # Get a sample from the dataset
139 | record = self.manifest.records[idx]
140 |
141 | # Get the structure
142 | try:
143 | input_data = load_input(record, self.target_dir, self.msa_dir)
144 | except Exception as e: # noqa: BLE001
145 | print(f"Failed to load input for {record.id} with error {e}. Skipping.") # noqa: T201
146 | return self.__getitem__(0)
147 |
148 | # Tokenize structure
149 | try:
150 | tokenized = self.tokenizer.tokenize(input_data)
151 | except Exception as e: # noqa: BLE001
152 | print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") # noqa: T201
153 | return self.__getitem__(0)
154 |
155 | # Compute features
156 | try:
157 | features = self.featurizer.process(
158 | tokenized,
159 | training=False,
160 | max_atoms=None,
161 | max_tokens=None,
162 | max_seqs=const.max_msa_seqs,
163 | pad_to_max_seqs=False,
164 | symmetries={},
165 | compute_symmetries=False,
166 | )
167 | except Exception as e: # noqa: BLE001
168 | print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201
169 | return self.__getitem__(0)
170 |
171 | features["record"] = record
172 | return features
173 |
174 | def __len__(self) -> int:
175 | """Get the length of the dataset.
176 |
177 | Returns
178 | -------
179 | int
180 | The length of the dataset.
181 |
182 | """
183 | return len(self.manifest.records)
184 |
185 |
186 | class BoltzInferenceDataModule(pl.LightningDataModule):
187 | """DataModule for Boltz inference."""
188 |
189 | def __init__(
190 | self,
191 | manifest: Manifest,
192 | target_dir: Path,
193 | msa_dir: Path,
194 | num_workers: int,
195 | ) -> None:
196 | """Initialize the DataModule.
197 |
198 | Parameters
199 | ----------
200 |         manifest : Manifest
201 |             The records to predict. Data is read from `target_dir` and
202 |             `msa_dir`, using `num_workers` dataloader workers.
203 | """
204 | super().__init__()
205 | self.num_workers = num_workers
206 | self.manifest = manifest
207 | self.target_dir = target_dir
208 | self.msa_dir = msa_dir
209 |
210 | def predict_dataloader(self) -> DataLoader:
211 |         """Get the prediction dataloader.
212 |
213 | Returns
214 | -------
215 | DataLoader
216 |             The prediction dataloader.
217 |
218 | """
219 | dataset = PredictionDataset(
220 | manifest=self.manifest,
221 | target_dir=self.target_dir,
222 | msa_dir=self.msa_dir,
223 | )
224 | return DataLoader(
225 | dataset,
226 | batch_size=1,
227 | num_workers=self.num_workers,
228 | pin_memory=True,
229 | shuffle=False,
230 | collate_fn=collate,
231 | )
232 |
233 | def transfer_batch_to_device(
234 | self,
235 | batch: dict,
236 | device: torch.device,
237 | dataloader_idx: int, # noqa: ARG002
238 | ) -> dict:
239 | """Transfer a batch to the given device.
240 |
241 | Parameters
242 | ----------
243 | batch : Dict
244 | The batch to transfer.
245 | device : torch.device
246 | The device to transfer to.
247 | dataloader_idx : int
248 | The dataloader index.
249 |
250 | Returns
251 | -------
252 |         dict
253 | The transferred batch.
254 |
255 | """
256 | for key in batch:
257 | if key not in [
258 | "all_coords",
259 | "all_resolved_mask",
260 | "crop_to_all_atom_map",
261 | "chain_symmetries",
262 | "amino_acids_symmetries",
263 | "ligand_symmetries",
264 | "record",
265 | ]:
266 | batch[key] = batch[key].to(device)
267 | return batch
268 |
--------------------------------------------------------------------------------
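The `collate` function above stacks same-shaped tensors directly and falls back to `pad_to_max` for ragged ones. A minimal stand-in for that padding step, assuming right-padding with zeros (a sketch, not the actual `pad_to_max` implementation):

```python
import torch
import torch.nn.functional as F

def pad_and_stack(values: list, pad_value: float = 0.0) -> torch.Tensor:
    """Right-pad each tensor to the per-dimension maximum, then stack."""
    max_shape = [max(v.shape[d] for v in values) for d in range(values[0].dim())]
    padded = []
    for v in values:
        # F.pad expects pairs ordered from the last dimension backwards
        pad = []
        for d in reversed(range(v.dim())):
            pad.extend([0, max_shape[d] - v.shape[d]])
        padded.append(F.pad(v, pad, value=pad_value))
    return torch.stack(padded, dim=0)

batch = pad_and_stack([torch.ones(3, 4), torch.ones(5, 2)])
print(batch.shape)  # torch.Size([2, 5, 4])
```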
/docs/prediction.md:
--------------------------------------------------------------------------------
1 | # Prediction
2 |
3 | Once you have installed `boltz`, you can start making predictions by simply running:
4 |
5 | `boltz predict <INPUT_PATH>`
6 |
7 | where `<INPUT_PATH>` is a path to the input file or a directory. The input file can either be in fasta format (sufficient for most use cases) or YAML format (for more complex inputs). If you specify a directory, `boltz` will run predictions on each `.yaml` or `.fasta` file in the directory.
8 |
9 | Before diving into more details about the input formats, here are the key differences in what they each support:
10 |
11 | | Feature | Fasta | YAML |
12 | | -------- |--------------------| ------- |
13 | | Polymers | :white_check_mark: | :white_check_mark: |
14 | | Smiles | :white_check_mark: | :white_check_mark: |
15 | | CCD code | :white_check_mark: | :white_check_mark: |
16 | | Custom MSA | :white_check_mark: | :white_check_mark: |
17 | | Modified Residues | :x: | :white_check_mark: |
18 | | Covalent bonds | :x: | :white_check_mark: |
19 | | Pocket conditioning | :x: | :white_check_mark: |
20 |
21 |
22 |
23 | ## Fasta format
24 |
25 | The fasta format should contain entries as follows:
26 |
27 | ```
28 | >CHAIN_ID|ENTITY_TYPE|MSA_PATH
29 | SEQUENCE
30 | ```
31 |
32 | where `CHAIN_ID` is a unique identifier for each input chain; `ENTITY_TYPE` is one of `protein`, `dna`, `rna`, `smiles`, or `ccd`; and `MSA_PATH`, specified only for protein entities, is the path to the `.a3m` file containing a precomputed MSA for the protein sequence. Note that both SMILES strings and CCD codes are supported for ligands.
33 |
34 | For each of these cases, the corresponding `SEQUENCE` will contain an amino acid sequence (e.g. `EFKEAFSLF`), a sequence of nucleotide bases (e.g. `ATCG`), a smiles string (e.g. `CC1=CC=CC=C1`), or a CCD code (e.g. `ATP`), depending on the entity.
35 |
36 | As an example:
37 |
38 | ```
39 | >A|protein|./examples/msa/seq1.a3m
40 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
41 | >B|protein|./examples/msa/seq1.a3m
42 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
43 | >C|ccd
44 | SAH
45 | >D|ccd
46 | SAH
47 | >E|smiles
48 | N[C@@H](Cc1ccc(O)cc1)C(=O)O
49 | >F|smiles
50 | N[C@@H](Cc1ccc(O)cc1)C(=O)O
51 | ```
52 |
53 |
54 | ## YAML format
55 |
56 | The YAML format is more flexible and allows for more complex inputs, particularly around covalent bonds. The schema of the YAML is the following:
57 |
58 | ```yaml
59 | sequences:
60 | - ENTITY_TYPE:
61 | id: CHAIN_ID
62 | sequence: SEQUENCE # only for protein, dna, rna
63 | smiles: SMILES # only for ligand, exclusive with ccd
64 | ccd: CCD # only for ligand, exclusive with smiles
65 | msa: MSA_PATH # only for protein
66 | modifications:
67 | - position: RES_IDX # index of residue, starting from 1
68 | ccd: CCD # CCD code of the modified residue
69 |
70 | - ENTITY_TYPE:
71 | id: [CHAIN_ID, CHAIN_ID] # multiple ids in case of multiple identical entities
72 | ...
73 | constraints:
74 | - bond:
75 | atom1: [CHAIN_ID, RES_IDX, ATOM_NAME]
76 | atom2: [CHAIN_ID, RES_IDX, ATOM_NAME]
77 | - pocket:
78 | binder: CHAIN_ID
79 | contacts: [[CHAIN_ID, RES_IDX], [CHAIN_ID, RES_IDX]]
80 | ```
81 | `sequences` has one entry for every unique chain/molecule in the input. Each polymer entity has an `ENTITY_TYPE` of `protein`, `dna`, or `rna` and a `sequence` attribute. Non-polymer entities have an `ENTITY_TYPE` of `ligand` and a `smiles` or `ccd` attribute. `CHAIN_ID` is the unique identifier for each chain/molecule; set it to a list when the structure contains multiple identical entities. Protein entities should also contain an `msa` attribute, with `MSA_PATH` indicating the path to the `.a3m` file containing a precomputed MSA for the protein sequence.
82 |
83 | The `modifications` field is an optional field that allows you to specify modified residues in the polymer (`protein`, `dna`, or `rna`). The `position` field specifies the index (starting from 1) of the residue, and `ccd` is the CCD code of the modified residue. Modifications are currently only supported through CCD codes.
84 |
85 | `constraints` is an optional field that allows you to specify additional information about the input structure. The `bond` constraint specifies a covalent bond between two atoms (`atom1` and `atom2`). It is currently only supported for CCD ligands and canonical residues: `CHAIN_ID` refers to the chain id set above, `RES_IDX` is the index (starting from 1) of the residue (1 for ligands), and `ATOM_NAME` is the standardized atom name (which can be verified in the CIF file of that component on the RCSB website). The `pocket` constraint conditions the prediction on a binding pocket: `binder` is the chain placed in the pocket, and `contacts` is a list of `[CHAIN_ID, RES_IDX]` pairs identifying the pocket residues (see the sketch after the example below).
86 |
87 | As an example:
88 |
89 | ```yaml
90 | version: 1
91 | sequences:
92 | - protein:
93 | id: [A, B]
94 | sequence: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
95 | msa: ./examples/msa/seq1.a3m
96 | - ligand:
97 | id: [C, D]
98 | ccd: SAH
99 | - ligand:
100 | id: [E, F]
101 | smiles: N[C@@H](Cc1ccc(O)cc1)C(=O)O
102 | ```
103 |
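The example above omits `constraints`. A hypothetical fragment showing both constraint types (the chain ids, residue indices, and atom names are illustrative, not taken from a real input):

```yaml
constraints:
  - bond:
      atom1: [A, 12, SG]
      atom2: [C, 1, C1]
  - pocket:
      binder: E
      contacts: [[A, 45], [A, 67]]
```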
104 |
105 | ## Options
106 |
107 | The following options are available for the `predict` command:
108 |
109 | boltz predict [OPTIONS] input_path
110 |
111 | | **Option** | **Type** | **Default** | **Description** |
112 | |-----------------------------|-----------------|--------------------|---------------------------------------------------------------------------------|
113 | | `--out_dir PATH` | `PATH` | `./` | The path where to save the predictions. |
114 | | `--cache PATH` | `PATH` | `~/.boltz` | The directory where to download the data and model. |
115 | | `--checkpoint PATH` | `PATH` | None | An optional checkpoint. Uses the provided Boltz-1 model by default. |
116 | | `--devices INTEGER` | `INTEGER` | `1` | The number of devices to use for prediction. |
117 | | `--accelerator` | `[gpu,cpu,tpu]` | `gpu` | The accelerator to use for prediction. |
118 | | `--recycling_steps INTEGER` | `INTEGER` | `3` | The number of recycling steps to use for prediction. |
119 | | `--sampling_steps INTEGER` | `INTEGER` | `200` | The number of sampling steps to use for prediction. |
120 | | `--diffusion_samples INTEGER` | `INTEGER` | `1` | The number of diffusion samples to use for prediction. |
121 | | `--output_format` | `[pdb,mmcif]` | `mmcif` | The output format to use for the predictions. |
122 | | `--num_workers INTEGER` | `INTEGER` | `2` | The number of dataloader workers to use for prediction. |
123 | | `--override` | `FLAG` | `False` | Whether to override existing predictions if found. |
124 |
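For example, the following run reads an input YAML and writes PDB files instead of the default mmCIF (the paths are illustrative):

```
boltz predict ./examples/ligand.yaml --out_dir ./results --output_format pdb
```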
125 | ## Output
126 |
127 | After running the model, the generated outputs are organized into the output directory following the structure below:
128 | ```
129 | out_dir/
130 | ├── lightning_logs/ # Logs generated during training or evaluation
131 | ├── predictions/ # Contains the model's predictions
132 | ├── [input_file1]/
133 | ├── [input_file1]_model_0.cif # The predicted structure in CIF format
134 | ...
135 | └── [input_file1]_model_[diffusion_samples-1].cif # The predicted structure in CIF format
136 | └── [input_file2]/
137 | ...
138 | └── processed/ # Processed data used during execution
139 | ```
140 | The `predictions` folder contains a unique folder for each input file. Each input folder contains `diffusion_samples` predictions saved in the chosen `output_format`. The `processed` folder contains the processed input files that are used by the model during inference.
141 |
--------------------------------------------------------------------------------
/src/boltz/data/sample/cluster.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Iterator, List
2 |
3 | import numpy as np
4 | from numpy.random import RandomState
5 |
6 | from boltz.data import const
7 | from boltz.data.types import ChainInfo, InterfaceInfo, Record
8 | from boltz.data.sample.sampler import Sample, Sampler
9 |
10 |
11 | def get_chain_cluster(chain: ChainInfo, record: Record) -> str: # noqa: ARG001
12 | """Get the cluster id for a chain.
13 |
14 | Parameters
15 | ----------
16 | chain : ChainInfo
 17 |         The chain to get the cluster id for.
 18 |     record : Record
 19 |         The record the chain is part of.
20 |
21 | Returns
22 | -------
23 | str
24 | The cluster id of the chain.
25 |
26 | """
27 | return chain.cluster_id
28 |
29 |
 30 | def get_interface_cluster(interface: InterfaceInfo, record: Record) -> tuple:
31 | """Get the cluster id for an interface.
32 |
33 | Parameters
34 | ----------
35 | interface : InterfaceInfo
36 | The interface to get the cluster id for.
37 | record : Record
38 | The record the interface is part of.
39 |
40 | Returns
41 | -------
 42 |     tuple[str, str]
43 | The cluster id of the interface.
44 |
45 | """
46 | chain1 = record.chains[interface.chain_1]
47 | chain2 = record.chains[interface.chain_2]
48 |
49 | cluster_1 = str(chain1.cluster_id)
50 | cluster_2 = str(chain2.cluster_id)
51 |
52 | cluster_id = (cluster_1, cluster_2)
53 | cluster_id = tuple(sorted(cluster_id))
54 |
55 | return cluster_id
56 |
57 |
58 | def get_chain_weight(
59 | chain: ChainInfo,
60 | record: Record, # noqa: ARG001
61 | clusters: Dict[str, int],
62 | beta_chain: float,
63 | alpha_prot: float,
64 | alpha_nucl: float,
65 | alpha_ligand: float,
66 | ) -> float:
67 | """Get the weight of a chain.
68 |
69 | Parameters
70 | ----------
71 | chain : ChainInfo
72 | The chain to get the weight for.
73 | record : Record
74 | The record the chain is part of.
75 | clusters : Dict[str, int]
76 | The cluster sizes.
77 | beta_chain : float
78 | The beta value for chains.
79 | alpha_prot : float
80 | The alpha value for proteins.
81 | alpha_nucl : float
82 | The alpha value for nucleic acids.
83 | alpha_ligand : float
84 | The alpha value for ligands.
85 |
86 | Returns
87 | -------
88 | float
89 | The weight of the chain.
90 |
91 | """
92 | prot_id = const.chain_type_ids["PROTEIN"]
93 | rna_id = const.chain_type_ids["RNA"]
94 | dna_id = const.chain_type_ids["DNA"]
95 | ligand_id = const.chain_type_ids["NONPOLYMER"]
96 |
97 | weight = beta_chain / clusters[chain.cluster_id]
98 | if chain.mol_type == prot_id:
99 | weight *= alpha_prot
100 | elif chain.mol_type in [rna_id, dna_id]:
101 | weight *= alpha_nucl
102 | elif chain.mol_type == ligand_id:
103 | weight *= alpha_ligand
104 |
105 | return weight
106 |
107 |
108 | def get_interface_weight(
109 | interface: InterfaceInfo,
110 | record: Record,
111 | clusters: Dict[str, int],
112 | beta_interface: float,
113 | alpha_prot: float,
114 | alpha_nucl: float,
115 | alpha_ligand: float,
116 | ) -> float:
117 | """Get the weight of an interface.
118 |
119 | Parameters
120 | ----------
121 | interface : InterfaceInfo
122 | The interface to get the weight for.
123 | record : Record
124 | The record the interface is part of.
125 | clusters : Dict[str, int]
126 | The cluster sizes.
127 | beta_interface : float
128 | The beta value for interfaces.
129 | alpha_prot : float
130 | The alpha value for proteins.
131 | alpha_nucl : float
132 | The alpha value for nucleic acids.
133 | alpha_ligand : float
134 | The alpha value for ligands.
135 |
136 | Returns
137 | -------
138 | float
139 | The weight of the interface.
140 |
141 | """
142 | prot_id = const.chain_type_ids["PROTEIN"]
143 | rna_id = const.chain_type_ids["RNA"]
144 | dna_id = const.chain_type_ids["DNA"]
145 | ligand_id = const.chain_type_ids["NONPOLYMER"]
146 |
147 | chain1 = record.chains[interface.chain_1]
148 | chain2 = record.chains[interface.chain_2]
149 |
150 |     n_prot = chain1.mol_type == prot_id
151 | n_nuc = chain1.mol_type in [rna_id, dna_id]
152 | n_ligand = chain1.mol_type == ligand_id
153 |
154 | n_prot += chain2.mol_type == prot_id
155 | n_nuc += chain2.mol_type in [rna_id, dna_id]
156 | n_ligand += chain2.mol_type == ligand_id
157 |
158 | weight = beta_interface / clusters[get_interface_cluster(interface, record)]
159 | weight *= alpha_prot * n_prot + alpha_nucl * n_nuc + alpha_ligand * n_ligand
160 | return weight
161 |
162 |
163 | class ClusterSampler(Sampler):
164 | """The weighted sampling approach, as described in AF3.
165 |
166 | Each chain / interface is given a weight according
167 | to the following formula, and sampled accordingly:
168 |
169 |     w = b / n_clust * (a_prot * n_prot + a_nuc * n_nuc
170 |                        + a_ligand * n_ligand)
171 |
172 | """
173 |
174 | def __init__(
175 | self,
176 | alpha_prot: float = 3.0,
177 | alpha_nucl: float = 3.0,
178 | alpha_ligand: float = 1.0,
179 | beta_chain: float = 0.5,
180 | beta_interface: float = 1.0,
181 | ) -> None:
182 | """Initialize the sampler.
183 |
184 | Parameters
185 | ----------
186 | alpha_prot : float, optional
187 | The alpha value for proteins.
188 | alpha_nucl : float, optional
189 | The alpha value for nucleic acids.
190 | alpha_ligand : float, optional
191 | The alpha value for ligands.
192 | beta_chain : float, optional
193 | The beta value for chains.
194 | beta_interface : float, optional
195 | The beta value for interfaces.
196 |
197 | """
198 | self.alpha_prot = alpha_prot
199 | self.alpha_nucl = alpha_nucl
200 | self.alpha_ligand = alpha_ligand
201 | self.beta_chain = beta_chain
202 | self.beta_interface = beta_interface
203 |
204 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]: # noqa: C901, PLR0912
205 | """Sample a structure from the dataset infinitely.
206 |
207 | Parameters
208 | ----------
209 | records : List[Record]
210 | The records to sample from.
211 | random : RandomState
212 | The random state for reproducibility.
213 |
214 | Yields
215 | ------
216 | Sample
217 | A data sample.
218 |
219 | """
220 | # Compute chain cluster sizes
221 | chain_clusters: Dict[str, int] = {}
222 | for record in records:
223 | for chain in record.chains:
224 | if not chain.valid:
225 | continue
226 | cluster_id = get_chain_cluster(chain, record)
227 | if cluster_id not in chain_clusters:
228 | chain_clusters[cluster_id] = 0
229 | chain_clusters[cluster_id] += 1
230 |
231 | # Compute interface clusters sizes
232 | interface_clusters: Dict[str, int] = {}
233 | for record in records:
234 | for interface in record.interfaces:
235 | if not interface.valid:
236 | continue
237 | cluster_id = get_interface_cluster(interface, record)
238 | if cluster_id not in interface_clusters:
239 | interface_clusters[cluster_id] = 0
240 | interface_clusters[cluster_id] += 1
241 |
242 | # Compute weights
243 | items, weights = [], []
244 | for record in records:
245 | for chain_id, chain in enumerate(record.chains):
246 | if not chain.valid:
247 | continue
248 | weight = get_chain_weight(
249 | chain,
250 | record,
251 | chain_clusters,
252 | self.beta_chain,
253 | self.alpha_prot,
254 | self.alpha_nucl,
255 | self.alpha_ligand,
256 | )
257 | items.append((record, 0, chain_id))
258 | weights.append(weight)
259 |
260 | for int_id, interface in enumerate(record.interfaces):
261 | if not interface.valid:
262 | continue
263 | weight = get_interface_weight(
264 | interface,
265 | record,
266 | interface_clusters,
267 | self.beta_interface,
268 | self.alpha_prot,
269 | self.alpha_nucl,
270 | self.alpha_ligand,
271 | )
272 | items.append((record, 1, int_id))
273 | weights.append(weight)
274 |
275 | # Sample infinitely
276 | weights = np.array(weights) / np.sum(weights)
277 | while True:
278 | item_idx = random.choice(len(items), p=weights)
279 | record, kind, index = items[item_idx]
280 | if kind == 0:
281 | yield Sample(record=record, chain_id=index)
282 | else:
283 | yield Sample(record=record, interface_id=index)
284 |
--------------------------------------------------------------------------------
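To make the weighting concrete, here is a toy rerun of the chain formula with made-up cluster sizes and the default alpha/beta values (the labels and numbers are illustrative):

```python
import numpy as np
from numpy.random import RandomState

alpha = {"protein": 3.0, "nucleic": 3.0, "ligand": 1.0}
beta_chain = 0.5

# (label, molecule type, cluster size) tuples, made up for the example
items = [("chain_A", "protein", 10), ("chain_B", "ligand", 2)]

# w = beta / n_clust * alpha_type, then normalize to probabilities
weights = np.array([beta_chain / size * alpha[mol] for _, mol, size in items])
weights /= weights.sum()

random = RandomState(42)
print(items[random.choice(len(items), p=weights)][0])
```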
/src/boltz/model/modules/transformers.py:
--------------------------------------------------------------------------------
1 | # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
2 |
3 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
4 | from torch import nn, sigmoid
5 | from torch.nn import (
6 | LayerNorm,
7 | Linear,
8 | Module,
9 | ModuleList,
10 | Sequential,
11 | )
12 |
13 | from boltz.model.layers.attention import AttentionPairBias
14 | from boltz.model.modules.utils import LinearNoBias, SwiGLU, default
15 |
16 |
17 | class AdaLN(Module):
18 | """Adaptive Layer Normalization"""
19 |
20 | def __init__(self, dim, dim_single_cond):
21 | """Initialize the adaptive layer normalization.
22 |
23 | Parameters
24 | ----------
25 | dim : int
26 | The input dimension.
27 | dim_single_cond : int
28 | The single condition dimension.
29 |
30 | """
31 | super().__init__()
32 | self.a_norm = LayerNorm(dim, elementwise_affine=False, bias=False)
33 | self.s_norm = LayerNorm(dim_single_cond, bias=False)
34 | self.s_scale = Linear(dim_single_cond, dim)
35 | self.s_bias = LinearNoBias(dim_single_cond, dim)
36 |
37 | def forward(self, a, s):
38 | a = self.a_norm(a)
39 | s = self.s_norm(s)
40 | a = sigmoid(self.s_scale(s)) * a + self.s_bias(s)
41 | return a
42 |
43 |
44 | class ConditionedTransitionBlock(Module):
45 | """Conditioned Transition Block"""
46 |
47 | def __init__(self, dim_single, dim_single_cond, expansion_factor=2):
48 | """Initialize the conditioned transition block.
49 |
50 | Parameters
51 | ----------
52 | dim_single : int
53 | The single dimension.
54 | dim_single_cond : int
55 | The single condition dimension.
56 | expansion_factor : int, optional
57 | The expansion factor, by default 2
58 |
59 | """
60 | super().__init__()
61 |
62 | self.adaln = AdaLN(dim_single, dim_single_cond)
63 |
64 | dim_inner = int(dim_single * expansion_factor)
65 | self.swish_gate = Sequential(
66 | LinearNoBias(dim_single, dim_inner * 2),
67 | SwiGLU(),
68 | )
69 | self.a_to_b = LinearNoBias(dim_single, dim_inner)
70 | self.b_to_a = LinearNoBias(dim_inner, dim_single)
71 |
72 | output_projection_linear = Linear(dim_single_cond, dim_single)
73 | nn.init.zeros_(output_projection_linear.weight)
74 | nn.init.constant_(output_projection_linear.bias, -2.0)
75 |
76 | self.output_projection = nn.Sequential(output_projection_linear, nn.Sigmoid())
77 |
78 | def forward(
79 | self,
80 | a,
81 | s,
82 | ):
83 | a = self.adaln(a, s)
84 | b = self.swish_gate(a) * self.a_to_b(a)
85 | a = self.output_projection(s) * self.b_to_a(b)
86 |
87 | return a
88 |
89 |
90 | class DiffusionTransformer(Module):
91 | """Diffusion Transformer"""
92 |
93 | def __init__(
94 | self,
95 | depth,
96 | heads,
97 | dim=384,
98 | dim_single_cond=None,
99 | dim_pairwise=128,
100 | activation_checkpointing=False,
101 | offload_to_cpu=False,
102 | ):
103 | """Initialize the diffusion transformer.
104 |
105 | Parameters
106 | ----------
107 | depth : int
108 |             The number of transformer layers.
109 | heads : int
110 | The number of heads.
111 | dim : int, optional
112 | The dimension, by default 384
113 | dim_single_cond : int, optional
114 | The single condition dimension, by default None
115 | dim_pairwise : int, optional
116 | The pairwise dimension, by default 128
117 | activation_checkpointing : bool, optional
118 | Whether to use activation checkpointing, by default False
119 | offload_to_cpu : bool, optional
120 | Whether to offload to CPU, by default False
121 |
122 | """
123 | super().__init__()
124 | self.activation_checkpointing = activation_checkpointing
125 | dim_single_cond = default(dim_single_cond, dim)
126 |
127 | self.layers = ModuleList()
128 | for _ in range(depth):
129 | if activation_checkpointing:
130 | self.layers.append(
131 | checkpoint_wrapper(
132 | DiffusionTransformerLayer(
133 | heads,
134 | dim,
135 | dim_single_cond,
136 | dim_pairwise,
137 | ),
138 | offload_to_cpu=offload_to_cpu,
139 | )
140 | )
141 | else:
142 | self.layers.append(
143 | DiffusionTransformerLayer(
144 | heads,
145 | dim,
146 | dim_single_cond,
147 | dim_pairwise,
148 | )
149 | )
150 |
151 | def forward(
152 | self,
153 | a,
154 | s,
155 | z,
156 | mask=None,
157 | to_keys=None,
158 | multiplicity=1,
159 | model_cache=None,
160 | ):
161 | for i, layer in enumerate(self.layers):
162 | layer_cache = None
163 | if model_cache is not None:
164 | prefix_cache = "layer_" + str(i)
165 | if prefix_cache not in model_cache:
166 | model_cache[prefix_cache] = {}
167 | layer_cache = model_cache[prefix_cache]
168 | a = layer(
169 | a,
170 | s,
171 | z,
172 | mask=mask,
173 | to_keys=to_keys,
174 | multiplicity=multiplicity,
175 | layer_cache=layer_cache,
176 | )
177 | return a
178 |
179 |
180 | class DiffusionTransformerLayer(Module):
181 | """Diffusion Transformer Layer"""
182 |
183 | def __init__(
184 | self,
185 | heads,
186 | dim=384,
187 | dim_single_cond=None,
188 | dim_pairwise=128,
189 | ):
190 | """Initialize the diffusion transformer layer.
191 |
192 | Parameters
193 | ----------
194 | heads : int
195 | The number of heads.
196 | dim : int, optional
197 | The dimension, by default 384
198 | dim_single_cond : int, optional
199 | The single condition dimension, by default None
200 | dim_pairwise : int, optional
201 | The pairwise dimension, by default 128
202 |
203 | """
204 | super().__init__()
205 |
206 | dim_single_cond = default(dim_single_cond, dim)
207 |
208 | self.adaln = AdaLN(dim, dim_single_cond)
209 |
210 | self.pair_bias_attn = AttentionPairBias(
211 | c_s=dim, c_z=dim_pairwise, num_heads=heads, initial_norm=False
212 | )
213 |
214 | self.output_projection_linear = Linear(dim_single_cond, dim)
215 | nn.init.zeros_(self.output_projection_linear.weight)
216 | nn.init.constant_(self.output_projection_linear.bias, -2.0)
217 |
218 | self.output_projection = nn.Sequential(
219 | self.output_projection_linear, nn.Sigmoid()
220 | )
221 | self.transition = ConditionedTransitionBlock(
222 | dim_single=dim, dim_single_cond=dim_single_cond
223 | )
224 |
225 | def forward(
226 | self,
227 | a,
228 | s,
229 | z,
230 | mask=None,
231 | to_keys=None,
232 | multiplicity=1,
233 | layer_cache=None,
234 | ):
235 | b = self.adaln(a, s)
236 | b = self.pair_bias_attn(
237 | s=b,
238 | z=z,
239 | mask=mask,
240 | multiplicity=multiplicity,
241 | to_keys=to_keys,
242 | model_cache=layer_cache,
243 | )
244 | b = self.output_projection(s) * b
245 |
246 | # NOTE: Added residual connection!
247 | a = a + b
248 | a = a + self.transition(a, s)
249 | return a
250 |
251 |
252 | class AtomTransformer(Module):
253 |     """Atom Transformer."""
254 |
255 | def __init__(
256 | self,
257 | attn_window_queries=None,
258 | attn_window_keys=None,
259 | **diffusion_transformer_kwargs,
260 | ):
261 | """Initialize the atom transformer.
262 |
263 | Parameters
264 | ----------
265 |         attn_window_queries : int, optional
266 |             The query window size for local attention, by default None
267 |         attn_window_keys : int, optional
268 |             The key window size for local attention, by default None
269 | diffusion_transformer_kwargs : dict
270 | The diffusion transformer keyword arguments
271 |
272 | """
273 | super().__init__()
274 | self.attn_window_queries = attn_window_queries
275 | self.attn_window_keys = attn_window_keys
276 | self.diffusion_transformer = DiffusionTransformer(
277 | **diffusion_transformer_kwargs
278 | )
279 |
280 | def forward(
281 | self,
282 | q,
283 | c,
284 | p,
285 | to_keys=None,
286 | mask=None,
287 | multiplicity=1,
288 | model_cache=None,
289 | ):
290 | W = self.attn_window_queries
291 | H = self.attn_window_keys
292 |
293 | if W is not None:
294 | B, N, D = q.shape
295 | NW = N // W
296 |
297 | # reshape tokens
298 | q = q.view((B * NW, W, -1))
299 | c = c.view((B * NW, W, -1))
300 | if mask is not None:
301 | mask = mask.view(B * NW, W)
302 | p = p.view((p.shape[0] * NW, W, H, -1))
303 |
304 | to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1)
305 | else:
306 | to_keys_new = None
307 |
308 | # main transformer
309 | q = self.diffusion_transformer(
310 | a=q,
311 | s=c,
312 | z=p,
313 |             mask=mask.float() if mask is not None else None,
314 | multiplicity=multiplicity,
315 | to_keys=to_keys_new,
316 | model_cache=model_cache,
317 | )
318 |
319 | if W is not None:
320 | q = q.view((B, NW * W, D))
321 |
322 | return q
323 |
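324 | 
325 | # A minimal, self-contained sketch of the window reshape performed in
326 | # AtomTransformer.forward; the sizes below are illustrative assumptions,
327 | # not values used elsewhere in this module. Note that the sigmoid output
328 | # gates above are initialized with a bias of -2.0, so each block starts
329 | # nearly closed (sigmoid(-2.0) ~= 0.12).
330 | if __name__ == "__main__":
331 |     import torch
332 | 
333 |     B, N, D, W, H = 2, 256, 64, 32, 128
334 |     q = torch.randn(B, N, D)
335 | 
336 |     NW = N // W  # number of query windows per batch element
337 |     q_windows = q.view(B * NW, W, D)  # each window of W atoms attends locally
338 |     assert q_windows.shape == (16, 32, 64)
339 | 
340 |     # The pairwise bias is windowed the same way: one (W, H) block per window
341 |     p = torch.randn(B * NW, W, H, 16)
342 |     p_windows = p.view(B * NW, W, H, -1)
343 |     assert p_windows.shape == (16, 32, 128, 16)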
--------------------------------------------------------------------------------
/src/boltz/data/filter/static/polymer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 |
4 | import numpy as np
5 | from scipy.spatial.distance import cdist
6 |
7 | from boltz.data import const
8 | from boltz.data.types import Structure
9 | from boltz.data.filter.static.filter import StaticFilter
10 |
11 |
12 | class MinimumLengthFilter(StaticFilter):
13 | """Filter polymers based on their length.
14 |
15 | We use the number of resolved residues when considering
16 | the minimum, and the sequence length for the maximum.
17 |
18 | """
19 |
20 | def __init__(self, min_len: int = 4, max_len: int = 5000) -> None:
21 | """Initialize the filter.
22 |
23 | Parameters
24 | ----------
25 |         min_len : int, optional
26 |             The minimum number of resolved residues.
27 |         max_len : int, optional
28 |             The maximum sequence length.
29 |
30 | """
31 | self._min = min_len
32 | self._max = max_len
33 |
34 | def filter(self, structure: Structure) -> np.ndarray:
35 |         """Filter chains based on their length.
36 |
37 | Parameters
38 | ----------
39 | structure : Structure
40 | The structure to filter chains from.
41 |
42 | Returns
43 | -------
44 | np.ndarray
45 | The chains to keep, as a boolean mask.
46 |
47 | """
48 | valid = np.ones(len(structure.chains), dtype=bool)
49 |
50 | for i, chain in enumerate(structure.chains):
51 | if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
52 | continue
53 |
54 | res_start = chain["res_idx"]
55 | res_end = res_start + chain["res_num"]
56 | residues = structure.residues[res_start:res_end]
57 | resolved = residues["is_present"].sum()
58 |
59 |             if (resolved < self._min) or (len(residues) > self._max):
60 | valid[i] = 0
61 |
62 | return valid
63 |
64 |
65 | class UnknownFilter(StaticFilter):
66 | """Filter proteins with all unknown residues."""
67 |
68 | def filter(self, structure: Structure) -> np.ndarray:
69 | """Filter proteins with all unknown residues.
70 |
71 | Parameters
72 | ----------
73 | structure : Structure
74 | The structure to filter chains from.
75 |
76 | Returns
77 | -------
78 | np.ndarray
79 | The chains to keep, as a boolean mask.
80 |
81 | """
82 | valid = np.ones(len(structure.chains), dtype=bool)
83 | unk_toks = {
84 | const.chain_type_ids["PROTEIN"]: const.unk_token_ids["PROTEIN"],
85 | const.chain_type_ids["DNA"]: const.unk_token_ids["DNA"],
86 | const.chain_type_ids["RNA"]: const.unk_token_ids["RNA"],
87 | }
88 |
89 | for i, chain in enumerate(structure.chains):
90 | if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
91 | continue
92 |
93 | res_start = chain["res_idx"]
94 | res_end = res_start + chain["res_num"]
95 | residues = structure.residues[res_start:res_end]
96 |
97 | unk_id = unk_toks[chain["mol_type"]]
98 | if np.all(residues["res_type"] == unk_id):
99 | valid[i] = 0
100 |
101 | return valid
102 |
103 |
104 | class ConsecutiveCA(StaticFilter):
105 |     """Filter proteins with consecutive CA distances above a threshold."""
106 |
107 |     def __init__(self, max_dist: float = 10.0) -> None:
108 | """Initialize the filter.
109 |
110 | Parameters
111 | ----------
112 | max_dist : float, optional
113 | The maximum allowed distance.
114 |
115 | """
116 | self._max_dist = max_dist
117 |
118 | def filter(self, structure: Structure) -> np.ndarray:
119 |         """Filter a protein if consecutive CA distances exceed the threshold.
120 |
121 | Parameters
122 | ----------
123 | structure : Structure
124 | The structure to filter chains from.
125 |
126 | Returns
127 | -------
128 | np.ndarray
129 | The chains to keep, as a boolean mask.
130 |
131 | """
132 | valid = np.ones(len(structure.chains), dtype=bool)
133 |
134 | # Remove chain if consecutive CA atoms are above threshold
135 | for i, chain in enumerate(structure.chains):
136 | # Skip non-protein chains
137 | if chain["mol_type"] != const.chain_type_ids["PROTEIN"]:
138 | continue
139 |
140 | # Get residues
141 | res_start = chain["res_idx"]
142 | res_end = res_start + chain["res_num"]
143 | residues = structure.residues[res_start:res_end]
144 |
145 | # Get c-alphas
146 | ca_ids = residues["atom_center"]
147 | ca_atoms = structure.atoms[ca_ids]
148 |
149 | res_valid = residues["is_present"]
150 | ca_valid = ca_atoms["is_present"] & res_valid
151 | ca_coords = ca_atoms["coords"]
152 |
153 | # Compute distances between consecutive atoms
154 | dist = np.linalg.norm(ca_coords[1:] - ca_coords[:-1], axis=1)
155 | dist = dist > self._max_dist
156 | dist = dist[ca_valid[1:] & ca_valid[:-1]]
157 |
158 | # Remove the chain if any valid pair is above threshold
159 | if np.any(dist):
160 | valid[i] = 0
161 |
162 | return valid
163 |
164 |
165 | @dataclass(frozen=True)
166 | class Clash:
167 | """A clash between two chains."""
168 |
169 | chain: int
170 | other: int
171 | num_atoms: int
172 | num_clashes: int
173 |
174 |
175 | class ClashingChainsFilter(StaticFilter):
176 | """A filter that filters clashing chains.
177 |
178 | Clashing chains are defined as those with >30% of atoms
179 | within 1.7 Å of an atom in another chain. If two chains
180 | are clashing with each other, the chain with the greater
181 | percentage of clashing atoms will be removed. If the same
182 | fraction of atoms are clashing, the chain with fewer total
183 | atoms is removed. If the chains have the same number of
184 | atoms, then the chain with the larger chain id is removed.
185 |
186 | """
187 |
188 | def __init__(self, dist: float = 1.7, freq: float = 0.3) -> None:
189 | """Initialize the filter.
190 |
191 | Parameters
192 | ----------
193 | dist : float, optional
194 | The maximum distance for a clash.
195 | freq : float, optional
196 | The maximum allowed frequency of clashes.
197 |
198 | """
199 | self._dist = dist
200 | self._freq = freq
201 |
202 | def filter(self, structure: Structure) -> np.ndarray: # noqa: PLR0912, C901
203 | """Filter out clashing chains.
204 |
205 | Parameters
206 | ----------
207 | structure : Structure
208 | The structure to filter chains from.
209 |
210 | Returns
211 | -------
212 | np.ndarray
213 | The chains to keep, as a boolean mask.
214 |
215 | """
216 | num_chains = len(structure.chains)
217 | if num_chains < 2: # noqa: PLR2004
218 | return np.ones(num_chains, dtype=bool)
219 |
220 | # Get unique chain pairs
221 |         pairs = ((i, j) for i in range(num_chains) for j in range(num_chains))
222 |         pairs = [(i, j) for i, j in pairs if i < j]
223 |
224 | # Compute clashes
225 | clashes: List[Clash] = []
226 | for i, j in pairs:
227 | # Get the chains
228 | c1 = structure.chains[i]
229 | c2 = structure.chains[j]
230 |
231 | # Get the atoms from each chain
232 | c1_start = c1["atom_idx"]
233 | c2_start = c2["atom_idx"]
234 | c1_end = c1_start + c1["atom_num"]
235 | c2_end = c2_start + c2["atom_num"]
236 |
237 | atoms1 = structure.atoms[c1_start:c1_end]
238 | atoms2 = structure.atoms[c2_start:c2_end]
239 | atoms1 = atoms1[atoms1["is_present"]]
240 | atoms2 = atoms2[atoms2["is_present"]]
241 |
242 | # Compute the number of clashes
243 | dists = cdist(atoms1["coords"], atoms2["coords"])
244 |             clash_matrix = dists < self._dist
245 |             c1_clashes = np.any(clash_matrix, axis=1).sum().item()
246 |             c2_clashes = np.any(clash_matrix, axis=0).sum().item()
247 |
248 | # Save results
249 | if (c1_clashes / len(atoms1)) > self._freq:
250 | clashes.append(Clash(i, j, len(atoms1), c1_clashes))
251 | if (c2_clashes / len(atoms2)) > self._freq:
252 | clashes.append(Clash(j, i, len(atoms2), c2_clashes))
253 |
254 | # Compute indices to clash map
255 | removed = set()
256 | ids_to_clash = {(c.chain, c.other): c for c in clashes}
257 |
258 | # Filter out chains according to ruleset
259 | for clash in clashes:
260 | # If either is already removed, skip
261 | if clash.chain in removed or clash.other in removed:
262 | continue
263 |
264 | # Check if the two chains clash with each other
265 | other_clash = ids_to_clash.get((clash.other, clash.chain))
266 | if other_clash is not None:
267 | # Remove the chain with the most clashes
268 | clash1_freq = clash.num_clashes / clash.num_atoms
269 | clash2_freq = other_clash.num_clashes / other_clash.num_atoms
270 | if clash1_freq > clash2_freq:
271 | removed.add(clash.chain)
272 | elif clash1_freq < clash2_freq:
273 | removed.add(clash.other)
274 |
275 | # If same, remove the chain with fewer atoms
276 | elif clash.num_atoms < other_clash.num_atoms:
277 | removed.add(clash.chain)
278 | elif clash.num_atoms > other_clash.num_atoms:
279 | removed.add(clash.other)
280 |
281 | # If same, remove the chain with the larger chain id
282 | else:
283 | removed.add(max(clash.chain, clash.other))
284 |
285 | # Otherwise, just remove the chain directly
286 | else:
287 | removed.add(clash.chain)
288 |
289 | # Remove the chains
290 | valid = np.ones(len(structure.chains), dtype=bool)
291 | for i in removed:
292 | valid[i] = 0
293 |
294 | return valid
295 |
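296 | # A minimal sketch of the clash counting used above; the coordinates are
297 | # made-up assumptions, not data from any real structure.
298 | if __name__ == "__main__":
299 |     coords1 = np.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
300 |     coords2 = np.array([[0.5, 0.0, 0.0], [20.0, 0.0, 0.0]])
301 | 
302 |     dists = cdist(coords1, coords2)  # (2, 2) pairwise distances
303 |     clash_matrix = dists < 1.7  # default clash distance in Angstroms
304 | 
305 |     # One of the two atoms in each chain clashes: a 50% clash fraction,
306 |     # above the 30% default, so both chains would be flagged.
307 |     assert np.any(clash_matrix, axis=1).sum() == 1
308 |     assert np.any(clash_matrix, axis=0).sum() == 1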
--------------------------------------------------------------------------------
/src/boltz/data/const.py:
--------------------------------------------------------------------------------
1 | ####################################################################################################
2 | # CHAINS
3 | ####################################################################################################
4 |
5 | chain_types = [
6 | "PROTEIN",
7 | "DNA",
8 | "RNA",
9 | "NONPOLYMER",
10 | ]
11 | chain_type_ids = {chain: i for i, chain in enumerate(chain_types)}
12 |
13 | out_types = [
14 | "dna_protein",
15 | "rna_protein",
16 | "ligand_protein",
17 | "dna_ligand",
18 | "rna_ligand",
19 | "intra_ligand",
20 | "intra_dna",
21 | "intra_rna",
22 | "intra_protein",
23 | "protein_protein",
24 | ]
25 |
26 | out_types_weights_af3 = {
27 | "dna_protein": 10.0,
28 | "rna_protein": 10.0,
29 | "ligand_protein": 10.0,
30 | "dna_ligand": 5.0,
31 | "rna_ligand": 5.0,
32 | "intra_ligand": 20.0,
33 | "intra_dna": 4.0,
34 | "intra_rna": 16.0,
35 | "intra_protein": 20.0,
36 | "protein_protein": 20.0,
37 | }
38 |
39 | out_types_weights = {
40 | "dna_protein": 5.0,
41 | "rna_protein": 5.0,
42 | "ligand_protein": 20.0,
43 | "dna_ligand": 2.0,
44 | "rna_ligand": 2.0,
45 | "intra_ligand": 20.0,
46 | "intra_dna": 2.0,
47 | "intra_rna": 8.0,
48 | "intra_protein": 20.0,
49 | "protein_protein": 20.0,
50 | }
51 |
52 |
53 | out_single_types = ["protein", "ligand", "dna", "rna"]
54 |
55 | ####################################################################################################
56 | # RESIDUES & TOKENS
57 | ####################################################################################################
58 |
59 | tokens = [
60 | "",
61 | "-",
62 | "ALA",
63 | "ARG",
64 | "ASN",
65 | "ASP",
66 | "CYS",
67 | "GLN",
68 | "GLU",
69 | "GLY",
70 | "HIS",
71 | "ILE",
72 | "LEU",
73 | "LYS",
74 | "MET",
75 | "PHE",
76 | "PRO",
77 | "SER",
78 | "THR",
79 | "TRP",
80 | "TYR",
81 | "VAL",
82 | "UNK", # unknown protein token
83 | "A",
84 | "G",
85 | "C",
86 | "U",
87 | "N", # unknown rna token
88 | "DA",
89 | "DG",
90 | "DC",
91 | "DT",
92 | "DN", # unknown dna token
93 | ]
94 |
95 | token_ids = {token: i for i, token in enumerate(tokens)}
96 | num_tokens = len(tokens)
97 | unk_token = {"PROTEIN": "UNK", "DNA": "DN", "RNA": "N"}
98 | unk_token_ids = {m: token_ids[t] for m, t in unk_token.items()}
99 |
100 | prot_letter_to_token = {
101 | "A": "ALA",
102 | "R": "ARG",
103 | "N": "ASN",
104 | "D": "ASP",
105 | "C": "CYS",
106 | "E": "GLU",
107 | "Q": "GLN",
108 | "G": "GLY",
109 | "H": "HIS",
110 | "I": "ILE",
111 | "L": "LEU",
112 | "K": "LYS",
113 | "M": "MET",
114 | "F": "PHE",
115 | "P": "PRO",
116 | "S": "SER",
117 | "T": "THR",
118 | "W": "TRP",
119 | "Y": "TYR",
120 | "V": "VAL",
121 | "X": "UNK",
122 | "J": "UNK",
123 | "B": "UNK",
124 | "Z": "UNK",
125 | "O": "UNK",
126 | "U": "UNK",
127 | "-": "-",
128 | }
129 |
130 | prot_token_to_letter = {v: k for k, v in prot_letter_to_token.items()}
131 | prot_token_to_letter["UNK"] = "X"
132 |
133 | rna_letter_to_token = {
134 | "A": "A",
135 | "G": "G",
136 | "C": "C",
137 | "U": "U",
138 | "N": "N",
139 | }
140 | rna_token_to_letter = {v: k for k, v in rna_letter_to_token.items()}
141 |
142 | dna_letter_to_token = {
143 | "A": "DA",
144 | "G": "DG",
145 | "C": "DC",
146 | "T": "DT",
147 | "N": "DN",
148 | }
149 | dna_token_to_letter = {v: k for k, v in dna_letter_to_token.items()}
150 |
151 | ####################################################################################################
152 | # ATOMS
153 | ####################################################################################################
154 |
155 | num_elements = 128
156 |
157 | chirality_types = [
158 | "CHI_UNSPECIFIED",
159 | "CHI_TETRAHEDRAL_CW",
160 | "CHI_TETRAHEDRAL_CCW",
161 | "CHI_OTHER",
162 | ]
163 | chirality_type_ids = {chirality: i for i, chirality in enumerate(chirality_types)}
164 | unk_chirality_type = "CHI_UNSPECIFIED"
165 |
166 | # fmt: off
167 | ref_atoms = {
168 | "PAD": [],
169 | "UNK": ["N", "CA", "C", "O", "CB"],
170 | "-": [],
171 | "ALA": ["N", "CA", "C", "O", "CB"],
172 | "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
173 | "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
174 | "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
175 | "CYS": ["N", "CA", "C", "O", "CB", "SG"],
176 | "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
177 | "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
178 | "GLY": ["N", "CA", "C", "O"],
179 | "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
180 | "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
181 | "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
182 | "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
183 | "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
184 | "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
185 | "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
186 | "SER": ["N", "CA", "C", "O", "CB", "OG"],
187 | "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
188 | "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"], # noqa: E501
189 | "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
190 | "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
191 | "A": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"], # noqa: E501
192 | "G": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"], # noqa: E501
193 | "C": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"], # noqa: E501
194 | "U": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6"], # noqa: E501
195 | "N": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'"], # noqa: E501
196 | "DA": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"], # noqa: E501
197 | "DG": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"], # noqa: E501
198 | "DC": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"], # noqa: E501
199 | "DT": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C7", "C6"], # noqa: E501
200 | "DN": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'"]
201 | }
202 |
203 | ref_symmetries = {
204 | "PAD": [],
205 | "ALA": [],
206 | "ARG": [],
207 | "ASN": [],
208 | "ASP": [[(6, 7), (7, 6)]],
209 | "CYS": [],
210 | "GLN": [],
211 | "GLU": [[(7, 8), (8, 7)]],
212 | "GLY": [],
213 | "HIS": [],
214 | "ILE": [],
215 | "LEU": [],
216 | "LYS": [],
217 | "MET": [],
218 | "PHE": [[(6, 7), (7, 6), (8, 9), (9, 8)]],
219 | "PRO": [],
220 | "SER": [],
221 | "THR": [],
222 | "TRP": [],
223 | "TYR": [[(6, 7), (7, 6), (8, 9), (9, 8)]],
224 | "VAL": [],
225 | "A": [[(1, 2), (2, 1)]],
226 | "G": [[(1, 2), (2, 1)]],
227 | "C": [[(1, 2), (2, 1)]],
228 | "U": [[(1, 2), (2, 1)]],
229 | "N": [[(1, 2), (2, 1)]],
230 | "DA": [[(1, 2), (2, 1)]],
231 | "DG": [[(1, 2), (2, 1)]],
232 | "DC": [[(1, 2), (2, 1)]],
233 | "DT": [[(1, 2), (2, 1)]],
234 | "DN": [[(1, 2), (2, 1)]]
235 | }
236 |
237 |
238 | res_to_center_atom = {
239 | "UNK": "CA",
240 | "ALA": "CA",
241 | "ARG": "CA",
242 | "ASN": "CA",
243 | "ASP": "CA",
244 | "CYS": "CA",
245 | "GLN": "CA",
246 | "GLU": "CA",
247 | "GLY": "CA",
248 | "HIS": "CA",
249 | "ILE": "CA",
250 | "LEU": "CA",
251 | "LYS": "CA",
252 | "MET": "CA",
253 | "PHE": "CA",
254 | "PRO": "CA",
255 | "SER": "CA",
256 | "THR": "CA",
257 | "TRP": "CA",
258 | "TYR": "CA",
259 | "VAL": "CA",
260 | "A": "C1'",
261 | "G": "C1'",
262 | "C": "C1'",
263 | "U": "C1'",
264 | "N": "C1'",
265 | "DA": "C1'",
266 | "DG": "C1'",
267 | "DC": "C1'",
268 | "DT": "C1'",
269 | "DN": "C1'"
270 | }
271 |
272 | res_to_disto_atom = {
273 | "UNK": "CB",
274 | "ALA": "CB",
275 | "ARG": "CB",
276 | "ASN": "CB",
277 | "ASP": "CB",
278 | "CYS": "CB",
279 | "GLN": "CB",
280 | "GLU": "CB",
281 | "GLY": "CA",
282 | "HIS": "CB",
283 | "ILE": "CB",
284 | "LEU": "CB",
285 | "LYS": "CB",
286 | "MET": "CB",
287 | "PHE": "CB",
288 | "PRO": "CB",
289 | "SER": "CB",
290 | "THR": "CB",
291 | "TRP": "CB",
292 | "TYR": "CB",
293 | "VAL": "CB",
294 | "A": "C4",
295 | "G": "C4",
296 | "C": "C2",
297 | "U": "C2",
298 | "N": "C1'",
299 | "DA": "C4",
300 | "DG": "C4",
301 | "DC": "C2",
302 | "DT": "C2",
303 | "DN": "C1'"
304 | }
305 |
306 | res_to_center_atom_id = {
307 | res: ref_atoms[res].index(atom)
308 | for res, atom in res_to_center_atom.items()
309 | }
310 |
311 | res_to_disto_atom_id = {
312 | res: ref_atoms[res].index(atom)
313 | for res, atom in res_to_disto_atom.items()
314 | }
315 |
316 | # fmt: on
317 |
318 | ####################################################################################################
319 | # BONDS
320 | ####################################################################################################
321 |
322 | atom_interface_cutoff = 5.0
323 | interface_cutoff = 15.0
324 |
325 | bond_types = [
326 | "OTHER",
327 | "SINGLE",
328 | "DOUBLE",
329 | "TRIPLE",
330 | "AROMATIC",
331 | ]
332 | bond_type_ids = {bond: i for i, bond in enumerate(bond_types)}
333 | unk_bond_type = "OTHER"
334 |
335 |
336 | ####################################################################################################
337 | # Contacts
338 | ####################################################################################################
339 |
340 |
341 | pocket_contact_info = {
342 | "UNSPECIFIED": 0,
343 | "UNSELECTED": 1,
344 | "POCKET": 2,
345 | "BINDER": 3,
346 | }
347 |
348 |
349 | ####################################################################################################
350 | # MSA
351 | ####################################################################################################
352 |
353 | max_msa_seqs = 16384
354 |
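355 | # A small sketch of how the tables above compose; the sequence is an
356 | # illustrative assumption, not data used elsewhere in this module.
357 | if __name__ == "__main__":
358 |     seq = "ACDX"
359 |     toks = [prot_letter_to_token[c] for c in seq]
360 |     assert toks == ["ALA", "CYS", "ASP", "UNK"]
361 | 
362 |     ids = [token_ids[t] for t in toks]
363 |     assert ids[-1] == unk_token_ids["PROTEIN"]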
--------------------------------------------------------------------------------
/src/boltz/data/crop/boltz.py:
--------------------------------------------------------------------------------
1 | from dataclasses import replace
2 | from typing import Optional
3 |
4 | import numpy as np
5 | from scipy.spatial.distance import cdist
6 |
7 | from boltz.data import const
8 | from boltz.data.crop.cropper import Cropper
9 | from boltz.data.types import Tokenized
10 |
11 |
12 | def pick_random_token(
13 | tokens: np.ndarray,
14 | random: np.random.RandomState,
15 | ) -> np.ndarray:
16 | """Pick a random token from the data.
17 |
18 | Parameters
19 | ----------
20 | tokens : np.ndarray
21 | The token data.
22 |     random : np.random.RandomState
23 | The random state for reproducibility.
24 |
25 | Returns
26 | -------
27 | np.ndarray
28 | The selected token.
29 |
30 | """
31 | return tokens[random.randint(len(tokens))]
32 |
33 |
34 | def pick_chain_token(
35 | tokens: np.ndarray,
36 | chain_id: int,
37 | random: np.random.RandomState,
38 | ) -> np.ndarray:
39 | """Pick a random token from a chain.
40 |
41 | Parameters
42 | ----------
43 | tokens : np.ndarray
44 | The token data.
45 | chain_id : int
46 | The chain ID.
47 |     random : np.random.RandomState
48 | The random state for reproducibility.
49 |
50 | Returns
51 | -------
52 | np.ndarray
53 | The selected token.
54 |
55 | """
56 | # Filter to chain
57 | chain_tokens = tokens[tokens["asym_id"] == chain_id]
58 |
59 | # Pick from chain, fallback to all tokens
60 | if chain_tokens.size:
61 | query = pick_random_token(chain_tokens, random)
62 | else:
63 | query = pick_random_token(tokens, random)
64 |
65 | return query
66 |
67 |
68 | def pick_interface_token(
69 | tokens: np.ndarray,
70 | interface: np.ndarray,
71 | random: np.random.RandomState,
72 | ) -> np.ndarray:
73 | """Pick a random token from an interface.
74 |
75 | Parameters
76 | ----------
77 | tokens : np.ndarray
78 | The token data.
79 |     interface : np.ndarray
80 |         The interface record.
81 |     random : np.random.RandomState
82 | The random state for reproducibility.
83 |
84 | Returns
85 | -------
86 | np.ndarray
87 | The selected token.
88 |
89 | """
90 |     # Get the two chains forming the interface
91 | chain_1 = int(interface["chain_1"])
92 | chain_2 = int(interface["chain_2"])
93 |
94 | tokens_1 = tokens[tokens["asym_id"] == chain_1]
95 | tokens_2 = tokens[tokens["asym_id"] == chain_2]
96 |
97 |     # Fall back when either chain has no valid tokens
98 | if tokens_1.size and (not tokens_2.size):
99 | query = pick_random_token(tokens_1, random)
100 | elif tokens_2.size and (not tokens_1.size):
101 | query = pick_random_token(tokens_2, random)
102 | elif (not tokens_1.size) and (not tokens_2.size):
103 | query = pick_random_token(tokens, random)
104 | else:
105 | # If we have tokens, compute distances
106 | tokens_1_coords = tokens_1["center_coords"]
107 | tokens_2_coords = tokens_2["center_coords"]
108 |
109 | dists = cdist(tokens_1_coords, tokens_2_coords)
110 |         cutoff = dists < const.interface_cutoff
111 | 
112 |         # In rare cases the interface cutoff is slightly
113 |         # too small; expand it slightly when that happens
114 |         if not np.any(cutoff):
115 |             cutoff = dists < (const.interface_cutoff + 5.0)
116 | 
117 |         tokens_1 = tokens_1[np.any(cutoff, axis=1)]
118 |         tokens_2 = tokens_2[np.any(cutoff, axis=0)]
119 |
120 | # Select random token
121 | candidates = np.concatenate([tokens_1, tokens_2])
122 | query = pick_random_token(candidates, random)
123 |
124 | return query
125 |
126 |
127 | class BoltzCropper(Cropper):
128 | """Interpolate between contiguous and spatial crops."""
129 |
130 | def __init__(self, min_neighborhood: int = 0, max_neighborhood: int = 40) -> None:
131 | """Initialize the cropper.
132 |
133 | Modulates the type of cropping to be performed.
134 | Smaller neighborhoods result in more spatial
135 | cropping. Larger neighborhoods result in more
136 | continuous cropping. A mix can be achieved by
137 | providing a range over which to sample.
138 |
139 | Parameters
140 | ----------
141 | min_neighborhood : int
142 | The minimum neighborhood size, by default 0.
143 | max_neighborhood : int
144 | The maximum neighborhood size, by default 40.
145 |
146 | """
147 | sizes = list(range(min_neighborhood, max_neighborhood + 1, 2))
148 | self.neighborhood_sizes = sizes
149 |
150 | def crop( # noqa: PLR0915
151 | self,
152 | data: Tokenized,
153 | max_tokens: int,
154 | random: np.random.RandomState,
155 | max_atoms: Optional[int] = None,
156 | chain_id: Optional[int] = None,
157 | interface_id: Optional[int] = None,
158 | ) -> Tokenized:
159 | """Crop the data to a maximum number of tokens.
160 |
161 | Parameters
162 | ----------
163 | data : Tokenized
164 | The tokenized data.
165 | max_tokens : int
166 | The maximum number of tokens to crop.
167 | random : np.random.RandomState
168 | The random state for reproducibility.
169 | max_atoms : int, optional
170 | The maximum number of atoms to consider.
171 | chain_id : int, optional
172 | The chain ID to crop.
173 | interface_id : int, optional
174 | The interface ID to crop.
175 |
176 | Returns
177 | -------
178 | Tokenized
179 | The cropped data.
180 |
181 | """
182 | # Check inputs
183 | if chain_id is not None and interface_id is not None:
184 | msg = "Only one of chain_id or interface_id can be provided."
185 | raise ValueError(msg)
186 |
187 | # Randomly select a neighborhood size
188 | neighborhood_size = random.choice(self.neighborhood_sizes)
189 |
190 | # Get token data
191 | token_data = data.tokens
192 | token_bonds = data.bonds
193 | mask = data.structure.mask
194 | chains = data.structure.chains
195 | interfaces = data.structure.interfaces
196 |
197 | # Filter to valid chains
198 | valid_chains = chains[mask]
199 |
200 | # Filter to valid interfaces
201 | valid_interfaces = interfaces
202 | valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_1"]]]
203 | valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_2"]]]
204 |
205 | # Filter to resolved tokens
206 | valid_tokens = token_data[token_data["resolved_mask"]]
207 |
208 | # Check if we have any valid tokens
209 | if not valid_tokens.size:
210 | msg = "No valid tokens in structure"
211 | raise ValueError(msg)
212 |
213 | # Pick a random token, chain, or interface
214 | if chain_id is not None:
215 | query = pick_chain_token(valid_tokens, chain_id, random)
216 | elif interface_id is not None:
217 | interface = interfaces[interface_id]
218 | query = pick_interface_token(valid_tokens, interface, random)
219 | elif valid_interfaces.size:
220 | idx = random.randint(len(valid_interfaces))
221 | interface = valid_interfaces[idx]
222 | query = pick_interface_token(valid_tokens, interface, random)
223 | else:
224 | idx = random.randint(len(valid_chains))
225 | chain_id = valid_chains[idx]["asym_id"]
226 | query = pick_chain_token(valid_tokens, chain_id, random)
227 |
228 |         # Sort all valid tokens by distance to the query token
229 | dists = valid_tokens["center_coords"] - query["center_coords"]
230 | indices = np.argsort(np.linalg.norm(dists, axis=1))
231 |
232 | # Select cropped indices
233 | cropped: set[int] = set()
234 | total_atoms = 0
235 | for idx in indices:
236 | # Get the token
237 | token = valid_tokens[idx]
238 |
239 | # Get all tokens from this chain
240 | chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]
241 |
242 | # Pick the whole chain if possible, otherwise select
243 | # a contiguous subset centered at the query token
244 | if len(chain_tokens) <= neighborhood_size:
245 | new_tokens = chain_tokens
246 | else:
247 | # First limit to the maximum set of tokens, with the
248 |                 # neighborhood on both sides to handle edges. This
249 | # is mostly for efficiency with the while loop below.
250 | min_idx = token["res_idx"] - neighborhood_size
251 | max_idx = token["res_idx"] + neighborhood_size
252 |
253 | max_token_set = chain_tokens
254 | max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
255 | max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]
256 |
257 | # Start by adding just the query token
258 | new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]
259 |
260 | # Expand the neighborhood until we have enough tokens, one
261 | # by one to handle some edge cases with non-standard chains.
262 | # We switch to the res_idx instead of the token_idx to always
263 | # include all tokens from modified residues or from ligands.
264 | min_idx = max_idx = token["res_idx"]
265 | while new_tokens.size < neighborhood_size:
266 | min_idx = min_idx - 1
267 | max_idx = max_idx + 1
268 | new_tokens = max_token_set
269 | new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
270 | new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]
271 |
272 | # Compute new tokens and new atoms
273 | new_indices = set(new_tokens["token_idx"]) - cropped
274 | new_tokens = token_data[list(new_indices)]
275 | new_atoms = np.sum(new_tokens["atom_num"])
276 |
277 | # Stop if we exceed the max number of tokens or atoms
278 | if (len(new_indices) > (max_tokens - len(cropped))) or (
279 | (max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms)
280 | ):
281 | break
282 |
283 | # Add new indices
284 | cropped.update(new_indices)
285 | total_atoms += new_atoms
286 |
287 | # Get the cropped tokens sorted by index
288 | token_data = token_data[sorted(cropped)]
289 |
290 | # Only keep bonds within the cropped tokens
291 | indices = token_data["token_idx"]
292 | token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
293 | token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]
294 |
295 | # Return the cropped tokens
296 | return replace(data, tokens=token_data, bonds=token_bonds)
297 |
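298 | # A minimal sketch of how the neighborhood range drives the crop style;
299 | # the values below are illustrative assumptions. A size of 0 degenerates
300 | # to a purely spatial crop (residues added in distance order), while a
301 | # size of 40 pulls in contiguous stretches of up to 40 tokens around
302 | # each selected token, approximating a contiguous crop.
303 | if __name__ == "__main__":
304 |     cropper = BoltzCropper(min_neighborhood=0, max_neighborhood=40)
305 |     assert cropper.neighborhood_sizes == list(range(0, 41, 2))
306 | 
307 |     random = np.random.RandomState(0)
308 |     size = random.choice(cropper.neighborhood_sizes)
309 |     assert size in cropper.neighborhood_sizes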
--------------------------------------------------------------------------------