├── src └── boltz │ ├── data │ ├── __init__.py │ ├── crop │ │ ├── __init__.py │ │ ├── cropper.py │ │ └── boltz.py │ ├── filter │ │ ├── __init__.py │ │ ├── dynamic │ │ │ ├── __init__.py │ │ │ ├── filter.py │ │ │ ├── resolution.py │ │ │ ├── max_residues.py │ │ │ ├── size.py │ │ │ ├── subset.py │ │ │ └── date.py │ │ └── static │ │ │ ├── __init__.py │ │ │ ├── filter.py │ │ │ ├── ligand.py │ │ │ └── polymer.py │ ├── module │ │ ├── __init__.py │ │ └── inference.py │ ├── parse │ │ ├── __init__.py │ │ ├── yaml.py │ │ ├── a3m.py │ │ └── fasta.py │ ├── sample │ │ ├── __init__.py │ │ ├── random.py │ │ ├── sampler.py │ │ ├── distillation.py │ │ └── cluster.py │ ├── write │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── pdb.py │ │ ├── writer.py │ │ └── mmcif.py │ ├── feature │ │ ├── __init__.py │ │ └── pad.py │ ├── tokenize │ │ ├── __init__.py │ │ ├── tokenizer.py │ │ └── boltz.py │ └── const.py │ ├── model │ ├── __init__.py │ ├── loss │ │ ├── __init__.py │ │ ├── distogram.py │ │ └── diffusion.py │ ├── optim │ │ ├── __init__.py │ │ └── scheduler.py │ ├── layers │ │ ├── __init__.py │ │ ├── triangular_attention │ │ │ ├── __init__.py │ │ │ └── attention.py │ │ ├── dropout.py │ │ ├── transition.py │ │ ├── initialize.py │ │ ├── outer_product_mean.py │ │ ├── attention.py │ │ ├── triangular_mult.py │ │ └── pair_averaging.py │ └── modules │ │ ├── __init__.py │ │ ├── confidence_utils.py │ │ └── transformers.py │ └── __init__.py ├── docs ├── boltz1_pred_figure.png ├── training.md └── prediction.md ├── examples ├── ligand.yaml └── ligand.fasta ├── scripts └── train │ ├── assets │ ├── casp15_ids.txt │ ├── test_ids.txt │ └── validation_ids.txt │ ├── configs │ ├── structure.yaml │ └── confidence.yaml │ └── train.py ├── LICENSE ├── README.md ├── pyproject.toml └── .gitignore /src/boltz/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/crop/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/filter/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/module/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/parse/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/sample/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/write/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/model/loss/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/boltz/model/optim/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/feature/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/tokenize/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/model/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/model/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/filter/dynamic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/data/filter/static/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/boltz/model/layers/triangular_attention/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/boltz1_pred_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/boltz/main/docs/boltz1_pred_figure.png -------------------------------------------------------------------------------- /src/boltz/__init__.py: --------------------------------------------------------------------------------
 1 | from importlib.metadata import PackageNotFoundError, version
 2 | 
 3 | try:  # noqa: SIM105
 4 |     __version__ = version("boltz")
 5 | except PackageNotFoundError:
 6 |     # package is not installed
 7 |     pass
 8 | 
-------------------------------------------------------------------------------- /src/boltz/data/tokenize/tokenizer.py: --------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | 
 3 | from boltz.data.types import Input, Tokenized
 4 | 
 5 | 
 6 | class Tokenizer(ABC):
 7 |     """Tokenize an input structure for training."""
 8 | 
 9 |     @abstractmethod
10 |     def tokenize(self, data: Input) -> Tokenized:
11 |         """Tokenize the input data.
12 | 
13 |         Parameters
14 |         ----------
15 |         data : Input
16 |             The input data.
17 | 
18 |         Returns
19 |         -------
20 |         Tokenized
21 |             The tokenized data.
22 | 
23 |         """
24 |         raise NotImplementedError
25 | 
-------------------------------------------------------------------------------- /src/boltz/data/filter/dynamic/filter.py: --------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | 
 3 | from boltz.data.types import Record
 4 | 
 5 | 
 6 | class DynamicFilter(ABC):
 7 |     """Base class for data filters."""
 8 | 
 9 |     @abstractmethod
10 |     def filter(self, record: Record) -> bool:
11 |         """Filter a data record.
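        Implementations return ``True`` to keep the record and ``False``
        to discard it.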
12 | 13 | Parameters 14 | ---------- 15 | record : Record 16 | The object to consider filtering in / out. 17 | 18 | Returns 19 | ------- 20 | bool 21 | True if the data passes the filter, False otherwise. 22 | 23 | """ 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /src/boltz/data/write/utils.py: -------------------------------------------------------------------------------- 1 | import string 2 | from collections.abc import Iterator 3 | 4 | 5 | def generate_tags() -> Iterator[str]: 6 | """Generate chain tags. 7 | 8 | Yields 9 | ------ 10 | str 11 | The next chain tag 12 | 13 | """ 14 | for i in range(1, 4): 15 | for j in range(len(string.ascii_uppercase) ** i): 16 | tag = "" 17 | for k in range(i): 18 | tag += string.ascii_uppercase[ 19 | j 20 | // (len(string.ascii_uppercase) ** k) 21 | % len(string.ascii_uppercase) 22 | ] 23 | yield tag 24 | -------------------------------------------------------------------------------- /examples/ligand.yaml: -------------------------------------------------------------------------------- 1 | version: 1 # Optional, defaults to 1 2 | sequences: 3 | - protein: 4 | id: [A, B] 5 | sequence: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ 6 | msa: ./examples/msa/seq1.a3m 7 | - ligand: 8 | id: [C, D] 9 | ccd: SAH 10 | - ligand: 11 | id: [E, F] 12 | smiles: N[C@@H](Cc1ccc(O)cc1)C(=O)O 13 | -------------------------------------------------------------------------------- /src/boltz/data/filter/static/filter.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | 5 | from boltz.data.types import Structure 6 | 7 | 8 | class StaticFilter(ABC): 9 | """Base class for structure filters.""" 10 | 11 | @abstractmethod 12 | def filter(self, structure: Structure) -> np.ndarray: 13 | """Filter chains in a structure. 14 | 15 | Parameters 16 | ---------- 17 | structure : Structure 18 | The structure to filter chains from. 19 | 20 | Returns 21 | ------- 22 | np.ndarray 23 | The chains to keep, as a boolean mask. 
24 | 25 | """ 26 | raise NotImplementedError 27 | -------------------------------------------------------------------------------- /scripts/train/assets/casp15_ids.txt: -------------------------------------------------------------------------------- 1 | T1112 2 | T1118v1 3 | T1154 4 | T1137s1 5 | T1188 6 | T1157s1 7 | T1137s6 8 | R1117 9 | H1106 10 | T1106s2 11 | R1149 12 | T1158 13 | T1137s2 14 | T1145 15 | T1121 16 | T1123 17 | T1113 18 | R1156 19 | T1114s1 20 | T1183 21 | R1107 22 | T1137s7 23 | T1124 24 | T1178 25 | T1147 26 | R1128 27 | T1161 28 | R1108 29 | T1194 30 | T1185s2 31 | T1176 32 | T1158v3 33 | T1137s4 34 | T1160 35 | T1120 36 | H1185 37 | T1134s1 38 | T1119 39 | H1151 40 | T1137s8 41 | T1133 42 | T1187 43 | H1157 44 | T1122 45 | T1104 46 | T1158v2 47 | T1137s5 48 | T1129s2 49 | T1174 50 | T1157s2 51 | T1155 52 | T1158v4 53 | T1152 54 | T1137s9 55 | T1134s2 56 | T1125 57 | R1116 58 | H1134 59 | R1136 60 | T1159 61 | T1137s3 62 | T1185s1 63 | T1179 64 | T1106s1 65 | T1132 66 | T1185s4 67 | T1114s3 68 | T1114s2 69 | T1151s2 70 | T1158v1 71 | R1117v2 72 | T1173 73 | -------------------------------------------------------------------------------- /src/boltz/model/layers/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def get_dropout_mask( 6 | dropout: float, 7 | z: Tensor, 8 | training: bool, 9 | columnwise: bool = False, 10 | ) -> Tensor: 11 | """Get the dropout mask. 12 | 13 | Parameters 14 | ---------- 15 | dropout : float 16 | The dropout rate 17 | z : torch.Tensor 18 | The tensor to apply dropout to 19 | training : bool 20 | Whether the model is in training mode 21 | columnwise : bool, optional 22 | Whether to apply dropout columnwise 23 | 24 | Returns 25 | ------- 26 | torch.Tensor 27 | The dropout mask 28 | 29 | """ 30 | dropout = dropout * training 31 | v = z[:, 0:1, :, 0:1] if columnwise else z[:, :, 0:1, 0:1] 32 | d = torch.rand_like(v) > dropout 33 | d = d * 1.0 / (1.0 - dropout) 34 | return d 35 | -------------------------------------------------------------------------------- /examples/ligand.fasta: -------------------------------------------------------------------------------- 1 | >A|protein|./examples/msa/seq1.a3m 2 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ 3 | >B|protein|./examples/msa/seq1.a3m 4 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ 5 | >C|ccd 6 | SAH 7 | >D|ccd 8 | SAH 9 | >E|smiles 10 | N[C@@H](Cc1ccc(O)cc1)C(=O)O 11 | >F|smiles 12 | N[C@@H](Cc1ccc(O)cc1)C(=O)O -------------------------------------------------------------------------------- /src/boltz/data/filter/dynamic/resolution.py: -------------------------------------------------------------------------------- 1 | from boltz.data.types import Record 2 | from boltz.data.filter.dynamic.filter import DynamicFilter 3 | 4 | 5 
| class ResolutionFilter(DynamicFilter):
 6 |     """A filter that filters complexes based on their resolution."""
 7 | 
 8 |     def __init__(self, resolution: float = 9.0) -> None:
 9 |         """Initialize the filter.
10 | 
11 |         Parameters
12 |         ----------
13 |         resolution : float, optional
14 |             The maximum allowed resolution.
15 | 
16 |         """
17 |         self.resolution = resolution
18 | 
19 |     def filter(self, record: Record) -> bool:
20 |         """Filter complexes based on their resolution.
21 | 
22 |         Parameters
23 |         ----------
24 |         record : Record
25 |             The record to filter.
26 | 
27 |         Returns
28 |         -------
29 |         bool
30 |             True if the resolution is within the allowed maximum, False otherwise.
31 | 
32 |         """
33 |         structure = record.structure
34 |         return structure.resolution <= self.resolution
35 | 
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /src/boltz/data/filter/dynamic/max_residues.py: --------------------------------------------------------------------------------
 1 | from boltz.data.types import Record
 2 | from boltz.data.filter.dynamic.filter import DynamicFilter
 3 | 
 4 | 
 5 | class MaxResiduesFilter(DynamicFilter):
 6 |     """A filter that filters structures based on their number of residues."""
 7 | 
 8 |     def __init__(self, min_residues: int = 1, max_residues: int = 500) -> None:
 9 |         """Initialize the filter.
10 | 
11 |         Parameters
12 |         ----------
13 |         min_residues : int
14 |             The minimum number of residues allowed.
15 |         max_residues : int
16 |             The maximum number of residues allowed.
17 | 
18 |         """
19 |         self.min_residues = min_residues
20 |         self.max_residues = max_residues
21 | 
22 |     def filter(self, record: Record) -> bool:
23 |         """Filter structures based on their number of residues.
24 | 
25 |         Parameters
26 |         ----------
27 |         record : Record
28 |             The record to filter.
29 | 
30 |         Returns
31 |         -------
32 |         bool
33 |             True if the record passes the filter, False otherwise.
34 | 
35 |         """
36 |         num_residues = sum(chain.num_residues for chain in record.chains)
37 |         return num_residues <= self.max_residues and num_residues >= self.min_residues
38 | 
-------------------------------------------------------------------------------- /src/boltz/data/filter/dynamic/size.py: --------------------------------------------------------------------------------
 1 | from boltz.data.types import Record
 2 | from boltz.data.filter.dynamic.filter import DynamicFilter
 3 | 
 4 | 
 5 | class SizeFilter(DynamicFilter):
 6 |     """A filter that filters structures based on their size."""
 7 | 
 8 |     def __init__(self, min_chains: int = 1, max_chains: int = 300) -> None:
 9 |         """Initialize the filter.
10 | 
11 |         Parameters
12 |         ----------
13 |         min_chains : int
14 |             The minimum number of chains allowed.
15 |         max_chains : int
16 |             The maximum number of chains allowed.
17 | 
18 |         """
19 |         self.min_chains = min_chains
20 |         self.max_chains = max_chains
21 | 
22 |     def filter(self, record: Record) -> bool:
23 |         """Filter structures based on their size.
24 | 
25 |         Parameters
26 |         ----------
27 |         record : Record
28 |             The record to filter.
29 | 
30 |         Returns
31 |         -------
32 |         bool
33 |             True if the record passes the filter, False otherwise.
34 | 
35 |         """
36 |         num_chains = record.structure.num_chains
37 |         num_valid = sum(1 for chain in record.chains if chain.valid)
38 |         return num_chains <= self.max_chains and num_valid >= self.min_chains
39 | 
-------------------------------------------------------------------------------- /src/boltz/data/filter/dynamic/subset.py: --------------------------------------------------------------------------------
 1 | from pathlib import Path
 2 | 
 3 | from boltz.data.types import Record
 4 | from boltz.data.filter.dynamic.filter import DynamicFilter
 5 | 
 6 | 
 7 | class SubsetFilter(DynamicFilter):
 8 |     """Filter a data record based on a subset of the data."""
 9 | 
10 |     def __init__(self, subset: str, reverse: bool = False) -> None:
11 |         """Initialize the filter.
12 | 
13 |         Parameters
14 |         ----------
15 |         subset : str
16 |             Path to a file listing the record IDs to consider, one per line.
17 |         reverse : bool, optional
18 |             Whether to invert the check, keeping only records not in the subset.
19 | 
20 |         """
21 |         with Path(subset).open("r") as f:
22 |             subset = f.read().splitlines()
23 | 
24 |         self.subset = {s.lower() for s in subset}
25 |         self.reverse = reverse
26 | 
27 |     def filter(self, record: Record) -> bool:
28 |         """Filter a data record.
29 | 
30 |         Parameters
31 |         ----------
32 |         record : Record
33 |             The object to consider filtering in / out.
34 | 
35 |         Returns
36 |         -------
37 |         bool
38 |             True if the data passes the filter, False otherwise.
39 | 
40 |         """
41 |         if self.reverse:
42 |             return record.id.lower() not in self.subset
43 |         else:  # noqa: RET505
44 |             return record.id.lower() in self.subset
45 | 
-------------------------------------------------------------------------------- /src/boltz/data/sample/random.py: --------------------------------------------------------------------------------
 1 | from dataclasses import replace
 2 | from typing import Iterator, List
 3 | 
 4 | from numpy.random import RandomState
 5 | 
 6 | from boltz.data.types import Record
 7 | from boltz.data.sample.sampler import Sample, Sampler
 8 | 
 9 | 
10 | class RandomSampler(Sampler):
11 |     """A simple random sampler with replacement."""
12 | 
13 |     def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
14 |         """Sample a structure from the dataset infinitely.
15 | 
16 |         Parameters
17 |         ----------
18 |         records : List[Record]
19 |             The records to sample from.
20 |         random : RandomState
21 |             The random state for reproducibility.
22 | 23 | Yields 24 | ------ 25 | Sample 26 | A data sample. 27 | 28 | """ 29 | while True: 30 | # Sample item from the list 31 | index = random.randint(0, len(records)) 32 | record = records[index] 33 | 34 | # Remove invalid chains and interfaces 35 | chains = [c for c in record.chains if c.valid] 36 | interfaces = [i for i in record.interfaces if i.valid] 37 | record = replace(record, chains=chains, interfaces=interfaces) 38 | 39 | yield Sample(record=record) 40 | -------------------------------------------------------------------------------- /src/boltz/data/sample/sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Iterator, List, Optional 4 | 5 | from numpy.random import RandomState 6 | 7 | from boltz.data.types import Record 8 | 9 | 10 | @dataclass 11 | class Sample: 12 | """A sample with optional chain and interface IDs. 13 | 14 | Attributes 15 | ---------- 16 | record : Record 17 | The record. 18 | chain_id : Optional[int] 19 | The chain ID. 20 | interface_id : Optional[int] 21 | The interface ID. 22 | """ 23 | 24 | record: Record 25 | chain_id: Optional[int] = None 26 | interface_id: Optional[int] = None 27 | 28 | 29 | class Sampler(ABC): 30 | """Abstract base class for samplers.""" 31 | 32 | @abstractmethod 33 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]: 34 | """Sample a structure from the dataset infinitely. 35 | 36 | Parameters 37 | ---------- 38 | records : List[Record] 39 | The records to sample from. 40 | random : RandomState 41 | The random state for reproducibility. 42 | 43 | Yields 44 | ------ 45 | Sample 46 | A data sample. 47 | 48 | """ 49 | raise NotImplementedError 50 | -------------------------------------------------------------------------------- /src/boltz/data/crop/cropper.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | 4 | import numpy as np 5 | 6 | from boltz.data.types import Tokenized 7 | 8 | 9 | class Cropper(ABC): 10 | """Abstract base class for cropper.""" 11 | 12 | @abstractmethod 13 | def crop( 14 | self, 15 | data: Tokenized, 16 | max_tokens: int, 17 | random: np.random.RandomState, 18 | max_atoms: Optional[int] = None, 19 | chain_id: Optional[int] = None, 20 | interface_id: Optional[int] = None, 21 | ) -> Tokenized: 22 | """Crop the data to a maximum number of tokens. 23 | 24 | Parameters 25 | ---------- 26 | data : Tokenized 27 | The tokenized data. 28 | max_tokens : int 29 | The maximum number of tokens to crop. 30 | random : np.random.RandomState 31 | The random state for reproducibility. 32 | max_atoms : Optional[int] 33 | The maximum number of atoms to consider. 34 | chain_id : Optional[int] 35 | The chain ID to crop. 36 | interface_id : Optional[int] 37 | The interface ID to crop. 38 | 39 | Returns 40 | ------- 41 | Tokenized 42 | The cropped data. 43 | 44 | """ 45 | raise NotImplementedError 46 | -------------------------------------------------------------------------------- /src/boltz/model/loss/distogram.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def distogram_loss( 8 | output: Dict[str, Tensor], 9 | feats: Dict[str, Tensor], 10 | ) -> Tuple[Tensor, Tensor]: 11 | """Compute the distogram loss. 
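    The loss is a cross-entropy between the predicted distance-bin logits and
    the one-hot target distogram, averaged over valid (non-padded, off-diagonal)
    token pairs.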
12 | 
13 |     Parameters
14 |     ----------
15 |     output : Dict[str, Tensor]
16 |         Output of the model
17 |     feats : Dict[str, Tensor]
18 |         Input features
19 | 
20 |     Returns
21 |     -------
22 |     Tensor
23 |         The globally averaged loss.
24 |     Tensor
25 |         Per example loss.
26 | 
27 |     """
28 |     # Get predicted distograms
29 |     pred = output["pdistogram"]
30 | 
31 |     # Compute target distogram
32 |     target = feats["disto_target"]
33 | 
34 |     # Combine target mask and padding mask
35 |     mask = feats["token_disto_mask"]
36 |     mask = mask[:, None, :] * mask[:, :, None]
37 |     mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
38 | 
39 |     # Compute the distogram loss
40 |     errors = -1 * torch.sum(
41 |         target * torch.nn.functional.log_softmax(pred, dim=-1),
42 |         dim=-1,
43 |     )
44 |     denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
45 |     mean = errors * mask
46 |     mean = torch.sum(mean, dim=-1)
47 |     mean = mean / denom[..., None]
48 |     batch_loss = torch.sum(mean, dim=-1)
49 |     global_loss = torch.mean(batch_loss)
50 |     return global_loss, batch_loss
51 | 
-------------------------------------------------------------------------------- /src/boltz/data/parse/yaml.py: --------------------------------------------------------------------------------
 1 | from pathlib import Path
 2 | 
 3 | import yaml
 4 | from rdkit.Chem.rdchem import Mol
 5 | 
 6 | from boltz.data.parse.schema import parse_boltz_schema
 7 | from boltz.data.types import Target
 8 | 
 9 | 
10 | def parse_yaml(path: Path, ccd: dict[str, Mol]) -> Target:
11 |     """Parse a Boltz input yaml / json.
12 | 
13 |     The input file should be a yaml file with the following format:
14 | 
15 |     sequences:
16 |         - protein:
17 |             id: A
18 |             sequence: "MADQLTEEQIAEFKEAFSLF"
19 |         - protein:
20 |             id: [B, C]
21 |             sequence: "AKLSILPWGHC"
22 |         - rna:
23 |             id: D
24 |             sequence: "GCAUAGC"
25 |         - ligand:
26 |             id: E
27 |             smiles: "CC1=CC=CC=C1"
28 |         - ligand:
29 |             id: [F, G]
30 |             ccd: []
31 |     constraints:
32 |         - bond:
33 |             atom1: [A, 1, CA]
34 |             atom2: [A, 2, N]
35 |         - pocket:
36 |             binder: E
37 |             contacts: [[B, 1], [B, 2]]
38 |     version: 1
39 | 
40 |     Parameters
41 |     ----------
42 |     path : Path
43 |         Path to the YAML input format.
44 |     ccd : dict[str, Mol]
45 |         Dictionary of CCD components.
46 | 
47 |     Returns
48 |     -------
49 |     Target
50 |         The parsed target.
51 | 
52 |     """
53 |     with path.open("r") as file:
54 |         data = yaml.safe_load(file)
55 | 
56 |     name = path.stem
57 |     return parse_boltz_schema(name, data, ccd)
58 | 
-------------------------------------------------------------------------------- /src/boltz/data/sample/distillation.py: --------------------------------------------------------------------------------
 1 | from typing import Iterator, List
 2 | 
 3 | from numpy.random import RandomState
 4 | 
 5 | from boltz.data.types import Record
 6 | from boltz.data.sample.sampler import Sample, Sampler
 7 | 
 8 | 
 9 | class DistillationSampler(Sampler):
10 |     """A sampler for monomer distillation data."""
11 | 
12 |     def __init__(self, small_size: int = 200, small_prob: float = 0.01) -> None:
13 |         """Initialize the sampler.
14 | 
15 |         Parameters
16 |         ----------
17 |         small_size : int, optional
18 |             The maximum size to be considered small.
19 |         small_prob : float, optional
20 |             The probability of sampling a small item.
21 | 
22 |         """
23 |         self._size = small_size
24 |         self._prob = small_prob
25 | 
26 |     def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
27 |         """Sample a structure from the dataset infinitely.
28 | 
29 |         Parameters
30 |         ----------
31 |         records : List[Record]
32 |             The records to sample from.
33 | random : RandomState 34 | The random state for reproducibility. 35 | 36 | Yields 37 | ------ 38 | Sample 39 | A data sample. 40 | 41 | """ 42 | # Remove records with invalid chains 43 | records = [r for r in records if r.chains[0].valid] 44 | 45 | # Split in small and large proteins. We assume that there is only 46 | # one chain per record, as is the case for monomer distillation 47 | small = [r for r in records if r.chains[0].num_residues <= self._size] 48 | large = [r for r in records if r.chains[0].num_residues > self._size] 49 | 50 | # Sample infinitely 51 | while True: 52 | # Sample small or large 53 | samples = small if random.rand() < self._prob else large 54 | 55 | # Sample item from the list 56 | index = random.randint(0, len(samples)) 57 | yield Sample(record=samples[index]) 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Boltz-1: Democratizing Biomolecular Interaction Modeling

 5 | 
 6 | ![](docs/boltz1_pred_figure.png)
 7 | 
 8 | Boltz-1 is an open-source model that predicts the 3D structure of proteins, RNA, DNA, and small molecules; it handles modified residues, covalent ligands, and glycans, and it can condition the generation on pocket residues.
 9 | 
10 | For more information about the model, see our [technical report](https://gcorso.github.io/assets/boltz1.pdf).
11 | 
12 | ## Installation
13 | Install boltz from PyPI (recommended):
14 | 
15 | ```
16 | pip install boltz
17 | ```
18 | 
19 | or directly from GitHub for daily updates:
20 | 
21 | ```
22 | git clone https://github.com/jwohlwend/boltz.git
23 | cd boltz; pip install -e .
24 | ```
25 | > Note: we recommend installing boltz in a fresh Python environment
26 | 
27 | ## Inference
28 | 
29 | You can run inference using Boltz-1 with:
30 | 
31 | ```
32 | boltz predict input_path
33 | ```
34 | 
35 | Boltz currently accepts three input formats:
36 | 
37 | 1. A FASTA file, for most use cases
38 | 
39 | 2. A comprehensive YAML schema, for more complex use cases
40 | 
41 | 3. A directory containing files of the above formats, for batched processing
42 | 
43 | To see all available options, run `boltz predict --help`. For more information on these input formats, see our [prediction instructions](docs/prediction.md).
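For reference, each FASTA header encodes `CHAIN_ID|ENTITY_TYPE|MSA_ID`, where the MSA path applies to proteins only. As a sketch adapted from `examples/ligand.fasta` in this repository (protein sequence truncated here for brevity), a minimal input pairing a protein with a SMILES ligand looks like:

```
>A|protein|./examples/msa/seq1.a3m
MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERL...
>E|smiles
N[C@@H](Cc1ccc(O)cc1)C(=O)O
```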
44 | 
45 | ## Training
46 | 
47 | If you're interested in retraining the model, see our [training instructions](docs/training.md).
48 | 
49 | ## Contributing
50 | 
51 | We welcome external contributions and are eager to engage with the community. Connect with us on our [Slack channel](https://boltz-community.slack.com/archives/C0818M6DWH2) to discuss advancements, share insights, and foster collaboration around Boltz-1.
52 | 
53 | ## Coming very soon
54 | 
55 | - [ ] Pocket conditioning support
56 | - [ ] More examples
57 | - [ ] Full data processing pipeline
58 | - [ ] Colab notebook for inference
59 | - [ ] Confidence model checkpoint
60 | - [ ] Support for custom paired MSA
61 | - [ ] Kernel integration
62 | 
63 | ## License
64 | 
65 | Our model and code are released under the MIT License and can be freely used for both academic and commercial purposes.
66 | 
-------------------------------------------------------------------------------- /src/boltz/data/filter/dynamic/date.py: --------------------------------------------------------------------------------
 1 | from datetime import datetime
 2 | from typing import Literal
 3 | 
 4 | from boltz.data.types import Record
 5 | from boltz.data.filter.dynamic.filter import DynamicFilter
 6 | 
 7 | 
 8 | class DateFilter(DynamicFilter):
 9 |     """A filter that filters complexes based on their date.
10 | 
11 |     The date can be the deposition, release, or revision date.
12 |     If the date is not available, the previous date is used.
13 | 
14 |     If no date is available, the complex is rejected.
15 | 
16 |     """
17 | 
18 |     def __init__(
19 |         self,
20 |         date: str,
21 |         ref: Literal["deposited", "revised", "released"],
22 |     ) -> None:
23 |         """Initialize the filter.
24 | 
25 |         Parameters
26 |         ----------
27 |         date : str
28 |             The maximum allowed date of PDB entries, in ISO format.
29 |         ref : Literal["deposited", "revised", "released"]
30 |             The reference date to use.
31 | 
32 |         """
33 |         self.filter_date = datetime.fromisoformat(date)
34 |         self.ref = ref
35 | 
36 |         if ref not in ["deposited", "revised", "released"]:
37 |             msg = "Invalid reference date. Must be deposited, revised, or released."
38 |             raise ValueError(msg)
39 | 
40 |     def filter(self, record: Record) -> bool:
41 |         """Filter a record based on its date.
42 | 
43 |         Parameters
44 |         ----------
45 |         record : Record
46 |             The record to filter.
47 | 
48 |         Returns
49 |         -------
50 |         bool
51 |             True if the record's date is at or before the cutoff, False otherwise.
52 | 
53 |         """
54 |         structure = record.structure
55 | 
56 |         if self.ref == "deposited":
57 |             date = structure.deposited
58 |         elif self.ref == "released":
59 |             date = structure.released
60 |             if not date:
61 |                 date = structure.deposited
62 |         elif self.ref == "revised":
63 |             date = structure.revised
64 |             if not date and structure.released:
65 |                 date = structure.released
66 |             elif not date:
67 |                 date = structure.deposited
68 | 
69 |         if date is None or date == "":
70 |             return False
71 | 
72 |         date = datetime.fromisoformat(date)
73 |         return date <= self.filter_date
74 | 
-------------------------------------------------------------------------------- /src/boltz/data/feature/pad.py: --------------------------------------------------------------------------------
 1 | import torch
 2 | from torch import Tensor
 3 | from torch.nn.functional import pad
 4 | 
 5 | 
 6 | def pad_dim(data: Tensor, dim: int, pad_len: float, value: float = 0) -> Tensor:
 7 |     """Pad a tensor along a given dimension.
 8 | 
 9 |     Parameters
10 |     ----------
11 |     data : Tensor
12 |         The input tensor.
13 |     dim : int
14 |         The dimension to pad.
15 |     pad_len : float
16 |         The padding length.
17 |     value : float, optional
18 |         The value to pad with.
19 | 
20 |     Returns
21 |     -------
22 |     Tensor
23 |         The padded tensor.
24 | 
25 |     """
26 |     if pad_len == 0:
27 |         return data
28 | 
29 |     total_dims = len(data.shape)
30 |     padding = [0] * (2 * (total_dims - dim))
31 |     padding[2 * (total_dims - 1 - dim) + 1] = pad_len
32 |     return pad(data, tuple(padding), value=value)
33 | 
34 | 
35 | def pad_to_max(data: list[Tensor], value: float = 0) -> tuple[Tensor, Tensor]:
36 |     """Pad the data in all dimensions to the maximum found.
37 | 
38 |     Parameters
39 |     ----------
40 |     data : List[Tensor]
41 |         List of tensors to pad.
42 |     value : float
43 |         The value to use for padding.
44 | 
45 |     Returns
46 |     -------
47 |     Tensor
48 |         The padded tensor.
49 |     Tensor
50 |         The padding mask.
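        If all tensors already share the same shape (or the inputs are
        strings), a scalar ``0`` is returned in place of the mask.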
51 | 
52 |     """
53 |     if isinstance(data[0], str):
54 |         return data, 0
55 | 
56 |     # Check if all have the same shape
57 |     if all(d.shape == data[0].shape for d in data):
58 |         return torch.stack(data, dim=0), 0
59 | 
60 |     # Get the maximum in each dimension
61 |     num_dims = len(data[0].shape)
62 |     max_dims = [max(d.shape[i] for d in data) for i in range(num_dims)]
63 | 
64 |     # Get the padding lengths
65 |     pad_lengths = []
66 |     for d in data:
67 |         dims = []
68 |         for i in range(num_dims):
69 |             dims.append(0)
70 |             dims.append(max_dims[num_dims - i - 1] - d.shape[num_dims - i - 1])
71 |         pad_lengths.append(dims)
72 | 
73 |     # Pad the data
74 |     padding = [
75 |         pad(torch.ones_like(d), pad_len, value=0)
76 |         for d, pad_len in zip(data, pad_lengths)
77 |     ]
78 |     data = [pad(d, pad_len, value=value) for d, pad_len in zip(data, pad_lengths)]
79 | 
80 |     # Stack the data
81 |     padding = torch.stack(padding, dim=0)
82 |     data = torch.stack(data, dim=0)
83 | 
84 |     return data, padding
85 | 
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
 1 | [build-system]
 2 | requires = ["setuptools >= 61.0"]
 3 | build-backend = "setuptools.build_meta"
 4 | 
 5 | [project]
 6 | name = "boltz"
 7 | version = "0.1.0"
 8 | requires-python = ">=3.9"
 9 | description = "Boltz-1"
10 | readme = "README.md"
11 | dependencies = [
12 |     "torch>=2.2",
13 |     "numpy==1.26.3",
14 |     "hydra-core==1.3.2",
15 |     "pytorch-lightning==2.4.0",
16 |     "rdkit==2024.3.6",
17 |     "dm-tree==0.1.8",
18 |     "requests==2.32.3",
19 |     "pandas==2.2.3",
20 |     "types-requests",
21 |     "einops==0.8.0",
22 |     "einx==0.3.0",
23 |     "fairscale==0.4.13",
24 |     "mashumaro==3.14",
25 |     "modelcif==1.2",
26 |     "wandb==0.18.7",
27 |     "click==8.1.7",
28 |     "pyyaml==6.0.2",
29 |     "biopython==1.84",
30 |     "scipy==1.13.1",
31 | ]
32 | 
33 | [project.scripts]
34 | boltz = "boltz.main:cli"
35 | 
36 | [project.optional-dependencies]
37 | lint = ["ruff"]
38 | 
39 | [tool.ruff]
40 | src = ["src"]
41 | extend-exclude = ["conf.py"]
42 | target-version = "py39"
43 | lint.select = ["ALL"]
44 | lint.ignore = [
45 |     "COM812", # Conflicts with the formatter
46 |     "ISC001", # Conflicts with the formatter
47 |     "ANN101", # "missing-type-self"
48 |     "RET504", # Unnecessary assignment to `x` before `return` statement (Ruff)
49 |     "S101", # Use of `assert` detected
50 |     "D100", # Missing docstring in public module
51 |     "D104", # Missing docstring in public package
52 |     "PT001", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
53 |     "PT004", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
54 |     "PT005", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
55 |     "PT023", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
56 |     "FBT001",
57 |     "FBT002",
58 |     "PLR0913", # Too many arguments to init (> 5)
59 | ]
60 | 
61 | [tool.ruff.lint.per-file-ignores]
62 | "**/__init__.py" = [
63 |     "F401", # Imported but unused
64 |     "F403", # Wildcard imports
65 | ]
66 | "docs/**" = [
67 |     "INP001", # Requires __init__.py but folder is not a package.
68 | ]
69 | "scripts/**" = [
70 |     "INP001", # Requires __init__.py but folder is not a package.
71 | ]
72 | 
73 | [tool.ruff.lint.pyupgrade]
74 | # Preserve types, even if a file imports `from __future__ import annotations` (https://github.com/astral-sh/ruff/issues/5434)
75 | keep-runtime-typing = true
76 | 
77 | [tool.ruff.lint.pydocstyle]
78 | convention = "numpy"
79 | 
-------------------------------------------------------------------------------- /docs/training.md: --------------------------------------------------------------------------------
 1 | # Training
 2 | 
 3 | ## Download processed data
 4 | 
 5 | Instructions on how to download the processed dataset for training are coming soon; we are currently uploading the data to shareable storage and will update this page when ready.
 6 | 
 7 | ## Modify the configuration file
 8 | 
 9 | The training script requires a configuration file to run. This file specifies the paths to the data, the output directory, and other parameters of the data, model, and training process.
10 | 
11 | We provide under `scripts/train/configs` template configuration files analogous to the ones we used for training the structure model (`structure.yaml`) and the confidence model (`confidence.yaml`).
12 | 
13 | The following are the main parameters that you should modify in the configuration file to get the structure model to train:
14 | 
15 | ```yaml
16 | trainer:
17 |   devices: 1
18 | 
19 | output: SET_PATH_HERE # Path to the output directory
20 | resume: PATH_TO_CHECKPOINT_FILE # Path to a checkpoint file to resume training from, if any; null otherwise
21 | 
22 | data:
23 |   datasets:
24 |     - _target_: boltz.data.module.training.DatasetConfig
25 |       target_dir: PATH_TO_TARGETS_DIR # Path to the directory containing the processed structure files
26 |       msa_dir: PATH_TO_MSA_DIR # Path to the directory containing the processed MSA files
27 | 
28 |   symmetries: PATH_TO_SYMMETRY_FILE # Path to the file containing the molecule symmetry information
29 |   max_tokens: 512 # Maximum number of tokens in the input sequence
30 |   max_atoms: 4608 # Maximum number of atoms in the input structure
31 | ```
32 | 
33 | `max_tokens` and `max_atoms` are the maximum number of tokens and atoms in the crop. Depending on the size of the GPUs you are using (as well as the training speed desired), you may want to adjust these values. Other recommended values are 256 and 2304, or 384 and 3456, respectively.
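For example, to train with the smallest recommended crop, you could set (illustrative values taken from the recommendation above):

```yaml
max_tokens: 256 # Maximum number of tokens in the crop
max_atoms: 2304 # Maximum number of atoms in the crop
```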
34 | 
35 | ## Run the training script
36 | 
37 | Before running the full training, we recommend using the debug flag. This turns off DDP (uses a single device), sets `num_workers` to 0 so that everything runs in a single process, and disables wandb:
38 | 
39 |     python scripts/train/train.py scripts/train/configs/structure.yaml debug=1
40 | 
41 | Once that seems to run okay, you can kill it and launch the training run:
42 | 
43 |     python scripts/train/train.py scripts/train/configs/structure.yaml
44 | 
45 | We also provide a different configuration file to train the confidence model:
46 | 
47 |     python scripts/train/train.py scripts/train/configs/confidence.yaml
-------------------------------------------------------------------------------- /src/boltz/model/layers/transition.py: --------------------------------------------------------------------------------
 1 | from typing import Optional
 2 | 
 3 | from torch import Tensor, nn
 4 | 
 5 | import boltz.model.layers.initialize as init
 6 | 
 7 | 
 8 | class Transition(nn.Module):
 9 |     """Perform a two-layer MLP."""
10 | 
11 |     def __init__(
12 |         self,
13 |         dim: int = 128,
14 |         hidden: int = 512,
15 |         out_dim: Optional[int] = None,
16 |         chunk_size: Optional[int] = None,
17 |     ) -> None:
18 |         """Initialize the Transition module.
19 | 
20 |         Parameters
21 |         ----------
22 |         dim: int
23 |             The dimension of the input, default 128
24 |         hidden: int
25 |             The dimension of the hidden, default 512
26 |         out_dim: Optional[int]
27 |             The dimension of the output, default None
28 |         chunk_size: Optional[int]
29 |             The chunk size for inference, default None
30 | 
31 |         """
32 |         super().__init__()
33 |         if out_dim is None:
34 |             out_dim = dim
35 | 
36 |         self.norm = nn.LayerNorm(dim, eps=1e-5)
37 |         self.fc1 = nn.Linear(dim, hidden, bias=False)
38 |         self.fc2 = nn.Linear(dim, hidden, bias=False)
39 |         self.fc3 = nn.Linear(hidden, out_dim, bias=False)
40 |         self.silu = nn.SiLU()
41 |         self.hidden = hidden
42 |         self.chunk_size = chunk_size
43 | 
44 |         init.bias_init_one_(self.norm.weight)
45 |         init.bias_init_zero_(self.norm.bias)
46 | 
47 |         init.lecun_normal_init_(self.fc1.weight)
48 |         init.lecun_normal_init_(self.fc2.weight)
49 |         init.final_init_(self.fc3.weight)
50 | 
51 |     def forward(self, x: Tensor) -> Tensor:
52 |         """Perform a forward pass.
53 | 
54 |         Parameters
55 |         ----------
56 |         x: torch.Tensor
57 |             The input data of shape (..., D)
58 | 
59 |         Returns
60 |         -------
61 |         x: torch.Tensor
62 |             The output data of shape (..., D)
63 | 
64 |         """
65 |         x = self.norm(x)
66 | 
67 |         if self.chunk_size is None or self.training:
68 |             x = self.silu(self.fc1(x)) * self.fc2(x)
69 |             x = self.fc3(x)
70 |             return x
71 |         else:
72 |             # Compute in chunks
73 |             for i in range(0, self.hidden, self.chunk_size):
74 |                 fc1_slice = self.fc1.weight[i : i + self.chunk_size, :]
75 |                 fc2_slice = self.fc2.weight[i : i + self.chunk_size, :]
76 |                 fc3_slice = self.fc3.weight[:, i : i + self.chunk_size]
77 |                 x_chunk = self.silu((x @ fc1_slice.T)) * (x @ fc2_slice.T)
78 |                 if i == 0:
79 |                     x_out = x_chunk @ fc3_slice.T
80 |                 else:
81 |                     x_out = x_out + x_chunk @ fc3_slice.T
82 |             return x_out
83 | 
-------------------------------------------------------------------------------- /src/boltz/model/layers/initialize.py: --------------------------------------------------------------------------------
 1 | """Utility functions for initializing weights and biases."""
 2 | 
 3 | # Copyright 2021 AlQuraishi Laboratory
 4 | # Copyright 2021 DeepMind Technologies Limited
 5 | #
 6 | # Licensed under the Apache License, Version 2.0 (the "License");
 7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import math 19 | import numpy as np 20 | from scipy.stats import truncnorm 21 | import torch 22 | 23 | 24 | def _prod(nums): 25 | out = 1 26 | for n in nums: 27 | out = out * n 28 | return out 29 | 30 | 31 | def _calculate_fan(linear_weight_shape, fan="fan_in"): 32 | fan_out, fan_in = linear_weight_shape 33 | 34 | if fan == "fan_in": 35 | f = fan_in 36 | elif fan == "fan_out": 37 | f = fan_out 38 | elif fan == "fan_avg": 39 | f = (fan_in + fan_out) / 2 40 | else: 41 | raise ValueError("Invalid fan option") 42 | 43 | return f 44 | 45 | 46 | def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): 47 | shape = weights.shape 48 | f = _calculate_fan(shape, fan) 49 | scale = scale / max(1, f) 50 | a = -2 51 | b = 2 52 | std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) 53 | size = _prod(shape) 54 | samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) 55 | samples = np.reshape(samples, shape) 56 | with torch.no_grad(): 57 | weights.copy_(torch.tensor(samples, device=weights.device)) 58 | 59 | 60 | def lecun_normal_init_(weights): 61 | trunc_normal_init_(weights, scale=1.0) 62 | 63 | 64 | def he_normal_init_(weights): 65 | trunc_normal_init_(weights, scale=2.0) 66 | 67 | 68 | def glorot_uniform_init_(weights): 69 | torch.nn.init.xavier_uniform_(weights, gain=1) 70 | 71 | 72 | def final_init_(weights): 73 | with torch.no_grad(): 74 | weights.fill_(0.0) 75 | 76 | 77 | def gating_init_(weights): 78 | with torch.no_grad(): 79 | weights.fill_(0.0) 80 | 81 | 82 | def bias_init_zero_(bias): 83 | with torch.no_grad(): 84 | bias.fill_(0.0) 85 | 86 | 87 | def bias_init_one_(bias): 88 | with torch.no_grad(): 89 | bias.fill_(1.0) 90 | 91 | 92 | def normal_init_(weights): 93 | torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") 94 | 95 | 96 | def ipa_point_weights_init_(weights): 97 | with torch.no_grad(): 98 | softplus_inverse_1 = 0.541324854612918 99 | weights.fill_(softplus_inverse_1) 100 | -------------------------------------------------------------------------------- /src/boltz/model/layers/outer_product_mean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | import boltz.model.layers.initialize as init 5 | 6 | 7 | class OuterProductMean(nn.Module): 8 | """Outer product mean layer.""" 9 | 10 | def __init__( 11 | self, c_in: int, c_hidden: int, c_out: int, chunk_size: int = None 12 | ) -> None: 13 | """Initialize the outer product mean layer. 14 | 15 | Parameters 16 | ---------- 17 | c_in : int 18 | The input dimension. 19 | c_hidden : int 20 | The hidden dimension. 21 | c_out : int 22 | The output dimension. 23 | chunk_size : int, optional 24 | The inference chunk size, by default None. 
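            Chunking over the hidden dimension at inference time lowers
            peak memory for the outer product at the cost of extra matrix
            multiplies.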
25 | 
26 |         """
27 |         super().__init__()
28 |         self.chunk_size = chunk_size
29 |         self.c_hidden = c_hidden
30 |         self.norm = nn.LayerNorm(c_in)
31 |         self.proj_a = nn.Linear(c_in, c_hidden, bias=False)
32 |         self.proj_b = nn.Linear(c_in, c_hidden, bias=False)
33 |         self.proj_o = nn.Linear(c_hidden * c_hidden, c_out)
34 |         init.final_init_(self.proj_o.weight)
35 |         init.final_init_(self.proj_o.bias)
36 | 
37 |     def forward(self, m: Tensor, mask: Tensor) -> Tensor:
38 |         """Forward pass.
39 | 
40 |         Parameters
41 |         ----------
42 |         m : torch.Tensor
43 |             The sequence tensor (B, S, N, c_in).
44 |         mask : torch.Tensor
45 |             The mask tensor (B, S, N).
46 | 
47 |         Returns
48 |         -------
49 |         torch.Tensor
50 |             The output tensor (B, N, N, c_out).
51 | 
52 |         """
53 |         # Expand mask
54 |         mask = mask.unsqueeze(-1).to(m)
55 | 
56 |         # Compute projections
57 |         m = self.norm(m)
58 |         a = self.proj_a(m) * mask
59 |         b = self.proj_b(m) * mask
60 | 
61 |         # Compute pairwise mask
62 |         mask = mask[:, :, None, :] * mask[:, :, :, None]
63 | 
64 |         # Compute outer product mean
65 |         if self.chunk_size is not None and not self.training:
66 |             # Compute sequentially in chunks
67 |             for i in range(0, self.c_hidden, self.chunk_size):
68 |                 a_chunk = a[:, :, :, i : i + self.chunk_size]
69 |                 sliced_weight_proj_o = self.proj_o.weight[
70 |                     :, i * self.c_hidden : (i + self.chunk_size) * self.c_hidden
71 |                 ]
72 | 
73 |                 z = torch.einsum("bsic,bsjd->bijcd", a_chunk, b)
74 |                 z = z.reshape(*z.shape[:3], -1)
75 |                 z = z / mask.sum(dim=1).clamp(min=1)
76 | 
77 |                 # Project to output
78 |                 if i == 0:
79 |                     z_out = z.to(m) @ sliced_weight_proj_o.T
80 |                 else:
81 |                     z_out = z_out + z.to(m) @ sliced_weight_proj_o.T
82 |             return z_out
83 |         else:
84 |             z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float())
85 |             z = z.reshape(*z.shape[:3], -1)
86 |             z = z / mask.sum(dim=1).clamp(min=1)
87 | 
88 |             # Project to output
89 |             z = self.proj_o(z.to(m))
90 |             return z
91 | 
-------------------------------------------------------------------------------- /src/boltz/data/filter/static/ligand.py: --------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | from boltz.data import const
 4 | from boltz.data.types import Structure
 5 | from boltz.data.filter.static.filter import StaticFilter
 6 | 
 7 | LIGAND_EXCLUSION = {
 8 |     "144",
 9 |     "15P",
10 |     "1PE",
11 |     "2F2",
12 |     "2JC",
13 |     "3HR",
14 |     "3SY",
15 |     "7N5",
16 |     "7PE",
17 |     "9JE",
18 |     "AAE",
19 |     "ABA",
20 |     "ACE",
21 |     "ACN",
22 |     "ACT",
23 |     "ACY",
24 |     "AZI",
25 |     "BAM",
26 |     "BCN",
27 |     "BCT",
28 |     "BDN",
29 |     "BEN",
30 |     "BME",
31 |     "BO3",
32 |     "BTB",
33 |     "BTC",
34 |     "BU1",
35 |     "C8E",
36 |     "CAD",
37 |     "CAQ",
38 |     "CBM",
39 |     "CCN",
40 |     "CIT",
41 |     "CL",
42 |     "CLR",
43 |     "CM",
44 |     "CMO",
45 |     "CO3",
46 |     "CPT",
47 |     "CXS",
48 |     "D10",
49 |     "DEP",
50 |     "DIO",
51 |     "DMS",
52 |     "DN",
53 |     "DOD",
54 |     "DOX",
55 |     "EDO",
56 |     "EEE",
57 |     "EGL",
58 |     "EOH",
59 |     "EOX",
60 |     "EPE",
61 |     "ETF",
62 |     "FCY",
63 |     "FJO",
64 |     "FLC",
65 |     "FMT",
66 |     "FW5",
67 |     "GOL",
68 |     "GSH",
69 |     "GTT",
70 |     "GYF",
71 |     "HED",
72 |     "IHP",
73 |     "IHS",
74 |     "IMD",
75 |     "IOD",
76 |     "IPA",
77 |     "IPH",
78 |     "LDA",
79 |     "MB3",
80 |     "MEG",
81 |     "MES",
82 |     "MLA",
83 |     "MLI",
84 |     "MOH",
85 |     "MPD",
86 |     "MRD",
87 |     "MSE",
88 |     "MYR",
89 |     "N",
90 |     "NA",
91 |     "NH2",
92 |     "NH4",
93 |     "NHE",
94 |     "NO3",
95 |     "O4B",
96 |     "OHE",
97 |     "OLA",
98 |     "OLC",
99 |     "OMB",
100 |     "OME",
101 |     "OXA",
102 |     "P6G",
103 |     "PE3",
104 |     "PE4",
105 |     "PEG",
106 |     "PEO",
107 |     "PEP",
108 |     "PG0",
109 |     "PG4",
110 |     "PGE",
111 | 
"PGR", 112 | "PLM", 113 | "PO4", 114 | "POL", 115 | "POP", 116 | "PVO", 117 | "SAR", 118 | "SCN", 119 | "SEO", 120 | "SEP", 121 | "SIN", 122 | "SO4", 123 | "SPD", 124 | "SPM", 125 | "SR", 126 | "STE", 127 | "STO", 128 | "STU", 129 | "TAR", 130 | "TBU", 131 | "TME", 132 | "TPO", 133 | "TRS", 134 | "UNK", 135 | "UNL", 136 | "UNX", 137 | "UPL", 138 | "URE", 139 | } 140 | 141 | 142 | class ExcludedLigands(StaticFilter): 143 | """Filter excluded ligands.""" 144 | 145 | def filter(self, structure: Structure) -> np.ndarray: 146 | """Filter excluded ligands. 147 | 148 | Parameters 149 | ---------- 150 | structure : Structure 151 | The structure to filter chains from. 152 | 153 | Returns 154 | ------- 155 | np.ndarray 156 | The chains to keep, as a boolean mask. 157 | 158 | """ 159 | valid = np.ones(len(structure.chains), dtype=bool) 160 | 161 | for i, chain in enumerate(structure.chains): 162 | if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]: 163 | continue 164 | 165 | res_start = chain["res_idx"] 166 | res_end = res_start + chain["res_num"] 167 | residues = structure.residues[res_start:res_end] 168 | if any(res["name"] in LIGAND_EXCLUSION for res in residues): 169 | valid[i] = 0 170 | 171 | return valid 172 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ -------------------------------------------------------------------------------- /src/boltz/data/parse/a3m.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | from pathlib import Path 3 | from typing import Optional, TextIO 4 | 5 | import numpy as np 6 | 7 | from boltz.data import const 8 | from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence 9 | 10 | 11 | def _parse_a3m( # noqa: C901 12 | lines: TextIO, 13 | taxonomy: Optional[dict[str, str]], 14 | max_seqs: Optional[int] = None, 15 | ) -> MSA: 16 | """Process an MSA file. 17 | 18 | Parameters 19 | ---------- 20 | lines : TextIO 21 | The lines of the MSA file. 22 | taxonomy : dict[str, str] 23 | The taxonomy database, if available. 24 | max_seqs : int, optional 25 | The maximum number of sequences. 26 | 27 | Returns 28 | ------- 29 | MSA 30 | The MSA object. 
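    Notes
    -----
    In A3M format, lowercase letters mark insertions relative to the query;
    the parser skips them as aligned residues and records them instead as
    per-position deletion counts.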
31 | 
32 |     """
33 |     visited = set()
34 |     sequences = []
35 |     deletions = []
36 |     residues = []
37 | 
38 |     seq_idx = 0
39 |     for line in lines:
40 |         line: str
41 |         line = line.strip()  # noqa: PLW2901
42 |         if not line or line.startswith("#"):
43 |             continue
44 | 
45 |         # Get taxonomy, if annotated
46 |         if line.startswith(">"):
47 |             header = line.split()[0]
48 |             if taxonomy and header.startswith(">UniRef100"):
49 |                 uniref_id = header.split("_")[1]
50 |                 taxonomy_id = taxonomy.get(uniref_id)
51 |                 if taxonomy_id is None:
52 |                     taxonomy_id = -1
53 |             else:
54 |                 taxonomy_id = -1
55 |             continue
56 | 
57 |         # Skip if duplicate sequence
58 |         str_seq = line.replace("-", "").upper()
59 |         if str_seq not in visited:
60 |             visited.add(str_seq)
61 |         else:
62 |             continue
63 | 
64 |         # Process sequence
65 |         residue = []
66 |         deletion = []
67 |         count = 0
68 |         res_idx = 0
69 |         for c in line:
70 |             if c != "-" and c.islower():
71 |                 count += 1
72 |                 continue
73 |             token = const.prot_letter_to_token[c]
74 |             token = const.token_ids[token]
75 |             residue.append(token)
76 |             if count > 0:
77 |                 deletion.append((res_idx, count))
78 |                 count = 0
79 |             res_idx += 1
80 | 
81 |         res_start = len(residues)
82 |         res_end = res_start + len(residue)
83 | 
84 |         del_start = len(deletions)
85 |         del_end = del_start + len(deletion)
86 | 
87 |         sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
88 |         residues.extend(residue)
89 |         deletions.extend(deletion)
90 | 
91 |         seq_idx += 1
92 |         if (max_seqs is not None) and (seq_idx >= max_seqs):
93 |             break
94 | 
95 |     # Create MSA object
96 |     msa = MSA(
97 |         residues=np.array(residues, dtype=MSAResidue),
98 |         deletions=np.array(deletions, dtype=MSADeletion),
99 |         sequences=np.array(sequences, dtype=MSASequence),
100 |     )
101 |     return msa
102 | 
103 | 
104 | def parse_a3m(
105 |     path: Path,
106 |     taxonomy: Optional[dict[str, str]],
107 |     max_seqs: Optional[int] = None,
108 | ) -> MSA:
109 |     """Process an A3M file.
110 | 
111 |     Parameters
112 |     ----------
113 |     path : Path
114 |         The path to the a3m(.gz) file.
115 |     taxonomy : dict[str, str], optional
116 |         The taxonomy database, if available.
117 |     max_seqs : int, optional
118 |         The maximum number of sequences.
119 | 
120 |     Returns
121 |     -------
122 |     MSA
123 |         The MSA object.
124 | 
125 |     """
126 |     # Read the file
127 |     if path.suffix == ".gz":
128 |         with gzip.open(str(path), "rt") as f:
129 |             msa = _parse_a3m(f, taxonomy, max_seqs)
130 |     else:
131 |         with path.open("r") as f:
132 |             msa = _parse_a3m(f, taxonomy, max_seqs)
133 | 
134 |     return msa
135 | 
-------------------------------------------------------------------------------- /src/boltz/model/optim/scheduler.py: --------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
 5 |     """Implements the learning rate schedule defined in AF3.
 6 | 
 7 |     A linear warmup is followed by a plateau at the maximum
 8 |     learning rate and then exponential decay. Note that the
 9 |     initial learning rate of the optimizer in question is
10 |     ignored; use this class' base_lr parameter to specify
11 |     the starting point of the warmup.
12 | 
13 |     """
14 | 
15 |     def __init__(
16 |         self,
17 |         optimizer: torch.optim.Optimizer,
18 |         last_epoch: int = -1,
19 |         verbose: bool = False,
20 |         base_lr: float = 0.0,
21 |         max_lr: float = 1.8e-3,
22 |         warmup_no_steps: int = 1000,
23 |         start_decay_after_n_steps: int = 50000,
24 |         decay_every_n_steps: int = 50000,
25 |         decay_factor: float = 0.95,
26 |     ) -> None:
27 |         """Initialize the learning rate scheduler.
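        The resulting schedule is: lr = base_lr + (step / warmup_no_steps) * max_lr
        during warmup, a plateau at max_lr, then max_lr * decay_factor ** k after
        start_decay_after_n_steps, where k increments every decay_every_n_steps.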
28 | 29 | Parameters 30 | ---------- 31 | optimizer : torch.optim.Optimizer 32 | The optimizer. 33 | last_epoch : int, optional 34 | The last epoch, by default -1 35 | verbose : bool, optional 36 | Whether to print verbose output, by default False 37 | base_lr : float, optional 38 | The base learning rate, by default 0.0 39 | max_lr : float, optional 40 | The maximum learning rate, by default 1.8e-3 41 | warmup_no_steps : int, optional 42 | The number of warmup steps, by default 1000 43 | start_decay_after_n_steps : int, optional 44 | The number of steps after which to start decay, by default 50000 45 | decay_every_n_steps : int, optional 46 | The number of steps after which to decay, by default 50000 47 | decay_factor : float, optional 48 | The decay factor, by default 0.95 49 | 50 | """ 51 | step_counts = { 52 | "warmup_no_steps": warmup_no_steps, 53 | "start_decay_after_n_steps": start_decay_after_n_steps, 54 | } 55 | 56 | for k, v in step_counts.items(): 57 | if v < 0: 58 | msg = f"{k} must be nonnegative" 59 | raise ValueError(msg) 60 | 61 | if warmup_no_steps > start_decay_after_n_steps: 62 | msg = "warmup_no_steps must not exceed start_decay_after_n_steps" 63 | raise ValueError(msg) 64 | 65 | self.optimizer = optimizer 66 | self.last_epoch = last_epoch 67 | self.verbose = verbose 68 | self.base_lr = base_lr 69 | self.max_lr = max_lr 70 | self.warmup_no_steps = warmup_no_steps 71 | self.start_decay_after_n_steps = start_decay_after_n_steps 72 | self.decay_every_n_steps = decay_every_n_steps 73 | self.decay_factor = decay_factor 74 | 75 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) 76 | 77 | def state_dict(self) -> dict: 78 | state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]} 79 | return state_dict 80 | 81 | def load_state_dict(self, state_dict): 82 | self.__dict__.update(state_dict) 83 | 84 | def get_lr(self): 85 | if not self._get_lr_called_within_step: 86 | msg = ( 87 | "To get the last learning rate computed by the scheduler, use " 88 | "get_last_lr()" 89 | ) 90 | raise RuntimeError(msg) 91 | 92 | step_no = self.last_epoch 93 | 94 | if step_no <= self.warmup_no_steps: 95 | lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr 96 | elif step_no > self.start_decay_after_n_steps: 97 | steps_since_decay = step_no - self.start_decay_after_n_steps 98 | exp = (steps_since_decay // self.decay_every_n_steps) + 1 99 | lr = self.max_lr * (self.decay_factor**exp) 100 | else: # plateau 101 | lr = self.max_lr 102 | 103 | return [lr for group in self.optimizer.param_groups] 104 | -------------------------------------------------------------------------------- /src/boltz/data/parse/fasta.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | from pathlib import Path 3 | 4 | from Bio import SeqIO 5 | from rdkit.Chem.rdchem import Mol 6 | 7 | from boltz.data.parse.yaml import parse_boltz_schema 8 | from boltz.data.types import Target 9 | 10 | 11 | def parse_fasta(path: Path, ccd: Mapping[str, Mol]) -> Target: # noqa: C901 12 | """Parse a fasta file. 13 | 14 | The name of the fasta file is used as the name of this job. 15 | We rely on the fasta record id to determine the entity type. 16 | 17 | > CHAIN_ID|ENTITY_TYPE|MSA_ID 18 | SEQUENCE 19 | > CHAIN_ID|ENTITY_TYPE|MSA_ID 20 | ... 21 | 22 | Where ENTITY_TYPE is either protein, rna, dna, ccd or smiles, 23 | and CHAIN_ID is the chain identifier, which should be unique. 
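    For example, a header of the form `>A|protein|msa_a3m` declares protein
    chain `A` with MSA id `msa_a3m` (the ids here are hypothetical).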
 24 |     The MSA_ID is optional and should only be used on proteins.
 25 | 
 26 |     Parameters
 27 |     ----------
 28 |     path : Path
 29 |         Path to the fasta file.
 30 |     ccd : Mapping[str, Mol]
 31 |         Dictionary of CCD components.
 32 | 
 33 |     Returns
 34 |     -------
 35 |     Target
 36 |         The parsed target.
 37 | 
 38 |     """
 39 |     # Read fasta file
 40 |     with path.open("r") as f:
 41 |         records = list(SeqIO.parse(f, "fasta"))
 42 | 
 43 |     # Make sure all records have a chain id and entity
 44 |     for seq_record in records:
 45 |         if "|" not in seq_record.id:
 46 |             msg = f"Invalid record id: {seq_record.id}"
 47 |             raise ValueError(msg)
 48 | 
 49 |         header = seq_record.id.split("|")
 50 |         assert len(header) >= 2, f"Invalid record id: {seq_record.id}"
 51 | 
 52 |         chain_id, entity_type = header[:2]
 53 |         if entity_type.lower() not in {"protein", "dna", "rna", "ccd", "smiles"}:
 54 |             msg = f"Invalid entity type: {entity_type}"
 55 |             raise ValueError(msg)
 56 |         if chain_id == "":
 57 |             msg = "Empty chain id in input fasta!"
 58 |             raise ValueError(msg)
 59 |         if entity_type == "":
 60 |             msg = "Empty entity type in input fasta!"
 61 |             raise ValueError(msg)
 62 | 
 63 |     # Convert to yaml format
 64 |     sequences = []
 65 |     for seq_record in records:
 66 |         # Get chain id, entity type and sequence
 67 |         header = seq_record.id.split("|")
 68 |         chain_id, entity_type = header[:2]
 69 |         if len(header) == 3 and header[2] != "":
 70 |             assert (
 71 |                 entity_type.lower() == "protein"
 72 |             ), "MSA_ID is only allowed for proteins"
 73 |             msa_id = header[2]
 74 |         else: msa_id = None  # reset, so a previous record's MSA id is never reused
 75 |         entity_type = entity_type.upper()
 76 |         seq = str(seq_record.seq)
 77 | 
 78 |         if entity_type == "PROTEIN":
 79 |             molecule = {
 80 |                 "protein": {
 81 |                     "id": chain_id,
 82 |                     "sequence": seq,
 83 |                     "modifications": [],
 84 |                     "msa": msa_id,
 85 |                 },
 86 |             }
 87 |         elif entity_type == "RNA":
 88 |             molecule = {
 89 |                 "rna": {
 90 |                     "id": chain_id,
 91 |                     "sequence": seq,
 92 |                     "modifications": [],
 93 |                 },
 94 |             }
 95 |         elif entity_type == "DNA":
 96 |             molecule = {
 97 |                 "dna": {
 98 |                     "id": chain_id,
 99 |                     "sequence": seq,
100 |                     "modifications": [],
101 |                 }
102 |             }
103 |         elif entity_type == "CCD":
104 |             molecule = {
105 |                 "ligand": {
106 |                     "id": chain_id,
107 |                     "ccd": seq,
108 |                 }
109 |             }
110 |         elif entity_type == "SMILES":
111 |             molecule = {
112 |                 "ligand": {
113 |                     "id": chain_id,
114 |                     "smiles": seq,
115 |                 }
116 |             }
117 | 
118 |         sequences.append(molecule)
119 | 
120 |     data = {
121 |         "sequences": sequences,
122 |         "bonds": [],
123 |         "version": 1,
124 |     }
125 | 
126 |     name = path.stem
127 |     return parse_boltz_schema(name, data, ccd)
128 | 
--------------------------------------------------------------------------------
/src/boltz/model/layers/attention.py:
 1 | from einops.layers.torch import Rearrange
 2 | import torch
 3 | from torch import Tensor, nn
 4 | 
 5 | import boltz.model.layers.initialize as init
 6 | 
 7 | 
 8 | class AttentionPairBias(nn.Module):
 9 |     """Attention pair bias layer."""
10 | 
11 |     def __init__(
12 |         self,
13 |         c_s: int,
14 |         c_z: int,
15 |         num_heads: int,
16 |         inf: float = 1e6,
17 |         initial_norm: bool = True,
18 |     ) -> None:
19 |         """Initialize the attention pair bias layer.
20 | 
21 |         Parameters
22 |         ----------
23 |         c_s : int
24 |             The input sequence dimension.
25 |         c_z : int
26 |             The input pairwise dimension.
27 |         num_heads : int
28 |             The number of heads.
29 | inf : float, optional 30 | The inf value, by default 1e6 31 | initial_norm: bool, optional 32 | Whether to apply layer norm to the input, by default True 33 | 34 | """ 35 | super().__init__() 36 | 37 | assert c_s % num_heads == 0 38 | 39 | self.c_s = c_s 40 | self.num_heads = num_heads 41 | self.head_dim = c_s // num_heads 42 | self.inf = inf 43 | 44 | self.initial_norm = initial_norm 45 | if self.initial_norm: 46 | self.norm_s = nn.LayerNorm(c_s) 47 | 48 | self.proj_q = nn.Linear(c_s, c_s) 49 | self.proj_k = nn.Linear(c_s, c_s, bias=False) 50 | self.proj_v = nn.Linear(c_s, c_s, bias=False) 51 | self.proj_g = nn.Linear(c_s, c_s, bias=False) 52 | 53 | self.proj_z = nn.Sequential( 54 | nn.LayerNorm(c_z), 55 | nn.Linear(c_z, num_heads, bias=False), 56 | Rearrange("b ... h -> b h ..."), 57 | ) 58 | 59 | self.proj_o = nn.Linear(c_s, c_s, bias=False) 60 | init.final_init_(self.proj_o.weight) 61 | 62 | def forward( 63 | self, 64 | s: Tensor, 65 | z: Tensor, 66 | mask: Tensor, 67 | multiplicity: int = 1, 68 | to_keys=None, 69 | model_cache=None, 70 | ) -> Tensor: 71 | """Forward pass. 72 | 73 | Parameters 74 | ---------- 75 | s : torch.Tensor 76 | The input sequence tensor (B, S, D) 77 | z : torch.Tensor 78 | The input pairwise tensor (B, N, N, D) 79 | mask : torch.Tensor 80 | The pairwise mask tensor (B, N, N) 81 | multiplicity : int, optional 82 | The diffusion batch size, by default 1 83 | 84 | Returns 85 | ------- 86 | torch.Tensor 87 | The output sequence tensor. 88 | 89 | """ 90 | B = s.shape[0] 91 | 92 | # Layer norms 93 | if self.initial_norm: 94 | s = self.norm_s(s) 95 | 96 | if to_keys is not None: 97 | k_in = to_keys(s) 98 | mask = to_keys(mask.unsqueeze(-1)).squeeze(-1) 99 | else: 100 | k_in = s 101 | 102 | # Compute projections 103 | q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim) 104 | k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim) 105 | v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim) 106 | 107 | # Caching z projection during diffusion roll-out 108 | if model_cache is None or "z" not in model_cache: 109 | z = self.proj_z(z) 110 | 111 | if model_cache is not None: 112 | model_cache["z"] = z 113 | else: 114 | z = model_cache["z"] 115 | z = z.repeat_interleave(multiplicity, 0) 116 | 117 | g = self.proj_g(s).sigmoid() 118 | 119 | with torch.autocast("cuda", enabled=False): 120 | # Compute attention weights 121 | attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float()) 122 | attn = attn / (self.head_dim**0.5) + z.float() 123 | attn = attn + (1 - mask[:, None, None].float()) * -self.inf 124 | attn = attn.softmax(dim=-1) 125 | 126 | # Compute output 127 | o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype) 128 | o = o.reshape(B, -1, self.c_s) 129 | o = self.proj_o(g * o) 130 | 131 | return o 132 | -------------------------------------------------------------------------------- /src/boltz/data/write/pdb.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | 3 | from boltz.data import const 4 | from boltz.data.types import Structure 5 | from boltz.data.write.utils import generate_tags 6 | 7 | 8 | def to_pdb(structure: Structure) -> str: # noqa: PLR0915 9 | """Write a structure into a PDB file. 
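    A minimal usage sketch (the .npz path is hypothetical):

    >>> from boltz.data.types import Structure
    >>> structure = Structure.load("processed/example.npz")
    >>> with open("example.pdb", "w") as f:
    ...     _ = f.write(to_pdb(structure))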
10 | 11 | Parameters 12 | ---------- 13 | structure : Structure 14 | The input structure 15 | 16 | Returns 17 | ------- 18 | str 19 | the output PDB file 20 | 21 | """ 22 | pdb_lines = [] 23 | 24 | atom_index = 1 25 | atom_reindex_ter = [] 26 | chain_tags = generate_tags() 27 | 28 | # Load periodic table for element mapping 29 | periodic_table = Chem.GetPeriodicTable() 30 | 31 | # Add all atom sites. 32 | for chain in structure.chains: 33 | # We rename the chains in alphabetical order 34 | chain_idx = chain["asym_id"] 35 | chain_tag = next(chain_tags) 36 | 37 | res_start = chain["res_idx"] 38 | res_end = chain["res_idx"] + chain["res_num"] 39 | 40 | residues = structure.residues[res_start:res_end] 41 | for residue in residues: 42 | atom_start = residue["atom_idx"] 43 | atom_end = residue["atom_idx"] + residue["atom_num"] 44 | atoms = structure.atoms[atom_start:atom_end] 45 | atom_coords = atoms["coords"] 46 | for i, atom in enumerate(atoms): 47 | # This should not happen on predictions, but just in case. 48 | if not atom["is_present"]: 49 | continue 50 | 51 | record_type = ( 52 | "ATOM" 53 | if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"] 54 | else "HETATM" 55 | ) 56 | name = atom["name"] 57 | name = [chr(c + 32) for c in name if c != 0] 58 | name = "".join(name) 59 | name = name if len(name) == 4 else f" {name}" # noqa: PLR2004 60 | alt_loc = "" 61 | insertion_code = "" 62 | occupancy = 1.00 63 | element = periodic_table.GetElementSymbol(atom["element"].item()) 64 | element = element.upper() 65 | charge = "" 66 | residue_index = residue["res_idx"] + 1 67 | pos = atom_coords[i] 68 | res_name_3 = ( 69 | "LIG" if record_type == "HETATM" else str(residue["name"][:3]) 70 | ) 71 | b_factor = 1.00 72 | 73 | # PDB is a columnar format, every space matters here! 74 | atom_line = ( 75 | f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}" 76 | f"{res_name_3:>3} {chain_tag:>1}" 77 | f"{residue_index:>4}{insertion_code:>1} " 78 | f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}" 79 | f"{occupancy:>6.2f}{b_factor:>6.2f} " 80 | f"{element:>2}{charge:>2}" 81 | ) 82 | pdb_lines.append(atom_line) 83 | atom_reindex_ter.append(atom_index) 84 | atom_index += 1 85 | 86 | should_terminate = chain_idx < (len(structure.chains) - 1) 87 | if should_terminate: 88 | # Close the chain. 89 | chain_end = "TER" 90 | chain_termination_line = ( 91 | f"{chain_end:<6}{atom_index:>5} " 92 | f"{res_name_3:>3} " 93 | f"{chain_tag:>1}{residue_index:>4}" 94 | ) 95 | pdb_lines.append(chain_termination_line) 96 | atom_index += 1 97 | 98 | # Dump CONECT records. 
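    # CONECT records must reference the serial numbers the atoms were actually
    # written with, and every TER line above consumed a serial of its own. The
    # atom_reindex_ter list built during atom writing maps an atom's position
    # in the structure to its written serial (assuming all atoms are present,
    # as is the case on predictions), keeping the bonds below consistent.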
99 | for bonds in [structure.bonds, structure.connections]: 100 | for bond in bonds: 101 | atom1 = structure.atoms[bond["atom_1"]] 102 | atom2 = structure.atoms[bond["atom_2"]] 103 | if not atom1["is_present"] or not atom2["is_present"]: 104 | continue 105 | atom1_idx = atom_reindex_ter[bond["atom_1"]] 106 | atom2_idx = atom_reindex_ter[bond["atom_2"]] 107 | conect_line = f"CONECT{atom1_idx:>5}{atom2_idx:>5}" 108 | pdb_lines.append(conect_line) 109 | 110 | pdb_lines.append("END") 111 | pdb_lines.append("") 112 | pdb_lines = [line.ljust(80) for line in pdb_lines] 113 | return "\n".join(pdb_lines) 114 | -------------------------------------------------------------------------------- /src/boltz/model/layers/triangular_mult.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from boltz.model.layers import initialize as init 5 | 6 | 7 | class TriangleMultiplicationOutgoing(nn.Module): 8 | """TriangleMultiplicationOutgoing.""" 9 | 10 | def __init__(self, dim: int = 128) -> None: 11 | """Initialize the TriangularUpdate module. 12 | 13 | Parameters 14 | ---------- 15 | dim: int 16 | The dimension of the input, default 128 17 | 18 | """ 19 | super().__init__() 20 | 21 | self.norm_in = nn.LayerNorm(dim, eps=1e-5) 22 | self.p_in = nn.Linear(dim, 2 * dim, bias=False) 23 | self.g_in = nn.Linear(dim, 2 * dim, bias=False) 24 | 25 | self.norm_out = nn.LayerNorm(dim) 26 | self.p_out = nn.Linear(dim, dim, bias=False) 27 | self.g_out = nn.Linear(dim, dim, bias=False) 28 | 29 | init.bias_init_one_(self.norm_in.weight) 30 | init.bias_init_zero_(self.norm_in.bias) 31 | 32 | init.lecun_normal_init_(self.p_in.weight) 33 | init.gating_init_(self.g_in.weight) 34 | 35 | init.bias_init_one_(self.norm_out.weight) 36 | init.bias_init_zero_(self.norm_out.bias) 37 | 38 | init.final_init_(self.p_out.weight) 39 | init.gating_init_(self.g_out.weight) 40 | 41 | def forward(self, x: Tensor, mask: Tensor) -> Tensor: 42 | """Perform a forward pass. 43 | 44 | Parameters 45 | ---------- 46 | x: torch.Tensor 47 | The input data of shape (B, N, N, D) 48 | mask: torch.Tensor 49 | The input mask of shape (B, N, N) 50 | 51 | Returns 52 | ------- 53 | x: torch.Tensor 54 | The output data of shape (B, N, N, D) 55 | 56 | """ 57 | # Input gating: D -> D 58 | x = self.norm_in(x) 59 | x_in = x 60 | x = self.p_in(x) * self.g_in(x).sigmoid() 61 | 62 | # Apply mask 63 | x = x * mask.unsqueeze(-1) 64 | 65 | # Split input and cast to float 66 | a, b = torch.chunk(x.float(), 2, dim=-1) 67 | 68 | # Triangular projection 69 | x = torch.einsum("bikd,bjkd->bijd", a, b) 70 | 71 | # Output gating 72 | x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid() 73 | 74 | return x 75 | 76 | 77 | class TriangleMultiplicationIncoming(nn.Module): 78 | """TriangleMultiplicationIncoming.""" 79 | 80 | def __init__(self, dim: int = 128) -> None: 81 | """Initialize the TriangularUpdate module. 
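        This mirrors TriangleMultiplicationOutgoing above; the only difference
        is in forward, where the einsum contracts over incoming edges
        ("bkid,bkjd->bijd") rather than outgoing ones ("bikd,bjkd->bijd").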
82 | 83 | Parameters 84 | ---------- 85 | dim: int 86 | The dimension of the input, default 128 87 | 88 | """ 89 | super().__init__() 90 | 91 | self.norm_in = nn.LayerNorm(dim, eps=1e-5) 92 | self.p_in = nn.Linear(dim, 2 * dim, bias=False) 93 | self.g_in = nn.Linear(dim, 2 * dim, bias=False) 94 | 95 | self.norm_out = nn.LayerNorm(dim) 96 | self.p_out = nn.Linear(dim, dim, bias=False) 97 | self.g_out = nn.Linear(dim, dim, bias=False) 98 | 99 | init.bias_init_one_(self.norm_in.weight) 100 | init.bias_init_zero_(self.norm_in.bias) 101 | 102 | init.lecun_normal_init_(self.p_in.weight) 103 | init.gating_init_(self.g_in.weight) 104 | 105 | init.bias_init_one_(self.norm_out.weight) 106 | init.bias_init_zero_(self.norm_out.bias) 107 | 108 | init.final_init_(self.p_out.weight) 109 | init.gating_init_(self.g_out.weight) 110 | 111 | def forward(self, x: Tensor, mask: Tensor) -> Tensor: 112 | """Perform a forward pass. 113 | 114 | Parameters 115 | ---------- 116 | x: torch.Tensor 117 | The input data of shape (B, N, N, D) 118 | mask: torch.Tensor 119 | The input mask of shape (B, N, N) 120 | 121 | Returns 122 | ------- 123 | x: torch.Tensor 124 | The output data of shape (B, N, N, D) 125 | 126 | """ 127 | # Input gating: D -> D 128 | x = self.norm_in(x) 129 | x_in = x 130 | x = self.p_in(x) * self.g_in(x).sigmoid() 131 | 132 | # Apply mask 133 | x = x * mask.unsqueeze(-1) 134 | 135 | # Split input and cast to float 136 | a, b = torch.chunk(x.float(), 2, dim=-1) 137 | 138 | # Triangular projection 139 | x = torch.einsum("bkid,bkjd->bijd", a, b) 140 | 141 | # Output gating 142 | x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid() 143 | 144 | return x 145 | -------------------------------------------------------------------------------- /src/boltz/model/layers/pair_averaging.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | import boltz.model.layers.initialize as init 5 | 6 | 7 | class PairWeightedAveraging(nn.Module): 8 | """Pair weighted averaging layer.""" 9 | 10 | def __init__( 11 | self, 12 | c_m: int, 13 | c_z: int, 14 | c_h: int, 15 | num_heads: int, 16 | inf: float = 1e6, 17 | chunk_heads: bool = False, 18 | ) -> None: 19 | """Initialize the pair weighted averaging layer. 20 | 21 | Parameters 22 | ---------- 23 | c_m: int 24 | The dimension of the input sequence. 25 | c_z: int 26 | The dimension of the input pairwise tensor. 27 | c_h: int 28 | The dimension of the hidden. 29 | num_heads: int 30 | The number of heads. 31 | inf: float 32 | The value to use for masking, default 1e6. 33 | chunk_heads: bool 34 | Whether to sequentially compute heads at inference, default False. 35 | 36 | """ 37 | super().__init__() 38 | self.c_m = c_m 39 | self.c_z = c_z 40 | self.c_h = c_h 41 | self.num_heads = num_heads 42 | self.inf = inf 43 | self.chunk_heads = chunk_heads 44 | 45 | self.norm_m = nn.LayerNorm(c_m) 46 | self.norm_z = nn.LayerNorm(c_z) 47 | 48 | self.proj_m = nn.Linear(c_m, c_h * num_heads, bias=False) 49 | self.proj_g = nn.Linear(c_m, c_h * num_heads, bias=False) 50 | self.proj_z = nn.Linear(c_z, num_heads, bias=False) 51 | self.proj_o = nn.Linear(c_h * num_heads, c_m, bias=False) 52 | init.final_init_(self.proj_o.weight) 53 | 54 | def forward(self, m: Tensor, z: Tensor, mask: Tensor) -> Tensor: 55 | """Forward pass. 
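        A minimal shape check with random inputs (the sizes are hypothetical):

        >>> import torch
        >>> layer = PairWeightedAveraging(c_m=64, c_z=128, c_h=32, num_heads=8)
        >>> m, z = torch.randn(1, 4, 16, 64), torch.randn(1, 16, 16, 128)
        >>> mask = torch.ones(1, 16, 16)
        >>> layer(m, z, mask).shape
        torch.Size([1, 4, 16, 64])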
56 | 57 | Parameters 58 | ---------- 59 | m : torch.Tensor 60 | The input sequence tensor (B, S, N, D) 61 | z : torch.Tensor 62 | The input pairwise tensor (B, N, N, D) 63 | mask : torch.Tensor 64 | The pairwise mask tensor (B, N, N) 65 | 66 | Returns 67 | ------- 68 | torch.Tensor 69 | The output sequence tensor (B, S, N, D) 70 | 71 | """ 72 | # Compute layer norms 73 | m = self.norm_m(m) 74 | z = self.norm_z(z) 75 | 76 | if self.chunk_heads and not self.training: 77 | # Compute heads sequentially 78 | o_chunks = [] 79 | for head_idx in range(self.num_heads): 80 | sliced_weight_proj_m = self.proj_m.weight[ 81 | head_idx * self.c_h : (head_idx + 1) * self.c_h, : 82 | ] 83 | sliced_weight_proj_g = self.proj_g.weight[ 84 | head_idx * self.c_h : (head_idx + 1) * self.c_h, : 85 | ] 86 | sliced_weight_proj_z = self.proj_z.weight[head_idx : (head_idx + 1), :] 87 | sliced_weight_proj_o = self.proj_o.weight[ 88 | :, head_idx * self.c_h : (head_idx + 1) * self.c_h 89 | ] 90 | 91 | # Project input tensors 92 | v: Tensor = m @ sliced_weight_proj_m.T 93 | v = v.reshape(*v.shape[:3], 1, self.c_h) 94 | v = v.permute(0, 3, 1, 2, 4) 95 | 96 | # Compute weights 97 | b: Tensor = z @ sliced_weight_proj_z.T 98 | b = b.permute(0, 3, 1, 2) 99 | b = b + (1 - mask[:, None]) * -self.inf 100 | w = torch.softmax(b, dim=-1) 101 | 102 | # Compute gating 103 | g: Tensor = m @ sliced_weight_proj_g.T 104 | g = g.sigmoid() 105 | 106 | # Compute output 107 | o = torch.einsum("bhij,bhsjd->bhsid", w, v) 108 | o = o.permute(0, 2, 3, 1, 4) 109 | o = o.reshape(*o.shape[:3], 1 * self.c_h) 110 | o_chunks = g * o 111 | if head_idx == 0: 112 | o_out = o_chunks @ sliced_weight_proj_o.T 113 | else: 114 | o_out += o_chunks @ sliced_weight_proj_o.T 115 | return o_out 116 | else: 117 | # Project input tensors 118 | v: Tensor = self.proj_m(m) 119 | v = v.reshape(*v.shape[:3], self.num_heads, self.c_h) 120 | v = v.permute(0, 3, 1, 2, 4) 121 | 122 | # Compute weights 123 | b: Tensor = self.proj_z(z) 124 | b = b.permute(0, 3, 1, 2) 125 | b = b + (1 - mask[:, None]) * -self.inf 126 | w = torch.softmax(b, dim=-1) 127 | 128 | # Compute gating 129 | g: Tensor = self.proj_g(m) 130 | g = g.sigmoid() 131 | 132 | # Compute output 133 | o = torch.einsum("bhij,bhsjd->bhsid", w, v) 134 | o = o.permute(0, 2, 3, 1, 4) 135 | o = o.reshape(*o.shape[:3], self.num_heads * self.c_h) 136 | o = self.proj_o(g * o) 137 | return o 138 | -------------------------------------------------------------------------------- /scripts/train/configs/structure.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | accelerator: gpu 3 | devices: 1 4 | precision: 32 5 | gradient_clip_val: 10.0 6 | max_epochs: -1 7 | 8 | # Optional set wandb here 9 | # wandb: 10 | # name: boltz 11 | # project: boltz 12 | # entity: boltz 13 | 14 | output: SET_PATH_HERE 15 | resume: PATH_TO_CHECKPOINT_FILE 16 | disable_checkpoint: false 17 | matmul_precision: null 18 | save_top_k: -1 19 | 20 | data: 21 | datasets: 22 | - _target_: boltz.data.module.training.DatasetConfig 23 | target_dir: PATH_TO_TARGETS_DIR 24 | msa_dir: PATH_TO_MSA_DIR 25 | prob: 1.0 26 | sampler: 27 | _target_: boltz.data.sample.cluster.ClusterSampler 28 | cropper: 29 | _target_: boltz.data.crop.boltz.BoltzCropper 30 | min_neighborhood: 0 31 | max_neighborhood: 40 32 | split: ./scripts/train/assets/validation_ids.txt 33 | 34 | filters: 35 | - _target_: boltz.data.filter.dynamic.size.SizeFilter 36 | min_chains: 1 37 | max_chains: 300 38 | - _target_: 
boltz.data.filter.dynamic.date.DateFilter 39 | date: "2021-09-30" 40 | ref: released 41 | - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter 42 | resolution: 9.0 43 | 44 | tokenizer: 45 | _target_: boltz.data.tokenize.boltz.BoltzTokenizer 46 | featurizer: 47 | _target_: boltz.data.feature.featurizer.BoltzFeaturizer 48 | 49 | symmetries: PATH_TO_SYMMETRY_FILE 50 | max_tokens: 512 51 | max_atoms: 4608 52 | max_seqs: 2048 53 | pad_to_max_tokens: true 54 | pad_to_max_atoms: true 55 | pad_to_max_seqs: true 56 | samples_per_epoch: 100000 57 | batch_size: 1 58 | num_workers: 4 59 | random_seed: 42 60 | pin_memory: true 61 | overfit: null 62 | crop_validation: false 63 | return_train_symmetries: false 64 | return_val_symmetries: true 65 | train_binder_pocket_conditioned_prop: 0.3 66 | val_binder_pocket_conditioned_prop: 0.3 67 | binder_pocket_cutoff: 6.0 68 | binder_pocket_sampling_geometric_p: 0.3 69 | min_dist: 2.0 70 | max_dist: 22.0 71 | num_bins: 64 72 | atoms_per_window_queries: 32 73 | 74 | model: 75 | _target_: boltz.model.model.Boltz1 76 | atom_s: 128 77 | atom_z: 16 78 | token_s: 384 79 | token_z: 128 80 | num_bins: 64 81 | atom_feature_dim: 389 82 | atoms_per_window_queries: 32 83 | atoms_per_window_keys: 128 84 | compile_pairformer: false 85 | nucleotide_rmsd_weight: 5.0 86 | ligand_rmsd_weight: 10.0 87 | ema: true 88 | ema_decay: 0.999 89 | 90 | embedder_args: 91 | atom_encoder_depth: 3 92 | atom_encoder_heads: 4 93 | 94 | msa_args: 95 | msa_s: 64 96 | msa_blocks: 4 97 | msa_dropout: 0.15 98 | z_dropout: 0.25 99 | pairwise_head_width: 32 100 | pairwise_num_heads: 4 101 | activation_checkpointing: true 102 | offload_to_cpu: false 103 | 104 | pairformer_args: 105 | num_blocks: 48 106 | num_heads: 16 107 | dropout: 0.25 108 | activation_checkpointing: true 109 | offload_to_cpu: false 110 | 111 | score_model_args: 112 | sigma_data: 16 113 | dim_fourier: 256 114 | atom_encoder_depth: 3 115 | atom_encoder_heads: 4 116 | token_transformer_depth: 24 117 | token_transformer_heads: 16 118 | atom_decoder_depth: 3 119 | atom_decoder_heads: 4 120 | conditioning_transition_layers: 2 121 | activation_checkpointing: true 122 | offload_to_cpu: false 123 | 124 | confidence_prediction: false 125 | confidence_model_args: 126 | use_gaussian: false 127 | num_dist_bins: 64 128 | max_dist: 22 129 | add_s_to_z_prod: true 130 | add_s_input_to_s: true 131 | use_s_diffusion: true 132 | add_z_input_to_z: true 133 | 134 | confidence_args: 135 | num_plddt_bins: 50 136 | num_pde_bins: 64 137 | num_pae_bins: 64 138 | relative_confidence: none 139 | 140 | training_args: 141 | recycling_steps: 3 142 | sampling_steps: 20 143 | diffusion_multiplicity: 16 144 | diffusion_samples: 2 145 | confidence_loss_weight: 1e-4 146 | diffusion_loss_weight: 4.0 147 | distogram_loss_weight: 3e-2 148 | adam_beta_1: 0.9 149 | adam_beta_2: 0.95 150 | adam_eps: 0.00000001 151 | lr_scheduler: af3 152 | base_lr: 0.0 153 | max_lr: 0.0018 154 | lr_warmup_no_steps: 1000 155 | lr_start_decay_after_n_steps: 50000 156 | lr_decay_every_n_steps: 50000 157 | lr_decay_factor: 0.95 158 | 159 | validation_args: 160 | recycling_steps: 3 161 | sampling_steps: 200 162 | diffusion_samples: 5 163 | symmetry_correction: true 164 | 165 | diffusion_process_args: 166 | sigma_min: 0.0004 167 | sigma_max: 160.0 168 | sigma_data: 16.0 169 | rho: 7 170 | P_mean: -1.2 171 | P_std: 1.5 172 | gamma_0: 0.8 173 | gamma_min: 1.0 174 | noise_scale: 1.0 175 | step_scale: 1.0 176 | coordinate_augmentation: true 177 | alignment_reverse_diff: true 178 | 
synchronize_sigmas: true 179 | use_inference_model_cache: true 180 | 181 | diffusion_loss_args: 182 | add_smooth_lddt_loss: true 183 | nucleotide_loss_weight: 5.0 184 | ligand_loss_weight: 10.0 185 | -------------------------------------------------------------------------------- /src/boltz/model/modules/confidence_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from boltz.data import const 5 | from boltz.model.loss.confidence import compute_frame_pred 6 | 7 | 8 | def compute_aggregated_metric(logits, end=1.0): 9 | """Compute the metric from the logits. 10 | 11 | Parameters 12 | ---------- 13 | logits : torch.Tensor 14 | The logits of the metric 15 | end : float 16 | Max value of the metric, by default 1.0 17 | 18 | Returns 19 | ------- 20 | Tensor 21 | The metric value 22 | 23 | """ 24 | num_bins = logits.shape[-1] 25 | bin_width = end / num_bins 26 | bounds = torch.arange( 27 | start=0.5 * bin_width, end=end, step=bin_width, device=logits.device 28 | ) 29 | probs = nn.functional.softmax(logits, dim=-1) 30 | plddt = torch.sum( 31 | probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape), 32 | dim=-1, 33 | ) 34 | return plddt 35 | 36 | 37 | def tm_function(d, Nres): 38 | """Compute the rescaling function for pTM. 39 | 40 | Parameters 41 | ---------- 42 | d : torch.Tensor 43 | The input 44 | Nres : torch.Tensor 45 | The number of residues 46 | 47 | Returns 48 | ------- 49 | Tensor 50 | Output of the function 51 | 52 | """ 53 | d0 = 1.24 * (torch.clip(Nres, min=19) - 15) ** (1 / 3) - 1.8 54 | return 1 / (1 + (d / d0) ** 2) 55 | 56 | 57 | def compute_ptms(logits, x_preds, feats, multiplicity): 58 | """Compute pTM and ipTM scores. 59 | 60 | Parameters 61 | ---------- 62 | logits : torch.Tensor 63 | pae logits 64 | x_preds : torch.Tensor 65 | The predicted coordinates 66 | feats : Dict[str, torch.Tensor] 67 | The input features 68 | multiplicity : int 69 | The batch size of the diffusion roll-out 70 | 71 | Returns 72 | ------- 73 | Tensor 74 | pTM score 75 | Tensor 76 | ipTM score 77 | Tensor 78 | ligand ipTM score 79 | Tensor 80 | protein ipTM score 81 | 82 | """ 83 | # Compute mask for collinear and overlapping tokens 84 | _, mask_collinear_pred = compute_frame_pred( 85 | x_preds, feats["frames_idx"], feats, multiplicity, inference=True 86 | ) 87 | mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0) 88 | maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1]) 89 | pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None] 90 | asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0) 91 | pair_mask_iptm = ( 92 | maski[:, :, None] 93 | * (asym_id[:, None, :] != asym_id[:, :, None]) 94 | * mask_pad[:, None, :] 95 | * mask_pad[:, :, None] 96 | ) 97 | 98 | # Extract pae values 99 | num_bins = logits.shape[-1] 100 | bin_width = 32.0 / num_bins 101 | end = 32.0 102 | pae_value = torch.arange( 103 | start=0.5 * bin_width, end=end, step=bin_width, device=logits.device 104 | ).unsqueeze(0) 105 | N_res = mask_pad.sum(dim=-1, keepdim=True) 106 | 107 | # compute pTM and ipTM 108 | tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2) 109 | probs = nn.functional.softmax(logits, dim=-1) 110 | tm_expected_value = torch.sum( 111 | probs * tm_value, 112 | dim=-1, 113 | ) # shape (B, N, N) 114 | ptm = torch.max( 115 | torch.sum(tm_expected_value * pair_mask_ptm, dim=-1) 116 | / (torch.sum(pair_mask_ptm, dim=-1) + 1e-5), 117 | dim=1, 118 
| ).values
119 |     iptm = torch.max(
120 |         torch.sum(tm_expected_value * pair_mask_iptm, dim=-1)
121 |         / (torch.sum(pair_mask_iptm, dim=-1) + 1e-5),
122 |         dim=1,
123 |     ).values
124 | 
125 |     # compute ligand and protein ipTM
126 |     token_type = feats["mol_type"]
127 |     token_type = token_type.repeat_interleave(multiplicity, 0)
128 |     is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
129 |     is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()
130 | 
131 |     ligand_iptm_mask = (
132 |         maski[:, :, None]
133 |         * (asym_id[:, None, :] != asym_id[:, :, None])
134 |         * mask_pad[:, None, :]
135 |         * mask_pad[:, :, None]
136 |         * (
137 |             (is_ligand_token[:, :, None] * is_protein_token[:, None, :])
138 |             + (is_protein_token[:, :, None] * is_ligand_token[:, None, :])
139 |         )
140 |     )
141 |     protein_iptm_mask = (
142 |         maski[:, :, None]
143 |         * (asym_id[:, None, :] != asym_id[:, :, None])
144 |         * mask_pad[:, None, :]
145 |         * mask_pad[:, :, None]
146 |         * (is_protein_token[:, :, None] * is_protein_token[:, None, :])
147 |     )
148 | 
149 |     ligand_iptm = torch.max(
150 |         torch.sum(tm_expected_value * ligand_iptm_mask, dim=-1)
151 |         / (torch.sum(ligand_iptm_mask, dim=-1) + 1e-5),
152 |         dim=1,
153 |     ).values
154 |     protein_iptm = torch.max(
155 |         torch.sum(tm_expected_value * protein_iptm_mask, dim=-1)
156 |         / (torch.sum(protein_iptm_mask, dim=-1) + 1e-5),
157 |         dim=1,
158 |     ).values
159 | 
160 |     return ptm, iptm, ligand_iptm, protein_iptm
161 | 
--------------------------------------------------------------------------------
/scripts/train/configs/confidence.yaml:
 1 | trainer:
 2 |   accelerator: gpu
 3 |   devices: 1
 4 |   precision: 32
 5 |   gradient_clip_val: 10.0
 6 |   max_epochs: -1
 7 | 
 8 | # Optionally set wandb here
 9 | # wandb:
10 | #   name: boltz
11 | #   project: boltz
12 | #   entity: boltz
13 | 
14 | 
15 | output: SET_PATH_HERE
16 | pretrained: PATH_TO_STRUCTURE_CHECKPOINT_FILE
17 | resume: null
18 | disable_checkpoint: false
19 | matmul_precision: null
20 | save_top_k: -1
21 | load_confidence_from_trunk: true
22 | 
23 | data:
24 |   datasets:
25 |     - _target_: boltz.data.module.training.DatasetConfig
26 |       target_dir: PATH_TO_TARGETS_DIR
27 |       msa_dir: PATH_TO_MSA_DIR
28 |       prob: 1.0
29 |       sampler:
30 |         _target_: boltz.data.sample.cluster.ClusterSampler
31 |       cropper:
32 |         _target_: boltz.data.crop.boltz.BoltzCropper
33 |         min_neighborhood: 0
34 |         max_neighborhood: 40
35 |       split: ./scripts/train/assets/validation_ids.txt
36 | 
37 |   filters:
38 |     - _target_: boltz.data.filter.dynamic.size.SizeFilter
39 |       min_chains: 1
40 |       max_chains: 300
41 |     - _target_: boltz.data.filter.dynamic.date.DateFilter
42 |       date: "2021-09-30"
43 |       ref: released
44 |     - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
45 |       resolution: 4.0
46 | 
47 |   tokenizer:
48 |     _target_: boltz.data.tokenize.boltz.BoltzTokenizer
49 |   featurizer:
50 |     _target_: boltz.data.feature.featurizer.BoltzFeaturizer
51 | 
52 |   symmetries: PATH_TO_SYMMETRY_FILE
53 |   max_tokens: 512
54 |   max_atoms: 4608
55 |   max_seqs: 2048
56 |   pad_to_max_tokens: true
57 |   pad_to_max_atoms: true
58 |   pad_to_max_seqs: true
59 |   samples_per_epoch: 100000
60 |   batch_size: 1
61 |   num_workers: 4
62 |   random_seed: 42
63 |   pin_memory: true
64 |   overfit: null
65 |   crop_validation: true
66 |   return_train_symmetries: true
67 |   return_val_symmetries: true
68 |   train_binder_pocket_conditioned_prop: 0.3
69 |   val_binder_pocket_conditioned_prop: 0.3
70 |   binder_pocket_cutoff: 6.0
71 |
binder_pocket_sampling_geometric_p: 0.3 72 | min_dist: 2.0 73 | max_dist: 22.0 74 | num_bins: 64 75 | atoms_per_window_queries: 32 76 | 77 | model: 78 | _target_: boltz.model.models.boltz.BoltzPreview 79 | atom_s: 128 80 | atom_z: 16 81 | token_s: 384 82 | token_z: 128 83 | num_bins: 64 84 | atom_feature_dim: 389 85 | atoms_per_window_queries: 32 86 | atoms_per_window_keys: 128 87 | compile_pairformer: false 88 | nucleotide_rmsd_weight: 5.0 89 | ligand_rmsd_weight: 10.0 90 | ema: true 91 | ema_decay: 0.999 92 | 93 | embedder_args: 94 | atom_encoder_depth: 3 95 | atom_encoder_heads: 4 96 | 97 | msa_args: 98 | msa_s: 64 99 | msa_blocks: 4 100 | msa_dropout: 0.15 101 | z_dropout: 0.25 102 | pairwise_head_width: 32 103 | pairwise_num_heads: 4 104 | activation_checkpointing: true 105 | offload_to_cpu: false 106 | 107 | pairformer_args: 108 | num_blocks: 48 109 | num_heads: 16 110 | dropout: 0.25 111 | activation_checkpointing: true 112 | offload_to_cpu: false 113 | 114 | score_model_args: 115 | sigma_data: 16 116 | dim_fourier: 256 117 | atom_encoder_depth: 3 118 | atom_encoder_heads: 4 119 | token_transformer_depth: 24 120 | token_transformer_heads: 16 121 | atom_decoder_depth: 3 122 | atom_decoder_heads: 4 123 | conditioning_transition_layers: 2 124 | activation_checkpointing: true 125 | offload_to_cpu: false 126 | 127 | structure_prediction_training: false 128 | run_trunk_and_structure: true 129 | confidence_prediction: true 130 | alpha_pae: 1 131 | confidence_imitate_trunk: true 132 | confidence_model_args: 133 | use_gaussian: false 134 | num_dist_bins: 64 135 | max_dist: 22 136 | add_s_to_z_prod: true 137 | add_s_input_to_s: true 138 | use_s_diffusion: true 139 | add_z_input_to_z: true 140 | 141 | confidence_args: 142 | num_plddt_bins: 50 143 | num_pde_bins: 64 144 | num_pae_bins: 64 145 | relative_confidence: none 146 | 147 | training_args: 148 | recycling_steps: 3 149 | sampling_steps: 200 150 | diffusion_multiplicity: 16 151 | diffusion_samples: 1 152 | confidence_loss_weight: 3e-3 153 | diffusion_loss_weight: 4.0 154 | distogram_loss_weight: 3e-2 155 | adam_beta_1: 0.9 156 | adam_beta_2: 0.95 157 | adam_eps: 0.00000001 158 | lr_scheduler: af3 159 | base_lr: 0.0 160 | max_lr: 0.0018 161 | lr_warmup_no_steps: 1000 162 | lr_start_decay_after_n_steps: 50000 163 | lr_decay_every_n_steps: 50000 164 | lr_decay_factor: 0.95 165 | symmetry_correction: true 166 | 167 | validation_args: 168 | recycling_steps: 3 169 | sampling_steps: 200 170 | diffusion_samples: 5 171 | symmetry_correction: true 172 | 173 | diffusion_process_args: 174 | sigma_min: 0.0004 175 | sigma_max: 160.0 176 | sigma_data: 16.0 177 | rho: 7 178 | P_mean: -1.2 179 | P_std: 1.5 180 | gamma_0: 0.8 181 | gamma_min: 1.0 182 | noise_scale: 1.0 183 | step_scale: 1.0 184 | coordinate_augmentation: true 185 | alignment_reverse_diff: true 186 | synchronize_sigmas: true 187 | use_inference_model_cache: true 188 | 189 | diffusion_loss_args: 190 | add_smooth_lddt_loss: true 191 | nucleotide_loss_weight: 5.0 192 | ligand_loss_weight: 10.0 193 | -------------------------------------------------------------------------------- /src/boltz/model/layers/triangular_attention/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from functools import partial, partialmethod 17 | from typing import List, Optional 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | from boltz.model.layers.triangular_attention.primitives import ( 23 | Attention, 24 | LayerNorm, 25 | Linear, 26 | ) 27 | from boltz.model.layers.triangular_attention.utils import ( 28 | chunk_layer, 29 | permute_final_dims, 30 | ) 31 | 32 | 33 | class TriangleAttention(nn.Module): 34 | def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9): 35 | """ 36 | Args: 37 | c_in: 38 | Input channel dimension 39 | c_hidden: 40 | Overall hidden channel dimension (not per-head) 41 | no_heads: 42 | Number of attention heads 43 | """ 44 | super(TriangleAttention, self).__init__() 45 | 46 | self.c_in = c_in 47 | self.c_hidden = c_hidden 48 | self.no_heads = no_heads 49 | self.starting = starting 50 | self.inf = inf 51 | 52 | self.layer_norm = LayerNorm(self.c_in) 53 | 54 | self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") 55 | 56 | self.mha = Attention( 57 | self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads 58 | ) 59 | 60 | @torch.jit.ignore 61 | def _chunk( 62 | self, 63 | x: torch.Tensor, 64 | biases: List[torch.Tensor], 65 | chunk_size: int, 66 | use_memory_efficient_kernel: bool = False, 67 | use_deepspeed_evo_attention: bool = False, 68 | use_lma: bool = False, 69 | inplace_safe: bool = False, 70 | ) -> torch.Tensor: 71 | "triangle! triangle!" 72 | mha_inputs = { 73 | "q_x": x, 74 | "kv_x": x, 75 | "biases": biases, 76 | } 77 | 78 | return chunk_layer( 79 | partial( 80 | self.mha, 81 | use_memory_efficient_kernel=use_memory_efficient_kernel, 82 | use_deepspeed_evo_attention=use_deepspeed_evo_attention, 83 | use_lma=use_lma, 84 | ), 85 | mha_inputs, 86 | chunk_size=chunk_size, 87 | no_batch_dims=len(x.shape[:-2]), 88 | _out=x if inplace_safe else None, 89 | ) 90 | 91 | def forward( 92 | self, 93 | x: torch.Tensor, 94 | mask: Optional[torch.Tensor] = None, 95 | chunk_size: Optional[int] = None, 96 | use_memory_efficient_kernel: bool = False, 97 | use_deepspeed_evo_attention: bool = False, 98 | use_lma: bool = False, 99 | inplace_safe: bool = False, 100 | ) -> torch.Tensor: 101 | """ 102 | Args: 103 | x: 104 | [*, I, J, C_in] input tensor (e.g. 
the pair representation) 105 | Returns: 106 | [*, I, J, C_in] output tensor 107 | """ 108 | if mask is None: 109 | # [*, I, J] 110 | mask = x.new_ones( 111 | x.shape[:-1], 112 | ) 113 | 114 | if not self.starting: 115 | x = x.transpose(-2, -3) 116 | mask = mask.transpose(-1, -2) 117 | 118 | # [*, I, J, C_in] 119 | x = self.layer_norm(x) 120 | 121 | # [*, I, 1, 1, J] 122 | mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] 123 | 124 | # [*, H, I, J] 125 | triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) 126 | 127 | # [*, 1, H, I, J] 128 | triangle_bias = triangle_bias.unsqueeze(-4) 129 | 130 | biases = [mask_bias, triangle_bias] 131 | 132 | if chunk_size is not None: 133 | x = self._chunk( 134 | x, 135 | biases, 136 | chunk_size, 137 | use_memory_efficient_kernel=use_memory_efficient_kernel, 138 | use_deepspeed_evo_attention=use_deepspeed_evo_attention, 139 | use_lma=use_lma, 140 | inplace_safe=inplace_safe, 141 | ) 142 | else: 143 | x = self.mha( 144 | q_x=x, 145 | kv_x=x, 146 | biases=biases, 147 | use_memory_efficient_kernel=use_memory_efficient_kernel, 148 | use_deepspeed_evo_attention=use_deepspeed_evo_attention, 149 | use_lma=use_lma, 150 | ) 151 | 152 | if not self.starting: 153 | x = x.transpose(-2, -3) 154 | 155 | return x 156 | 157 | 158 | # Implements Algorithm 13 159 | TriangleAttentionStartingNode = TriangleAttention 160 | 161 | 162 | class TriangleAttentionEndingNode(TriangleAttention): 163 | """Implement Algorithm 14.""" 164 | 165 | __init__ = partialmethod(TriangleAttention.__init__, starting=False) 166 | -------------------------------------------------------------------------------- /scripts/train/assets/test_ids.txt: -------------------------------------------------------------------------------- 1 | 8BZ4 2 | 8URN 3 | 7U71 4 | 7Z64 5 | 7Y3Z 6 | 8SOT 7 | 8GH8 8 | 8IIB 9 | 7U08 10 | 8EB5 11 | 8G49 12 | 8K7Y 13 | 7QQD 14 | 8EIL 15 | 8JQE 16 | 8V1K 17 | 7ZRZ 18 | 7YN2 19 | 8D40 20 | 8RXO 21 | 8SXS 22 | 7UDL 23 | 8ADD 24 | 7Z3I 25 | 7YUK 26 | 7XWY 27 | 8F9Y 28 | 8WO7 29 | 8C27 30 | 8I3J 31 | 8HVC 32 | 8SXU 33 | 8K1I 34 | 8FTV 35 | 8ERC 36 | 8DVQ 37 | 8DTQ 38 | 8J12 39 | 8D0P 40 | 8POG 41 | 8HN0 42 | 7QPK 43 | 8AGR 44 | 8GXR 45 | 8K7X 46 | 8BL6 47 | 8HAW 48 | 8SRO 49 | 8HHM 50 | 8C26 51 | 7SPQ 52 | 8SME 53 | 7XGV 54 | 8GTY 55 | 8Q42 56 | 8BRY 57 | 8HDV 58 | 8B3Z 59 | 7XNJ 60 | 8EEL 61 | 8IOI 62 | 8Q70 63 | 8Y4U 64 | 8ANT 65 | 8IUB 66 | 8D49 67 | 8CPQ 68 | 8BAT 69 | 8E2B 70 | 8IWP 71 | 8IJT 72 | 7Y01 73 | 8CJG 74 | 8HML 75 | 8WU2 76 | 8VRM 77 | 8J1J 78 | 8DAJ 79 | 8SUT 80 | 8PTJ 81 | 8IVZ 82 | 8SDZ 83 | 7YDQ 84 | 8JU7 85 | 8K34 86 | 8B6Q 87 | 8F7N 88 | 8IBZ 89 | 7WOI 90 | 8R7D 91 | 8T65 92 | 8IQC 93 | 8SIU 94 | 8QK8 95 | 8HIG 96 | 7Y43 97 | 8IN8 98 | 8IBW 99 | 8GOY 100 | 7ZAO 101 | 8J9G 102 | 7ZCA 103 | 8HIO 104 | 8EFZ 105 | 8IQ8 106 | 8OQ0 107 | 8HHL 108 | 7XMW 109 | 8GI1 110 | 8AYR 111 | 7ZCB 112 | 8BRD 113 | 8IN6 114 | 8I3F 115 | 8HIU 116 | 8ER5 117 | 8WIL 118 | 7YPR 119 | 8UA2 120 | 8BW6 121 | 8IL8 122 | 8J3R 123 | 8K1F 124 | 8OHI 125 | 8WCT 126 | 8AN0 127 | 8BDQ 128 | 7FCT 129 | 8J69 130 | 8HTX 131 | 8PE3 132 | 8K5U 133 | 8AXT 134 | 8PSO 135 | 8JHR 136 | 8GY0 137 | 8QCW 138 | 8K3D 139 | 8P6J 140 | 8J0Q 141 | 7XS3 142 | 8DHJ 143 | 8EIN 144 | 7WKP 145 | 8GAQ 146 | 7WRN 147 | 8AHD 148 | 7SC4 149 | 8B3E 150 | 8AAS 151 | 8UZ8 152 | 8Q1K 153 | 8K5K 154 | 8B45 155 | 8PT7 156 | 7ZPN 157 | 8UQ9 158 | 8TJG 159 | 8TN8 160 | 8B2E 161 | 7XFZ 162 | 8FW7 163 | 8B3W 164 | 7T4W 165 | 8SVA 166 | 7YL4 167 | 8GLD 168 | 8OEI 169 | 8GMX 170 | 8OWF 171 | 8FNR 172 | 8IRQ 173 | 
8JDG 174 | 7UXA 175 | 8TKA 176 | 7YH1 177 | 8HUZ 178 | 8TA2 179 | 8E5D 180 | 7YUN 181 | 7UOI 182 | 7WMY 183 | 8AA9 184 | 8ISZ 185 | 8EXA 186 | 8E7F 187 | 8B2S 188 | 8TP8 189 | 8GSY 190 | 7XRX 191 | 8SY3 192 | 8CIL 193 | 8WBR 194 | 7XF1 195 | 7YPO 196 | 8AXF 197 | 7QNL 198 | 8OYY 199 | 7R1N 200 | 8H5S 201 | 8B6U 202 | 8IBX 203 | 8Q43 204 | 8OW8 205 | 7XSG 206 | 8U0M 207 | 8IOO 208 | 8HR5 209 | 8BVK 210 | 8P0C 211 | 7TL6 212 | 8J48 213 | 8S0U 214 | 8K8A 215 | 8G53 216 | 7XYO 217 | 8POF 218 | 8U1K 219 | 8HF2 220 | 8K4L 221 | 8JAH 222 | 8KGZ 223 | 8BNB 224 | 7UG2 225 | 8A0A 226 | 8Q3Z 227 | 8XBI 228 | 8JNM 229 | 8GPS 230 | 8K1R 231 | 8Q66 232 | 7YLQ 233 | 7YNX 234 | 8IMD 235 | 7Y8H 236 | 8OXU 237 | 8BVE 238 | 8B4E 239 | 8V14 240 | 7R5I 241 | 8IR2 242 | 8UK7 243 | 8EBB 244 | 7XCC 245 | 8AEP 246 | 7YDW 247 | 8XX9 248 | 7VS6 249 | 8K3F 250 | 8CQM 251 | 7XH4 252 | 8BH9 253 | 7VXT 254 | 8SM9 255 | 8HGU 256 | 8PSQ 257 | 8SSU 258 | 8VXA 259 | 8GSX 260 | 8GHZ 261 | 8BJ3 262 | 8C9V 263 | 8T66 264 | 7XPC 265 | 8RH3 266 | 8CMQ 267 | 8AGG 268 | 8ERM 269 | 8P6M 270 | 8BUX 271 | 7S2J 272 | 8G32 273 | 8AXJ 274 | 8CID 275 | 8CPK 276 | 8P5Q 277 | 8HP8 278 | 7YUJ 279 | 8PT2 280 | 7YK3 281 | 7YYG 282 | 8ABV 283 | 7XL7 284 | 7YLZ 285 | 8JWS 286 | 8IW5 287 | 8SM6 288 | 8BBZ 289 | 8EOV 290 | 8PXC 291 | 7UWV 292 | 8A9N 293 | 7YH5 294 | 8DEO 295 | 7X2X 296 | 8W7P 297 | 8B5W 298 | 8CIH 299 | 8RB4 300 | 8HLG 301 | 8J8H 302 | 8UA5 303 | 7YKM 304 | 8S9W 305 | 7YPD 306 | 8GA6 307 | 7YPQ 308 | 8X7X 309 | 8HI8 310 | 8H7A 311 | 8C4D 312 | 8XAT 313 | 8W8S 314 | 8HM4 315 | 8H3Z 316 | 7W91 317 | 8GPP 318 | 8TNM 319 | 7YSI 320 | 8OML 321 | 8BBR 322 | 7YOJ 323 | 8JZX 324 | 8I3X 325 | 8AU6 326 | 8ITO 327 | 7SFY 328 | 8B6P 329 | 7Y8S 330 | 8ESL 331 | 8DSP 332 | 8CLZ 333 | 8F72 334 | 8QLD 335 | 8K86 336 | 8G8E 337 | 8QDO 338 | 8ANU 339 | 8PT6 340 | 8F5D 341 | 8DQ6 342 | 8IFK 343 | 8OJN 344 | 8SSC 345 | 7QRR 346 | 8E55 347 | 7TPU 348 | 7UQU 349 | 8HFP 350 | 7XGT 351 | 8A39 352 | 8CB2 353 | 8ACR 354 | 8G5S 355 | 7TZL 356 | 8T4R 357 | 8H18 358 | 7UI4 359 | 8Q41 360 | 8K76 361 | 7WUY 362 | 8VXC 363 | 8GYG 364 | 8IMS 365 | 8IKS 366 | 8X51 367 | 7Y7O 368 | 8PX4 369 | 8BF8 370 | 7XMJ 371 | 8GDW 372 | 7YTU 373 | 8CH4 374 | 7XHZ 375 | 7YH4 376 | 8PSN 377 | 8A16 378 | 8FBJ 379 | 7Y9G 380 | 8JI2 381 | 7YR9 382 | 8SW0 383 | 8A90 384 | 8X6V 385 | 8H8P 386 | 7WJU 387 | 8PSS 388 | 8HL8 389 | 8FJD 390 | 8PM4 391 | 7UK8 392 | 8DX0 393 | 8PHB 394 | 8FBN 395 | 8FXF 396 | 8GKH 397 | 8ENR 398 | 8PTH 399 | 8CBV 400 | 8GKV 401 | 8CQO 402 | 8OK3 403 | 8GSR 404 | 8TPK 405 | 8H1J 406 | 8QFL 407 | 8CHW 408 | 7V34 409 | 8HE2 410 | 7ZIE 411 | 8A50 412 | 7Z8E 413 | 8ILL 414 | 7WWC 415 | 7XVI 416 | 8Q2A 417 | 8HNO 418 | 8PR6 419 | 7XCA 420 | 7XGS 421 | 8H55 422 | 8FJE 423 | 7UNH 424 | 8AY2 425 | 8ARD 426 | 8HBR 427 | 8EWG 428 | 8D4A 429 | 8FIT 430 | 8E5E 431 | 8PMU 432 | 8F5G 433 | 8AMU 434 | 8CPN 435 | 7QPL 436 | 8EHN 437 | 8SQU 438 | 8F70 439 | 8FX9 440 | 7UR2 441 | 8T1M 442 | 7ZDS 443 | 7YH2 444 | 8B6A 445 | 8CHX 446 | 8G0N 447 | 8GY4 448 | 7YKG 449 | 8BH8 450 | 8BVI 451 | 7XF2 452 | 8BFY 453 | 8IA3 454 | 8JW3 455 | 8OQJ 456 | 8TFS 457 | 7Y1S 458 | 8HBB 459 | 8AF9 460 | 8IP1 461 | 7XZ3 462 | 8T0P 463 | 7Y16 464 | 8BRP 465 | 8JNX 466 | 8JP0 467 | 8EC3 468 | 8PZH 469 | 7URP 470 | 8B4D 471 | 8JFR 472 | 8GYR 473 | 7XFS 474 | 8SMQ 475 | 7WNH 476 | 8H0L 477 | 8OWI 478 | 8HFC 479 | 7X6G 480 | 8FKL 481 | 8PAG 482 | 8UPI 483 | 8D4B 484 | 8BCK 485 | 8JFU 486 | 8FUQ 487 | 8IF8 488 | 8PAQ 489 | 8HDU 490 | 8W9O 491 | 8ACA 492 | 7YIA 493 | 7ZFR 494 | 7Y9A 495 | 8TTO 496 | 
7YFX 497 | 8B2H 498 | 8PSU 499 | 8ACC 500 | 8JMR 501 | 8IHA 502 | 7UYX 503 | 8DWJ 504 | 8BY5 505 | 8EZW 506 | 8A82 507 | 8TVL 508 | 8R79 509 | 8R8A 510 | 8AHZ 511 | 8AYV 512 | 8JHU 513 | 8Q44 514 | 8ARE 515 | 8OLJ 516 | 7Y95 517 | 7XP0 518 | 8EX9 519 | 8BID 520 | 8Q40 521 | 7QSJ 522 | 7UBA 523 | 7XFU 524 | 8OU1 525 | 8G2V 526 | 8YA7 527 | 8GMZ 528 | 8T8L 529 | 8CK0 530 | 7Y4H 531 | 8IOM 532 | 7ZLQ 533 | 8BZ2 534 | 8B4C 535 | 8DZJ 536 | 8CEG 537 | 8IBY 538 | 8T3J 539 | 8IVI 540 | 8ITN 541 | 8CR7 542 | 8TGH 543 | 8OKH 544 | 7UI8 545 | 8EHT 546 | 8ADC 547 | 8T4C 548 | 7XBJ 549 | 8CLU 550 | 7QA1 551 | -------------------------------------------------------------------------------- /src/boltz/model/loss/diffusion.py: -------------------------------------------------------------------------------- 1 | # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang 2 | 3 | from einops import einsum 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def weighted_rigid_align( 9 | true_coords, 10 | pred_coords, 11 | weights, 12 | mask, 13 | ): 14 | """Compute weighted alignment. 15 | 16 | Parameters 17 | ---------- 18 | true_coords: torch.Tensor 19 | The ground truth atom coordinates 20 | pred_coords: torch.Tensor 21 | The predicted atom coordinates 22 | weights: torch.Tensor 23 | The weights for alignment 24 | mask: torch.Tensor 25 | The atoms mask 26 | 27 | Returns 28 | ------- 29 | torch.Tensor 30 | Aligned coordinates 31 | 32 | """ 33 | 34 | batch_size, num_points, dim = true_coords.shape 35 | weights = (mask * weights).unsqueeze(-1) 36 | 37 | # Compute weighted centroids 38 | true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum( 39 | dim=1, keepdim=True 40 | ) 41 | pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum( 42 | dim=1, keepdim=True 43 | ) 44 | 45 | # Center the coordinates 46 | true_coords_centered = true_coords - true_centroid 47 | pred_coords_centered = pred_coords - pred_centroid 48 | 49 | if num_points < (dim + 1): 50 | print( 51 | "Warning: The size of one of the point clouds is <= dim+1. " 52 | + "`WeightedRigidAlign` cannot return a unique rotation." 53 | ) 54 | 55 | # Compute the weighted covariance matrix 56 | cov_matrix = einsum( 57 | weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j" 58 | ) 59 | 60 | # Compute the SVD of the covariance matrix, required float32 for svd and determinant 61 | original_dtype = cov_matrix.dtype 62 | cov_matrix_32 = cov_matrix.to(dtype=torch.float32) 63 | U, S, V = torch.linalg.svd( 64 | cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None 65 | ) 66 | V = V.mH 67 | 68 | # Catch ambiguous rotation by checking the magnitude of singular values 69 | if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)): 70 | print( 71 | "Warning: Excessively low rank of " 72 | + "cross-correlation between aligned point clouds. " 73 | + "`WeightedRigidAlign` cannot return a unique rotation." 
 74 |         )
 75 | 
 76 |     # Compute the rotation matrix
 77 |     rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)
 78 | 
 79 |     # Ensure a proper rotation matrix with determinant 1
 80 |     det_fix = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
 81 |         None
 82 |     ].repeat(batch_size, 1, 1)  # named det_fix to avoid shadowing torch.nn.functional (F)
 83 |     det_fix[:, -1, -1] = torch.det(rot_matrix)
 84 |     rot_matrix = einsum(U, det_fix, V, "b i j, b j k, b l k -> b i l")
 85 |     rot_matrix = rot_matrix.to(dtype=original_dtype)
 86 | 
 87 |     # Apply the rotation and translation
 88 |     aligned_coords = (
 89 |         einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j")
 90 |         + pred_centroid
 91 |     )
 92 |     aligned_coords.detach_()
 93 | 
 94 |     return aligned_coords
 95 | 
 96 | 
 97 | def smooth_lddt_loss(
 98 |     pred_coords,
 99 |     true_coords,
100 |     is_nucleotide,
101 |     coords_mask,
102 |     nucleic_acid_cutoff: float = 30.0,
103 |     other_cutoff: float = 15.0,
104 |     multiplicity: int = 1,
105 | ):
106 |     """Compute the smooth LDDT loss.
107 | 
108 |     Parameters
109 |     ----------
110 |     pred_coords: torch.Tensor
111 |         The predicted atom coordinates
112 |     true_coords: torch.Tensor
113 |         The ground truth atom coordinates
114 |     is_nucleotide: torch.Tensor
115 |         Mask indicating which atoms belong to nucleotide chains
116 |     coords_mask: torch.Tensor
117 |         The atoms mask
118 |     nucleic_acid_cutoff: float
119 |         The nucleic acid cutoff
120 |     other_cutoff: float
121 |         The non nucleic acid cutoff
122 |     multiplicity: int
123 |         The multiplicity
124 |     Returns
125 |     -------
126 |     torch.Tensor
127 |         The smooth LDDT loss
128 | 
129 |     """
130 |     B, N, _ = true_coords.shape
131 |     true_dists = torch.cdist(true_coords, true_coords)
132 |     is_nucleotide = is_nucleotide.repeat_interleave(multiplicity, 0)
133 | 
134 |     coords_mask = coords_mask.repeat_interleave(multiplicity, 0)
135 |     is_nucleotide_pair = is_nucleotide.unsqueeze(-1).expand(
136 |         -1, -1, is_nucleotide.shape[-1]
137 |     )
138 | 
139 |     mask = (
140 |         is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
141 |         + (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
142 |     )
143 |     mask = mask * (1 - torch.eye(pred_coords.shape[1], device=pred_coords.device))
144 |     mask = mask * (coords_mask.unsqueeze(-1) * coords_mask.unsqueeze(-2))
145 | 
146 |     # Compute distances between all pairs of atoms
147 |     pred_dists = torch.cdist(pred_coords, pred_coords)
148 |     dist_diff = torch.abs(true_dists - pred_dists)
149 | 
150 |     # Compute epsilon values
151 |     eps = (
152 |         (
153 |             (
154 |                 F.sigmoid(0.5 - dist_diff)
155 |                 + F.sigmoid(1.0 - dist_diff)
156 |                 + F.sigmoid(2.0 - dist_diff)
157 |                 + F.sigmoid(4.0 - dist_diff)
158 |             )
159 |             / 4.0
160 |         )
161 |         .view(multiplicity, B // multiplicity, N, N)
162 |         .mean(dim=0)
163 |     )
164 | 
165 |     # Calculate masked averaging
166 |     eps = eps.repeat_interleave(multiplicity, 0)
167 |     num = (eps * mask).sum(dim=(-1, -2))
168 |     den = mask.sum(dim=(-1, -2)).clamp(min=1)
169 |     lddt = num / den
170 | 
171 |     return 1.0 - lddt.mean()
172 | 
--------------------------------------------------------------------------------
/scripts/train/assets/validation_ids.txt:
 1 | 7UTN
 2 | 7F9H
 3 | 7TZV
 4 | 7ZHH
 5 | 7SOV
 6 | 7EOF
 7 | 7R8H
 8 | 8AW3
 9 | 7F2F
10 | 8BAO
11 | 7BCB
12 | 7D8T
13 | 7D3T
14 | 7BHY
15 | 7YZ7
16 | 8DC2
17 | 7SOW
18 | 8CTL
19 | 7SOS
20 | 7V6W
21 | 7Z55
22 | 7NQF
23 | 7VTN
24 | 7KSP
25 | 7BJQ
26 | 7YZC
27 | 7Y3L
28 | 7TDX
29 | 7R8I
30 | 7OYK
31 | 7TZ1
32 | 7KIJ
33 | 7T8K
34 | 7KII
35 | 7YZA
36 | 7VP4
37 | 7KIK
38 | 7M5W
39 | 7Q94
40 | 7BCA
41 | 7YZB
42 | 7OG0
43 |
7VTI 44 | 7SOP 45 | 7S03 46 | 7YZG 47 | 7TXC 48 | 7VP5 49 | 7Y3I 50 | 7TDW 51 | 8B0R 52 | 7R8G 53 | 7FEF 54 | 7VP1 55 | 7VP3 56 | 7RGU 57 | 7DV2 58 | 7YZD 59 | 7OFZ 60 | 7Y3K 61 | 7TEC 62 | 7WQ5 63 | 7VP2 64 | 7EDB 65 | 7VP7 66 | 7PDV 67 | 7XHT 68 | 7R6R 69 | 8CSH 70 | 8CSZ 71 | 7V9O 72 | 7Q1C 73 | 8EDC 74 | 7PWI 75 | 7FI1 76 | 7ESI 77 | 7F0Y 78 | 7EYR 79 | 7ZVA 80 | 7WEG 81 | 7E4N 82 | 7U5Q 83 | 7FAV 84 | 7LJ2 85 | 7S6F 86 | 7B3N 87 | 7V4P 88 | 7AJO 89 | 7WH1 90 | 8DQP 91 | 7STT 92 | 7VQ7 93 | 7E4J 94 | 7RIS 95 | 7FH8 96 | 7BMW 97 | 7RD0 98 | 7V54 99 | 7LKC 100 | 7OU1 101 | 7QOD 102 | 7PX1 103 | 7EBY 104 | 7U1V 105 | 7PLP 106 | 7T8N 107 | 7SJK 108 | 7RGB 109 | 7TEM 110 | 7UG9 111 | 7B7A 112 | 7TM2 113 | 7Z74 114 | 7PCM 115 | 7V8G 116 | 7EUU 117 | 7VTL 118 | 7ZEI 119 | 7ZC0 120 | 7DZ9 121 | 8B2M 122 | 7NE9 123 | 7ALV 124 | 7M96 125 | 7O6T 126 | 7SKO 127 | 7Z2V 128 | 7OWX 129 | 7SHW 130 | 7TNI 131 | 7ZQY 132 | 7MDF 133 | 7EXR 134 | 7W6B 135 | 7EQF 136 | 7WWO 137 | 7FBW 138 | 8EHE 139 | 7CLE 140 | 7T80 141 | 7WMV 142 | 7SMG 143 | 7WSJ 144 | 7DBU 145 | 7VHY 146 | 7W5F 147 | 7SHG 148 | 7VU3 149 | 7ATH 150 | 7FGZ 151 | 7ADS 152 | 7REO 153 | 7T7H 154 | 7X0N 155 | 7TCU 156 | 7SKH 157 | 7EF6 158 | 7TBV 159 | 7B29 160 | 7VO5 161 | 7TM1 162 | 7QLD 163 | 7BB9 164 | 7SZ8 165 | 7RLM 166 | 7WWP 167 | 7NBV 168 | 7PLD 169 | 7DNM 170 | 7SFZ 171 | 7EAW 172 | 7QNQ 173 | 7SZX 174 | 7U2S 175 | 7WZX 176 | 7TYG 177 | 7QCE 178 | 7DCN 179 | 7WJL 180 | 7VV6 181 | 7TJ4 182 | 7VI8 183 | 8AKP 184 | 7WAO 185 | 7N7V 186 | 7EYO 187 | 7VTD 188 | 7VEG 189 | 7QY5 190 | 7ELV 191 | 7P0J 192 | 7YX8 193 | 7U4H 194 | 7TBD 195 | 7WME 196 | 7RI3 197 | 7TOH 198 | 7ZVM 199 | 7PUL 200 | 7VBO 201 | 7DM0 202 | 7XN9 203 | 7ALY 204 | 7LTB 205 | 8A28 206 | 7UBZ 207 | 8DTE 208 | 7TA2 209 | 7QST 210 | 7AN1 211 | 7FIB 212 | 8BAL 213 | 7TMJ 214 | 7REV 215 | 7PZJ 216 | 7T9X 217 | 7SUU 218 | 7KJQ 219 | 7V6P 220 | 7QA3 221 | 7ULC 222 | 7Y3X 223 | 7TMU 224 | 7OA7 225 | 7PO9 226 | 7Q20 227 | 8H2C 228 | 7VW1 229 | 7VLJ 230 | 8EP4 231 | 7P57 232 | 7QUL 233 | 7ZQE 234 | 7UJU 235 | 7WG1 236 | 7DMK 237 | 7Y8X 238 | 7EHG 239 | 7W13 240 | 7NL4 241 | 7R4J 242 | 7AOV 243 | 7RFT 244 | 7VUF 245 | 7F72 246 | 8DSR 247 | 7MK3 248 | 7MQQ 249 | 7R55 250 | 7T85 251 | 7NCY 252 | 7ZHL 253 | 7E1N 254 | 7W8F 255 | 7PGK 256 | 8GUN 257 | 7P8D 258 | 7PUK 259 | 7N9D 260 | 7XWN 261 | 7ZHA 262 | 7TVP 263 | 7VI6 264 | 7PW6 265 | 7YM0 266 | 7RWK 267 | 8DKR 268 | 7WGU 269 | 7LJI 270 | 7THW 271 | 7OB6 272 | 7N3Z 273 | 7T3S 274 | 7PAB 275 | 7F9F 276 | 7PPP 277 | 7AD5 278 | 7VGM 279 | 7WBO 280 | 7RWM 281 | 7QFI 282 | 7T91 283 | 7ANU 284 | 7UX0 285 | 7USR 286 | 7RDN 287 | 7VW5 288 | 7Q4T 289 | 7W3R 290 | 8DKQ 291 | 7RCX 292 | 7UOF 293 | 7OKR 294 | 7NX1 295 | 6ZBS 296 | 7VEV 297 | 8E8U 298 | 7WJ6 299 | 7MP4 300 | 7RPY 301 | 7R5Z 302 | 7VLM 303 | 7SNE 304 | 7WDW 305 | 8E19 306 | 7PP2 307 | 7Z5H 308 | 7P7I 309 | 7LJJ 310 | 7QPC 311 | 7VJS 312 | 7QOE 313 | 7KZH 314 | 7F6N 315 | 7TMI 316 | 7POH 317 | 8DKS 318 | 7YMO 319 | 6S5I 320 | 7N6O 321 | 7LYU 322 | 7POK 323 | 7BLK 324 | 7TCY 325 | 7W19 326 | 8B55 327 | 7SMU 328 | 7QFK 329 | 7T5T 330 | 7EPQ 331 | 7DCK 332 | 7S69 333 | 6ZSV 334 | 7ZGT 335 | 7TJ1 336 | 7V09 337 | 7ZHD 338 | 7ALL 339 | 7P1Y 340 | 7T71 341 | 7MNK 342 | 7W5Q 343 | 7PZ2 344 | 7QSQ 345 | 7QI3 346 | 7NZZ 347 | 7Q47 348 | 8D08 349 | 7QH5 350 | 7RXQ 351 | 7F45 352 | 8D07 353 | 8EHC 354 | 7PZT 355 | 7K3C 356 | 7ZGI 357 | 7MC4 358 | 7NPQ 359 | 7VD7 360 | 7XAN 361 | 7FDP 362 | 8A0K 363 | 7TXO 364 | 7ZB1 365 | 7V5V 366 | 7WWS 367 | 7PBK 368 | 8EBG 369 | 7N0J 370 | 7UMA 371 | 
7T1S 372 | 8EHB 373 | 7DWC 374 | 7K6W 375 | 7WEJ 376 | 7LRH 377 | 7ZCV 378 | 7RKC 379 | 7X8C 380 | 7PV1 381 | 7UGK 382 | 7ULN 383 | 7A66 384 | 7R7M 385 | 7M0Q 386 | 7BGS 387 | 7UPP 388 | 7O62 389 | 7VKK 390 | 7L6Y 391 | 7VG4 392 | 7V2V 393 | 7ETN 394 | 7ZTB 395 | 7AOO 396 | 7OH2 397 | 7E0M 398 | 7PEG 399 | 8CUK 400 | 7ZP0 401 | 7T6A 402 | 7BTM 403 | 7DOV 404 | 7VVV 405 | 7P22 406 | 7RUO 407 | 7E40 408 | 7O5Y 409 | 7XPK 410 | 7R0K 411 | 8D04 412 | 7TYD 413 | 7LSV 414 | 7XSI 415 | 7RTZ 416 | 7UXR 417 | 7QH3 418 | 8END 419 | 8CYK 420 | 7MRJ 421 | 7DJL 422 | 7S5B 423 | 7XUX 424 | 7EV8 425 | 7R6S 426 | 7UH4 427 | 7R9X 428 | 7F7P 429 | 7ACW 430 | 7SPN 431 | 7W70 432 | 7Q5G 433 | 7DXN 434 | 7DK9 435 | 8DT0 436 | 7FDN 437 | 7DGX 438 | 7UJB 439 | 7X4O 440 | 7F4O 441 | 7T9W 442 | 8AID 443 | 7ERQ 444 | 7EQB 445 | 7YDG 446 | 7ETR 447 | 8D27 448 | 7OUU 449 | 7R5Y 450 | 7T8I 451 | 7UZT 452 | 7X8V 453 | 7QLH 454 | 7SAF 455 | 7EN6 456 | 8D4Y 457 | 7ESJ 458 | 7VWO 459 | 7SBE 460 | 7VYU 461 | 7RVJ 462 | 7FCL 463 | 7WUO 464 | 7WWF 465 | 7VMT 466 | 7SHJ 467 | 7SKP 468 | 7KOU 469 | 6ZSU 470 | 7VGW 471 | 7X45 472 | 8GYZ 473 | 8BFE 474 | 8DGL 475 | 7Z3H 476 | 8BD1 477 | 8A0J 478 | 7JRK 479 | 7QII 480 | 7X39 481 | 7Y6B 482 | 7OIY 483 | 7SBI 484 | 8A3I 485 | 7NLI 486 | 7F4U 487 | 7TVY 488 | 7X0O 489 | 7VMH 490 | 7EPN 491 | 7WBK 492 | 8BFJ 493 | 7XFP 494 | 7LXQ 495 | 7TIL 496 | 7O61 497 | 8B8B 498 | 7W2Q 499 | 8APR 500 | 7WZE 501 | 7NYQ 502 | 7RMX 503 | 7PGE 504 | 8F43 505 | 7N2K 506 | 7UXG 507 | 7SXN 508 | 7T5U 509 | 7R22 510 | 7E3T 511 | 7PTB 512 | 7OA8 513 | 7X5T 514 | 7PL7 515 | 7SQ5 516 | 7VBS 517 | 8D03 518 | 7TAE 519 | 7T69 520 | 7WF6 521 | 7LBU 522 | 8A06 523 | 8DA2 524 | 7QFL 525 | 7KUW 526 | 7X9R 527 | 7XT3 528 | 7RB4 529 | 7PT5 530 | 7RPS 531 | 7RXU 532 | 7TDY 533 | 7W89 534 | 7N9I 535 | 7T1M 536 | 7OBM 537 | 7K3X 538 | 7ZJC 539 | 8BDP 540 | 7V8W 541 | 7DJK 542 | 7W1K 543 | 7QFG 544 | 7DGY 545 | 7ZTQ 546 | 7F8A 547 | 7NEK 548 | 7CG9 549 | 7KOB 550 | 7TN7 551 | 8DYS 552 | 7WVR 553 | -------------------------------------------------------------------------------- /src/boltz/data/write/writer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, replace 2 | from pathlib import Path 3 | from typing import Literal 4 | 5 | import numpy as np 6 | from pytorch_lightning import LightningModule, Trainer 7 | from pytorch_lightning.callbacks import BasePredictionWriter 8 | from torch import Tensor 9 | 10 | from boltz.data.types import ( 11 | Interface, 12 | Record, 13 | Structure, 14 | ) 15 | from boltz.data.write.mmcif import to_mmcif 16 | from boltz.data.write.pdb import to_pdb 17 | 18 | 19 | class BoltzWriter(BasePredictionWriter): 20 | """Custom writer for predictions.""" 21 | 22 | def __init__( 23 | self, 24 | data_dir: str, 25 | output_dir: str, 26 | output_format: Literal["pdb", "mmcif"] = "mmcif", 27 | ) -> None: 28 | """Initialize the writer. 29 | 30 | Parameters 31 | ---------- 32 | output_dir : str 33 | The directory to save the predictions. 
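        data_dir : str
            The directory holding the processed input structures, one
            <record_id>.npz file per record.
        output_format : Literal["pdb", "mmcif"], optional
            The format of the saved structures, by default "mmcif".

        A minimal usage sketch (the paths are hypothetical):

        >>> writer = BoltzWriter("processed", "predictions", output_format="pdb")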
34 | 35 | """ 36 | super().__init__(write_interval="batch") 37 | if output_format not in ["pdb", "mmcif"]: 38 | msg = f"Invalid output format: {output_format}" 39 | raise ValueError(msg) 40 | 41 | self.data_dir = Path(data_dir) 42 | self.output_dir = Path(output_dir) 43 | self.output_format = output_format 44 | self.failed = 0 45 | 46 | # Create the output directories 47 | self.output_dir.mkdir(parents=True, exist_ok=True) 48 | 49 | def write_on_batch_end( 50 | self, 51 | trainer: Trainer, # noqa: ARG002 52 | pl_module: LightningModule, # noqa: ARG002 53 | prediction: dict[str, Tensor], 54 | batch_indices: list[int], # noqa: ARG002 55 | batch: dict[str, Tensor], 56 | batch_idx: int, # noqa: ARG002 57 | dataloader_idx: int, # noqa: ARG002 58 | ) -> None: 59 | """Write the predictions to disk.""" 60 | if prediction["exception"]: 61 | self.failed += 1 62 | return 63 | 64 | # Get the records 65 | records: list[Record] = batch["record"] 66 | 67 | # Get the predictions 68 | coords = prediction["coords"] 69 | coords = coords.unsqueeze(0) 70 | 71 | pad_masks = prediction["masks"] 72 | if prediction.get("confidence") is not None: 73 | confidences = prediction["confidence"] 74 | confidences = confidences.reshape(len(records), -1).tolist() 75 | else: 76 | confidences = [0.0 for _ in range(len(records))] 77 | 78 | # Iterate over the records 79 | for record, coord, pad_mask, _confidence in zip( 80 | records, coords, pad_masks, confidences 81 | ): 82 | # Load the structure 83 | path = self.data_dir / f"{record.id}.npz" 84 | structure: Structure = Structure.load(path) 85 | 86 | # Compute chain map with masked chains removed, to be used later 87 | chain_map = {} 88 | for i, mask in enumerate(structure.mask): 89 | if mask: 90 | chain_map[len(chain_map)] = i 91 | 92 | # Remove masked chains completely 93 | structure = structure.remove_invalid_chains() 94 | 95 | for model_idx in range(coord.shape[0]): 96 | # Get model coord 97 | model_coord = coord[model_idx] 98 | # Unpad 99 | coord_unpad = model_coord[pad_mask.bool()] 100 | coord_unpad = coord_unpad.cpu().numpy() 101 | 102 | # New atom table 103 | atoms = structure.atoms 104 | atoms["coords"] = coord_unpad 105 | atoms["is_present"] = True 106 | 107 | # New residue table 108 | residues = structure.residues 109 | residues["is_present"] = True 110 | 111 | # Update the structure 112 | interfaces = np.array([], dtype=Interface) 113 | new_structure: Structure = replace( 114 | structure, 115 | atoms=atoms, 116 | residues=residues, 117 | interfaces=interfaces, 118 | ) 119 | 120 | # Update chain info 121 | chain_info = [] 122 | for chain in new_structure.chains: 123 | old_chain_idx = chain_map[chain["asym_id"]] 124 | old_chain_info = record.chains[old_chain_idx] 125 | new_chain_info = replace( 126 | old_chain_info, 127 | chain_id=int(chain["asym_id"]), 128 | valid=True, 129 | ) 130 | chain_info.append(new_chain_info) 131 | 132 | # Save the structure 133 | struct_dir = self.output_dir / record.id 134 | struct_dir.mkdir(exist_ok=True) 135 | 136 | if self.output_format == "pdb": 137 | path = struct_dir / f"{record.id}_model_{model_idx}.pdb" 138 | with path.open("w") as f: 139 | f.write(to_pdb(new_structure)) 140 | elif self.output_format == "mmcif": 141 | path = struct_dir / f"{record.id}_model_{model_idx}.cif" 142 | with path.open("w") as f: 143 | f.write(to_mmcif(new_structure)) 144 | else: 145 | path = struct_dir / f"{record.id}_model_{model_idx}.npz" 146 | np.savez_compressed(path, **asdict(new_structure)) 147 | 148 | def on_predict_epoch_end( 149 | self, 150 | trainer: 
Trainer, # noqa: ARG002 151 | pl_module: LightningModule, # noqa: ARG002 152 | ) -> None: 153 | """Print the number of failed examples.""" 154 | # Print number of failed examples 155 | print(f"Number of failed examples: {self.failed}") # noqa: T201 156 | -------------------------------------------------------------------------------- /src/boltz/data/tokenize/boltz.py: -------------------------------------------------------------------------------- 1 | from dataclasses import astuple, dataclass 2 | 3 | import numpy as np 4 | 5 | from boltz.data import const 6 | from boltz.data.tokenize.tokenizer import Tokenizer 7 | from boltz.data.types import Input, Token, TokenBond, Tokenized 8 | 9 | 10 | @dataclass 11 | class TokenData: 12 | """TokenData datatype.""" 13 | 14 | token_idx: int 15 | atom_idx: int 16 | atom_num: int 17 | res_idx: int 18 | res_type: int 19 | sym_id: int 20 | asym_id: int 21 | entity_id: int 22 | mol_type: int 23 | center_idx: int 24 | disto_idx: int 25 | center_coords: np.ndarray 26 | disto_coords: np.ndarray 27 | resolved_mask: bool 28 | disto_mask: bool 29 | 30 | 31 | class BoltzTokenizer(Tokenizer): 32 | """Tokenize an input structure for training.""" 33 | 34 | def tokenize(self, data: Input) -> Tokenized: 35 | """Tokenize the input data. 36 | 37 | Parameters 38 | ---------- 39 | data : Input 40 | The input data. 41 | 42 | Returns 43 | ------- 44 | Tokenized 45 | The tokenized data. 46 | 47 | """ 48 | # Get structure data 49 | struct = data.structure 50 | 51 | # Create token data 52 | token_data = [] 53 | 54 | # Keep track of atom_idx to token_idx 55 | token_idx = 0 56 | atom_to_token = {} 57 | 58 | # Filter to valid chains only 59 | chains = struct.chains[struct.mask] 60 | 61 | for chain in chains: 62 | # Get residue indices 63 | res_start = chain["res_idx"] 64 | res_end = chain["res_idx"] + chain["res_num"] 65 | 66 | for res in struct.residues[res_start:res_end]: 67 | # Get atom indices 68 | atom_start = res["atom_idx"] 69 | atom_end = res["atom_idx"] + res["atom_num"] 70 | 71 | # Standard residues are tokens 72 | if res["is_standard"]: 73 | # Get center and disto atoms 74 | center = struct.atoms[res["atom_center"]] 75 | disto = struct.atoms[res["atom_disto"]] 76 | 77 | # Token is present if centers are 78 | is_present = res["is_present"] & center["is_present"] 79 | is_disto_present = res["is_present"] & disto["is_present"] 80 | 81 | # Apply chain transformation 82 | c_coords = center["coords"] 83 | d_coords = disto["coords"] 84 | 85 | # Create token 86 | token = TokenData( 87 | token_idx=token_idx, 88 | atom_idx=res["atom_idx"], 89 | atom_num=res["atom_num"], 90 | res_idx=res["res_idx"], 91 | res_type=res["res_type"], 92 | sym_id=chain["sym_id"], 93 | asym_id=chain["asym_id"], 94 | entity_id=chain["entity_id"], 95 | mol_type=chain["mol_type"], 96 | center_idx=res["atom_center"], 97 | disto_idx=res["atom_disto"], 98 | center_coords=c_coords, 99 | disto_coords=d_coords, 100 | resolved_mask=is_present, 101 | disto_mask=is_disto_present, 102 | ) 103 | token_data.append(astuple(token)) 104 | 105 | # Update atom_idx to token_idx 106 | for atom_idx in range(atom_start, atom_end): 107 | atom_to_token[atom_idx] = token_idx 108 | 109 | token_idx += 1 110 | 111 | # Non-standard residues are tokenized per atom 112 | else: 113 | # We use the unk protein token as res_type 114 | unk_token = const.unk_token["PROTEIN"] 115 | unk_id = const.token_ids[unk_token] 116 | 117 | # Get atom coordinates 118 | atom_data = struct.atoms[atom_start:atom_end] 119 | atom_coords = atom_data["coords"] 120 | 
121 | # Tokenize each atom 122 | for i, atom in enumerate(atom_data): 123 | # Token is present if atom is 124 | is_present = res["is_present"] & atom["is_present"] 125 | index = atom_start + i 126 | 127 | # Create token 128 | token = TokenData( 129 | token_idx=token_idx, 130 | atom_idx=index, 131 | atom_num=1, 132 | res_idx=res["res_idx"], 133 | res_type=unk_id, 134 | sym_id=chain["sym_id"], 135 | asym_id=chain["asym_id"], 136 | entity_id=chain["entity_id"], 137 | mol_type=chain["mol_type"], 138 | center_idx=index, 139 | disto_idx=index, 140 | center_coords=atom_coords[i], 141 | disto_coords=atom_coords[i], 142 | resolved_mask=is_present, 143 | disto_mask=is_present, 144 | ) 145 | token_data.append(astuple(token)) 146 | 147 | # Update atom_idx to token_idx 148 | atom_to_token[index] = token_idx 149 | token_idx += 1 150 | 151 | # Create token bonds 152 | token_bonds = [] 153 | 154 | # Add atom-atom bonds from ligands 155 | for bond in struct.bonds: 156 | if ( 157 | bond["atom_1"] not in atom_to_token 158 | or bond["atom_2"] not in atom_to_token 159 | ): 160 | continue 161 | token_bond = ( 162 | atom_to_token[bond["atom_1"]], 163 | atom_to_token[bond["atom_2"]], 164 | ) 165 | token_bonds.append(token_bond) 166 | 167 | # Add connection bonds (covalent) 168 | for conn in struct.connections: 169 | if ( 170 | conn["atom_1"] not in atom_to_token 171 | or conn["atom_2"] not in atom_to_token 172 | ): 173 | continue 174 | token_bond = ( 175 | atom_to_token[conn["atom_1"]], 176 | atom_to_token[conn["atom_2"]], 177 | ) 178 | token_bonds.append(token_bond) 179 | 180 | # Consider adding missing bond for modified residues to standard? 181 | # I'm not sure it's necessary because the bond is probably always 182 | # the same and the model can use the residue indices to infer it 183 | token_data = np.array(token_data, dtype=Token) 184 | token_bonds = np.array(token_bonds, dtype=TokenBond) 185 | tokenized = Tokenized( 186 | token_data, 187 | token_bonds, 188 | data.structure, 189 | data.msa, 190 | ) 191 | return tokenized 192 | -------------------------------------------------------------------------------- /scripts/train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | import hydra 8 | import omegaconf 9 | import pytorch_lightning as pl 10 | import torch 11 | import torch.multiprocessing 12 | from omegaconf import OmegaConf, listconfig 13 | from pytorch_lightning import LightningModule 14 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 15 | from pytorch_lightning.loggers import WandbLogger 16 | from pytorch_lightning.strategies import DDPStrategy 17 | from pytorch_lightning.utilities import rank_zero_only 18 | 19 | from boltz.data.module.training import BoltzTrainingDataModule, DataConfig 20 | 21 | 22 | @dataclass 23 | class TrainConfig: 24 | """Train configuration. 25 | 26 | Attributes 27 | ---------- 28 | data : DataConfig 29 | The data configuration. 30 | model : ModelConfig 31 | The model configuration. 32 | output : str 33 | The output directory. 34 | trainer : Optional[dict] 35 | The trainer configuration. 36 | resume : Optional[str] 37 | The resume checkpoint. 38 | pretrained : Optional[str] 39 | The pretrained model. 40 | wandb : Optional[dict] 41 | The wandb configuration. 42 | disable_checkpoint : bool 43 | Disable checkpoint. 44 | matmul_precision : Optional[str] 45 | The matmul precision. 
46 | find_unused_parameters : Optional[bool] 47 | Find unused parameters. 48 | save_top_k : Optional[int] 49 | Save top k checkpoints. 50 | validation_only : bool 51 | Run validation only. 52 | debug : bool 53 | Debug mode. 54 | strict_loading : bool 55 | Fail on mismatched checkpoint weights. 56 | load_confidence_from_trunk: Optional[bool] 57 | Load pre-trained confidence weights from trunk. 58 | 59 | """ 60 | 61 | data: DataConfig 62 | model: LightningModule 63 | output: str 64 | trainer: Optional[dict] = None 65 | resume: Optional[str] = None 66 | pretrained: Optional[str] = None 67 | wandb: Optional[dict] = None 68 | disable_checkpoint: bool = False 69 | matmul_precision: Optional[str] = None 70 | find_unused_parameters: Optional[bool] = False 71 | save_top_k: Optional[int] = 1 72 | validation_only: bool = False 73 | debug: bool = False 74 | strict_loading: bool = True 75 | load_confidence_from_trunk: Optional[bool] = False 76 | 77 | 78 | def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR0915 79 | """Run training. 80 | 81 | Parameters 82 | ---------- 83 | raw_config : str 84 | The input yaml configuration. 85 | args : list[str] 86 | Any command line overrides. 87 | 88 | """ 89 | # Load the configuration 90 | raw_config = omegaconf.OmegaConf.load(raw_config) 91 | 92 | # Apply input arguments 93 | args = omegaconf.OmegaConf.from_dotlist(args) 94 | raw_config = omegaconf.OmegaConf.merge(raw_config, args) 95 | 96 | # Instantiate the task 97 | cfg = hydra.utils.instantiate(raw_config) 98 | cfg = TrainConfig(**cfg) 99 | 100 | # Set matmul precision 101 | if cfg.matmul_precision is not None: 102 | torch.set_float32_matmul_precision(cfg.matmul_precision) 103 | 104 | # Create trainer dict 105 | trainer = cfg.trainer 106 | if trainer is None: 107 | trainer = {} 108 | 109 | # Flip some arguments in debug mode 110 | devices = trainer.get("devices", 1) 111 | 112 | wandb = cfg.wandb 113 | if cfg.debug: 114 | if isinstance(devices, int): 115 | devices = 1 116 | elif isinstance(devices, (list, listconfig.ListConfig)): 117 | devices = [devices[0]] 118 | trainer["devices"] = devices 119 | cfg.data.num_workers = 0 120 | if wandb: 121 | wandb = None 122 | 123 | # Create objects 124 | data_config = DataConfig(**cfg.data) 125 | data_module = BoltzTrainingDataModule(data_config) 126 | model_module = cfg.model 127 | 128 | if cfg.pretrained and not cfg.resume: 129 | # Load the pretrained weights into the confidence module 130 | if cfg.load_confidence_from_trunk: 131 | checkpoint = torch.load(cfg.pretrained, map_location="cpu") 132 | 133 | # Modify parameter names in the state_dict 134 | new_state_dict = {} 135 | for key, value in checkpoint["state_dict"].items(): 136 | if not key.startswith("structure_module") and not key.startswith( 137 | "distogram_module" 138 | ): 139 | new_key = "confidence_module." 
+ key 140 | new_state_dict[new_key] = value 141 | new_state_dict.update(checkpoint["state_dict"]) 142 | checkpoint["state_dict"] = new_state_dict 143 | # Save a patched copy so `file_path` is defined for loading below 144 | torch.save(checkpoint, file_path := f"{cfg.pretrained}.tmp") 145 | else: 146 | file_path = cfg.pretrained 147 | 148 | print(f"Loading model from {file_path}") 149 | model_module = type(model_module).load_from_checkpoint( 150 | file_path, strict=False, **(model_module.hparams) 151 | ) 152 | 153 | if cfg.load_confidence_from_trunk: 154 | os.remove(file_path) 155 | 156 | # Create checkpoint callback 157 | callbacks = [] 158 | dirpath = cfg.output 159 | if not cfg.disable_checkpoint: 160 | mc = ModelCheckpoint( 161 | monitor="val/lddt", 162 | save_top_k=cfg.save_top_k, 163 | save_last=True, 164 | mode="max", 165 | every_n_epochs=1, 166 | ) 167 | callbacks = [mc] 168 | 169 | # Create wandb logger 170 | loggers = [] 171 | if wandb: 172 | wdb_logger = WandbLogger( 173 | group=wandb["name"], 174 | save_dir=cfg.output, 175 | project=wandb["project"], 176 | entity=wandb["entity"], 177 | log_model=False, 178 | ) 179 | loggers.append(wdb_logger) 180 | # Save the config to wandb 181 | 182 | @rank_zero_only 183 | def save_config_to_wandb() -> None: 184 | config_out = Path(wdb_logger.experiment.dir) / "run.yaml" 185 | with Path.open(config_out, "w") as f: 186 | OmegaConf.save(raw_config, f) 187 | wdb_logger.experiment.save(str(config_out)) 188 | 189 | save_config_to_wandb() 190 | 191 | # Set up trainer 192 | strategy = "auto" 193 | if (isinstance(devices, int) and devices > 1) or ( 194 | isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1 195 | ): 196 | strategy = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters) 197 | 198 | trainer = pl.Trainer( 199 | default_root_dir=str(dirpath), 200 | strategy=strategy, 201 | callbacks=callbacks, 202 | logger=loggers, 203 | enable_checkpointing=not cfg.disable_checkpoint, 204 | reload_dataloaders_every_n_epochs=1, 205 | **trainer, 206 | ) 207 | 208 | if not cfg.strict_loading: 209 | model_module.strict_loading = False 210 | 211 | if cfg.validation_only: 212 | trainer.validate( 213 | model_module, 214 | datamodule=data_module, 215 | ckpt_path=cfg.resume, 216 | ) 217 | else: 218 | trainer.fit( 219 | model_module, 220 | datamodule=data_module, 221 | ckpt_path=cfg.resume, 222 | ) 223 | 224 | 225 | if __name__ == "__main__": 226 | arg1 = sys.argv[1] 227 | arg2 = sys.argv[2:] 228 | train(arg1, arg2) 229 | -------------------------------------------------------------------------------- /src/boltz/data/write/mmcif.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import Iterator 3 | 4 | import ihm 5 | from modelcif import Assembly, AsymUnit, Entity, System, dumper 6 | from modelcif.model import AbInitioModel, Atom, ModelGroup 7 | from rdkit import Chem 8 | 9 | from boltz.data import const 10 | from boltz.data.types import Structure 11 | from boltz.data.write.utils import generate_tags 12 | 13 | 14 | def to_mmcif(structure: Structure) -> str: # noqa: C901 15 | """Write a structure into an MMCIF file. 
16 | 17 | Parameters 18 | ---------- 19 | structure : Structure 20 | The input structure 21 | 22 | Returns 23 | ------- 24 | str 25 | the output MMCIF file 26 | 27 | """ 28 | system = System() 29 | 30 | # Load periodic table for element mapping 31 | periodic_table = Chem.GetPeriodicTable() 32 | 33 | # Map entities to chain_ids 34 | entity_to_chains = {} 35 | entity_to_moltype = {} 36 | 37 | for chain in structure.chains: 38 | entity_id = chain["entity_id"] 39 | mol_type = chain["mol_type"] 40 | entity_to_chains.setdefault(entity_id, []).append(chain) 41 | entity_to_moltype[entity_id] = mol_type 42 | 43 | # Map entities to sequences 44 | sequences = {} 45 | for entity in entity_to_chains: 46 | # Get the first chain 47 | chain = entity_to_chains[entity][0] 48 | 49 | # Get the sequence 50 | res_start = chain["res_idx"] 51 | res_end = chain["res_idx"] + chain["res_num"] 52 | residues = structure.residues[res_start:res_end] 53 | sequence = [str(res["name"]) for res in residues] 54 | sequences[entity] = sequence 55 | 56 | # Create entity objects 57 | entities_map = {} 58 | for entity, sequence in sequences.items(): 59 | mol_type = entity_to_moltype[entity] 60 | 61 | if mol_type == const.chain_type_ids["PROTEIN"]: 62 | alphabet = ihm.LPeptideAlphabet() 63 | chem_comp = lambda x: ihm.LPeptideChemComp(id=x, code=x, code_canonical="X") # noqa: E731 64 | elif mol_type == const.chain_type_ids["DNA"]: 65 | alphabet = ihm.DNAAlphabet() 66 | chem_comp = lambda x: ihm.DNAChemComp(id=x, code=x, code_canonical="N") # noqa: E731 67 | elif mol_type == const.chain_type_ids["RNA"]: 68 | alphabet = ihm.RNAAlphabet() 69 | chem_comp = lambda x: ihm.RNAChemComp(id=x, code=x, code_canonical="N") # noqa: E731 70 | elif len(sequence) > 1: 71 | alphabet = {} 72 | chem_comp = lambda x: ihm.SaccharideChemComp(id=x) # noqa: E731 73 | else: 74 | alphabet = {} 75 | chem_comp = lambda x: ihm.NonPolymerChemComp(id=x) # noqa: E731 76 | 77 | seq = [ 78 | alphabet[item] if item in alphabet else chem_comp(item) for item in sequence 79 | ] 80 | model_e = Entity(seq) 81 | for chain in entity_to_chains[entity]: 82 | chain_idx = chain["asym_id"] 83 | entities_map[chain_idx] = model_e 84 | 85 | # We don't assume that symmetry is perfect, so we dump everything 86 | # into the asymmetric unit, and produce just a single assembly 87 | chain_tags = generate_tags() 88 | asym_unit_map = {} 89 | for chain in structure.chains: 90 | # Define the model assembly 91 | chain_idx = chain["asym_id"] 92 | chain_tag = next(chain_tags) 93 | asym = AsymUnit( 94 | entities_map[chain_idx], 95 | details="Model subunit %s" % chain_tag, 96 | id=chain_tag, 97 | ) 98 | asym_unit_map[chain_idx] = asym 99 | modeled_assembly = Assembly(asym_unit_map.values(), name="Modeled assembly") 100 | 101 | # class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT): 102 | # name = "pLDDT" 103 | # software = None 104 | # description = "Predicted lddt" 105 | 106 | # class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT): 107 | # name = "pLDDT" 108 | # software = None 109 | # description = "Global pLDDT, mean of per-residue pLDDTs" 110 | 111 | class _MyModel(AbInitioModel): 112 | def get_atoms(self) -> Iterator[Atom]: 113 | # Add all atom sites. 
114 | for chain in structure.chains: 115 | # We rename the chains in alphabetical order 116 | het = chain["mol_type"] == const.chain_type_ids["NONPOLYMER"] 117 | chain_idx = chain["asym_id"] 118 | res_start = chain["res_idx"] 119 | res_end = chain["res_idx"] + chain["res_num"] 120 | 121 | residues = structure.residues[res_start:res_end] 122 | for residue in residues: 123 | atom_start = residue["atom_idx"] 124 | atom_end = residue["atom_idx"] + residue["atom_num"] 125 | atoms = structure.atoms[atom_start:atom_end] 126 | atom_coords = atoms["coords"] 127 | for i, atom in enumerate(atoms): 128 | # This should not happen on predictions, but just in case. 129 | if not atom["is_present"]: 130 | continue 131 | 132 | name = atom["name"] 133 | name = [chr(c + 32) for c in name if c != 0] 134 | name = "".join(name) 135 | element = periodic_table.GetElementSymbol( 136 | atom["element"].item() 137 | ) 138 | element = element.upper() 139 | residue_index = residue["res_idx"] + 1 140 | pos = atom_coords[i] 141 | yield Atom( 142 | asym_unit=asym_unit_map[chain_idx], 143 | type_symbol=element, 144 | seq_id=residue_index, 145 | atom_id=name, 146 | x=pos[0], 147 | y=pos[1], 148 | z=pos[2], 149 | het=het, 150 | biso=1.00, 151 | occupancy=1.00, 152 | ) 153 | 154 | def add_scores(self): 155 | return 156 | # local scores 157 | # plddt_per_residue = {} 158 | # for i in range(n): 159 | # for mask, b_factor in zip(atom_mask[i], b_factors[i]): 160 | # if mask < 0.5: 161 | # continue 162 | # # add 1 per residue, not 1 per atom 163 | # if chain_index[i] not in plddt_per_residue: 164 | # # first time a chain index is seen: add the key and start the residue dict 165 | # plddt_per_residue[chain_index[i]] = {residue_index[i]: b_factor} 166 | # if residue_index[i] not in plddt_per_residue[chain_index[i]]: 167 | # plddt_per_residue[chain_index[i]][residue_index[i]] = b_factor 168 | # plddts = [] 169 | # for chain_idx in plddt_per_residue: 170 | # for residue_idx in plddt_per_residue[chain_idx]: 171 | # plddt = plddt_per_residue[chain_idx][residue_idx] 172 | # plddts.append(plddt) 173 | # self.qa_metrics.append( 174 | # _LocalPLDDT( 175 | # asym_unit_map[chain_idx].residue(residue_idx), plddt 176 | # ) 177 | # ) 178 | # # global score 179 | # self.qa_metrics.append((_GlobalPLDDT(np.mean(plddts)))) 180 | 181 | # Add the model and modeling protocol to the file and write them out: 182 | model = _MyModel(assembly=modeled_assembly, name="Model") 183 | # model.add_scores() 184 | 185 | model_group = ModelGroup([model], name="All models") 186 | system.model_groups.append(model_group) 187 | 188 | fh = io.StringIO() 189 | dumper.write(fh, [system]) 190 | return fh.getvalue() 191 | -------------------------------------------------------------------------------- /src/boltz/data/module/inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | from torch import Tensor 7 | from torch.utils.data import DataLoader 8 | 9 | from boltz.data import const 10 | from boltz.data.feature.featurizer import BoltzFeaturizer 11 | from boltz.data.feature.pad import pad_to_max 12 | from boltz.data.tokenize.boltz import BoltzTokenizer 13 | from boltz.data.types import MSA, Input, Manifest, Record, Structure 14 | 15 | 16 | def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input: 17 | """Load the given input data. 18 | 19 | Parameters 20 | ---------- 21 | record : Record 22 | The record to load. 
23 | target_dir : Path 24 | The path to the data directory. 25 | msa_dir : Path 26 | The path to msa directory. 27 | 28 | Returns 29 | ------- 30 | Input 31 | The loaded input. 32 | 33 | """ 34 | # Load the structure 35 | structure = np.load(target_dir / f"{record.id}.npz") 36 | structure = Structure( 37 | atoms=structure["atoms"], 38 | bonds=structure["bonds"], 39 | residues=structure["residues"], 40 | chains=structure["chains"], 41 | connections=structure["connections"], 42 | interfaces=structure["interfaces"], 43 | mask=structure["mask"], 44 | ) 45 | 46 | msas = {} 47 | for chain in record.chains: 48 | msa_id = chain.msa_id 49 | # Load the MSA for this chain, if any 50 | if msa_id != -1: 51 | msa = np.load(msa_dir / f"{msa_id}.npz") 52 | msas[chain.chain_id] = MSA(**msa) 53 | 54 | return Input(structure, msas) 55 | 56 | 57 | def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]: 58 | """Collate the data. 59 | 60 | Parameters 61 | ---------- 62 | data : List[Dict[str, Tensor]] 63 | The data to collate. 64 | 65 | Returns 66 | ------- 67 | Dict[str, Tensor] 68 | The collated data. 69 | 70 | """ 71 | # Get the keys 72 | keys = data[0].keys() 73 | 74 | # Collate the data 75 | collated = {} 76 | for key in keys: 77 | values = [d[key] for d in data] 78 | 79 | if key not in [ 80 | "all_coords", 81 | "all_resolved_mask", 82 | "crop_to_all_atom_map", 83 | "chain_symmetries", 84 | "amino_acids_symmetries", 85 | "ligand_symmetries", 86 | "record", 87 | ]: 88 | # Check if all have the same shape 89 | shape = values[0].shape 90 | if not all(v.shape == shape for v in values): 91 | values, _ = pad_to_max(values, 0) 92 | else: 93 | values = torch.stack(values, dim=0) 94 | 95 | # Stack the values 96 | collated[key] = values 97 | 98 | return collated 99 | 100 | 101 | class PredictionDataset(torch.utils.data.Dataset): 102 | """Base iterable dataset.""" 103 | 104 | def __init__( 105 | self, 106 | manifest: Manifest, 107 | target_dir: Path, 108 | msa_dir: Path, 109 | ) -> None: 110 | """Initialize the training dataset. 111 | 112 | Parameters 113 | ---------- 114 | manifest : Manifest 115 | The manifest to load data from. 116 | target_dir : Path 117 | The path to the target directory. 118 | msa_dir : Path 119 | The path to the msa directory. 120 | 121 | """ 122 | super().__init__() 123 | self.manifest = manifest 124 | self.target_dir = target_dir 125 | self.msa_dir = msa_dir 126 | self.tokenizer = BoltzTokenizer() 127 | self.featurizer = BoltzFeaturizer() 128 | 129 | def __getitem__(self, idx: int) -> dict: 130 | """Get an item from the dataset. 131 | 132 | Returns 133 | ------- 134 | Dict[str, Tensor] 135 | The sampled data features. 136 | 137 | """ 138 | # Get a sample from the dataset 139 | record = self.manifest.records[idx] 140 | 141 | # Get the structure 142 | try: 143 | input_data = load_input(record, self.target_dir, self.msa_dir) 144 | except Exception as e: # noqa: BLE001 145 | print(f"Failed to load input for {record.id} with error {e}. Skipping.") # noqa: T201 146 | return self.__getitem__(0) 147 | 148 | # Tokenize structure 149 | try: 150 | tokenized = self.tokenizer.tokenize(input_data) 151 | except Exception as e: # noqa: BLE001 152 | print(f"Tokenizer failed on {record.id} with error {e}. 
Skipping.") # noqa: T201 153 | return self.__getitem__(0) 154 | 155 | # Compute features 156 | try: 157 | features = self.featurizer.process( 158 | tokenized, 159 | training=False, 160 | max_atoms=None, 161 | max_tokens=None, 162 | max_seqs=const.max_msa_seqs, 163 | pad_to_max_seqs=False, 164 | symmetries={}, 165 | compute_symmetries=False, 166 | ) 167 | except Exception as e: # noqa: BLE001 168 | print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 169 | return self.__getitem__(0) 170 | 171 | features["record"] = record 172 | return features 173 | 174 | def __len__(self) -> int: 175 | """Get the length of the dataset. 176 | 177 | Returns 178 | ------- 179 | int 180 | The length of the dataset. 181 | 182 | """ 183 | return len(self.manifest.records) 184 | 185 | 186 | class BoltzInferenceDataModule(pl.LightningDataModule): 187 | """DataModule for Boltz inference.""" 188 | 189 | def __init__( 190 | self, 191 | manifest: Manifest, 192 | target_dir: Path, 193 | msa_dir: Path, 194 | num_workers: int, 195 | ) -> None: 196 | """Initialize the DataModule. 197 | 198 | Parameters 199 | ---------- 200 | config : DataConfig 201 | The data configuration. 202 | 203 | """ 204 | super().__init__() 205 | self.num_workers = num_workers 206 | self.manifest = manifest 207 | self.target_dir = target_dir 208 | self.msa_dir = msa_dir 209 | 210 | def predict_dataloader(self) -> DataLoader: 211 | """Get the training dataloader. 212 | 213 | Returns 214 | ------- 215 | DataLoader 216 | The training dataloader. 217 | 218 | """ 219 | dataset = PredictionDataset( 220 | manifest=self.manifest, 221 | target_dir=self.target_dir, 222 | msa_dir=self.msa_dir, 223 | ) 224 | return DataLoader( 225 | dataset, 226 | batch_size=1, 227 | num_workers=self.num_workers, 228 | pin_memory=True, 229 | shuffle=False, 230 | collate_fn=collate, 231 | ) 232 | 233 | def transfer_batch_to_device( 234 | self, 235 | batch: dict, 236 | device: torch.device, 237 | dataloader_idx: int, # noqa: ARG002 238 | ) -> dict: 239 | """Transfer a batch to the given device. 240 | 241 | Parameters 242 | ---------- 243 | batch : Dict 244 | The batch to transfer. 245 | device : torch.device 246 | The device to transfer to. 247 | dataloader_idx : int 248 | The dataloader index. 249 | 250 | Returns 251 | ------- 252 | np.Any 253 | The transferred batch. 254 | 255 | """ 256 | for key in batch: 257 | if key not in [ 258 | "all_coords", 259 | "all_resolved_mask", 260 | "crop_to_all_atom_map", 261 | "chain_symmetries", 262 | "amino_acids_symmetries", 263 | "ligand_symmetries", 264 | "record", 265 | ]: 266 | batch[key] = batch[key].to(device) 267 | return batch 268 | -------------------------------------------------------------------------------- /docs/prediction.md: -------------------------------------------------------------------------------- 1 | # Prediction 2 | 3 | Once you have installed `boltz`, you can start making predictions by simply running: 4 | 5 | `boltz predict ` 6 | 7 | where `` is a path to the input file or a directory. The input file can either be in fasta (enough for most use cases) or YAML format (for more complex inputs). If you specify a directory, `boltz` will run predictions on each `.yaml` or `.fasta` file in the directory. 
8 | 9 | Before diving into more details about the input formats, here are the key differences in what they each support: 10 | 11 | | Feature | Fasta | YAML | 12 | | -------- |--------------------| ------- | 13 | | Polymers | :white_check_mark: | :white_check_mark: | 14 | | Smiles | :white_check_mark: | :white_check_mark: | 15 | | CCD code | :white_check_mark: | :white_check_mark: | 16 | | Custom MSA | :white_check_mark: | :white_check_mark: | 17 | | Modified Residues | :x: | :white_check_mark: | 18 | | Covalent bonds | :x: | :white_check_mark: | 19 | | Pocket conditioning | :x: | :white_check_mark: | 20 | 21 | 22 | 23 | ## Fasta format 24 | 25 | The fasta format should contain entries as follows: 26 | 27 | ``` 28 | >CHAIN_ID|ENTITY_TYPE|MSA_PATH 29 | SEQUENCE 30 | ``` 31 | 32 | Where `CHAIN_ID` is a unique identifier for each input chain, `ENTITY_TYPE` can be one of `protein`, `dna`, `rna`, `smiles`, `ccd`, and `MSA_PATH` is only specified for protein entities and is the path to the `.a3m` file containing a computed MSA for the sequence of the protein. Note that we support both smiles and CCD code for ligands. 33 | 34 | For each of these cases, the corresponding `SEQUENCE` will contain an amino acid sequence (e.g. `EFKEAFSLF`), a sequence of nucleotide bases (e.g. `ATCG`), a smiles string (e.g. `CC1=CC=CC=C1`), or a CCD code (e.g. `ATP`), depending on the entity. 35 | 36 | As an example: 37 | 38 | ``` 39 | >A|protein|./examples/msa/seq1.a3m 40 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ 41 | >B|protein|./examples/msa/seq1.a3m 42 | MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ 43 | >C|ccd 44 | SAH 45 | >D|ccd 46 | SAH 47 | >E|smiles 48 | N[C@@H](Cc1ccc(O)cc1)C(=O)O 49 | >F|smiles 50 | N[C@@H](Cc1ccc(O)cc1)C(=O)O 51 | ``` 52 | 53 | 54 | ## YAML format 55 | 56 | The YAML format is more flexible and allows for more complex inputs, particularly around covalent bonds. The schema of the YAML is the following: 57 | 58 | ```yaml 59 | sequences: 60 | - ENTITY_TYPE: 61 | id: CHAIN_ID 62 | sequence: SEQUENCE # only for protein, dna, rna 63 | smiles: SMILES # only for ligand, exclusive with ccd 64 | ccd: CCD # only for ligand, exclusive with smiles 65 | msa: MSA_PATH # only for protein 66 | modifications: 67 | - position: RES_IDX # index of residue, starting from 1 68 | ccd: CCD # CCD code of the modified residue 69 | 70 | - ENTITY_TYPE: 71 | id: [CHAIN_ID, CHAIN_ID] # multiple ids in case of multiple identical entities 72 | ... 73 | constraints: 74 | - bond: 75 | atom1: [CHAIN_ID, RES_IDX, ATOM_NAME] 76 | atom2: [CHAIN_ID, RES_IDX, ATOM_NAME] 77 | - pocket: 78 | binder: CHAIN_ID 79 | contacts: [[CHAIN_ID, RES_IDX], [CHAIN_ID, RES_IDX]] 80 | ``` 81 | `sequences` has one entry for every unique chain/molecule in the input. Each polymer entity has an `ENTITY_TYPE` of either `protein`, `dna`, or `rna`, and has a `sequence` attribute. 
Non-polymer entities are indicated by `ENTITY_TYPE` equal to `ligand` and have a `smiles` or `ccd` attribute. `CHAIN_ID` is the unique identifier for each chain/molecule, and it should be set as a list in case of multiple identical entities in the structure. Protein entities should also contain an `msa` attribute with `MSA_PATH` indicating the path to the `.a3m` file containing a computed MSA for the sequence of the protein. 82 | 83 | The `modifications` field is an optional field that allows you to specify modified residues in the polymer (`protein`, `dna`, or `rna`). The `position` field specifies the index (starting from 1) of the residue, and `ccd` is the CCD code of the modified residue. This field is currently only supported for CCD ligands. 84 | 85 | `constraints` is an optional field that allows you to specify additional information about the input structure. Currently, we support the `bond` and `pocket` constraints shown in the schema above. The `bond` constraint specifies a covalent bond between two atoms (`atom1` and `atom2`). It is currently only supported for CCD ligands and canonical residues. `CHAIN_ID` refers to the chain `id` set above, `RES_IDX` is the index (starting from 1) of the residue (1 for ligands), and `ATOM_NAME` is the standardized atom name (can be verified in the CIF file of that component on the RCSB website). The `pocket` constraint conditions the prediction on a binding pocket: `binder` is the chain being conditioned, and `contacts` is a list of `[CHAIN_ID, RES_IDX]` pairs identifying the pocket residues. 86 | 87 | As an example: 88 | 89 | ```yaml 90 | version: 1 91 | sequences: 92 | - protein: 93 | id: [A, B] 94 | sequence: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ 95 | msa: ./examples/msa/seq1.a3m 96 | - ligand: 97 | id: [C, D] 98 | ccd: SAH 99 | - ligand: 100 | id: [E, F] 101 | smiles: N[C@@H](Cc1ccc(O)cc1)C(=O)O 102 | ``` 103 | 104 | 105 | ## Options 106 | 107 | The following options are available for the `predict` command: 108 | 109 | boltz predict [OPTIONS] input_path 110 | 111 | | **Option** | **Type** | **Default** | **Description** | 112 | |-----------------------------|-----------------|--------------------|---------------------------------------------------------------------------------| 113 | | `--out_dir PATH` | `PATH` | `./` | The path where to save the predictions. | 114 | | `--cache PATH` | `PATH` | `~/.boltz` | The directory where to download the data and model. | 115 | | `--checkpoint PATH` | `PATH` | None | An optional checkpoint. Uses the provided Boltz-1 model by default. | 116 | | `--devices INTEGER` | `INTEGER` | `1` | The number of devices to use for prediction. | 117 | | `--accelerator` | `[gpu,cpu,tpu]` | `gpu` | The accelerator to use for prediction. | 118 | | `--recycling_steps INTEGER` | `INTEGER` | `3` | The number of recycling steps to use for prediction. | 119 | | `--sampling_steps INTEGER` | `INTEGER` | `200` | The number of sampling steps to use for prediction. | 120 | | `--diffusion_samples INTEGER` | `INTEGER` | `1` | The number of diffusion samples to use for prediction. | 121 | | `--output_format` | `[pdb,mmcif]` | `mmcif` | The output format to use for the predictions. | 122 | | `--num_workers INTEGER` | `INTEGER` | `2` | The number of dataloader workers to use for prediction. | 123 | | `--override` | `FLAG` | `False` | Whether to override existing predictions if found. 
| 124 | 125 | ## Output 126 | 127 | After running the model, the generated outputs are organized into the output directory following the structure below: 128 | ``` 129 | out_dir/ 130 | ├── lightning_logs/ # Logs generated during training or evaluation 131 | ├── predictions/ # Contains the model's predictions 132 | ├── [input_file1]/ 133 | ├── [input_file1]_model_0.cif # The predicted structure in CIF format 134 | ... 135 | └── [input_file1]_model_[diffusion_samples-1].cif # The predicted structure in CIF format 136 | └── [input_file2]/ 137 | ... 138 | └── processed/ # Processed data used during execution 139 | ``` 140 | The `predictions` folder contains a unique folder for each input file. The input folders contain diffusion_samples predictions saved in the output_format. The `processed` folder contains the processed input files that are used by the model during inference. 141 | -------------------------------------------------------------------------------- /src/boltz/data/sample/cluster.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterator, List 2 | 3 | import numpy as np 4 | from numpy.random import RandomState 5 | 6 | from boltz.data import const 7 | from boltz.data.types import ChainInfo, InterfaceInfo, Record 8 | from boltz.data.sample.sampler import Sample, Sampler 9 | 10 | 11 | def get_chain_cluster(chain: ChainInfo, record: Record) -> str: # noqa: ARG001 12 | """Get the cluster id for a chain. 13 | 14 | Parameters 15 | ---------- 16 | chain : ChainInfo 17 | The chain id to get the cluster id for. 18 | record : Record 19 | The record the interface is part of. 20 | 21 | Returns 22 | ------- 23 | str 24 | The cluster id of the chain. 25 | 26 | """ 27 | return chain.cluster_id 28 | 29 | 30 | def get_interface_cluster(interface: InterfaceInfo, record: Record) -> str: 31 | """Get the cluster id for an interface. 32 | 33 | Parameters 34 | ---------- 35 | interface : InterfaceInfo 36 | The interface to get the cluster id for. 37 | record : Record 38 | The record the interface is part of. 39 | 40 | Returns 41 | ------- 42 | str 43 | The cluster id of the interface. 44 | 45 | """ 46 | chain1 = record.chains[interface.chain_1] 47 | chain2 = record.chains[interface.chain_2] 48 | 49 | cluster_1 = str(chain1.cluster_id) 50 | cluster_2 = str(chain2.cluster_id) 51 | 52 | cluster_id = (cluster_1, cluster_2) 53 | cluster_id = tuple(sorted(cluster_id)) 54 | 55 | return cluster_id 56 | 57 | 58 | def get_chain_weight( 59 | chain: ChainInfo, 60 | record: Record, # noqa: ARG001 61 | clusters: Dict[str, int], 62 | beta_chain: float, 63 | alpha_prot: float, 64 | alpha_nucl: float, 65 | alpha_ligand: float, 66 | ) -> float: 67 | """Get the weight of a chain. 68 | 69 | Parameters 70 | ---------- 71 | chain : ChainInfo 72 | The chain to get the weight for. 73 | record : Record 74 | The record the chain is part of. 75 | clusters : Dict[str, int] 76 | The cluster sizes. 77 | beta_chain : float 78 | The beta value for chains. 79 | alpha_prot : float 80 | The alpha value for proteins. 81 | alpha_nucl : float 82 | The alpha value for nucleic acids. 83 | alpha_ligand : float 84 | The alpha value for ligands. 85 | 86 | Returns 87 | ------- 88 | float 89 | The weight of the chain. 
90 | 91 | """ 92 | prot_id = const.chain_type_ids["PROTEIN"] 93 | rna_id = const.chain_type_ids["RNA"] 94 | dna_id = const.chain_type_ids["DNA"] 95 | ligand_id = const.chain_type_ids["NONPOLYMER"] 96 | 97 | weight = beta_chain / clusters[chain.cluster_id] 98 | if chain.mol_type == prot_id: 99 | weight *= alpha_prot 100 | elif chain.mol_type in [rna_id, dna_id]: 101 | weight *= alpha_nucl 102 | elif chain.mol_type == ligand_id: 103 | weight *= alpha_ligand 104 | 105 | return weight 106 | 107 | 108 | def get_interface_weight( 109 | interface: InterfaceInfo, 110 | record: Record, 111 | clusters: Dict[str, int], 112 | beta_interface: float, 113 | alpha_prot: float, 114 | alpha_nucl: float, 115 | alpha_ligand: float, 116 | ) -> float: 117 | """Get the weight of an interface. 118 | 119 | Parameters 120 | ---------- 121 | interface : InterfaceInfo 122 | The interface to get the weight for. 123 | record : Record 124 | The record the interface is part of. 125 | clusters : Dict[str, int] 126 | The cluster sizes. 127 | beta_interface : float 128 | The beta value for interfaces. 129 | alpha_prot : float 130 | The alpha value for proteins. 131 | alpha_nucl : float 132 | The alpha value for nucleic acids. 133 | alpha_ligand : float 134 | The alpha value for ligands. 135 | 136 | Returns 137 | ------- 138 | float 139 | The weight of the interface. 140 | 141 | """ 142 | prot_id = const.chain_type_ids["PROTEIN"] 143 | rna_id = const.chain_type_ids["RNA"] 144 | dna_id = const.chain_type_ids["DNA"] 145 | ligand_id = const.chain_type_ids["NONPOLYMER"] 146 | 147 | chain1 = record.chains[interface.chain_1] 148 | chain2 = record.chains[interface.chain_2] 149 | 150 | n_prot = (chain1.mol_type) == prot_id 151 | n_nuc = chain1.mol_type in [rna_id, dna_id] 152 | n_ligand = chain1.mol_type == ligand_id 153 | 154 | n_prot += chain2.mol_type == prot_id 155 | n_nuc += chain2.mol_type in [rna_id, dna_id] 156 | n_ligand += chain2.mol_type == ligand_id 157 | 158 | weight = beta_interface / clusters[get_interface_cluster(interface, record)] 159 | weight *= alpha_prot * n_prot + alpha_nucl * n_nuc + alpha_ligand * n_ligand 160 | return weight 161 | 162 | 163 | class ClusterSampler(Sampler): 164 | """The weighted sampling approach, as described in AF3. 165 | 166 | Each chain / interface is given a weight according 167 | to the following formula, and sampled accordingly: 168 | 169 | w = b / n_clust *(a_prot * n_prot + a_nuc * n_nuc 170 | + a_ligand * n_ligand) 171 | 172 | """ 173 | 174 | def __init__( 175 | self, 176 | alpha_prot: float = 3.0, 177 | alpha_nucl: float = 3.0, 178 | alpha_ligand: float = 1.0, 179 | beta_chain: float = 0.5, 180 | beta_interface: float = 1.0, 181 | ) -> None: 182 | """Initialize the sampler. 183 | 184 | Parameters 185 | ---------- 186 | alpha_prot : float, optional 187 | The alpha value for proteins. 188 | alpha_nucl : float, optional 189 | The alpha value for nucleic acids. 190 | alpha_ligand : float, optional 191 | The alpha value for ligands. 192 | beta_chain : float, optional 193 | The beta value for chains. 194 | beta_interface : float, optional 195 | The beta value for interfaces. 196 | 197 | """ 198 | self.alpha_prot = alpha_prot 199 | self.alpha_nucl = alpha_nucl 200 | self.alpha_ligand = alpha_ligand 201 | self.beta_chain = beta_chain 202 | self.beta_interface = beta_interface 203 | 204 | def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]: # noqa: C901, PLR0912 205 | """Sample a structure from the dataset infinitely. 
206 | 207 | Parameters 208 | ---------- 209 | records : List[Record] 210 | The records to sample from. 211 | random : RandomState 212 | The random state for reproducibility. 213 | 214 | Yields 215 | ------ 216 | Sample 217 | A data sample. 218 | 219 | """ 220 | # Compute chain cluster sizes 221 | chain_clusters: Dict[str, int] = {} 222 | for record in records: 223 | for chain in record.chains: 224 | if not chain.valid: 225 | continue 226 | cluster_id = get_chain_cluster(chain, record) 227 | if cluster_id not in chain_clusters: 228 | chain_clusters[cluster_id] = 0 229 | chain_clusters[cluster_id] += 1 230 | 231 | # Compute interface clusters sizes 232 | interface_clusters: Dict[str, int] = {} 233 | for record in records: 234 | for interface in record.interfaces: 235 | if not interface.valid: 236 | continue 237 | cluster_id = get_interface_cluster(interface, record) 238 | if cluster_id not in interface_clusters: 239 | interface_clusters[cluster_id] = 0 240 | interface_clusters[cluster_id] += 1 241 | 242 | # Compute weights 243 | items, weights = [], [] 244 | for record in records: 245 | for chain_id, chain in enumerate(record.chains): 246 | if not chain.valid: 247 | continue 248 | weight = get_chain_weight( 249 | chain, 250 | record, 251 | chain_clusters, 252 | self.beta_chain, 253 | self.alpha_prot, 254 | self.alpha_nucl, 255 | self.alpha_ligand, 256 | ) 257 | items.append((record, 0, chain_id)) 258 | weights.append(weight) 259 | 260 | for int_id, interface in enumerate(record.interfaces): 261 | if not interface.valid: 262 | continue 263 | weight = get_interface_weight( 264 | interface, 265 | record, 266 | interface_clusters, 267 | self.beta_interface, 268 | self.alpha_prot, 269 | self.alpha_nucl, 270 | self.alpha_ligand, 271 | ) 272 | items.append((record, 1, int_id)) 273 | weights.append(weight) 274 | 275 | # Sample infinitely 276 | weights = np.array(weights) / np.sum(weights) 277 | while True: 278 | item_idx = random.choice(len(items), p=weights) 279 | record, kind, index = items[item_idx] 280 | if kind == 0: 281 | yield Sample(record=record, chain_id=index) 282 | else: 283 | yield Sample(record=record, interface_id=index) 284 | -------------------------------------------------------------------------------- /src/boltz/model/modules/transformers.py: -------------------------------------------------------------------------------- 1 | # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang 2 | 3 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 4 | from torch import nn, sigmoid 5 | from torch.nn import ( 6 | LayerNorm, 7 | Linear, 8 | Module, 9 | ModuleList, 10 | Sequential, 11 | ) 12 | 13 | from boltz.model.layers.attention import AttentionPairBias 14 | from boltz.model.modules.utils import LinearNoBias, SwiGLU, default 15 | 16 | 17 | class AdaLN(Module): 18 | """Adaptive Layer Normalization""" 19 | 20 | def __init__(self, dim, dim_single_cond): 21 | """Initialize the adaptive layer normalization. 22 | 23 | Parameters 24 | ---------- 25 | dim : int 26 | The input dimension. 27 | dim_single_cond : int 28 | The single condition dimension. 
29 | 30 | """ 31 | super().__init__() 32 | self.a_norm = LayerNorm(dim, elementwise_affine=False, bias=False) 33 | self.s_norm = LayerNorm(dim_single_cond, bias=False) 34 | self.s_scale = Linear(dim_single_cond, dim) 35 | self.s_bias = LinearNoBias(dim_single_cond, dim) 36 | 37 | def forward(self, a, s): 38 | a = self.a_norm(a) 39 | s = self.s_norm(s) 40 | a = sigmoid(self.s_scale(s)) * a + self.s_bias(s) 41 | return a 42 | 43 | 44 | class ConditionedTransitionBlock(Module): 45 | """Conditioned Transition Block""" 46 | 47 | def __init__(self, dim_single, dim_single_cond, expansion_factor=2): 48 | """Initialize the conditioned transition block. 49 | 50 | Parameters 51 | ---------- 52 | dim_single : int 53 | The single dimension. 54 | dim_single_cond : int 55 | The single condition dimension. 56 | expansion_factor : int, optional 57 | The expansion factor, by default 2 58 | 59 | """ 60 | super().__init__() 61 | 62 | self.adaln = AdaLN(dim_single, dim_single_cond) 63 | 64 | dim_inner = int(dim_single * expansion_factor) 65 | self.swish_gate = Sequential( 66 | LinearNoBias(dim_single, dim_inner * 2), 67 | SwiGLU(), 68 | ) 69 | self.a_to_b = LinearNoBias(dim_single, dim_inner) 70 | self.b_to_a = LinearNoBias(dim_inner, dim_single) 71 | 72 | output_projection_linear = Linear(dim_single_cond, dim_single) 73 | nn.init.zeros_(output_projection_linear.weight) 74 | nn.init.constant_(output_projection_linear.bias, -2.0) 75 | 76 | self.output_projection = nn.Sequential(output_projection_linear, nn.Sigmoid()) 77 | 78 | def forward( 79 | self, 80 | a, 81 | s, 82 | ): 83 | a = self.adaln(a, s) 84 | b = self.swish_gate(a) * self.a_to_b(a) 85 | a = self.output_projection(s) * self.b_to_a(b) 86 | 87 | return a 88 | 89 | 90 | class DiffusionTransformer(Module): 91 | """Diffusion Transformer""" 92 | 93 | def __init__( 94 | self, 95 | depth, 96 | heads, 97 | dim=384, 98 | dim_single_cond=None, 99 | dim_pairwise=128, 100 | activation_checkpointing=False, 101 | offload_to_cpu=False, 102 | ): 103 | """Initialize the diffusion transformer. 104 | 105 | Parameters 106 | ---------- 107 | depth : int 108 | The depth. 109 | heads : int 110 | The number of heads. 
111 | dim : int, optional 112 | The dimension, by default 384 113 | dim_single_cond : int, optional 114 | The single condition dimension, by default None 115 | dim_pairwise : int, optional 116 | The pairwise dimension, by default 128 117 | activation_checkpointing : bool, optional 118 | Whether to use activation checkpointing, by default False 119 | offload_to_cpu : bool, optional 120 | Whether to offload to CPU, by default False 121 | 122 | """ 123 | super().__init__() 124 | self.activation_checkpointing = activation_checkpointing 125 | dim_single_cond = default(dim_single_cond, dim) 126 | 127 | self.layers = ModuleList() 128 | for _ in range(depth): 129 | if activation_checkpointing: 130 | self.layers.append( 131 | checkpoint_wrapper( 132 | DiffusionTransformerLayer( 133 | heads, 134 | dim, 135 | dim_single_cond, 136 | dim_pairwise, 137 | ), 138 | offload_to_cpu=offload_to_cpu, 139 | ) 140 | ) 141 | else: 142 | self.layers.append( 143 | DiffusionTransformerLayer( 144 | heads, 145 | dim, 146 | dim_single_cond, 147 | dim_pairwise, 148 | ) 149 | ) 150 | 151 | def forward( 152 | self, 153 | a, 154 | s, 155 | z, 156 | mask=None, 157 | to_keys=None, 158 | multiplicity=1, 159 | model_cache=None, 160 | ): 161 | for i, layer in enumerate(self.layers): 162 | layer_cache = None 163 | if model_cache is not None: 164 | prefix_cache = "layer_" + str(i) 165 | if prefix_cache not in model_cache: 166 | model_cache[prefix_cache] = {} 167 | layer_cache = model_cache[prefix_cache] 168 | a = layer( 169 | a, 170 | s, 171 | z, 172 | mask=mask, 173 | to_keys=to_keys, 174 | multiplicity=multiplicity, 175 | layer_cache=layer_cache, 176 | ) 177 | return a 178 | 179 | 180 | class DiffusionTransformerLayer(Module): 181 | """Diffusion Transformer Layer""" 182 | 183 | def __init__( 184 | self, 185 | heads, 186 | dim=384, 187 | dim_single_cond=None, 188 | dim_pairwise=128, 189 | ): 190 | """Initialize the diffusion transformer layer. 191 | 192 | Parameters 193 | ---------- 194 | heads : int 195 | The number of heads. 196 | dim : int, optional 197 | The dimension, by default 384 198 | dim_single_cond : int, optional 199 | The single condition dimension, by default None 200 | dim_pairwise : int, optional 201 | The pairwise dimension, by default 128 202 | 203 | """ 204 | super().__init__() 205 | 206 | dim_single_cond = default(dim_single_cond, dim) 207 | 208 | self.adaln = AdaLN(dim, dim_single_cond) 209 | 210 | self.pair_bias_attn = AttentionPairBias( 211 | c_s=dim, c_z=dim_pairwise, num_heads=heads, initial_norm=False 212 | ) 213 | 214 | self.output_projection_linear = Linear(dim_single_cond, dim) 215 | nn.init.zeros_(self.output_projection_linear.weight) 216 | nn.init.constant_(self.output_projection_linear.bias, -2.0) 217 | 218 | self.output_projection = nn.Sequential( 219 | self.output_projection_linear, nn.Sigmoid() 220 | ) 221 | self.transition = ConditionedTransitionBlock( 222 | dim_single=dim, dim_single_cond=dim_single_cond 223 | ) 224 | 225 | def forward( 226 | self, 227 | a, 228 | s, 229 | z, 230 | mask=None, 231 | to_keys=None, 232 | multiplicity=1, 233 | layer_cache=None, 234 | ): 235 | b = self.adaln(a, s) 236 | b = self.pair_bias_attn( 237 | s=b, 238 | z=z, 239 | mask=mask, 240 | multiplicity=multiplicity, 241 | to_keys=to_keys, 242 | model_cache=layer_cache, 243 | ) 244 | b = self.output_projection(s) * b 245 | 246 | # NOTE: Added residual connection! 
247 | a = a + b 248 | a = a + self.transition(a, s) 249 | return a 250 | 251 | 252 | class AtomTransformer(Module): 253 | """Atom Transformer""" 254 | 255 | def __init__( 256 | self, 257 | attn_window_queries=None, 258 | attn_window_keys=None, 259 | **diffusion_transformer_kwargs, 260 | ): 261 | """Initialize the atom transformer. 262 | 263 | Parameters 264 | ---------- 265 | attn_window_queries : int, optional 266 | The attention window queries, by default None 267 | attn_window_keys : int, optional 268 | The attention window keys, by default None 269 | diffusion_transformer_kwargs : dict 270 | The diffusion transformer keyword arguments 271 | 272 | """ 273 | super().__init__() 274 | self.attn_window_queries = attn_window_queries 275 | self.attn_window_keys = attn_window_keys 276 | self.diffusion_transformer = DiffusionTransformer( 277 | **diffusion_transformer_kwargs 278 | ) 279 | 280 | def forward( 281 | self, 282 | q, 283 | c, 284 | p, 285 | to_keys=None, 286 | mask=None, 287 | multiplicity=1, 288 | model_cache=None, 289 | ): 290 | W = self.attn_window_queries 291 | H = self.attn_window_keys 292 | 293 | if W is not None: 294 | B, N, D = q.shape 295 | NW = N // W 296 | 297 | # reshape tokens 298 | q = q.view((B * NW, W, -1)) 299 | c = c.view((B * NW, W, -1)) 300 | if mask is not None: 301 | mask = mask.view(B * NW, W) 302 | p = p.view((p.shape[0] * NW, W, H, -1)) 303 | 304 | to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1) 305 | else: 306 | to_keys_new = None 307 | 308 | # main transformer 309 | q = self.diffusion_transformer( 310 | a=q, 311 | s=c, 312 | z=p, 313 | mask=mask.float(), 314 | multiplicity=multiplicity, 315 | to_keys=to_keys_new, 316 | model_cache=model_cache, 317 | ) 318 | 319 | if W is not None: 320 | q = q.view((B, NW * W, D)) 321 | 322 | return q 323 | -------------------------------------------------------------------------------- /src/boltz/data/filter/static/polymer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | import numpy as np 5 | from scipy.spatial.distance import cdist 6 | 7 | from boltz.data import const 8 | from boltz.data.types import Structure 9 | from boltz.data.filter.static.filter import StaticFilter 10 | 11 | 12 | class MinimumLengthFilter(StaticFilter): 13 | """Filter polymers based on their length. 14 | 15 | We use the number of resolved residues when 16 | considering both the minimum and the maximum. 17 | 18 | """ 19 | 20 | def __init__(self, min_len: int = 4, max_len: int = 5000) -> None: 21 | """Initialize the filter. 22 | 23 | Parameters 24 | ---------- 25 | min_len : int, optional 26 | The minimum allowed length. 27 | max_len : int, optional 28 | The maximum allowed length. 29 | 30 | """ 31 | self._min = min_len 32 | self._max = max_len 33 | 34 | def filter(self, structure: Structure) -> np.ndarray: 35 | """Filter chains based on their length. 36 | 37 | Parameters 38 | ---------- 39 | structure : Structure 40 | The structure to filter chains from. 41 | 42 | Returns 43 | ------- 44 | np.ndarray 45 | The chains to keep, as a boolean mask. 
46 | 47 | """ 48 | valid = np.ones(len(structure.chains), dtype=bool) 49 | 50 | for i, chain in enumerate(structure.chains): 51 | if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]: 52 | continue 53 | 54 | res_start = chain["res_idx"] 55 | res_end = res_start + chain["res_num"] 56 | residues = structure.residues[res_start:res_end] 57 | resolved = residues["is_present"].sum() 58 | 59 | if (resolved < self._min) or (resolved > self._max): 60 | valid[i] = 0 61 | 62 | return valid 63 | 64 | 65 | class UnknownFilter(StaticFilter): 66 | """Filter proteins with all unknown residues.""" 67 | 68 | def filter(self, structure: Structure) -> np.ndarray: 69 | """Filter proteins with all unknown residues. 70 | 71 | Parameters 72 | ---------- 73 | structure : Structure 74 | The structure to filter chains from. 75 | 76 | Returns 77 | ------- 78 | np.ndarray 79 | The chains to keep, as a boolean mask. 80 | 81 | """ 82 | valid = np.ones(len(structure.chains), dtype=bool) 83 | unk_toks = { 84 | const.chain_type_ids["PROTEIN"]: const.unk_token_ids["PROTEIN"], 85 | const.chain_type_ids["DNA"]: const.unk_token_ids["DNA"], 86 | const.chain_type_ids["RNA"]: const.unk_token_ids["RNA"], 87 | } 88 | 89 | for i, chain in enumerate(structure.chains): 90 | if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]: 91 | continue 92 | 93 | res_start = chain["res_idx"] 94 | res_end = res_start + chain["res_num"] 95 | residues = structure.residues[res_start:res_end] 96 | 97 | unk_id = unk_toks[chain["mol_type"]] 98 | if np.all(residues["res_type"] == unk_id): 99 | valid[i] = 0 100 | 101 | return valid 102 | 103 | 104 | class ConsecutiveCA(StaticFilter): 105 | """Filter proteins with consecutive CA atoms above a threshold.""" 106 | 107 | def __init__(self, max_dist: int = 10.0) -> None: 108 | """Initialize the filter. 109 | 110 | Parameters 111 | ---------- 112 | max_dist : float, optional 113 | The maximum allowed distance. 114 | 115 | """ 116 | self._max_dist = max_dist 117 | 118 | def filter(self, structure: Structure) -> np.ndarray: 119 | """Filter protein if consecutive CA atoms above a threshold. 120 | 121 | Parameters 122 | ---------- 123 | structure : Structure 124 | The structure to filter chains from. 125 | 126 | Returns 127 | ------- 128 | np.ndarray 129 | The chains to keep, as a boolean mask. 
64 | 
65 | class UnknownFilter(StaticFilter):
66 |     """Filter proteins with all unknown residues."""
67 | 
68 |     def filter(self, structure: Structure) -> np.ndarray:
69 |         """Filter proteins with all unknown residues.
70 | 
71 |         Parameters
72 |         ----------
73 |         structure : Structure
74 |             The structure to filter chains from.
75 | 
76 |         Returns
77 |         -------
78 |         np.ndarray
79 |             The chains to keep, as a boolean mask.
80 | 
81 |         """
82 |         valid = np.ones(len(structure.chains), dtype=bool)
83 |         unk_toks = {
84 |             const.chain_type_ids["PROTEIN"]: const.unk_token_ids["PROTEIN"],
85 |             const.chain_type_ids["DNA"]: const.unk_token_ids["DNA"],
86 |             const.chain_type_ids["RNA"]: const.unk_token_ids["RNA"],
87 |         }
88 | 
89 |         for i, chain in enumerate(structure.chains):
90 |             if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
91 |                 continue
92 | 
93 |             res_start = chain["res_idx"]
94 |             res_end = res_start + chain["res_num"]
95 |             residues = structure.residues[res_start:res_end]
96 | 
97 |             unk_id = unk_toks[chain["mol_type"]]
98 |             if np.all(residues["res_type"] == unk_id):
99 |                 valid[i] = 0
100 | 
101 |         return valid
102 | 
103 | 
104 | class ConsecutiveCA(StaticFilter):
105 |     """Filter protein chains whose consecutive CA atoms are too far apart."""
106 | 
107 |     def __init__(self, max_dist: float = 10.0) -> None:
108 |         """Initialize the filter.
109 | 
110 |         Parameters
111 |         ----------
112 |         max_dist : float, optional
113 |             The maximum allowed distance between consecutive CA atoms.
114 | 
115 |         """
116 |         self._max_dist = max_dist
117 | 
118 |     def filter(self, structure: Structure) -> np.ndarray:
119 |         """Filter out chains with consecutive CA atoms above the distance threshold.
120 | 
121 |         Parameters
122 |         ----------
123 |         structure : Structure
124 |             The structure to filter chains from.
125 | 
126 |         Returns
127 |         -------
128 |         np.ndarray
129 |             The chains to keep, as a boolean mask.
130 | 
131 |         """
132 |         valid = np.ones(len(structure.chains), dtype=bool)
133 | 
134 |         # Remove a chain if any consecutive CA atoms are above the threshold
135 |         for i, chain in enumerate(structure.chains):
136 |             # Skip non-protein chains
137 |             if chain["mol_type"] != const.chain_type_ids["PROTEIN"]:
138 |                 continue
139 | 
140 |             # Get residues
141 |             res_start = chain["res_idx"]
142 |             res_end = res_start + chain["res_num"]
143 |             residues = structure.residues[res_start:res_end]
144 | 
145 |             # Get c-alphas
146 |             ca_ids = residues["atom_center"]
147 |             ca_atoms = structure.atoms[ca_ids]
148 | 
149 |             res_valid = residues["is_present"]
150 |             ca_valid = ca_atoms["is_present"] & res_valid
151 |             ca_coords = ca_atoms["coords"]
152 | 
153 |             # Compute distances between consecutive atoms
154 |             dist = np.linalg.norm(ca_coords[1:] - ca_coords[:-1], axis=1)
155 |             dist = dist > self._max_dist
156 |             dist = dist[ca_valid[1:] & ca_valid[:-1]]
157 | 
158 |             # Remove the chain if any valid pair is above threshold
159 |             if np.any(dist):
160 |                 valid[i] = 0
161 | 
162 |         return valid
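# The chain-break test above is a vectorized distance check between
# successive CA coordinates. A self-contained numpy sketch with toy
# coordinates containing one artificial 12 A jump:

import numpy as np

ca_coords = np.array([[0.0, 0, 0], [3.8, 0, 0], [7.6, 0, 0], [19.6, 0, 0]])
ca_valid = np.array([True, True, True, True])

dist = np.linalg.norm(ca_coords[1:] - ca_coords[:-1], axis=1)  # CA-CA gaps
breaks = dist > 10.0                                           # max_dist
breaks = breaks[ca_valid[1:] & ca_valid[:-1]]                  # resolved pairs
print(bool(np.any(breaks)))  # True: the 12 A gap flags this chain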
163 | 
164 | 
165 | @dataclass(frozen=True)
166 | class Clash:
167 |     """A clash between two chains."""
168 | 
169 |     chain: int
170 |     other: int
171 |     num_atoms: int
172 |     num_clashes: int
173 | 
174 | 
175 | class ClashingChainsFilter(StaticFilter):
176 |     """A filter that removes clashing chains.
177 | 
178 |     Clashing chains are defined as those with >30% of atoms
179 |     within 1.7 Å of an atom in another chain. If two chains
180 |     are clashing with each other, the chain with the greater
181 |     percentage of clashing atoms will be removed. If the same
182 |     fraction of atoms are clashing, the chain with fewer total
183 |     atoms is removed. If the chains have the same number of
184 |     atoms, then the chain with the larger chain id is removed.
185 | 
186 |     """
187 | 
188 |     def __init__(self, dist: float = 1.7, freq: float = 0.3) -> None:
189 |         """Initialize the filter.
190 | 
191 |         Parameters
192 |         ----------
193 |         dist : float, optional
194 |             The maximum distance for a clash.
195 |         freq : float, optional
196 |             The maximum allowed frequency of clashes.
197 | 
198 |         """
199 |         self._dist = dist
200 |         self._freq = freq
201 | 
202 |     def filter(self, structure: Structure) -> np.ndarray:  # noqa: PLR0912, C901
203 |         """Filter out clashing chains.
204 | 
205 |         Parameters
206 |         ----------
207 |         structure : Structure
208 |             The structure to filter chains from.
209 | 
210 |         Returns
211 |         -------
212 |         np.ndarray
213 |             The chains to keep, as a boolean mask.
214 | 
215 |         """
216 |         num_chains = len(structure.chains)
217 |         if num_chains < 2:  # noqa: PLR2004
218 |             return np.ones(num_chains, dtype=bool)
219 | 
220 |         # Get unique chain pairs
221 |         pairs = [(i, j) for i in range(num_chains) for j in range(num_chains)]
222 |         pairs = [(i, j) for i, j in pairs if i < j]
223 | 
224 |         # Compute clashes
225 |         clashes: List[Clash] = []
226 |         for i, j in pairs:
227 |             # Get the chains
228 |             c1 = structure.chains[i]
229 |             c2 = structure.chains[j]
230 | 
231 |             # Get the atoms from each chain
232 |             c1_start = c1["atom_idx"]
233 |             c2_start = c2["atom_idx"]
234 |             c1_end = c1_start + c1["atom_num"]
235 |             c2_end = c2_start + c2["atom_num"]
236 | 
237 |             atoms1 = structure.atoms[c1_start:c1_end]
238 |             atoms2 = structure.atoms[c2_start:c2_end]
239 |             atoms1 = atoms1[atoms1["is_present"]]
240 |             atoms2 = atoms2[atoms2["is_present"]]
241 | 
242 |             # Compute the number of clashing atoms (do not shadow the clash list)
243 |             dists = cdist(atoms1["coords"], atoms2["coords"])
244 |             clash_matrix = dists < self._dist
245 |             c1_clashes = np.any(clash_matrix, axis=1).sum().item()
246 |             c2_clashes = np.any(clash_matrix, axis=0).sum().item()
247 | 
248 |             # Save results
249 |             if (c1_clashes / len(atoms1)) > self._freq:
250 |                 clashes.append(Clash(i, j, len(atoms1), c1_clashes))
251 |             if (c2_clashes / len(atoms2)) > self._freq:
252 |                 clashes.append(Clash(j, i, len(atoms2), c2_clashes))
253 | 
254 |         # Compute indices to clash map
255 |         removed = set()
256 |         ids_to_clash = {(c.chain, c.other): c for c in clashes}
257 | 
258 |         # Filter out chains according to the ruleset
259 |         for clash in clashes:
260 |             # If either is already removed, skip
261 |             if clash.chain in removed or clash.other in removed:
262 |                 continue
263 | 
264 |             # Check if the two chains clash with each other
265 |             other_clash = ids_to_clash.get((clash.other, clash.chain))
266 |             if other_clash is not None:
267 |                 # Remove the chain with the higher clash frequency
268 |                 clash1_freq = clash.num_clashes / clash.num_atoms
269 |                 clash2_freq = other_clash.num_clashes / other_clash.num_atoms
270 |                 if clash1_freq > clash2_freq:
271 |                     removed.add(clash.chain)
272 |                 elif clash1_freq < clash2_freq:
273 |                     removed.add(clash.other)
274 | 
275 |                 # If same, remove the chain with fewer atoms
276 |                 elif clash.num_atoms < other_clash.num_atoms:
277 |                     removed.add(clash.chain)
278 |                 elif clash.num_atoms > other_clash.num_atoms:
279 |                     removed.add(clash.other)
280 | 
281 |                 # If same, remove the chain with the larger chain id
282 |                 else:
283 |                     removed.add(max(clash.chain, clash.other))
284 | 
285 |             # Otherwise, just remove the chain directly
286 |             else:
287 |                 removed.add(clash.chain)
288 | 
289 |         # Remove the chains
290 |         valid = np.ones(len(structure.chains), dtype=bool)
291 |         for i in removed:
292 |             valid[i] = 0
293 | 
294 |         return valid
295 | 
--------------------------------------------------------------------------------
/src/boltz/data/const.py:
--------------------------------------------------------------------------------
1 | ####################################################################################################
2 | # CHAINS
3 | ####################################################################################################
4 | 
5 | chain_types = [
6 |     "PROTEIN",
7 |     "DNA",
8 |     "RNA",
9 |     "NONPOLYMER",
10 | ]
11 | chain_type_ids = {chain: i for i, chain in enumerate(chain_types)}
12 | 
13 | out_types = [
14 |     "dna_protein",
15 |     "rna_protein",
16 |     "ligand_protein",
17 |     "dna_ligand",
18 |     "rna_ligand",
19 |     "intra_ligand",
20 |     "intra_dna",
21 |     "intra_rna",
22 |     "intra_protein",
23 |     "protein_protein",
24 | ]
25 | 
26 | 
out_types_weights_af3 = { 27 | "dna_protein": 10.0, 28 | "rna_protein": 10.0, 29 | "ligand_protein": 10.0, 30 | "dna_ligand": 5.0, 31 | "rna_ligand": 5.0, 32 | "intra_ligand": 20.0, 33 | "intra_dna": 4.0, 34 | "intra_rna": 16.0, 35 | "intra_protein": 20.0, 36 | "protein_protein": 20.0, 37 | } 38 | 39 | out_types_weights = { 40 | "dna_protein": 5.0, 41 | "rna_protein": 5.0, 42 | "ligand_protein": 20.0, 43 | "dna_ligand": 2.0, 44 | "rna_ligand": 2.0, 45 | "intra_ligand": 20.0, 46 | "intra_dna": 2.0, 47 | "intra_rna": 8.0, 48 | "intra_protein": 20.0, 49 | "protein_protein": 20.0, 50 | } 51 | 52 | 53 | out_single_types = ["protein", "ligand", "dna", "rna"] 54 | 55 | #################################################################################################### 56 | # RESIDUES & TOKENS 57 | #################################################################################################### 58 | 59 | tokens = [ 60 | "", 61 | "-", 62 | "ALA", 63 | "ARG", 64 | "ASN", 65 | "ASP", 66 | "CYS", 67 | "GLN", 68 | "GLU", 69 | "GLY", 70 | "HIS", 71 | "ILE", 72 | "LEU", 73 | "LYS", 74 | "MET", 75 | "PHE", 76 | "PRO", 77 | "SER", 78 | "THR", 79 | "TRP", 80 | "TYR", 81 | "VAL", 82 | "UNK", # unknown protein token 83 | "A", 84 | "G", 85 | "C", 86 | "U", 87 | "N", # unknown rna token 88 | "DA", 89 | "DG", 90 | "DC", 91 | "DT", 92 | "DN", # unknown dna token 93 | ] 94 | 95 | token_ids = {token: i for i, token in enumerate(tokens)} 96 | num_tokens = len(tokens) 97 | unk_token = {"PROTEIN": "UNK", "DNA": "DN", "RNA": "N"} 98 | unk_token_ids = {m: token_ids[t] for m, t in unk_token.items()} 99 | 100 | prot_letter_to_token = { 101 | "A": "ALA", 102 | "R": "ARG", 103 | "N": "ASN", 104 | "D": "ASP", 105 | "C": "CYS", 106 | "E": "GLU", 107 | "Q": "GLN", 108 | "G": "GLY", 109 | "H": "HIS", 110 | "I": "ILE", 111 | "L": "LEU", 112 | "K": "LYS", 113 | "M": "MET", 114 | "F": "PHE", 115 | "P": "PRO", 116 | "S": "SER", 117 | "T": "THR", 118 | "W": "TRP", 119 | "Y": "TYR", 120 | "V": "VAL", 121 | "X": "UNK", 122 | "J": "UNK", 123 | "B": "UNK", 124 | "Z": "UNK", 125 | "O": "UNK", 126 | "U": "UNK", 127 | "-": "-", 128 | } 129 | 130 | prot_token_to_letter = {v: k for k, v in prot_letter_to_token.items()} 131 | prot_token_to_letter["UNK"] = "X" 132 | 133 | rna_letter_to_token = { 134 | "A": "A", 135 | "G": "G", 136 | "C": "C", 137 | "U": "U", 138 | "N": "N", 139 | } 140 | rna_token_to_letter = {v: k for k, v in rna_letter_to_token.items()} 141 | 142 | dna_letter_to_token = { 143 | "A": "DA", 144 | "G": "DG", 145 | "C": "DC", 146 | "T": "DT", 147 | "N": "DN", 148 | } 149 | dna_token_to_letter = {v: k for k, v in dna_letter_to_token.items()} 150 | 151 | #################################################################################################### 152 | # ATOMS 153 | #################################################################################################### 154 | 155 | num_elements = 128 156 | 157 | chirality_types = [ 158 | "CHI_UNSPECIFIED", 159 | "CHI_TETRAHEDRAL_CW", 160 | "CHI_TETRAHEDRAL_CCW", 161 | "CHI_OTHER", 162 | ] 163 | chirality_type_ids = {chirality: i for i, chirality in enumerate(chirality_types)} 164 | unk_chirality_type = "CHI_UNSPECIFIED" 165 | 166 | # fmt: off 167 | ref_atoms = { 168 | "PAD": [], 169 | "UNK": ["N", "CA", "C", "O", "CB"], 170 | "-": [], 171 | "ALA": ["N", "CA", "C", "O", "CB"], 172 | "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"], 173 | "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"], 174 | "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"], 
175 | "CYS": ["N", "CA", "C", "O", "CB", "SG"], 176 | "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"], 177 | "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"], 178 | "GLY": ["N", "CA", "C", "O"], 179 | "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"], 180 | "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"], 181 | "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"], 182 | "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"], 183 | "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"], 184 | "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"], 185 | "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"], 186 | "SER": ["N", "CA", "C", "O", "CB", "OG"], 187 | "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"], 188 | "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"], # noqa: E501 189 | "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"], 190 | "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"], 191 | "A": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"], # noqa: E501 192 | "G": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"], # noqa: E501 193 | "C": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"], # noqa: E501 194 | "U": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6"], # noqa: E501 195 | "N": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'"], # noqa: E501 196 | "DA": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"], # noqa: E501 197 | "DG": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"], # noqa: E501 198 | "DC": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"], # noqa: E501 199 | "DT": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C7", "C6"], # noqa: E501 200 | "DN": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'"] 201 | } 202 | 203 | ref_symmetries = { 204 | "PAD": [], 205 | "ALA": [], 206 | "ARG": [], 207 | "ASN": [], 208 | "ASP": [[(6, 7), (7, 6)]], 209 | "CYS": [], 210 | "GLN": [], 211 | "GLU": [[(7, 8), (8, 7)]], 212 | "GLY": [], 213 | "HIS": [], 214 | "ILE": [], 215 | "LEU": [], 216 | "LYS": [], 217 | "MET": [], 218 | "PHE": [[(6, 7), (7, 6), (8, 9), (9, 8)]], 219 | "PRO": [], 220 | "SER": [], 221 | "THR": [], 222 | "TRP": [], 223 | "TYR": [[(6, 7), (7, 6), (8, 9), (9, 8)]], 224 | "VAL": [], 225 | "A": [[(1, 2), (2, 1)]], 226 | "G": [[(1, 2), (2, 1)]], 227 | "C": [[(1, 2), (2, 1)]], 228 | "U": [[(1, 2), (2, 1)]], 229 | "N": [[(1, 2), (2, 1)]], 230 | "DA": [[(1, 2), (2, 1)]], 231 | "DG": [[(1, 2), (2, 1)]], 232 | "DC": [[(1, 2), (2, 1)]], 233 | "DT": [[(1, 2), (2, 1)]], 234 | "DN": [[(1, 2), (2, 1)]] 235 | } 236 | 237 | 238 | res_to_center_atom = { 239 | "UNK": "CA", 240 | "ALA": "CA", 241 | "ARG": "CA", 242 | "ASN": "CA", 243 | "ASP": "CA", 244 | "CYS": "CA", 245 | "GLN": "CA", 246 | "GLU": "CA", 247 | "GLY": 
"CA", 248 | "HIS": "CA", 249 | "ILE": "CA", 250 | "LEU": "CA", 251 | "LYS": "CA", 252 | "MET": "CA", 253 | "PHE": "CA", 254 | "PRO": "CA", 255 | "SER": "CA", 256 | "THR": "CA", 257 | "TRP": "CA", 258 | "TYR": "CA", 259 | "VAL": "CA", 260 | "A": "C1'", 261 | "G": "C1'", 262 | "C": "C1'", 263 | "U": "C1'", 264 | "N": "C1'", 265 | "DA": "C1'", 266 | "DG": "C1'", 267 | "DC": "C1'", 268 | "DT": "C1'", 269 | "DN": "C1'" 270 | } 271 | 272 | res_to_disto_atom = { 273 | "UNK": "CB", 274 | "ALA": "CB", 275 | "ARG": "CB", 276 | "ASN": "CB", 277 | "ASP": "CB", 278 | "CYS": "CB", 279 | "GLN": "CB", 280 | "GLU": "CB", 281 | "GLY": "CA", 282 | "HIS": "CB", 283 | "ILE": "CB", 284 | "LEU": "CB", 285 | "LYS": "CB", 286 | "MET": "CB", 287 | "PHE": "CB", 288 | "PRO": "CB", 289 | "SER": "CB", 290 | "THR": "CB", 291 | "TRP": "CB", 292 | "TYR": "CB", 293 | "VAL": "CB", 294 | "A": "C4", 295 | "G": "C4", 296 | "C": "C2", 297 | "U": "C2", 298 | "N": "C1'", 299 | "DA": "C4", 300 | "DG": "C4", 301 | "DC": "C2", 302 | "DT": "C2", 303 | "DN": "C1'" 304 | } 305 | 306 | res_to_center_atom_id = { 307 | res: ref_atoms[res].index(atom) 308 | for res, atom in res_to_center_atom.items() 309 | } 310 | 311 | res_to_disto_atom_id = { 312 | res: ref_atoms[res].index(atom) 313 | for res, atom in res_to_disto_atom.items() 314 | } 315 | 316 | # fmt: on 317 | 318 | #################################################################################################### 319 | # BONDS 320 | #################################################################################################### 321 | 322 | atom_interface_cutoff = 5.0 323 | interface_cutoff = 15.0 324 | 325 | bond_types = [ 326 | "OTHER", 327 | "SINGLE", 328 | "DOUBLE", 329 | "TRIPLE", 330 | "AROMATIC", 331 | ] 332 | bond_type_ids = {bond: i for i, bond in enumerate(bond_types)} 333 | unk_bond_type = "OTHER" 334 | 335 | 336 | #################################################################################################### 337 | # Contacts 338 | #################################################################################################### 339 | 340 | 341 | pocket_contact_info = { 342 | "UNSPECIFIED": 0, 343 | "UNSELECTED": 1, 344 | "POCKET": 2, 345 | "BINDER": 3, 346 | } 347 | 348 | 349 | #################################################################################################### 350 | # MSA 351 | #################################################################################################### 352 | 353 | max_msa_seqs = 16384 354 | -------------------------------------------------------------------------------- /src/boltz/data/crop/boltz.py: -------------------------------------------------------------------------------- 1 | from dataclasses import replace 2 | from typing import Optional 3 | 4 | import numpy as np 5 | from scipy.spatial.distance import cdist 6 | 7 | from boltz.data import const 8 | from boltz.data.crop.cropper import Cropper 9 | from boltz.data.types import Tokenized 10 | 11 | 12 | def pick_random_token( 13 | tokens: np.ndarray, 14 | random: np.random.RandomState, 15 | ) -> np.ndarray: 16 | """Pick a random token from the data. 17 | 18 | Parameters 19 | ---------- 20 | tokens : np.ndarray 21 | The token data. 22 | random : np.ndarray 23 | The random state for reproducibility. 24 | 25 | Returns 26 | ------- 27 | np.ndarray 28 | The selected token. 
34 | def pick_chain_token(
35 |     tokens: np.ndarray,
36 |     chain_id: int,
37 |     random: np.random.RandomState,
38 | ) -> np.ndarray:
39 |     """Pick a random token from a chain.
40 | 
41 |     Parameters
42 |     ----------
43 |     tokens : np.ndarray
44 |         The token data.
45 |     chain_id : int
46 |         The chain ID.
47 |     random : np.random.RandomState
48 |         The random state for reproducibility.
49 | 
50 |     Returns
51 |     -------
52 |     np.ndarray
53 |         The selected token.
54 | 
55 |     """
56 |     # Filter to chain
57 |     chain_tokens = tokens[tokens["asym_id"] == chain_id]
58 | 
59 |     # Pick from chain, fallback to all tokens
60 |     if chain_tokens.size:
61 |         query = pick_random_token(chain_tokens, random)
62 |     else:
63 |         query = pick_random_token(tokens, random)
64 | 
65 |     return query
66 | 
67 | 
68 | def pick_interface_token(
69 |     tokens: np.ndarray,
70 |     interface: np.ndarray,
71 |     random: np.random.RandomState,
72 | ) -> np.ndarray:
73 |     """Pick a random token from an interface.
74 | 
75 |     Parameters
76 |     ----------
77 |     tokens : np.ndarray
78 |         The token data.
79 |     interface : np.ndarray
80 |         The interface data.
81 |     random : np.random.RandomState
82 |         The random state for reproducibility.
83 | 
84 |     Returns
85 |     -------
86 |     np.ndarray
87 |         The selected token.
88 | 
89 |     """
90 |     # Get the two chains of the interface
91 |     chain_1 = int(interface["chain_1"])
92 |     chain_2 = int(interface["chain_2"])
93 | 
94 |     tokens_1 = tokens[tokens["asym_id"] == chain_1]
95 |     tokens_2 = tokens[tokens["asym_id"] == chain_2]
96 | 
97 |     # If either chain has no tokens, fall back to the other, or to all tokens
98 |     if tokens_1.size and (not tokens_2.size):
99 |         query = pick_random_token(tokens_1, random)
100 |     elif tokens_2.size and (not tokens_1.size):
101 |         query = pick_random_token(tokens_2, random)
102 |     elif (not tokens_1.size) and (not tokens_2.size):
103 |         query = pick_random_token(tokens, random)
104 |     else:
105 |         # If we have tokens, compute distances
106 |         tokens_1_coords = tokens_1["center_coords"]
107 |         tokens_2_coords = tokens_2["center_coords"]
108 | 
109 |         dists = cdist(tokens_1_coords, tokens_2_coords)
110 |         cutoff = dists < const.interface_cutoff
111 | 
112 |         # In rare cases, the interface cutoff is slightly
113 |         # too small, so we slightly expand it when that happens
114 |         if not np.any(cutoff):
115 |             cutoff = dists < (const.interface_cutoff + 5.0)
116 | 
117 |         tokens_1 = tokens_1[np.any(cutoff, axis=1)]
118 |         tokens_2 = tokens_2[np.any(cutoff, axis=0)]
119 | 
120 |         # Select random token
121 |         candidates = np.concatenate([tokens_1, tokens_2])
122 |         query = pick_random_token(candidates, random)
123 | 
124 |     return query
125 | 
126 | 
127 | class BoltzCropper(Cropper):
128 |     """Interpolate between contiguous and spatial crops."""
129 | 
130 |     def __init__(self, min_neighborhood: int = 0, max_neighborhood: int = 40) -> None:
131 |         """Initialize the cropper.
132 | 
133 |         Modulates the type of cropping to be performed.
134 |         Smaller neighborhoods result in more spatial
135 |         cropping. Larger neighborhoods result in more
136 |         contiguous cropping. A mix can be achieved by
137 |         providing a range over which to sample.
138 | 
139 |         Parameters
140 |         ----------
141 |         min_neighborhood : int
142 |             The minimum neighborhood size, by default 0.
143 |         max_neighborhood : int
144 |             The maximum neighborhood size, by default 40.
145 | 
146 |         """
147 |         sizes = list(range(min_neighborhood, max_neighborhood + 1, 2))
148 |         self.neighborhood_sizes = sizes
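# The sampled neighborhood size is what interpolates between crop styles:
# size 0 degenerates to a purely spatial crop (tokens added strictly by
# distance), while larger sizes pull contiguous sequence stretches around
# each spatial seed. With the defaults the candidates are an even-stepped
# range, e.g.:
#
#     list(range(0, 40 + 1, 2))  ->  [0, 2, 4, ..., 40]
#
# so constructing the cropper with min_neighborhood == max_neighborhood
# (say 20 and 20) pins every crop to a single neighborhood size of 20.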
149 | 
150 |     def crop(  # noqa: PLR0915
151 |         self,
152 |         data: Tokenized,
153 |         max_tokens: int,
154 |         random: np.random.RandomState,
155 |         max_atoms: Optional[int] = None,
156 |         chain_id: Optional[int] = None,
157 |         interface_id: Optional[int] = None,
158 |     ) -> Tokenized:
159 |         """Crop the data to a maximum number of tokens.
160 | 
161 |         Parameters
162 |         ----------
163 |         data : Tokenized
164 |             The tokenized data.
165 |         max_tokens : int
166 |             The maximum number of tokens in the crop.
167 |         random : np.random.RandomState
168 |             The random state for reproducibility.
169 |         max_atoms : int, optional
170 |             The maximum number of atoms to consider.
171 |         chain_id : int, optional
172 |             The chain ID to crop.
173 |         interface_id : int, optional
174 |             The interface ID to crop.
175 | 
176 |         Returns
177 |         -------
178 |         Tokenized
179 |             The cropped data.
180 | 
181 |         """
182 |         # Check inputs
183 |         if chain_id is not None and interface_id is not None:
184 |             msg = "Only one of chain_id or interface_id can be provided."
185 |             raise ValueError(msg)
186 | 
187 |         # Randomly select a neighborhood size
188 |         neighborhood_size = random.choice(self.neighborhood_sizes)
189 | 
190 |         # Get token data
191 |         token_data = data.tokens
192 |         token_bonds = data.bonds
193 |         mask = data.structure.mask
194 |         chains = data.structure.chains
195 |         interfaces = data.structure.interfaces
196 | 
197 |         # Filter to valid chains
198 |         valid_chains = chains[mask]
199 | 
200 |         # Filter to valid interfaces
201 |         valid_interfaces = interfaces
202 |         valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_1"]]]
203 |         valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_2"]]]
204 | 
205 |         # Filter to resolved tokens
206 |         valid_tokens = token_data[token_data["resolved_mask"]]
207 | 
208 |         # Check if we have any valid tokens
209 |         if not valid_tokens.size:
210 |             msg = "No valid tokens in structure"
211 |             raise ValueError(msg)
212 | 
213 |         # Pick a random token, chain, or interface
214 |         if chain_id is not None:
215 |             query = pick_chain_token(valid_tokens, chain_id, random)
216 |         elif interface_id is not None:
217 |             interface = interfaces[interface_id]
218 |             query = pick_interface_token(valid_tokens, interface, random)
219 |         elif valid_interfaces.size:
220 |             idx = random.randint(len(valid_interfaces))
221 |             interface = valid_interfaces[idx]
222 |             query = pick_interface_token(valid_tokens, interface, random)
223 |         else:
224 |             idx = random.randint(len(valid_chains))
225 |             chain_id = valid_chains[idx]["asym_id"]
226 |             query = pick_chain_token(valid_tokens, chain_id, random)
227 | 
228 |         # Sort all tokens by distance to the query token
229 |         dists = valid_tokens["center_coords"] - query["center_coords"]
230 |         indices = np.argsort(np.linalg.norm(dists, axis=1))
231 | 
232 |         # Select cropped indices
233 |         cropped: set[int] = set()
234 |         total_atoms = 0
235 |         for idx in indices:
236 |             # Get the token
237 |             token = valid_tokens[idx]
238 | 
239 |             # Get all tokens from this chain
240 |             chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]
241 | 
242 |             # Pick the whole chain if possible, otherwise select
243 |             # a contiguous subset centered at the query token
244 |             if len(chain_tokens) <= neighborhood_size:
245 |                 new_tokens = chain_tokens
246 |             else:
247 |                 # First limit to the maximum set of tokens, with the
248 |                 # neighborhood on both sides to handle edges. 
This 249 | # is mostly for efficiency with the while loop below. 250 | min_idx = token["res_idx"] - neighborhood_size 251 | max_idx = token["res_idx"] + neighborhood_size 252 | 253 | max_token_set = chain_tokens 254 | max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx] 255 | max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx] 256 | 257 | # Start by adding just the query token 258 | new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]] 259 | 260 | # Expand the neighborhood until we have enough tokens, one 261 | # by one to handle some edge cases with non-standard chains. 262 | # We switch to the res_idx instead of the token_idx to always 263 | # include all tokens from modified residues or from ligands. 264 | min_idx = max_idx = token["res_idx"] 265 | while new_tokens.size < neighborhood_size: 266 | min_idx = min_idx - 1 267 | max_idx = max_idx + 1 268 | new_tokens = max_token_set 269 | new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx] 270 | new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx] 271 | 272 | # Compute new tokens and new atoms 273 | new_indices = set(new_tokens["token_idx"]) - cropped 274 | new_tokens = token_data[list(new_indices)] 275 | new_atoms = np.sum(new_tokens["atom_num"]) 276 | 277 | # Stop if we exceed the max number of tokens or atoms 278 | if (len(new_indices) > (max_tokens - len(cropped))) or ( 279 | (max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms) 280 | ): 281 | break 282 | 283 | # Add new indices 284 | cropped.update(new_indices) 285 | total_atoms += new_atoms 286 | 287 | # Get the cropped tokens sorted by index 288 | token_data = token_data[sorted(cropped)] 289 | 290 | # Only keep bonds within the cropped tokens 291 | indices = token_data["token_idx"] 292 | token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)] 293 | token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)] 294 | 295 | # Return the cropped tokens 296 | return replace(data, tokens=token_data, bonds=token_bonds) 297 | --------------------------------------------------------------------------------
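# The stopping rule in the crop loop above is a dual budget: a chain
# neighborhood is only committed if it fits in both the remaining token
# budget and the remaining atom budget. A self-contained sketch of that
# accounting (the helper and its numbers are illustrative, not part of
# the library):

def _neighborhood_fits(
    new_tokens: int,
    new_atoms: int,
    n_cropped: int,
    total_atoms: int,
    max_tokens: int,
    max_atoms: int,
) -> bool:
    """Mirror the break condition in BoltzCropper.crop, negated."""
    token_ok = new_tokens <= (max_tokens - n_cropped)
    atom_ok = (total_atoms + new_atoms) <= max_atoms
    return token_ok and atom_ok

assert _neighborhood_fits(40, 300, 200, 1500, max_tokens=256, max_atoms=2048)
assert not _neighborhood_fits(80, 300, 200, 1500, max_tokens=256, max_atoms=2048)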