├── src
│   └── ibex
│       ├── openfold
│       │   ├── utils
│       │   │   ├── __init__.py
│       │   │   ├── precision_utils.py
│       │   │   ├── tensor_utils.py
│       │   │   ├── data_transforms.py
│       │   │   ├── feats.py
│       │   │   └── protein.py
│       │   ├── resources
│       │   │   ├── __init__.py
│       │   │   └── stereo_chemical_props.txt
│       │   └── __init__.py
│       ├── __init__.py
│       ├── py.typed
│       ├── loss
│       │   ├── __init__.py
│       │   ├── loss.py
│       │   └── aligned_rmsd.py
│       ├── inference.py
│       ├── predict.py
│       ├── utils.py
│       ├── refine.py
│       ├── dataloader.py
│       └── model.py
├── docs
│   ├── assets
│   │   ├── ibex.png
│   │   └── ibex_transparent.png
│   ├── ibex_test.csv
│   └── Genentech_license_weights_ibex
├── pyproject.toml
├── .gitignore
├── README.md
└── LICENSE
/src/ibex/openfold/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/ibex/openfold/resources/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/assets/ibex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prescient-design/ibex/HEAD/docs/assets/ibex.png -------------------------------------------------------------------------------- /src/ibex/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Ibex 2 | from .predict import inference, batch_inference -------------------------------------------------------------------------------- /docs/assets/ibex_transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prescient-design/ibex/HEAD/docs/assets/ibex_transparent.png -------------------------------------------------------------------------------- /src/ibex/openfold/__init__.py: -------------------------------------------------------------------------------- 1 | from . import model, resources, utils 2 | 3 | __all__ = ["model", "utils", "resources"] 4 | -------------------------------------------------------------------------------- /src/ibex/py.typed: -------------------------------------------------------------------------------- 1 | Marks this package as providing type-annotations. 2 | See https://www.python.org/dev/peps/pep-0561/ for details. 3 | -------------------------------------------------------------------------------- /src/ibex/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from ibex.loss.aligned_rmsd import aligned_fv_and_cdrh3_rmsd, batch_align 2 | from ibex.loss.loss import IbexLoss 3 | -------------------------------------------------------------------------------- /src/ibex/openfold/utils/precision_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # Copyright 2024 Exscientia 3 | # Copyright 2022 AlQuraishi Laboratory 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | 19 | 20 | def is_fp16_enabled(): 21 | # Autocast world 22 | fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 23 | fp16_enabled = fp16_enabled and torch.is_autocast_enabled() 24 | 25 | return fp16_enabled 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 2 | [project] 3 | name = "prescient-ibex" 4 | description = "Lightweight antibody structure prediction tool (arXiv:2507.09054)" 5 | authors = [{name = "Prescient Design"}] 6 | dynamic = ["version", "readme"] 7 | requires-python = ">=3.10" 8 | dependencies = [ 9 | "click<8.2.0", 10 | "dm-tree", 11 | "einops", 12 | "esm==3.1.6", 13 | "hydra-core", 14 | "levenshtein", 15 | "lightning>=2.5.0.post0", 16 | "loguru>=0.7.3", 17 | "ml-collections>=1.0.0", 18 | "numpy>=1.26.4", 19 | "omegaconf>=2.3.0", 20 | "openmm>=8.2.0", 21 | "pandas>=2.2.3", 22 | "pdbfixer", 23 | "pyparsing>=3.2.3", 24 | "python-box>=7.3.2", 25 | "s3fs>=2025.2.0", 26 | "scipy>=1.10.1", 27 | "sentencepiece", 28 | "torch>=2.5.1", 29 | "tqdm", 30 | "transformers", 31 | "typer>=0.15.2", 32 | "wandb>=0.19.8", 33 | ] 34 | 35 | [tool.setuptools.packages.find] 36 | where = ["src"] 37 | 38 | [tool.setuptools.package-data] 39 | "ibex.configs" = ["**/*.yaml"] 40 | 41 | [tool.setuptools.dynamic] 42 | readme = {file = "README.md", content-type = "text/markdown"} 43 | 44 | [project.scripts] 45 | ibex = "ibex.inference:app" 46 | 47 | [project.optional-dependencies] 48 | analysis = [ 49 | "matplotlib>=3.10.1", 50 | "py3dmol>=2.4.2", 51 | "seaborn>=0.13.2", 52 | ] 53 | 54 | [build-system] 55 | requires = ["setuptools >= 65", "setuptools_scm[toml]>=6.2"] 56 | build-backend = 'setuptools.build_meta' 57 | 58 | [dependency-groups] 59 | dev = [ 60 | "coverage[toml]>=7.8.0", 61 | "ipykernel>=6.29.5", 62 | "ipython>=8.35.0", 63 | "pip>=25.0.1", 64 | "pre-commit>=4.2.0", 65 | "pytest>=8.3.5", 66 | "ruff>=0.11.5", 67 | ] 68 | 69 | [tool.setuptools_scm] 70 | search_parent_directories = true 71 | version_scheme = "no-guess-dev" 72 | local_scheme = "node-and-date" 73 | fallback_version = "0.0.1" 74 | 75 | [tool.uv.sources] 76 | pdbfixer = { git = "https://github.com/openmm/pdbfixer.git" } 77 | -------------------------------------------------------------------------------- /docs/ibex_test.csv: -------------------------------------------------------------------------------- 1 | pdb_id,is_vhh,is_tcr 2 | 8f5i,False,False 3 | 8byu,False,False 4 | 8f6l,False,False 5 | 8bse,False,False 6 | 8dy3,False,False 7 | 7wsl,False,False 8 | 7ycl,False,False 9 | 8ek6,False,False 10 | 7ucx,False,False 11 | 8d29,False,False 12 | 7s0j,False,False 13 | 7sts,False,False 14 | 7unb,False,False 15 | 7skz,False,False 16 | 8cwu,True,False 17 | 7t0j,False,False 18 | 7omn,True,False 19 | 7jkm,False,False 20 | 7r63,True,False 21 | 7vux,False,False 22 | 4lkx,False,False 23 | 1jps,False,False 24 | 7ue9,False,False 25 | 6xjq,False,False 26 | 7ndf,True,False 27 | 7zxu,True,False 28 | 6e65,False,False 29 | 7fau,True,False 30 | 5kzp,False,False 31 | 6l8t,False,False 32 | 7zmv,True,False 33 | 5c6w,False,False 34 | 7olz,True,False 35 | 6o26,False,False 36 | 7q6c,True,False 37 | 6cr1,False,False 38 | 4m6n,False,False 39 | 7nfr,True,False 40 | 7t0i,False,False 41 | 5chn,False,False 42 | 7zwi,False,False 43 | 6o3a,False,False 44 | 7sem,False,False 45 | 
7s4g,False,False 46 | 6dcv,False,False 47 | 7o06,True,False 48 | 6o3k,False,False 49 | 7apj,True,False 50 | 6blh,False,False 51 | 7z0x,False,False 52 | 7sg6,False,False 53 | 7djx,True,False 54 | 7kpg,False,False 55 | 7rt9,False,False 56 | 3lmj,False,False 57 | 7rqr,False,False 58 | 7lcv,False,False 59 | 7ps3,False,False 60 | 7vyr,False,False 61 | 5fhb,False,False 62 | 7kfy,False,False 63 | 7raq,False,False 64 | 7sg5,False,False 65 | 7so5,False,False 66 | 7rxi,False,False 67 | 7qu1,False,False 68 | 7rxl,False,False 69 | 7urq,False,False 70 | 5ob5,False,False 71 | 1uj3,False,False 72 | 7vnb,True,False 73 | 7w1s,True,False 74 | 7o0s,True,False 75 | 3h0t,False,False 76 | 3n9g,False,False 77 | 7q0g,False,False 78 | 7np9,True,False 79 | 7neh,False,False 80 | 7ar0,True,False 81 | 1t3f,False,False 82 | 7f1g,True,False 83 | 7zf6,False,False 84 | 7sjs,False,False 85 | 6wyt,False,False 86 | 7ps4,False,False 87 | 7om5,True,False 88 | 7rqq,False,False 89 | 7vfb,True,False 90 | 7n4n,True,False 91 | 6m3b,False,False 92 | 5e8e,False,False 93 | 6dc4,False,False 94 | 7anq,True,False 95 | 7nfq,True,False 96 | 6mxs,False,False 97 | 7qf0,False,False 98 | 7s7r,True,False 99 | 7sn1,False,False 100 | 7e53,True,False 101 | 6yio,False,False 102 | 6pzg,False,False 103 | 6vjt,False,False 104 | 7omt,True,False 105 | 6e56,False,False 106 | 7r20,True,False 107 | 7f07,True,False 108 | 7vfa,True,False 109 | 7ryu,False,False 110 | 1gaf,False,False 111 | 7tp4,False,False 112 | 4jam,False,False 113 | 7rg7,True,False 114 | 7q4q,False,False 115 | 7r73,True,False 116 | 4d9q,False,False 117 | 7k8o,False,False 118 | 1h0d,False,False 119 | 3grw,False,False 120 | 5wn9,False,False 121 | 7ps6,False,False 122 | 3tje,False,False 123 | 4j6r,False,False 124 | 7rp2,False,False 125 | 6uce,False,False 126 | 7ttm,False,False 127 | 4xvs,False,False 128 | 7n3d,False,False 129 | 2a9m,False,False 130 | 4buh,False,False 131 | 6ss5,False,False 132 | 7seg,False,False 133 | 7u8c,False,False 134 | 7pbe,False,True 135 | 7rrg,False,True 136 | 7fjd,False,True 137 | 6zkz,False,True 138 | 7amp,False,True 139 | 7r7z,False,True 140 | 7fjf,False,True 141 | 7dzm,False,True 142 | 7n5c,False,True 143 | 7qpj,False,True 144 | 7na5,False,True 145 | 7fje,False,True 146 | 7pb2,False,True 147 | 7n5p,False,True 148 | 6zkx,False,True 149 | 7z50,False,True 150 | 7l1d,False,True 151 | 7f5k,False,True 152 | 6zky,False,True 153 | 7r80,False,True 154 | 7su9,False,True 155 | -------------------------------------------------------------------------------- /src/ibex/openfold/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # Copyright 2024 Exscientia 3 | # Copyright 2021 AlQuraishi Laboratory 4 | # Copyright 2021 DeepMind Technologies Limited 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | from functools import partial 19 | from typing import List 20 | 21 | import torch 22 | 23 | 24 | def permute_final_dims(tensor: torch.Tensor, inds: List[int]): 25 | zero_index = -1 * len(inds) 26 | first_inds = list(range(len(tensor.shape[:zero_index]))) 27 | return tensor.permute(first_inds + [zero_index + i for i in inds]) 28 | 29 | 30 | def flatten_final_dims(t: torch.Tensor, no_dims: int): 31 | return t.reshape(t.shape[:-no_dims] + (-1,)) 32 | 33 | 34 | def masked_mean(mask, value, dim, eps=1e-4): 35 | mask = mask.expand(*value.shape) 36 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) 37 | 38 | 39 | def dict_multimap(fn, dicts): 40 | first = dicts[0] 41 | new_dict = {} 42 | for k, v in first.items(): 43 | all_v = [d[k] for d in dicts] 44 | if type(v) is dict: 45 | new_dict[k] = dict_multimap(fn, all_v) 46 | else: 47 | new_dict[k] = fn(all_v) 48 | 49 | return new_dict 50 | 51 | 52 | def one_hot(x, v_bins): 53 | reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) 54 | diffs = x[..., None] - reshaped_bins 55 | am = torch.argmin(torch.abs(diffs), dim=-1) 56 | return torch.nn.functional.one_hot(am, num_classes=len(v_bins)).float() 57 | 58 | 59 | def batched_gather(data, inds, dim=0, no_batch_dims=0): 60 | ranges = [] 61 | for i, s in enumerate(data.shape[:no_batch_dims]): 62 | r = torch.arange(s) 63 | r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) 64 | ranges.append(r) 65 | 66 | remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] 67 | remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds 68 | ranges.extend(remaining_dims) 69 | return data[tuple(ranges)] 70 | 71 | 72 | # dict_map and tree_map: a poor man's JAX tree_map 73 | def dict_map(fn, dic, leaf_type): 74 | new_dict = {} 75 | for k, v in dic.items(): 76 | if type(v) is dict: 77 | new_dict[k] = dict_map(fn, v, leaf_type) 78 | else: 79 | new_dict[k] = tree_map(fn, v, leaf_type) 80 | 81 | return new_dict 82 | 83 | 84 | def tree_map(fn, tree, leaf_type): 85 | if isinstance(tree, dict): 86 | return dict_map(fn, tree, leaf_type) 87 | elif isinstance(tree, list): 88 | return [tree_map(fn, x, leaf_type) for x in tree] 89 | elif isinstance(tree, tuple): 90 | return tuple([tree_map(fn, x, leaf_type) for x in tree]) 91 | elif isinstance(tree, leaf_type): 92 | return fn(tree) 93 | else: 94 | # Unsupported leaf or container type 95 | raise ValueError(f"Tree of type {type(tree)} is not supported") 96 | 97 | 98 | tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) 99 | -------------------------------------------------------------------------------- /src/ibex/openfold/utils/data_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # Copyright 2024 Exscientia 3 | # Copyright 2021 AlQuraishi Laboratory 4 | # Copyright 2021 DeepMind Technologies Limited 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch 19 | 20 | from ibex.openfold.utils import residue_constants as rc 21 | 22 | 23 | def make_atom14_masks(protein): 24 | """Construct denser atom positions (14 dimensions instead of 37).""" 25 | restype_atom14_to_atom37 = [] 26 | restype_atom37_to_atom14 = [] 27 | restype_atom14_mask = [] 28 | 29 | for rt in rc.restypes: 30 | atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]] 31 | restype_atom14_to_atom37.append( 32 | [(rc.atom_order[name] if name else 0) for name in atom_names] 33 | ) 34 | atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} 35 | restype_atom37_to_atom14.append( 36 | [ 37 | (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) 38 | for name in rc.atom_types 39 | ] 40 | ) 41 | 42 | restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names]) 43 | 44 | # Add dummy mapping for restype 'UNK' 45 | restype_atom14_to_atom37.append([0] * 14) 46 | restype_atom37_to_atom14.append([0] * 37) 47 | restype_atom14_mask.append([0.0] * 14) 48 | 49 | restype_atom14_to_atom37 = torch.tensor( 50 | restype_atom14_to_atom37, 51 | dtype=torch.int32, 52 | device=protein["aatype"].device, 53 | ) 54 | restype_atom37_to_atom14 = torch.tensor( 55 | restype_atom37_to_atom14, 56 | dtype=torch.int32, 57 | device=protein["aatype"].device, 58 | ) 59 | restype_atom14_mask = torch.tensor( 60 | restype_atom14_mask, 61 | dtype=torch.float32, 62 | device=protein["aatype"].device, 63 | ) 64 | protein_aatype = protein["aatype"].to(torch.long) 65 | 66 | # create the mapping for (residx, atom14) --> atom37, i.e. an array 67 | # with shape (num_res, 14) containing the atom37 indices for this protein 68 | residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype] 69 | residx_atom14_mask = restype_atom14_mask[protein_aatype] 70 | 71 | protein["atom14_atom_exists"] = residx_atom14_mask 72 | protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long() 73 | 74 | # create the gather indices for mapping back 75 | residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype] 76 | protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long() 77 | 78 | # create the corresponding mask 79 | restype_atom37_mask = torch.zeros( 80 | [21, 37], dtype=torch.float32, device=protein["aatype"].device 81 | ) 82 | for restype, restype_letter in enumerate(rc.restypes): 83 | restype_name = rc.restype_1to3[restype_letter] 84 | atom_names = rc.residue_atoms[restype_name] 85 | for atom_name in atom_names: 86 | atom_type = rc.atom_order[atom_name] 87 | restype_atom37_mask[restype, atom_type] = 1 88 | 89 | residx_atom37_mask = restype_atom37_mask[protein_aatype] 90 | protein["atom37_atom_exists"] = residx_atom37_mask 91 | 92 | return protein 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Abstra 171 | # Abstra is an AI-powered process automation framework. 172 | # Ignore directories containing user credentials, local state, and settings. 173 | # Learn more at https://abstra.io/docs 174 | .abstra/ 175 | 176 | # Visual Studio Code 177 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 178 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 179 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 180 | # you could uncomment the following to ignore the entire vscode folder 181 | # .vscode/ 182 | 183 | # Ruff stuff: 184 | .ruff_cache/ 185 | 186 | # PyPI configuration file 187 | .pypirc 188 | 189 | # Cursor 190 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 191 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data; 192 | # refer to https://docs.cursor.com/context/ignore-files 193 | .cursorignore 194 | .cursorindexingignore -------------------------------------------------------------------------------- /src/ibex/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import warnings 16 | from pathlib import Path 17 | 18 | import pandas as pd 19 | from loguru import logger 20 | import typer 21 | 22 | from ibex.model import Ibex 23 | from ibex.predict import inference, batch_inference 24 | from ibex.utils import MODEL_CHECKPOINTS, ENSEMBLE_MODELS, checkpoint_path 25 | 26 | warnings.filterwarnings("ignore") 27 | 28 | 29 | def main( 30 | abodybuilder3: bool = typer.Option(False, help="Use the ABodyBuilder3 model instead of Ibex for inference."), 31 | ckpt: str = typer.Option("", help="Path to model checkpoint. This is only needed to load a user-specified checkpoint."), 32 | fv_heavy: str = typer.Option("", help="Sequence of the heavy chain."), 33 | fv_light: str = typer.Option("", help="Sequence of the light chain."), 34 | csv: str = typer.Option("", help="CSV file containing sequences of heavy and light chains. Columns should be named 'fv_heavy' and 'fv_light'. Output file names can be provided in an 'id' column."), 35 | parquet: str = typer.Option("", help="Parquet file containing sequences of heavy and light chains. Columns should be named 'fv_heavy' and 'fv_light'. 
Output file names can be provided in an 'id' column."), 36 | output: Path = typer.Option("prediction.pdb", help="Output file for the PDB structure, or path to the output folder when a parquet or csv file is provided."), 37 | batch_size: int = typer.Option(32, help="Batch size for inference if a parquet or csv file is provided."), 38 | ensemble: bool = typer.Option(False, help="Specify if the checkpoint provided is an ensemble (only needed when using the explicit --ckpt flag)."), 39 | save_all: bool = typer.Option(False, help="Save all structures of the ensemble as output files."), 40 | refine: bool = typer.Option(False, help="Refine the output structures with openMM."), 41 | refine_checks: bool = typer.Option(False, help="Additional checks to fix cis-isomers and D-stereoisomers during refinement."), 42 | apo: bool = typer.Option(False, help="Predict structures in the apo conformation."), 43 | ): 44 | """ 45 | The model can be specified by name or by providing a checkpoint path. If a CSV or parquet file is provided, inference is run on all sequences in the file; otherwise it is run on a single heavy and light sequence pair. 46 | """ 47 | model = "ibex" 48 | if abodybuilder3: 49 | model = "abodybuilder3" 50 | if ckpt == "" and model not in MODEL_CHECKPOINTS: 51 | typer.echo(f"Invalid model name: {model}. Valid options are: {', '.join(MODEL_CHECKPOINTS.keys())}. Alternatively, provide a checkpoint path with the --ckpt option.") 52 | raise typer.Exit(code=1) 53 | if ckpt == "": 54 | ckpt = checkpoint_path(model) 55 | ckpt = Path(ckpt) 56 | 57 | if ensemble or model in ENSEMBLE_MODELS: 58 | logger.info(f"Loading ensemble model from {ckpt=}") 59 | ibex_model = Ibex.load_from_ensemble_checkpoint(ckpt) 60 | else: 61 | logger.info(f"Loading single model from {ckpt=}") 62 | ibex_model = Ibex.load_from_checkpoint(ckpt) 63 | if ibex_model.plm_model is not None: 64 | ibex_model.set_plm() 65 | if fv_heavy: 66 | logger.info("Performing inference on a single heavy and light sequence pair.") 67 | inference(ibex_model, fv_heavy, fv_light, output, save_all=save_all, refine=refine, refine_checks=refine_checks, apo=apo) 68 | elif csv or parquet: 69 | if save_all: 70 | logger.warning("save_all was set to True, but ensemble output is not implemented for batched inference. 
Setting save_all to False.") 71 | save_all = False 72 | if output == Path("prediction.pdb"): 73 | # Overwrite default for batch inference 74 | output = Path("predictions") 75 | if csv: 76 | csv = Path(csv) 77 | logger.info(f"Performing inference on sequences from {csv=}") 78 | df = pd.read_csv(csv) 79 | else: 80 | parquet = Path(parquet) 81 | logger.info(f"Performing inference on sequences from {parquet=}") 82 | df = pd.read_parquet(parquet) 83 | df["fv_light"] = df["fv_light"].fillna("") 84 | fv_heavy_list = df["fv_heavy"].tolist() 85 | fv_light_list = df["fv_light"].tolist() 86 | apo_list = None 87 | if ibex_model.conformation_aware: 88 | apo_list = [apo] * len(fv_heavy_list) 89 | if "id" in df.columns: 90 | names = df["id"].tolist() 91 | batch_inference( 92 | ibex_model, 93 | fv_heavy_list, 94 | fv_light_list, 95 | output, 96 | batch_size, 97 | names, 98 | refine=refine, 99 | refine_checks=refine_checks, 100 | apo_list=apo_list 101 | ) 102 | else: 103 | batch_inference( 104 | ibex_model, 105 | fv_heavy_list, 106 | fv_light_list, 107 | output, 108 | batch_size, 109 | refine=refine, 110 | refine_checks=refine_checks, 111 | apo_list=apo_list 112 | ) 113 | else: 114 | typer.echo("Please provide sequences of heavy and light chains or a csv/parquet file containing sequences.") 115 | raise typer.Exit(code=1) 116 | 117 | 118 | def app(): 119 | typer.run(main) 120 | -------------------------------------------------------------------------------- /src/ibex/predict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from pathlib import Path 17 | from typing import Optional 18 | import torch 19 | import tempfile 20 | from tqdm import tqdm 21 | from loguru import logger 22 | 23 | from ibex.model import Ibex, EnsembleStructureModule 24 | from ibex.refine import refine_file 25 | 26 | 27 | def process_file(pdb_string, output_file, refine, refine_checks=False): 28 | if not refine: 29 | with open(output_file, "w") as f: 30 | f.write(pdb_string) 31 | else: 32 | try: 33 | with tempfile.NamedTemporaryFile(delete=True) as tmpfile: 34 | tmpfile.write(pdb_string.encode('utf-8')) 35 | tmpfile.flush() 36 | refine_file(tmpfile.name, output_file, checks=refine_checks) 37 | except Exception as e: 38 | logger.warning(f"Refinement failed with error: {e}") 39 | with open(output_file, "w") as f: 40 | f.write(pdb_string) 41 | 42 | 43 | def inference( 44 | model: Ibex, 45 | fv_heavy: str, 46 | fv_light: str, 47 | output_file: Path, 48 | logging: bool = True, 49 | save_all: bool = False, 50 | refine: bool = False, 51 | refine_checks: bool = False, 52 | apo: bool = False, 53 | return_pdb: bool = True 54 | ): 55 | device = "cuda" if torch.cuda.is_available() else "cpu" 56 | if device == "cpu" and logging: 57 | logger.warning("Inference is being done on CPU as GPU not found.") 58 | if save_all and not isinstance(model.model, EnsembleStructureModule): 59 | raise ValueError("save_all is set to True but model is not an ensemble model.") 60 | if not return_pdb and refine: 61 | raise ValueError("Cannot return a protein object and refine at the same time. To run refinement, output format must be a PDB file (return_pdb==True).") 62 | if not return_pdb and save_all: 63 | raise ValueError("Cannot return a protein object and save all outputs at the same time. To save all, output format must be a PDB file (return_pdb==True).") 64 | with torch.no_grad(): 65 | pdb_string_or_protein = model.predict(fv_heavy, fv_light, device=device, ensemble=save_all, pdb_string=return_pdb, apo=apo) 66 | if not return_pdb: 67 | if logging: 68 | logger.info("Inference complete. Returning a protein object.") 69 | return pdb_string_or_protein 70 | if save_all: 71 | ensemble_files = [] 72 | for i, pdb_string_current in enumerate(pdb_string_or_protein): 73 | output_file_current = output_file.parent / f"{output_file.stem}_{i+1}{output_file.suffix}" 74 | process_file(pdb_string_current, output_file_current, refine, refine_checks) 75 | ensemble_files.append(str(output_file_current)) 76 | output_file = ensemble_files 77 | else: 78 | process_file(pdb_string_or_protein, output_file, refine, refine_checks) 79 | if logging: 80 | logger.info(f"Inference complete. 
Wrote PDB file to {output_file=}") 81 | return output_file 82 | 83 | 84 | def batch_inference( 85 | model: Ibex, 86 | fv_heavy_list: list[str], 87 | fv_light_list: list[str], 88 | output_dir: Path, 89 | batch_size: int, 90 | output_names: Optional[list[str]] = None, 91 | logging: bool = True, 92 | refine: bool = False, 93 | refine_checks: bool = False, 94 | apo_list: Optional[list[bool]] = None, 95 | return_pdb: bool = True 96 | ): 97 | if output_names is None: 98 | output_names = [f"output_{i}" for i in range(len(fv_heavy_list))] 99 | 100 | if len(fv_heavy_list) != len(fv_light_list) or len(fv_heavy_list) != len(output_names): 101 | raise ValueError("Input lists must have the same length.") 102 | 103 | device = "cuda" if torch.cuda.is_available() else "cpu" 104 | if device == "cpu" and logging: 105 | logger.warning("Inference is being done on CPU as GPU not found.") 106 | 107 | output_dir = Path(output_dir) 108 | output_dir.mkdir(parents=True, exist_ok=True) 109 | 110 | if apo_list is not None and not model.conformation_aware: 111 | raise ValueError("Model is not conformation-aware, but apo_list was provided.") 112 | if not return_pdb and refine: 113 | raise ValueError("Cannot return a protein object and refine at the same time. To run refinement, output format must be a PDB file (return_pdb==True).") 114 | 115 | if model.plm_model is not None: 116 | model.plm_model = model.plm_model.to(device) 117 | 118 | name_idx = 0 # Index for tracking the position in output_names 119 | result_files = [] 120 | for i in tqdm(range(0, len(fv_heavy_list), batch_size), desc="Processing batches"): 121 | fv_heavy_batch = fv_heavy_list[i:i+batch_size] 122 | fv_light_batch = fv_light_list[i:i+batch_size] if fv_light_list else None 123 | with torch.no_grad(): 124 | pdb_strings_or_proteins = model.predict_batch( 125 | fv_heavy_batch, fv_light_batch, device=device, pdb_string=return_pdb, apo_list=apo_list[i:i+batch_size] if apo_list is not None else None # Slice apo flags to match the current batch 126 | ) 127 | if not return_pdb: 128 | # Collect protein objects from every batch rather than returning after the first one 129 | result_files.extend(pdb_strings_or_proteins) 130 | continue 131 | for pdb_string in pdb_strings_or_proteins: 132 | output_file = output_dir / f"{output_names[name_idx]}.pdb" 133 | process_file(pdb_string, output_file, refine, refine_checks) 134 | result_files.append(output_file) 135 | name_idx += 1 136 | if logging and return_pdb: 137 | logger.info(f"Inference complete. Wrote {name_idx} PDB files to {output_dir=}") 138 | 139 | return result_files 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15866555.svg)](https://doi.org/10.5281/zenodo.15866555) 2 | 3 | # Ibex 🐐 4 | 5 | [Ibex](https://www.tandfonline.com/doi/full/10.1080/19420862.2025.2602217) is a lightweight antibody and TCR structure prediction model. 6 | 7 | 
8 | <p align="center"><img src="docs/assets/ibex_transparent.png" alt="Ibex logo"></p> 9 | 
10 | 11 | ## Installation 12 | 13 | Ibex can be installed through pip with 14 | ```bash 15 | pip install prescient-ibex 16 | ``` 17 | Alternatively, you can use `uv` and create a new virtual environment 18 | ```bash 19 | uv venv --python 3.10 20 | source .venv/bin/activate 21 | uv pip install -e . 22 | ``` 23 | 24 | ## Usage 25 | 26 | The simplest way to run inference is through the `ibex` command, e.g. 27 | 28 | ```bash 29 | ibex --fv-heavy EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVSS --fv-light DIQMTQSPSSLSASVGDRVTITCRASQDVNTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQHYTTPPTFGQGTKVEIK --output prediction.pdb 30 | ``` 31 | You can provide a csv file (with the `--csv` argument) or a parquet file (with the `--parquet` argument) and run batched inference, writing the output into a specified directory, with 32 | ```bash 33 | ibex --csv sequences.csv --output predictions 34 | ``` 35 | where `sequences.csv` should contain `fv_heavy` and `fv_light` columns with heavy and light chain sequences, and optionally an `id` column with a string that will be used as part of the output PDB filenames. 36 | 37 | By default, structures are predicted in the holo conformation. To predict the apo state, use the `--apo` flag. 38 | 39 | To run a refinement step on the predicted structures, use the `--refine` option. Additional checks to fix cis-isomers and D-stereoisomers during refinement can be activated with `--refine-checks`. 40 | 41 | Instead of running Ibex, you can use `--abodybuilder3` to run inference with the [ABodyBuilder3](https://academic.oup.com/bioinformatics/article/40/10/btae576/7810444) model. 42 | Below is a summary of all available options: 43 | ``` 44 | --abodybuilder3 Use the ABodyBuilder3 model instead of Ibex for inference. [default: no-abodybuilder3] 45 | --fv-heavy Sequence of the heavy chain. 46 | --fv-light Sequence of the light chain. 47 | --csv CSV file containing sequences of heavy and light chains. Columns should be named 'fv_heavy' and 'fv_light'. Output file names can 48 | be provided in an 'id' column. 49 | --parquet Parquet file containing sequences of heavy and light chains. Columns should be named 'fv_heavy' and 'fv_light'. Output file names 50 | can be provided in an 'id' column. 51 | --output Output file for the PDB structure, or path to the output folder when a parquet or csv file is provided. [default: prediction.pdb] 52 | --batch-size Batch size for inference if a parquet or csv file is provided. [default: 32] 53 | --save-all Save all structures of the ensemble as output files. [default: no-save-all] 54 | --refine Refine the output structures with openMM. [default: no-refine] 55 | --refine-checks Additional checks to fix cis-isomers and D-stereoisomers during refinement. [default: no-refine-checks] 56 | --apo Predict structures in the apo conformation. [default: no-apo] 57 | ``` 58 | 59 | To run Ibex programmatically, you can use 60 | ```python 61 | from ibex import Ibex, inference 62 | ibex_model = Ibex.from_pretrained("ibex") # or "abodybuilder3" 63 | inference(ibex_model, fv_heavy, fv_light, "prediction.pdb") 64 | ``` 65 | To predict structures for multiple sequence pairs, `batch_inference` is recommended instead of `inference`; see the sketch below. 66 | 67 | ## Predictions on nanobodies and TCRs 68 | 69 | To predict nanobody structures, leave out the `fv_light` argument, or set it as `""` or `None` in the csv column. 
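For example, paired antibodies and nanobody-style entries can be mixed in a single programmatic batch. The following is a minimal sketch based on the `batch_inference` signature in `src/ibex/predict.py`; it reuses the sequences from the CLI example above, and the second entry repeats the heavy chain with an empty light chain purely to illustrate the convention:

```python
from ibex import Ibex, batch_inference

# Sequences copied from the CLI example above.
fv_heavy = "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVSS"
fv_light = "DIQMTQSPSSLSASVGDRVTITCRASQDVNTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQHYTTPPTFGQGTKVEIK"

ibex_model = Ibex.from_pretrained("ibex")  # or "abodybuilder3"
batch_inference(
    ibex_model,
    [fv_heavy, fv_heavy],  # list of heavy chain sequences
    [fv_light, ""],        # "" marks a heavy-chain-only (nanobody-style) entry
    "predictions",         # output directory, one PDB file per input pair
    batch_size=32,
    output_names=["paired", "heavy_only"],  # optional; defaults to output_0, output_1, ...
)
```

This writes `predictions/paired.pdb` and `predictions/heavy_only.pdb`; if `output_names` is omitted, the files are named `output_0.pdb`, `output_1.pdb`, and so on.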
70 | 71 | For inference on TCRs, provide the beta chain variable domain sequence as `fv_heavy` and the alpha chain sequence as `fv_light`. Ibex has not been trained on gamma and delta chains. 72 | 73 | 74 | ## License 75 | The Ibex codebase is available under an [Apache 2.0 license](http://www.apache.org/licenses/LICENSE-2.0), and the [ABodyBuilder3](https://doi.org/10.5281/zenodo.11354576) model weights under a [Creative Commons Attribution 4.0 International license](https://creativecommons.org/licenses/by/4.0/legalcode), both of which allow for commercial use. 76 | 77 | The [Ibex model weights](https://doi.org/10.5281/zenodo.15866555) are available under a [Genentech Apache 2.0 Non-Commercial license](https://raw.githubusercontent.com/prescient-design/ibex/refs/heads/main/docs/Genentech_license_weights_ibex), which allows their use for non-commercial academic research purposes. 78 | 79 | Ibex uses embeddings from ESMC 300M as its input representation; ESMC is licensed under the [EvolutionaryScale Cambrian Open License Agreement](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement). 80 | 81 | ## Citation 82 | When using Ibex in your work, please cite the following paper: 83 | 84 | ```bibtex 85 | @article{ibex, 86 | author = {Frédéric A. Dreyer and Jan Ludwiczak and Karolis Martinkus and Brennan Abanades and Robert G. Alberstein and Pan Kessel and Pranav Rao and Jae Hyeon Lee and Richard Bonneau and Andrew M. Watkins and Franziska Seeger}, 87 | title = {Conformation-aware structure prediction of antigen-recognizing immune proteins}, 88 | journal = {mAbs}, 89 | volume = {18}, 90 | number = {1}, 91 | pages = {2602217}, 92 | year = {2026}, 93 | publisher = {Taylor \& Francis}, 94 | doi = {10.1080/19420862.2025.2602217}, 95 | note = {PMID: 41378904}, 96 | URL = {https://doi.org/10.1080/19420862.2025.2602217}, 97 | eprint = {https://doi.org/10.1080/19420862.2025.2602217} 98 | } 99 | ``` 100 | 101 | If you use the ABodyBuilder3 model weights, you should also cite: 102 | ```bibtex 103 | @article{abodybuilder3, 104 | author = {Kenlay, Henry and Dreyer, Frédéric A and Cutting, Daniel and Nissley, Daniel and Deane, Charlotte M}, 105 | title = "{ABodyBuilder3: improved and scalable antibody structure predictions}", 106 | journal = {Bioinformatics}, 107 | volume = {40}, 108 | number = {10}, 109 | pages = {btae576}, 110 | year = {2024}, 111 | month = {10}, 112 | issn = {1367-4811}, 113 | doi = {10.1093/bioinformatics/btae576} 114 | } 115 | ``` 116 | -------------------------------------------------------------------------------- /src/ibex/loss/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # Copyright 2024 Exscientia 3 | # Copyright 2021 AlQuraishi Laboratory 4 | # Copyright 2021 DeepMind Technologies Limited 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | # Based on AlphaFoldLoss class from https://github.com/aqlaboratory/openfold/blob/main/openfold/utils/loss.py#L1685 19 | 20 | 21 | import ml_collections 22 | import torch 23 | import torch.nn.functional as F 24 | from loguru import logger 25 | 26 | from ibex.loss.aligned_rmsd import aligned_fv_and_cdrh3_rmsd 27 | from ibex.openfold.utils.loss import ( 28 | compute_renamed_ground_truth, 29 | fape_loss, 30 | final_output_backbone_loss, 31 | find_structural_violations, 32 | lddt_loss, 33 | supervised_chi_loss, 34 | violation_loss_bondangle, 35 | violation_loss_bondlength, 36 | violation_loss_clash, 37 | ) 38 | 39 | 40 | def triplet_loss(positive, negative, anchor, margin): 41 | # Compute cosine similarity 42 | sim_pos = F.cosine_similarity(anchor, positive, dim=-1) 43 | sim_neg = F.cosine_similarity(anchor, negative, dim=-1) 44 | # Triplet loss 45 | loss_triplet = torch.mean(torch.clamp(margin - sim_pos + sim_neg, min=0.0)) 46 | return loss_triplet 47 | 48 | 49 | def triplet_loss_batch(single: torch.Tensor, cdr_mask: torch.Tensor, margin: float = 1.0): 50 | cdr_mask_expanded = cdr_mask.unsqueeze(-1).expand_as(single) 51 | masked_embeddings = single * cdr_mask_expanded 52 | sum_embeddings = masked_embeddings.sum(dim=1) 53 | count_non_zero = cdr_mask.sum(dim=1).unsqueeze(-1) 54 | # Avoid division by zero 55 | count_non_zero = torch.where(count_non_zero == 0, torch.ones_like(count_non_zero), count_non_zero) 56 | # Compute the average 57 | average_embeddings = sum_embeddings / count_non_zero 58 | loss = triplet_loss(average_embeddings[-3], average_embeddings[-2], average_embeddings[-1], margin) 59 | loss += triplet_loss(average_embeddings[-6], average_embeddings[-5], average_embeddings[-4], margin) 60 | return loss 61 | 62 | 63 | def negative_pairs_loss(positive: torch.Tensor, negative: torch.Tensor, threshold: float): 64 | cosine_sim = F.cosine_similarity(positive, negative, dim=-1) 65 | # Compute loss 66 | loss = torch.mean(torch.clamp(threshold - cosine_sim, min=0.0)) 67 | return loss 68 | 69 | def negative_pairs_loss_batch(single: torch.Tensor, cdr_mask: torch.Tensor, threshold: float, pair_mask: torch.Tensor): 70 | if torch.as_tensor(pair_mask).sum() == 0: 71 | return torch.tensor(0.0, device=single.device) 72 | cdr_mask_expanded = cdr_mask.unsqueeze(-1).expand_as(single) 73 | masked_embeddings = single * cdr_mask_expanded 74 | sum_embeddings = masked_embeddings.sum(dim=1) 75 | count_non_zero = cdr_mask.sum(dim=1).unsqueeze(-1) 76 | # Avoid division by zero 77 | count_non_zero = torch.where(count_non_zero == 0, torch.ones_like(count_non_zero), count_non_zero) 78 | # Compute the average 79 | average_embeddings = sum_embeddings / count_non_zero 80 | paired_samples = average_embeddings[pair_mask] 81 | pos_samples = paired_samples[0::2] 82 | neg_samples = paired_samples[1::2] 83 | loss = negative_pairs_loss(pos_samples, neg_samples, threshold) 84 | return loss 85 | 86 | class IbexLoss(torch.nn.Module): 87 | def __init__(self, config: ml_collections.config_dict.ConfigDict): 88 | super().__init__() 89 | self.config = config 90 | self.dist_and_angle_annealing = 0.0 91 | 92 | def forward(self, output: dict, batch: dict, finetune: bool = False, contrastive: bool = False): 93 | if finetune: 94 | output["violation"] = find_structural_violations( 95 | batch, 96 | output["positions"][-1], 97 | **self.config.violation, 98 | ) 99 | 100 | if "renamed_atom14_gt_positions" not in output.keys(): 101 | batch.update( 102 | compute_renamed_ground_truth( 103 | batch, 104 | 
output["positions"][-1], 105 | ) 106 | ) 107 | 108 | loss_fns = { 109 | "fape": lambda: fape_loss( 110 | {"sm": output}, 111 | batch, 112 | self.config.fape, 113 | ), 114 | "supervised_chi": lambda: supervised_chi_loss( 115 | output["angles"], 116 | output["unnormalized_angles"], 117 | **{**batch, **self.config.supervised_chi}, 118 | ), 119 | "final_output_backbone_loss": lambda: final_output_backbone_loss( 120 | output, batch 121 | ), 122 | } 123 | if "plddt" in output: 124 | loss_fns.update( 125 | { 126 | "plddt": lambda: lddt_loss( 127 | output["plddt"], 128 | output["positions"][-1], 129 | batch["atom14_gt_positions"], 130 | batch["atom14_atom_exists"], 131 | batch["resolution"], 132 | ), 133 | } 134 | ) 135 | 136 | if finetune: 137 | loss_fns.update( 138 | { 139 | "violation_loss_bondlength": lambda: violation_loss_bondlength( 140 | output["violation"] 141 | ), 142 | "violation_loss_bondangle": lambda: violation_loss_bondangle( 143 | output["violation"] 144 | ), 145 | "violation_loss_clash": lambda: violation_loss_clash( 146 | output["violation"], **batch 147 | ), 148 | } 149 | ) 150 | if contrastive: 151 | loss_fns.update( 152 | { 153 | "contrastive": lambda: negative_pairs_loss_batch( 154 | output["single"], 155 | batch["cdr_mask"], 156 | self.config.contrastive.margin, 157 | batch["is_matched"], 158 | ), 159 | # "contrastive": lambda: triplet_loss_batch( 160 | # output["single"], 161 | # batch["cdr_mask"], 162 | # self.config.contrastive.margin, 163 | # ), 164 | } 165 | ) 166 | cum_loss = 0.0 167 | losses = {} 168 | for loss_name, loss_fn in loss_fns.items(): 169 | weight = self.config[loss_name].weight 170 | loss = loss_fn() 171 | if torch.isnan(loss) or torch.isinf(loss): 172 | logger.warning(f"{loss_name} loss is NaN. Skipping...") 173 | loss = loss.new_tensor(0.0, requires_grad=True) 174 | if loss_name in ["violation_loss_bondlength", "violation_loss_bondangle"]: 175 | weight *= min(self.dist_and_angle_annealing / 50, 1) 176 | cum_loss = cum_loss + weight * loss 177 | losses[loss_name] = loss.detach().clone() 178 | losses[f"{loss_name}_weighted"] = weight * losses[loss_name] 179 | losses[f"{loss_name}_weight"] = weight 180 | 181 | # aligned_rmsd (not added to cum_loss) 182 | with torch.no_grad(): 183 | losses.update( 184 | aligned_fv_and_cdrh3_rmsd( 185 | coords_truth=batch["atom14_gt_positions"], 186 | coords_prediction=output["positions"][-1], 187 | sequence_mask=batch["seq_mask"], 188 | cdrh3_mask=batch["region_numeric"] == 2, 189 | ) 190 | ) 191 | 192 | losses["loss"] = cum_loss.detach().clone() 193 | 194 | return cum_loss, losses 195 | -------------------------------------------------------------------------------- /docs/Genentech_license_weights_ibex: -------------------------------------------------------------------------------- 1 | ACADEMIC SOFTWARE LICENSE AGREEMENT FOR END-USERS AT PUBLICLY FUNDED ACADEMIC, EDUCATION OR RESEARCH INSTITUTIONS FOR THE USE OF IBEX 2 | 3 | By downloading, installing, or using the Licensed Software you consent to be bound by and become a party to this agreement (hereinafter "Agreement") as a "LICENSEE". If you do not agree to all of the terms of this Agreement, you must not download, install, or use the Licensed Software, and you do not become a LICENSEE under this Agreement. 4 | 5 | If you are not a member of an academic research institution, you must obtain a commercial license; please send requests via email to dreyer.frederic@gene.com. This Agreement is entered into by and between Genentech, Inc. 
(hereinafter "GENENTECH") and the LICENSEE. 6 | WHEREAS GENENTECH has the right to license all copyrights and other property rights in the Licensed Software identified as Ibex and developed by GENENTECH and GENENTECH desires to license the Licensed Software so that it becomes available for internal teaching and non-commercial academic research purposes only. 7 | WHEREAS LICENSEE is a publicly funded academic and/or education and/or research institution. 8 | WHEREAS LICENSEE desires to acquire a free non-exclusive license to use the Licensed Software for internal teaching and non-commercial academic research purposes only. 9 | NOW, THEREFORE, in consideration of the mutual promises and covenants contained herein, the parties agree as follows: 10 | 11 | 1. Definitions 12 | "Licensed Software" means the specific version of Ibex pursuant to this Agreement. 13 | 14 | 2. License 15 | Subject to the terms and conditions of this Agreement a non-exclusive, non-transferable license to use and copy the Licensed Software is made available free of charge for the LICENSEE which is a non-profit educational, academic and/or research institution. The license is only granted for internal teaching and non-commercial academic research purposes at one Site, where a Site is defined as a set of contiguous buildings in one location. The Licensed Software will be used at only one location of LICENSEE. 16 | This license does not entitle Licensee to receive from GENENTECH copies of the Licensed Software on disks, tapes or CD's, hard-copy documentation, technical support, telephone assistance, or enhancements or updates to the Licensed Software. 17 | The user and any research assistants or co-workers who may use the Licensed Software agree to not give the program to third parties or grant licenses on software, which include the Licensed Software, alone or integrated into other software, to third parties. Modification of the source code is prohibited without the prior written consent of GENENTECH. 18 | 19 | 3. Ownership 20 | Except as expressly licensed in this Agreement, GENENTECH shall retain title to the Licensed Software, and any upgrades and modifications created by GENENTECH. 21 | 22 | 4. Consideration 23 | In consideration for the license rights granted by GENENTECH, LICENSEE will obtain this license free of charge. 24 | 25 | 5. Copies 26 | LICENSEE shall have the right to make copies of the Licensed Software for personal and internal teaching and non-commercial academic research purposes at the Site and for back-up purposes under this Agreement, but agrees that all such copies shall contain a copy of this Agreement, copyright notices, and all other reasonable and appropriate proprietary markings or confidential legends that appear on the Licensed Software. 27 | 28 | 6. Support 29 | GENENTECH shall have no obligation to offer support services to LICENSEE, and nothing contained herein shall be interpreted as to require GENENTECH to provide maintenance, installation services, version updates, debugging, consultation or end-user support of any kind. 30 | 31 | 7. Software Protection 32 | LICENSEE acknowledges that the Licensed Software is proprietary to GENENTECH. 
33 | Except as otherwise expressly permitted in this Agreement, Licensee may not (i) modify or create any derivative works of the Licensed Software or documentation, including customization, translation or localization; (ii) decompile, disassemble, reverse engineer, or otherwise attempt to derive the source code or model weights of the Licensed Software; (iii) redistribute, encumber, sell, rent, lease, sublicense, or otherwise transfer rights to the Licensed Software; (iv) remove or alter any trademark, logo, copyright or other proprietary notices, legends, symbols or labels in the Licensed Software; or (v) publish any results of benchmark tests run on the Licensed Software to a third party without GENENTECH's prior written consent. 34 | 35 | 8. Representations of GENENTECH to LICENSEE 36 | GENENTECH represents to LICENSEE that GENENTECH has the right to grant the License and to enter into this Agreement. 37 | 38 | 9. Indemnity and Disclaimer of Warranties 39 | GENENTECH makes no representations or warranties, express or implied, of any kind. 40 | The Licensed Software is provided free of charge, and, therefore, on an "as is" basis, without warranty of any kind, express or implied, including without limitation the warranties that it is free of defects, virus free, able to operate on an uninterrupted basis, merchantable, fit for a particular purpose or non-interfering. The entire risk as to the quality and performance of the Licensed Software is borne by LICENSEE. 41 | By way of example, but not limitation, GENENTECH makes no representations or warranties of merchantability or fitness for any particular application or that the use of the Licensed Software will not infringe any patents, copyrights or trademarks or other rights of third parties. The entire risk as to the quality and performance of the Licensed Software is borne by LICENSEE. GENENTECH shall not be liable for any liability or damages with respect to any claim by LICENSEE or any third party on account of, or arising from the license or use of the Licensed Software. 42 | Should the Licensed Software prove defective in any respect, LICENSEE and neither GENENTECH should assume the entire cost of any service and repair. This disclaimer of warranty constitutes an essential part of this Agreement. No use of the Licensed Software is authorized hereunder except under this disclaimer. 43 | In no event will GENENTECH be liable for any indirect, special, incidental or consequential damages arising out of the use of or inability to use the Licensed Software, including, without limitation, damages for lost profits, loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses, even if advised of the possibility thereof, and regardless of the legal or equitable theory (contract, tort or otherwise) upon which the claim is based. 44 | 45 | 10. Promotional Advertising & References 46 | LICENSEE may not use the name of the Licensed Software in its promotional advertising, product literature, and other similar promotional materials to be disseminated to the public or any portion thereof. LICENSEE agrees not to identify GENENTECH in any promotional advertising or other promotional materials to be disseminated to the public, or any portion thereof without GENENTECH's prior written consent. 47 | GENENTECH shall not use LICENSEE's name in publicity or advertising involving this Agreement or otherwise without LICENSEE's prior written consent which may be withheld at LICENSEE's sole discretion. 
48 | 49 | 11. Term 50 | This Agreement and the license rights granted herein shall become effective as of the date this Agreement is executed by both parties and shall be perpetual unless terminated in accordance with this Section. 51 | GENENTECH may terminate this Agreement at any time. 52 | Either party may terminate this Agreement at any time effective upon the other party's breach of any agreement, covenant, or representation made in this Agreement, such breach remaining uncorrected sixty (60) days after written notice thereof. 53 | LICENSEE shall have the right, at any time, to terminate this Agreement without cause by written notice to GENENTECH specifying the date of termination. 54 | Upon termination, LICENSEE shall destroy all full and partial copies of the Licensed Software. 55 | 56 | 12. Governing Law 57 | This Agreement shall be construed in accordance with the laws of California. 58 | 59 | 13. General 60 | The parties agree that this Agreement is the complete and exclusive agreement among the parties and supersedes all proposals and prior agreements whether written or oral, and all other communications among the parties relating to the subject matter of this Agreement. 61 | This Agreement cannot be modified except in writing and signed by both parties. Failure by either party at any time to enforce any of the provisions of this Agreement shall not constitute a waiver by such party of such provision nor in any way affect the validity of this Agreement. 62 | The invalidity of singular provisions does not affect the validity of the entire understanding. The parties are obligated, however, to replace the invalid provisions by a regulation which comes closest to the economic intent of the invalid provision. The same shall apply mutatis mutandis in case of a gap. 63 | -------------------------------------------------------------------------------- /src/ibex/openfold/utils/feats.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # Copyright 2024 Exscientia 3 | # Copyright 2021 AlQuraishi Laboratory 4 | # Copyright 2021 DeepMind Technologies Limited 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | from typing import Dict 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | import ibex.openfold.utils.residue_constants as rc 24 | from ibex.openfold.utils.rigid_utils import Rigid, Rotation 25 | from ibex.openfold.utils.tensor_utils import batched_gather 26 | 27 | 28 | def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): 29 | is_gly = aatype == rc.restype_order["G"] 30 | ca_idx = rc.atom_order["CA"] 31 | cb_idx = rc.atom_order["CB"] 32 | pseudo_beta = torch.where( 33 | is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3), 34 | all_atom_positions[..., ca_idx, :], 35 | all_atom_positions[..., cb_idx, :], 36 | ) 37 | 38 | if all_atom_masks is not None: 39 | pseudo_beta_mask = torch.where( 40 | is_gly, 41 | all_atom_masks[..., ca_idx], 42 | all_atom_masks[..., cb_idx], 43 | ) 44 | return pseudo_beta, pseudo_beta_mask 45 | else: 46 | return pseudo_beta 47 | 48 | 49 | def atom14_to_atom37(atom14, batch): 50 | atom37_data = batched_gather( 51 | atom14, 52 | batch["residx_atom37_to_atom14"], 53 | dim=-2, 54 | no_batch_dims=len(atom14.shape[:-2]), 55 | ) 56 | 57 | atom37_data = atom37_data * batch["atom37_atom_exists"][..., None] 58 | 59 | return atom37_data 60 | 61 | 62 | def build_template_angle_feat(template_feats): 63 | template_aatype = template_feats["template_aatype"] 64 | torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"] 65 | alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"] 66 | torsion_angles_mask = template_feats["template_torsion_angles_mask"] 67 | template_angle_feat = torch.cat( 68 | [ 69 | nn.functional.one_hot(template_aatype, 22), 70 | torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14), 71 | alt_torsion_angles_sin_cos.reshape( 72 | *alt_torsion_angles_sin_cos.shape[:-2], 14 73 | ), 74 | torsion_angles_mask, 75 | ], 76 | dim=-1, 77 | ) 78 | 79 | return template_angle_feat 80 | 81 | 82 | def build_template_pair_feat( 83 | batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8 84 | ): 85 | template_mask = batch["template_pseudo_beta_mask"] 86 | template_mask_2d = template_mask[..., None] * template_mask[..., None, :] 87 | 88 | # Compute distogram (this seems to differ slightly from Alg. 
5) 89 | tpb = batch["template_pseudo_beta"] 90 | dgram = torch.sum( 91 | (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True 92 | ) 93 | lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2 94 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) 95 | dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) 96 | 97 | to_concat = [dgram, template_mask_2d[..., None]] 98 | 99 | aatype_one_hot = nn.functional.one_hot( 100 | batch["template_aatype"], 101 | rc.restype_num + 2, 102 | ) 103 | 104 | n_res = batch["template_aatype"].shape[-1] 105 | to_concat.append( 106 | aatype_one_hot[..., None, :, :].expand( 107 | *aatype_one_hot.shape[:-2], n_res, -1, -1 108 | ) 109 | ) 110 | to_concat.append( 111 | aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1) 112 | ) 113 | 114 | n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] 115 | rigids = Rigid.make_transform_from_reference( 116 | n_xyz=batch["template_all_atom_positions"][..., n, :], 117 | ca_xyz=batch["template_all_atom_positions"][..., ca, :], 118 | c_xyz=batch["template_all_atom_positions"][..., c, :], 119 | eps=eps, 120 | ) 121 | points = rigids.get_trans()[..., None, :, :] 122 | rigid_vec = rigids[..., None].invert_apply(points) 123 | 124 | inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) 125 | 126 | t_aa_masks = batch["template_all_atom_mask"] 127 | template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c] 128 | template_mask_2d = template_mask[..., None] * template_mask[..., None, :] 129 | 130 | inv_distance_scalar = inv_distance_scalar * template_mask_2d 131 | unit_vector = rigid_vec * inv_distance_scalar[..., None] 132 | 133 | if not use_unit_vector: 134 | unit_vector = unit_vector * 0.0 135 | 136 | to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) 137 | to_concat.append(template_mask_2d[..., None]) 138 | 139 | act = torch.cat(to_concat, dim=-1) 140 | act = act * template_mask_2d[..., None] 141 | 142 | return act 143 | 144 | 145 | def build_extra_msa_feat(batch): 146 | msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23) 147 | msa_feat = [ 148 | msa_1hot, 149 | batch["extra_has_deletion"].unsqueeze(-1), 150 | batch["extra_deletion_value"].unsqueeze(-1), 151 | ] 152 | return torch.cat(msa_feat, dim=-1) 153 | 154 | 155 | def torsion_angles_to_frames( 156 | r: Rigid, 157 | alpha: torch.Tensor, 158 | aatype: torch.Tensor, 159 | rrgdf: torch.Tensor, 160 | ): 161 | # [*, N, 8, 4, 4] 162 | default_4x4 = rrgdf[aatype, ...] 163 | 164 | # [*, N, 8] transformations, i.e. 165 | # One [*, N, 8, 3, 3] rotation matrix and 166 | # One [*, N, 8, 3] translation matrix 167 | default_r = r.from_tensor_4x4(default_4x4) 168 | 169 | bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) 170 | bb_rot[..., 1] = 1 171 | 172 | # [*, N, 8, 2] 173 | alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) 174 | 175 | # [*, N, 8, 3, 3] 176 | # Produces rotation matrices of the form: 177 | # [ 178 | # [1, 0 , 0 ], 179 | # [0, a_2,-a_1], 180 | # [0, a_1, a_2] 181 | # ] 182 | # This follows the original code rather than the supplement, which uses 183 | # different indices. 
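    # (With a_1 = sin(alpha) and a_2 = cos(alpha), this is the standard
    # rotation about the x-axis, so each torsion angle contributes a
    # single-axis rotation in its own local frame.)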
184 | 185 | all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) 186 | all_rots[..., 0, 0] = 1 187 | all_rots[..., 1, 1] = alpha[..., 1] 188 | all_rots[..., 1, 2] = -alpha[..., 0] 189 | all_rots[..., 2, 1:] = alpha 190 | 191 | all_rots = Rigid(Rotation(rot_mats=all_rots), None) 192 | 193 | all_frames = default_r.compose(all_rots) 194 | 195 | chi2_frame_to_frame = all_frames[..., 5] 196 | chi3_frame_to_frame = all_frames[..., 6] 197 | chi4_frame_to_frame = all_frames[..., 7] 198 | 199 | chi1_frame_to_bb = all_frames[..., 4] 200 | chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) 201 | chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) 202 | chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) 203 | 204 | all_frames_to_bb = Rigid.cat( 205 | [ 206 | all_frames[..., :5], 207 | chi2_frame_to_bb.unsqueeze(-1), 208 | chi3_frame_to_bb.unsqueeze(-1), 209 | chi4_frame_to_bb.unsqueeze(-1), 210 | ], 211 | dim=-1, 212 | ) 213 | 214 | all_frames_to_global = r[..., None].compose(all_frames_to_bb) 215 | 216 | return all_frames_to_global 217 | 218 | 219 | def frames_and_literature_positions_to_atom14_pos( 220 | r: Rigid, 221 | aatype: torch.Tensor, 222 | default_frames, 223 | group_idx, 224 | atom_mask, 225 | lit_positions, 226 | ): 227 | # [*, N, 14, 4, 4] 228 | default_4x4 = default_frames[aatype, ...] 229 | 230 | # [*, N, 14] 231 | group_mask = group_idx[aatype, ...] 232 | 233 | # [*, N, 14, 8] 234 | group_mask = nn.functional.one_hot( 235 | group_mask, 236 | num_classes=default_frames.shape[-3], 237 | ) 238 | 239 | # [*, N, 14, 8] 240 | t_atoms_to_global = r[..., None, :] * group_mask 241 | 242 | # [*, N, 14] 243 | t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) 244 | 245 | # [*, N, 14, 1] 246 | atom_mask = atom_mask[aatype, ...].unsqueeze(-1) 247 | 248 | # [*, N, 14, 3] 249 | lit_positions = lit_positions[aatype, ...] 250 | pred_positions = t_atoms_to_global.apply(lit_positions) 251 | pred_positions = pred_positions * atom_mask 252 | 253 | return pred_positions 254 | -------------------------------------------------------------------------------- /src/ibex/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
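# Checkpoint download and caching helpers, pLDDT computation, AHo region
# masks and dihedral metrics, and conversion of Ibex outputs into Protein
# objects / PDB strings. compute_plddt below reduces the model's 50-bin
# histogram to an expected value over the bin centres 1, 3, ..., 99.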
14 | import io 15 | import os 16 | import requests 17 | import hashlib 18 | from pathlib import Path 19 | 20 | from tqdm import tqdm 21 | from loguru import logger 22 | import numpy as np 23 | import torch 24 | from typing import MutableMapping 25 | from omegaconf import DictConfig 26 | 27 | from ibex.openfold.utils.data_transforms import make_atom14_masks 28 | from ibex.openfold.utils.protein import Protein, to_pdb 29 | from ibex.openfold.utils.feats import atom14_to_atom37 30 | from ibex.loss.aligned_rmsd import positions_to_backbone_dihedrals, region_mapping, CDR_RANGES_AHO 31 | 32 | 33 | MODEL_CHECKPOINTS = { 34 | "ibex": "https://zenodo.org/records/15866556/files/ibex_v1.ckpt", 35 | "abodybuilder3": "https://zenodo.org/records/15866556/files/abb3.ckpt", 36 | } 37 | 38 | ENSEMBLE_MODELS = ["ibex"] 39 | 40 | 41 | def dihedral_distance_per_loop( 42 | positions_predicted: torch.Tensor, 43 | positions_reference: torch.Tensor, 44 | region_mask: torch.Tensor, 45 | mask: torch.Tensor | None = None, 46 | residue_index: torch.Tensor | None = None, 47 | chain_index: torch.Tensor | None = None 48 | ) -> dict[str, float]: 49 | """Computes the dihedral distance between predicted and reference backbone coordinates for each loop region. 50 | 51 | Args: 52 | positions_predicted (torch.Tensor): Predicted atomic coordinates of the protein backbone. 53 | positions_reference (torch.Tensor): Reference (ground truth) atomic coordinates of the protein backbone. 54 | region_mask (torch.Tensor): Boolean mask specifying residues belonging to each loop region (e.g., CDRs). 55 | mask (torch.Tensor | None, optional): Boolean mask indicating valid residues in the sequence. 56 | residue_index (torch.Tensor | None, optional): Indices of residues in the sequence. Defaults to None. 57 | chain_index (torch.Tensor | None, optional): Indices of chains in the structure. Defaults to None. 58 | Returns: 59 | dict[str, float]: Dictionary mapping each loop region to its computed dihedral distance between predicted and reference structures. 60 | 61 | """ 62 | if mask is None: 63 | mask = torch.ones(positions_reference.shape[:-1]).to(device=positions_reference.device) 64 | dihedral_angles_predicted, dihedral_mask = positions_to_backbone_dihedrals(positions_predicted, mask, residue_index, chain_index) 65 | dihedral_angles_reference, _ = positions_to_backbone_dihedrals(positions_reference, mask, residue_index, chain_index) 66 | results = {} 67 | dihedral_differences = 2 * (1 - torch.cos(dihedral_angles_predicted - dihedral_angles_reference)) 68 | for region_name, region_idx in region_mapping.items(): 69 | if region_name.startswith("cdr") and (region_mask == region_idx).any(): 70 | results[region_name] = torch.mean(dihedral_differences[region_mask == region_idx]).item() 71 | return results 72 | 73 | def region_mask_from_aho(fv_heavy_aho: str, fv_light_aho: str = "") -> torch.Tensor: 74 | """Return a tensor with CDR and framework identifiers for a given aho string of heavy and light chains. 
75 |     Args:
76 |         fv_heavy_aho (str): Heavy chain AHo string
77 |         fv_light_aho (str): Light chain AHo string
78 |     Returns:
79 |         torch.Tensor: Tensor of shape [N] with a region identifier per residue
80 |     """
81 |     region_len = {}
82 |     region_len["fwh1"] = len(fv_heavy_aho[:CDR_RANGES_AHO["H1"][0]].replace('-', ''))
83 |     region_len["cdrh1"] = len(fv_heavy_aho[CDR_RANGES_AHO["H1"][0]:CDR_RANGES_AHO["H1"][1]].replace('-', ''))
84 |     region_len["fwh2"] = len(fv_heavy_aho[CDR_RANGES_AHO["H1"][1]:CDR_RANGES_AHO["H2"][0]].replace('-', ''))
85 |     region_len["cdrh2"] = len(fv_heavy_aho[CDR_RANGES_AHO["H2"][0]:CDR_RANGES_AHO["H2"][1]].replace('-', ''))
86 |     region_len["fwh3"] = len(fv_heavy_aho[CDR_RANGES_AHO["H2"][1]:CDR_RANGES_AHO["H3"][0]].replace('-', ''))
87 |     region_len["cdrh3"] = len(fv_heavy_aho[CDR_RANGES_AHO["H3"][0]:CDR_RANGES_AHO["H3"][1]].replace('-', ''))
88 |     region_len["fwh4"] = len(fv_heavy_aho[CDR_RANGES_AHO["H3"][1]:].replace('-', ''))
89 | 
90 |     if fv_light_aho:
91 |         region_len["fwl1"] = len(fv_light_aho[:CDR_RANGES_AHO["L1"][0]].replace('-', ''))
92 |         region_len["cdrl1"] = len(fv_light_aho[CDR_RANGES_AHO["L1"][0]:CDR_RANGES_AHO["L1"][1]].replace('-', ''))
93 |         region_len["fwl2"] = len(fv_light_aho[CDR_RANGES_AHO["L1"][1]:CDR_RANGES_AHO["L2"][0]].replace('-', ''))
94 |         region_len["cdrl2"] = len(fv_light_aho[CDR_RANGES_AHO["L2"][0]:CDR_RANGES_AHO["L2"][1]].replace('-', ''))
95 |         region_len["fwl3"] = len(fv_light_aho[CDR_RANGES_AHO["L2"][1]:CDR_RANGES_AHO["L3"][0]].replace('-', ''))
96 |         region_len["cdrl3"] = len(fv_light_aho[CDR_RANGES_AHO["L3"][0]:CDR_RANGES_AHO["L3"][1]].replace('-', ''))
97 |         region_len["fwl4"] = len(fv_light_aho[CDR_RANGES_AHO["L3"][1]:].replace('-', ''))
98 | 
99 | 
100 |     res = []
101 |     for region in region_len:
102 |         res.append(torch.ones(region_len[region], dtype=torch.int) * region_mapping[region])
103 |     return torch.cat(res)
104 | 
105 | def compute_plddt(plddt: torch.Tensor) -> torch.Tensor:
106 |     """Computes pLDDT scores from the model output. The output is a histogram of
107 |     unnormalised pLDDT logits.
108 | 
109 |     Args:
110 |         plddt (torch.Tensor): (B, n, 50) output from the model
111 | 
112 |     Returns:
113 |         torch.Tensor: (B, n) pLDDT scores
114 |     """
115 |     pdf = torch.nn.functional.softmax(plddt, dim=-1)
116 |     vbins = torch.arange(1, 101, 2).to(plddt.device).float()
117 |     output = pdf @ vbins  # (B, n)
118 |     return output
119 | 
120 | 
121 | def add_atom37_to_output(output: dict, aatype: torch.Tensor):
122 |     """Adds atom37 coordinates to an output dictionary containing atom14 coordinates."""
123 |     atom14 = output["positions"][-1, 0]
124 |     batch = make_atom14_masks({"aatype": aatype.squeeze()})
125 |     atom37 = atom14_to_atom37(atom14, batch)
126 |     output["atom37"] = atom37
127 |     output["atom37_atom_exists"] = batch["atom37_atom_exists"]
128 |     return output
129 | 
130 | 
131 | def output_to_protein(output: dict, model_input: dict) -> Protein:
132 |     """Generates a Protein object from Ibex predictions.
133 | 
134 |     Args:
135 |         output (dict): Ibex output dictionary
136 |         model_input (dict): Ibex input dictionary
137 | 
138 |     Returns:
139 |         Protein: a Protein object containing the predicted structure.
140 | """ 141 | aatype = model_input["aatype"].squeeze().cpu().numpy().astype(int) 142 | atom37 = output["atom37"] 143 | chain_index = 1 - model_input["is_heavy"].cpu().numpy().astype(int) 144 | atom_mask = output["atom37_atom_exists"].cpu().numpy().astype(int) 145 | residue_index = np.arange(len(atom37)) 146 | if "plddt" in output: 147 | plddt = compute_plddt(output["plddt"]).squeeze().detach().cpu().numpy() 148 | b_factors = np.expand_dims(plddt, 1).repeat(37, 1) 149 | else: 150 | b_factors = np.zeros_like(atom_mask) 151 | protein = Protein( 152 | aatype=aatype, 153 | atom_positions=atom37, 154 | atom_mask=atom_mask, 155 | residue_index=residue_index, 156 | b_factors=b_factors, 157 | chain_index=chain_index, 158 | ) 159 | 160 | return protein 161 | 162 | def output_to_pdb(output: dict, model_input: dict) -> str: 163 | """Generates a pdb file from Ibex predictions. 164 | 165 | Args: 166 | output (dict): Ibex output dictionary 167 | model_input (dict): Ibex input dictionary 168 | 169 | Returns: 170 | str: the contents of a pdb file in string format. 171 | """ 172 | return to_pdb(output_to_protein(output, model_input)) 173 | 174 | def download_from_url(url, local_path): 175 | """Downloads a file from a URL with a progress bar.""" 176 | try: 177 | with requests.get(url, stream=True) as r: 178 | r.raise_for_status() 179 | total_size_in_bytes = int(r.headers.get('content-length', 0)) 180 | block_size = 1024 # 1 Kibibyte 181 | 182 | with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Downloading {os.path.basename(local_path)}") as progress_bar: 183 | with open(local_path, 'wb') as f: 184 | for chunk in r.iter_content(block_size): 185 | progress_bar.update(len(chunk)) 186 | f.write(chunk) 187 | 188 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 189 | raise IOError("ERROR: Something went wrong during download") 190 | 191 | except requests.exceptions.RequestException as e: 192 | logger.error(f"Failed to download {url}: {e}") 193 | # Clean up partially downloaded file 194 | if os.path.exists(local_path): 195 | os.remove(local_path) 196 | raise 197 | 198 | def get_checkpoint_path(checkpoint_url, cache_dir=None): 199 | """Generates a unique local path for a given URL.""" 200 | cache_dir = cache_dir or os.path.join(Path.home(), ".cache/ibex") 201 | os.makedirs(cache_dir, exist_ok=True) 202 | url_filename = os.path.basename(checkpoint_url) 203 | url_hash = hashlib.md5(checkpoint_url.encode()).hexdigest()[:8] 204 | filename = f"{url_hash}_{url_filename}" 205 | return os.path.join(cache_dir, filename) 206 | 207 | 208 | def checkpoint_path(model_name, cache_dir=None): 209 | """ 210 | Ensures the model checkpoint exists locally, downloading it if necessary. 
211 | """ 212 | if model_name not in MODEL_CHECKPOINTS: 213 | raise ValueError(f"Invalid model name: {model_name}") 214 | 215 | checkpoint_url = MODEL_CHECKPOINTS[model_name] 216 | local_path = get_checkpoint_path(checkpoint_url, cache_dir) 217 | 218 | if not os.path.exists(local_path): 219 | logger.info(f"Downloading checkpoint from {checkpoint_url} to {local_path}") 220 | # Call the new download function 221 | download_from_url(checkpoint_url, local_path) 222 | 223 | return local_path 224 | -------------------------------------------------------------------------------- /src/ibex/openfold/resources/stereo_chemical_props.txt: -------------------------------------------------------------------------------- 1 | Bond Residue Mean StdDev 2 | CA-CB ALA 1.520 0.021 3 | N-CA ALA 1.459 0.020 4 | CA-C ALA 1.525 0.026 5 | C-O ALA 1.229 0.019 6 | CA-CB ARG 1.535 0.022 7 | CB-CG ARG 1.521 0.027 8 | CG-CD ARG 1.515 0.025 9 | CD-NE ARG 1.460 0.017 10 | NE-CZ ARG 1.326 0.013 11 | CZ-NH1 ARG 1.326 0.013 12 | CZ-NH2 ARG 1.326 0.013 13 | N-CA ARG 1.459 0.020 14 | CA-C ARG 1.525 0.026 15 | C-O ARG 1.229 0.019 16 | CA-CB ASN 1.527 0.026 17 | CB-CG ASN 1.506 0.023 18 | CG-OD1 ASN 1.235 0.022 19 | CG-ND2 ASN 1.324 0.025 20 | N-CA ASN 1.459 0.020 21 | CA-C ASN 1.525 0.026 22 | C-O ASN 1.229 0.019 23 | CA-CB ASP 1.535 0.022 24 | CB-CG ASP 1.513 0.021 25 | CG-OD1 ASP 1.249 0.023 26 | CG-OD2 ASP 1.249 0.023 27 | N-CA ASP 1.459 0.020 28 | CA-C ASP 1.525 0.026 29 | C-O ASP 1.229 0.019 30 | CA-CB CYS 1.526 0.013 31 | CB-SG CYS 1.812 0.016 32 | N-CA CYS 1.459 0.020 33 | CA-C CYS 1.525 0.026 34 | C-O CYS 1.229 0.019 35 | CA-CB GLU 1.535 0.022 36 | CB-CG GLU 1.517 0.019 37 | CG-CD GLU 1.515 0.015 38 | CD-OE1 GLU 1.252 0.011 39 | CD-OE2 GLU 1.252 0.011 40 | N-CA GLU 1.459 0.020 41 | CA-C GLU 1.525 0.026 42 | C-O GLU 1.229 0.019 43 | CA-CB GLN 1.535 0.022 44 | CB-CG GLN 1.521 0.027 45 | CG-CD GLN 1.506 0.023 46 | CD-OE1 GLN 1.235 0.022 47 | CD-NE2 GLN 1.324 0.025 48 | N-CA GLN 1.459 0.020 49 | CA-C GLN 1.525 0.026 50 | C-O GLN 1.229 0.019 51 | N-CA GLY 1.456 0.015 52 | CA-C GLY 1.514 0.016 53 | C-O GLY 1.232 0.016 54 | CA-CB HIS 1.535 0.022 55 | CB-CG HIS 1.492 0.016 56 | CG-ND1 HIS 1.369 0.015 57 | CG-CD2 HIS 1.353 0.017 58 | ND1-CE1 HIS 1.343 0.025 59 | CD2-NE2 HIS 1.415 0.021 60 | CE1-NE2 HIS 1.322 0.023 61 | N-CA HIS 1.459 0.020 62 | CA-C HIS 1.525 0.026 63 | C-O HIS 1.229 0.019 64 | CA-CB ILE 1.544 0.023 65 | CB-CG1 ILE 1.536 0.028 66 | CB-CG2 ILE 1.524 0.031 67 | CG1-CD1 ILE 1.500 0.069 68 | N-CA ILE 1.459 0.020 69 | CA-C ILE 1.525 0.026 70 | C-O ILE 1.229 0.019 71 | CA-CB LEU 1.533 0.023 72 | CB-CG LEU 1.521 0.029 73 | CG-CD1 LEU 1.514 0.037 74 | CG-CD2 LEU 1.514 0.037 75 | N-CA LEU 1.459 0.020 76 | CA-C LEU 1.525 0.026 77 | C-O LEU 1.229 0.019 78 | CA-CB LYS 1.535 0.022 79 | CB-CG LYS 1.521 0.027 80 | CG-CD LYS 1.520 0.034 81 | CD-CE LYS 1.508 0.025 82 | CE-NZ LYS 1.486 0.025 83 | N-CA LYS 1.459 0.020 84 | CA-C LYS 1.525 0.026 85 | C-O LYS 1.229 0.019 86 | CA-CB MET 1.535 0.022 87 | CB-CG MET 1.509 0.032 88 | CG-SD MET 1.807 0.026 89 | SD-CE MET 1.774 0.056 90 | N-CA MET 1.459 0.020 91 | CA-C MET 1.525 0.026 92 | C-O MET 1.229 0.019 93 | CA-CB PHE 1.535 0.022 94 | CB-CG PHE 1.509 0.017 95 | CG-CD1 PHE 1.383 0.015 96 | CG-CD2 PHE 1.383 0.015 97 | CD1-CE1 PHE 1.388 0.020 98 | CD2-CE2 PHE 1.388 0.020 99 | CE1-CZ PHE 1.369 0.019 100 | CE2-CZ PHE 1.369 0.019 101 | N-CA PHE 1.459 0.020 102 | CA-C PHE 1.525 0.026 103 | C-O PHE 1.229 0.019 104 | CA-CB PRO 1.531 0.020 105 | CB-CG PRO 1.495 0.050 106 | CG-CD PRO 1.502 0.033 
107 | CD-N PRO 1.474 0.014 108 | N-CA PRO 1.468 0.017 109 | CA-C PRO 1.524 0.020 110 | C-O PRO 1.228 0.020 111 | CA-CB SER 1.525 0.015 112 | CB-OG SER 1.418 0.013 113 | N-CA SER 1.459 0.020 114 | CA-C SER 1.525 0.026 115 | C-O SER 1.229 0.019 116 | CA-CB THR 1.529 0.026 117 | CB-OG1 THR 1.428 0.020 118 | CB-CG2 THR 1.519 0.033 119 | N-CA THR 1.459 0.020 120 | CA-C THR 1.525 0.026 121 | C-O THR 1.229 0.019 122 | CA-CB TRP 1.535 0.022 123 | CB-CG TRP 1.498 0.018 124 | CG-CD1 TRP 1.363 0.014 125 | CG-CD2 TRP 1.432 0.017 126 | CD1-NE1 TRP 1.375 0.017 127 | NE1-CE2 TRP 1.371 0.013 128 | CD2-CE2 TRP 1.409 0.012 129 | CD2-CE3 TRP 1.399 0.015 130 | CE2-CZ2 TRP 1.393 0.017 131 | CE3-CZ3 TRP 1.380 0.017 132 | CZ2-CH2 TRP 1.369 0.019 133 | CZ3-CH2 TRP 1.396 0.016 134 | N-CA TRP 1.459 0.020 135 | CA-C TRP 1.525 0.026 136 | C-O TRP 1.229 0.019 137 | CA-CB TYR 1.535 0.022 138 | CB-CG TYR 1.512 0.015 139 | CG-CD1 TYR 1.387 0.013 140 | CG-CD2 TYR 1.387 0.013 141 | CD1-CE1 TYR 1.389 0.015 142 | CD2-CE2 TYR 1.389 0.015 143 | CE1-CZ TYR 1.381 0.013 144 | CE2-CZ TYR 1.381 0.013 145 | CZ-OH TYR 1.374 0.017 146 | N-CA TYR 1.459 0.020 147 | CA-C TYR 1.525 0.026 148 | C-O TYR 1.229 0.019 149 | CA-CB VAL 1.543 0.021 150 | CB-CG1 VAL 1.524 0.021 151 | CB-CG2 VAL 1.524 0.021 152 | N-CA VAL 1.459 0.020 153 | CA-C VAL 1.525 0.026 154 | C-O VAL 1.229 0.019 155 | - 156 | 157 | Angle Residue Mean StdDev 158 | N-CA-CB ALA 110.1 1.4 159 | CB-CA-C ALA 110.1 1.5 160 | N-CA-C ALA 111.0 2.7 161 | CA-C-O ALA 120.1 2.1 162 | N-CA-CB ARG 110.6 1.8 163 | CB-CA-C ARG 110.4 2.0 164 | CA-CB-CG ARG 113.4 2.2 165 | CB-CG-CD ARG 111.6 2.6 166 | CG-CD-NE ARG 111.8 2.1 167 | CD-NE-CZ ARG 123.6 1.4 168 | NE-CZ-NH1 ARG 120.3 0.5 169 | NE-CZ-NH2 ARG 120.3 0.5 170 | NH1-CZ-NH2 ARG 119.4 1.1 171 | N-CA-C ARG 111.0 2.7 172 | CA-C-O ARG 120.1 2.1 173 | N-CA-CB ASN 110.6 1.8 174 | CB-CA-C ASN 110.4 2.0 175 | CA-CB-CG ASN 113.4 2.2 176 | CB-CG-ND2 ASN 116.7 2.4 177 | CB-CG-OD1 ASN 121.6 2.0 178 | ND2-CG-OD1 ASN 121.9 2.3 179 | N-CA-C ASN 111.0 2.7 180 | CA-C-O ASN 120.1 2.1 181 | N-CA-CB ASP 110.6 1.8 182 | CB-CA-C ASP 110.4 2.0 183 | CA-CB-CG ASP 113.4 2.2 184 | CB-CG-OD1 ASP 118.3 0.9 185 | CB-CG-OD2 ASP 118.3 0.9 186 | OD1-CG-OD2 ASP 123.3 1.9 187 | N-CA-C ASP 111.0 2.7 188 | CA-C-O ASP 120.1 2.1 189 | N-CA-CB CYS 110.8 1.5 190 | CB-CA-C CYS 111.5 1.2 191 | CA-CB-SG CYS 114.2 1.1 192 | N-CA-C CYS 111.0 2.7 193 | CA-C-O CYS 120.1 2.1 194 | N-CA-CB GLU 110.6 1.8 195 | CB-CA-C GLU 110.4 2.0 196 | CA-CB-CG GLU 113.4 2.2 197 | CB-CG-CD GLU 114.2 2.7 198 | CG-CD-OE1 GLU 118.3 2.0 199 | CG-CD-OE2 GLU 118.3 2.0 200 | OE1-CD-OE2 GLU 123.3 1.2 201 | N-CA-C GLU 111.0 2.7 202 | CA-C-O GLU 120.1 2.1 203 | N-CA-CB GLN 110.6 1.8 204 | CB-CA-C GLN 110.4 2.0 205 | CA-CB-CG GLN 113.4 2.2 206 | CB-CG-CD GLN 111.6 2.6 207 | CG-CD-OE1 GLN 121.6 2.0 208 | CG-CD-NE2 GLN 116.7 2.4 209 | OE1-CD-NE2 GLN 121.9 2.3 210 | N-CA-C GLN 111.0 2.7 211 | CA-C-O GLN 120.1 2.1 212 | N-CA-C GLY 113.1 2.5 213 | CA-C-O GLY 120.6 1.8 214 | N-CA-CB HIS 110.6 1.8 215 | CB-CA-C HIS 110.4 2.0 216 | CA-CB-CG HIS 113.6 1.7 217 | CB-CG-ND1 HIS 123.2 2.5 218 | CB-CG-CD2 HIS 130.8 3.1 219 | CG-ND1-CE1 HIS 108.2 1.4 220 | ND1-CE1-NE2 HIS 109.9 2.2 221 | CE1-NE2-CD2 HIS 106.6 2.5 222 | NE2-CD2-CG HIS 109.2 1.9 223 | CD2-CG-ND1 HIS 106.0 1.4 224 | N-CA-C HIS 111.0 2.7 225 | CA-C-O HIS 120.1 2.1 226 | N-CA-CB ILE 110.8 2.3 227 | CB-CA-C ILE 111.6 2.0 228 | CA-CB-CG1 ILE 111.0 1.9 229 | CB-CG1-CD1 ILE 113.9 2.8 230 | CA-CB-CG2 ILE 110.9 2.0 231 | CG1-CB-CG2 ILE 111.4 2.2 232 | N-CA-C ILE 111.0 2.7 
233 | CA-C-O ILE 120.1 2.1 234 | N-CA-CB LEU 110.4 2.0 235 | CB-CA-C LEU 110.2 1.9 236 | CA-CB-CG LEU 115.3 2.3 237 | CB-CG-CD1 LEU 111.0 1.7 238 | CB-CG-CD2 LEU 111.0 1.7 239 | CD1-CG-CD2 LEU 110.5 3.0 240 | N-CA-C LEU 111.0 2.7 241 | CA-C-O LEU 120.1 2.1 242 | N-CA-CB LYS 110.6 1.8 243 | CB-CA-C LYS 110.4 2.0 244 | CA-CB-CG LYS 113.4 2.2 245 | CB-CG-CD LYS 111.6 2.6 246 | CG-CD-CE LYS 111.9 3.0 247 | CD-CE-NZ LYS 111.7 2.3 248 | N-CA-C LYS 111.0 2.7 249 | CA-C-O LYS 120.1 2.1 250 | N-CA-CB MET 110.6 1.8 251 | CB-CA-C MET 110.4 2.0 252 | CA-CB-CG MET 113.3 1.7 253 | CB-CG-SD MET 112.4 3.0 254 | CG-SD-CE MET 100.2 1.6 255 | N-CA-C MET 111.0 2.7 256 | CA-C-O MET 120.1 2.1 257 | N-CA-CB PHE 110.6 1.8 258 | CB-CA-C PHE 110.4 2.0 259 | CA-CB-CG PHE 113.9 2.4 260 | CB-CG-CD1 PHE 120.8 0.7 261 | CB-CG-CD2 PHE 120.8 0.7 262 | CD1-CG-CD2 PHE 118.3 1.3 263 | CG-CD1-CE1 PHE 120.8 1.1 264 | CG-CD2-CE2 PHE 120.8 1.1 265 | CD1-CE1-CZ PHE 120.1 1.2 266 | CD2-CE2-CZ PHE 120.1 1.2 267 | CE1-CZ-CE2 PHE 120.0 1.8 268 | N-CA-C PHE 111.0 2.7 269 | CA-C-O PHE 120.1 2.1 270 | N-CA-CB PRO 103.3 1.2 271 | CB-CA-C PRO 111.7 2.1 272 | CA-CB-CG PRO 104.8 1.9 273 | CB-CG-CD PRO 106.5 3.9 274 | CG-CD-N PRO 103.2 1.5 275 | CA-N-CD PRO 111.7 1.4 276 | N-CA-C PRO 112.1 2.6 277 | CA-C-O PRO 120.2 2.4 278 | N-CA-CB SER 110.5 1.5 279 | CB-CA-C SER 110.1 1.9 280 | CA-CB-OG SER 111.2 2.7 281 | N-CA-C SER 111.0 2.7 282 | CA-C-O SER 120.1 2.1 283 | N-CA-CB THR 110.3 1.9 284 | CB-CA-C THR 111.6 2.7 285 | CA-CB-OG1 THR 109.0 2.1 286 | CA-CB-CG2 THR 112.4 1.4 287 | OG1-CB-CG2 THR 110.0 2.3 288 | N-CA-C THR 111.0 2.7 289 | CA-C-O THR 120.1 2.1 290 | N-CA-CB TRP 110.6 1.8 291 | CB-CA-C TRP 110.4 2.0 292 | CA-CB-CG TRP 113.7 1.9 293 | CB-CG-CD1 TRP 127.0 1.3 294 | CB-CG-CD2 TRP 126.6 1.3 295 | CD1-CG-CD2 TRP 106.3 0.8 296 | CG-CD1-NE1 TRP 110.1 1.0 297 | CD1-NE1-CE2 TRP 109.0 0.9 298 | NE1-CE2-CD2 TRP 107.3 1.0 299 | CE2-CD2-CG TRP 107.3 0.8 300 | CG-CD2-CE3 TRP 133.9 0.9 301 | NE1-CE2-CZ2 TRP 130.4 1.1 302 | CE3-CD2-CE2 TRP 118.7 1.2 303 | CD2-CE2-CZ2 TRP 122.3 1.2 304 | CE2-CZ2-CH2 TRP 117.4 1.0 305 | CZ2-CH2-CZ3 TRP 121.6 1.2 306 | CH2-CZ3-CE3 TRP 121.2 1.1 307 | CZ3-CE3-CD2 TRP 118.8 1.3 308 | N-CA-C TRP 111.0 2.7 309 | CA-C-O TRP 120.1 2.1 310 | N-CA-CB TYR 110.6 1.8 311 | CB-CA-C TYR 110.4 2.0 312 | CA-CB-CG TYR 113.4 1.9 313 | CB-CG-CD1 TYR 121.0 0.6 314 | CB-CG-CD2 TYR 121.0 0.6 315 | CD1-CG-CD2 TYR 117.9 1.1 316 | CG-CD1-CE1 TYR 121.3 0.8 317 | CG-CD2-CE2 TYR 121.3 0.8 318 | CD1-CE1-CZ TYR 119.8 0.9 319 | CD2-CE2-CZ TYR 119.8 0.9 320 | CE1-CZ-CE2 TYR 119.8 1.6 321 | CE1-CZ-OH TYR 120.1 2.7 322 | CE2-CZ-OH TYR 120.1 2.7 323 | N-CA-C TYR 111.0 2.7 324 | CA-C-O TYR 120.1 2.1 325 | N-CA-CB VAL 111.5 2.2 326 | CB-CA-C VAL 111.4 1.9 327 | CA-CB-CG1 VAL 110.9 1.5 328 | CA-CB-CG2 VAL 110.9 1.5 329 | CG1-CB-CG2 VAL 110.9 1.6 330 | N-CA-C VAL 111.0 2.7 331 | CA-C-O VAL 120.1 2.1 332 | - 333 | 334 | Non-bonded distance Minimum Dist Tolerance 335 | C-C 3.4 1.5 336 | C-N 3.25 1.5 337 | C-S 3.5 1.5 338 | C-O 3.22 1.5 339 | N-N 3.1 1.5 340 | N-S 3.35 1.5 341 | N-O 3.07 1.5 342 | O-S 3.32 1.5 343 | O-O 3.04 1.5 344 | S-S 2.03 1.0 345 | - 346 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/ibex/refine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # Copyright (c) 2022, Brennan Abanades Kenyon 3 | # All rights reserved. 4 | 5 | # 6 | # BSD 3-Clause License 7 | # 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, are permitted provided that the following conditions are met: 10 | # 11 | # 1. Redistributions of source code must retain the above copyright notice, this 12 | # list of conditions and the following disclaimer. 13 | # 14 | # 2. Redistributions in binary form must reproduce the above copyright notice, 15 | # this list of conditions and the following disclaimer in the documentation 16 | # and/or other materials provided with the distribution. 17 | # 18 | # 3. Neither the name of the copyright holder nor the names of its 19 | # contributors may be used to endorse or promote products derived from 20 | # this software without specific prior written permission. 21 | # 22 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 23 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 24 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 25 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 26 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 27 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 28 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 29 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 30 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
32 | # 33 | # Implementation taken from https://github.com/brennanaba/ImmuneBuilder/blob/main/ImmuneBuilder/refine.py 34 | 35 | import os 36 | import numpy as np 37 | from openmm import app, LangevinIntegrator, CustomExternalForce, CustomTorsionForce, OpenMMException, Platform, unit 38 | from scipy import spatial 39 | import logging 40 | import pdbfixer 41 | logging.disable() 42 | 43 | ENERGY = unit.kilocalories_per_mole 44 | LENGTH = unit.angstroms 45 | spring_unit = ENERGY / (LENGTH ** 2) 46 | 47 | CLASH_CUTOFF = 0.63 48 | 49 | # Atomic radii for various atom types. 50 | atom_radii = {"C": 1.70, "N": 1.55, 'O': 1.52, 'S': 1.80} 51 | 52 | # Sum of van-der-waals radii 53 | radii_sums = dict( 54 | [(i + j, (atom_radii[i] + atom_radii[j])) for i in list(atom_radii.keys()) for j in list(atom_radii.keys())]) 55 | # Clash_cutoff-based radii values 56 | cutoffs = dict( 57 | [(i + j, CLASH_CUTOFF * (radii_sums[i + j])) for i in list(atom_radii.keys()) for j in list(atom_radii.keys())]) 58 | 59 | # Using amber14 recommended protein force field 60 | forcefield = app.ForceField("amber14/protein.ff14SB.xml") 61 | 62 | 63 | def refine_file(input_file, output_file, checks=True, tries=3, n=6, n_threads=-1): 64 | for _ in range(tries): 65 | if refine_once(input_file, output_file, checks=checks, n=n, n_threads=n_threads): 66 | return True 67 | return False 68 | 69 | 70 | def refine_once(input_file, output_file, checks=True, n=6, n_threads=-1): 71 | k1s = [2.5,1,0.5,0.25,0.1,0.001] 72 | k2s = [2.5,5,7.5,15,25,50] 73 | success = False 74 | 75 | fixer = pdbfixer.PDBFixer(input_file) 76 | 77 | fixer.findMissingResidues() 78 | fixer.findMissingAtoms() 79 | fixer.addMissingAtoms() 80 | 81 | k1 = k1s[0] 82 | if checks: 83 | k2 = -1 if (cis_check(fixer.topology, fixer.positions)) else k2s[0] 84 | else: 85 | k2 = -1 86 | 87 | topology, positions = fixer.topology, fixer.positions 88 | 89 | for i in range(n): 90 | try: 91 | simulation = minimize_energy(topology, positions, k1=k1, k2 = k2, n_threads=n_threads) 92 | topology, positions = simulation.topology, simulation.context.getState(getPositions=True).getPositions() 93 | if checks: 94 | acceptable_bonds, trans_peptide_bonds = bond_check(topology, positions), cis_check(topology, positions) 95 | except OpenMMException as e: 96 | if (i == n-1) and ("positions" not in locals()): 97 | print("OpenMM failed to refine {}".format(input_file), flush=True) 98 | raise e 99 | else: 100 | topology, positions = fixer.topology, fixer.positions 101 | continue 102 | 103 | if checks: 104 | 105 | # If peptide bonds are the wrong length, decrease the strength of the positional restraint 106 | if not acceptable_bonds: 107 | k1 = k1s[min(i, len(k1s)-1)] 108 | 109 | # If there are still cis isomers in the model, increase the force to fix these 110 | if not trans_peptide_bonds: 111 | k2 = k2s[min(i, len(k2s)-1)] 112 | else: 113 | k2 = -1 114 | 115 | if acceptable_bonds and trans_peptide_bonds: 116 | # If peptide bond lengths and torsions are okay, check and fix the chirality. 
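                # (chirality_fixer below mirrors the HA hydrogen of any
                # D-stereoisomer through its CA, re-minimises with that atom
                # pinned at zero mass, then restores the mass and minimises
                # once more.)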
117 |                 try:
118 |                     simulation = chirality_fixer(simulation)
119 |                     topology, positions = simulation.topology, simulation.context.getState(getPositions=True).getPositions()
120 |                 except OpenMMException as e:
121 |                     topology, positions = fixer.topology, fixer.positions
122 |                     continue
123 | 
124 |                 # If all other checks pass, check and fix strained sidechain bonds:
125 |                 try:
126 |                     strained_bonds = strained_sidechain_bonds_check(topology, positions)
127 |                     if len(strained_bonds) > 0:
128 |                         needs_recheck = True
129 |                         topology, positions = strained_sidechain_bonds_fixer(strained_bonds, topology, positions, n_threads=n_threads)
130 |                     else:
131 |                         needs_recheck = False
132 |                 except OpenMMException as e:
133 |                     topology, positions = fixer.topology, fixer.positions
134 |                     continue
135 | 
136 |                 # If it passes all the tests, we are done
137 |                 tests = bond_check(topology, positions) and cis_check(topology, positions)
138 |                 if needs_recheck:
139 |                     tests = tests and len(strained_sidechain_bonds_check(topology, positions)) == 0
140 |                 if tests and stereo_check(topology, positions) and clash_check(topology, positions):
141 |                     success = True
142 |                     break
143 | 
144 |         else:
145 |             success = True
146 |             break
147 | 
148 |     with open(output_file, "w") as out_handle:
149 |         app.PDBFile.writeFile(topology, positions, out_handle, keepIds=True)
150 | 
151 |     return success
152 | 
153 | 
154 | def minimize_energy(topology, positions, k1=2.5, k2=2.5, n_threads=-1):
155 |     # Fill in the gaps with OpenMM Modeller
156 |     modeller = app.Modeller(topology, positions)
157 |     modeller.addHydrogens(forcefield)
158 | 
159 |     # Set up force field
160 |     system = forcefield.createSystem(modeller.topology)
161 | 
162 |     # Keep atoms close to initial prediction
163 |     force = CustomExternalForce("k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
164 |     force.addGlobalParameter("k", k1 * spring_unit)
165 |     for p in ["x0", "y0", "z0"]:
166 |         force.addPerParticleParameter(p)
167 | 
168 |     for residue in modeller.topology.residues():
169 |         for atom in residue.atoms():
170 |             if atom.name in ["CA", "CB", "N", "C"]:
171 |                 force.addParticle(atom.index, modeller.positions[atom.index])
172 | 
173 |     system.addForce(force)
174 | 
175 |     if k2 > 0.0:
176 |         cis_force = CustomTorsionForce("10*k2*(1+cos(theta))^2")
177 |         cis_force.addGlobalParameter("k2", k2 * ENERGY)
178 | 
179 |         for chain in modeller.topology.chains():
180 |             residues = [res for res in chain.residues()]
181 |             relevant_atoms = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "CA", "C"]} for res in residues]
182 |             for i in range(1,len(residues)):
183 |                 if residues[i].name == "PRO":
184 |                     continue
185 | 
186 |                 resi = relevant_atoms[i-1]
187 |                 n_resi = relevant_atoms[i]
188 |                 cis_force.addTorsion(resi["CA"], resi["C"], n_resi["N"], n_resi["CA"])
189 | 
190 |         system.addForce(cis_force)
191 | 
192 |     # Set up integrator
193 |     integrator = LangevinIntegrator(0, 0.01, 0.0)
194 | 
195 |     # Set up the simulation
196 |     if n_threads > 0:
197 |         # Set number of threads used by OpenMM
198 |         platform = Platform.getPlatformByName('CPU')
199 |         simulation = app.Simulation(modeller.topology, system, integrator, platform, {'Threads': str(n_threads)})
200 |     else:
201 |         simulation = app.Simulation(modeller.topology, system, integrator)
202 |     simulation.context.setPositions(modeller.positions)
203 | 
204 |     # Minimize the energy
205 |     simulation.minimizeEnergy()
206 | 
207 |     return simulation
208 | 
209 | 
210 | def chirality_fixer(simulation):
211 |     topology = simulation.topology
212 |     positions =
simulation.context.getState(getPositions=True).getPositions() 213 | 214 | d_stereoisomers = [] 215 | for residue in topology.residues(): 216 | if residue.name == "GLY": 217 | continue 218 | 219 | atom_indices = {atom.name:atom.index for atom in residue.atoms() if atom.name in ["N", "CA", "C", "CB"]} 220 | vectors = [positions[atom_indices[i]] - positions[atom_indices["CA"]] for i in ["N", "C", "CB"]] 221 | 222 | if np.dot(np.cross(vectors[0], vectors[1]), vectors[2]) < .0*LENGTH**3: 223 | # If it is a D-stereoisomer then flip its H atom 224 | indices = {x.name:x.index for x in residue.atoms() if x.name in ["HA", "CA"]} 225 | positions[indices["HA"]] = 2*positions[indices["CA"]] - positions[indices["HA"]] 226 | 227 | # Fix the H atom in place 228 | particle_mass = simulation.system.getParticleMass(indices["HA"]) 229 | simulation.system.setParticleMass(indices["HA"], 0.0) 230 | d_stereoisomers.append((indices["HA"], particle_mass)) 231 | 232 | if len(d_stereoisomers) > 0: 233 | simulation.context.setPositions(positions) 234 | 235 | # Minimize the energy with the evil hydrogens fixed 236 | simulation.minimizeEnergy() 237 | 238 | # Minimize the energy letting the hydrogens move 239 | for atom in d_stereoisomers: 240 | simulation.system.setParticleMass(*atom) 241 | simulation.minimizeEnergy() 242 | 243 | return simulation 244 | 245 | 246 | def bond_check(topology, positions): 247 | for chain in topology.chains(): 248 | residues = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "C"]} for res in chain.residues()] 249 | for i in range(len(residues)-1): 250 | # For simplicity we only check the peptide bond length as the rest should be correct as they are hard coded 251 | v = np.linalg.norm(positions[residues[i]["C"]] - positions[residues[i+1]["N"]]) 252 | if abs(v - 1.329*LENGTH) > 0.1*LENGTH: 253 | return False 254 | return True 255 | 256 | 257 | def cis_bond(p0,p1,p2,p3): 258 | ab = p1-p0 259 | cd = p2-p1 260 | db = p3-p2 261 | 262 | u = np.cross(-ab, cd) 263 | v = np.cross(db, cd) 264 | return np.dot(u,v) > 0 265 | 266 | 267 | def cis_check(topology, positions): 268 | pos = np.array(positions.value_in_unit(LENGTH)) 269 | for chain in topology.chains(): 270 | residues = [res for res in chain.residues()] 271 | relevant_atoms = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "CA", "C"]} for res in residues] 272 | for i in range(1,len(residues)): 273 | if residues[i].name == "PRO": 274 | continue 275 | 276 | resi = relevant_atoms[i-1] 277 | n_resi = relevant_atoms[i] 278 | p0,p1,p2,p3 = pos[resi["CA"]],pos[resi["C"]],pos[n_resi["N"]],pos[n_resi["CA"]] 279 | if cis_bond(p0,p1,p2,p3): 280 | return False 281 | return True 282 | 283 | 284 | def stereo_check(topology, positions): 285 | pos = np.array(positions.value_in_unit(LENGTH)) 286 | for residue in topology.residues(): 287 | if residue.name == "GLY": 288 | continue 289 | 290 | atom_indices = {atom.name:atom.index for atom in residue.atoms() if atom.name in ["N", "CA", "C", "CB"]} 291 | vectors = pos[[atom_indices[i] for i in ["N", "C", "CB"]]] - pos[atom_indices["CA"]] 292 | 293 | if np.linalg.det(vectors) < 0: 294 | return False 295 | return True 296 | 297 | 298 | def clash_check(topology, positions): 299 | heavies = [x for x in topology.atoms() if x.element.symbol != "H"] 300 | pos = np.array(positions.value_in_unit(LENGTH))[[x.index for x in heavies]] 301 | 302 | tree = spatial.KDTree(pos) 303 | pairs = tree.query_pairs(r=max(cutoffs.values())) 304 | 305 | for pair in pairs: 306 | atom_i, atom_j = 
heavies[pair[0]], heavies[pair[1]]
307 | 
308 |         if atom_i.residue.index == atom_j.residue.index:
309 |             continue
310 |         elif (atom_i.name == "C" and atom_j.name == "N") or (atom_i.name == "N" and atom_j.name == "C"):
311 |             continue
312 | 
313 |         atom_distance = np.linalg.norm(pos[pair[0]] - pos[pair[1]])
314 | 
315 |         if (atom_i.name == "SG" and atom_j.name == "SG") and atom_distance > 1.88:
316 |             continue
317 | 
318 |         elif atom_distance < (cutoffs[atom_i.element.symbol + atom_j.element.symbol]):
319 |             return False
320 |     return True
321 | 
322 | 
323 | def strained_sidechain_bonds_check(topology, positions):
324 |     atoms = list(topology.atoms())
325 |     pos = np.array(positions.value_in_unit(LENGTH))
326 | 
327 |     system = forcefield.createSystem(topology)
328 |     bonds = [x for x in system.getForces() if type(x).__name__ == "HarmonicBondForce"][0]
329 | 
330 |     # Initialise arrays for bond details
331 |     n_bonds = bonds.getNumBonds()
332 |     i = np.empty(n_bonds, dtype=int)
333 |     j = np.empty(n_bonds, dtype=int)
334 |     k = np.empty(n_bonds)
335 |     x0 = np.empty(n_bonds)
336 | 
337 |     # Extract bond details to arrays
338 |     for n in range(n_bonds):
339 |         i[n],j[n],_x0,_k = bonds.getBondParameters(n)
340 |         k[n] = _k.value_in_unit(spring_unit)
341 |         x0[n] = _x0.value_in_unit(LENGTH)
342 | 
343 |     # Check if there are any abnormally strained bonds
344 |     distance = np.linalg.norm(pos[i] - pos[j], axis=-1)
345 |     check = k*(distance - x0)**2 > 100
346 | 
347 |     # Return residues with strained bonds if any
348 |     return [atoms[x].residue for x in i[check]]
349 | 
350 | 
351 | def strained_sidechain_bonds_fixer(strained_residues, topology, positions, n_threads=-1):
352 |     # Delete all atoms except the main chain for badly refined residues.
353 |     bb_atoms = ["N","CA","C"]
354 |     bad_side_chains = sum([[atom for atom in residue.atoms() if atom.name not in bb_atoms] for residue in strained_residues],[])
355 |     modeller = app.Modeller(topology, positions)
356 |     modeller.delete(bad_side_chains)
357 | 
358 |     # Save model with deleted side chains to temporary file.
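    # (pdbfixer.PDBFixer is pointed at a file on disk here, so the trimmed
    # model is round-tripped through a uniquely named scratch PDB, which is
    # removed again immediately after loading.)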
359 |     random_number = str(int(np.random.rand()*10**8))
360 |     tmp_file = f"side_chain_fix_tmp_{random_number}.pdb"
361 |     with open(tmp_file,"w") as handle:
362 |         app.PDBFile.writeFile(modeller.topology, modeller.positions, handle, keepIds=True)
363 | 
364 |     # Load model into pdbfixer
365 |     fixer = pdbfixer.PDBFixer(tmp_file)
366 |     os.remove(tmp_file)
367 | 
368 |     # Repair deleted side chains
369 |     fixer.findMissingResidues()
370 |     fixer.findMissingAtoms()
371 |     fixer.addMissingAtoms()
372 | 
373 |     # Fill in the gaps with OpenMM Modeller
374 |     modeller = app.Modeller(fixer.topology, fixer.positions)
375 |     modeller.addHydrogens(forcefield)
376 | 
377 |     # Set up force field
378 |     system = forcefield.createSystem(modeller.topology)
379 | 
380 |     # Set up integrator
381 |     integrator = LangevinIntegrator(0, 0.01, 0.0)
382 | 
383 |     # Set up the simulation
384 |     if n_threads > 0:
385 |         # Set number of threads used by OpenMM
386 |         platform = Platform.getPlatformByName('CPU')
387 |         simulation = app.Simulation(modeller.topology, system, integrator, platform, {'Threads': str(n_threads)})
388 |     else:
389 |         simulation = app.Simulation(modeller.topology, system, integrator)
390 |     simulation.context.setPositions(modeller.positions)
391 | 
392 |     # Minimize the energy
393 |     simulation.minimizeEnergy()
394 | 
395 |     return simulation.topology, simulation.context.getState(getPositions=True).getPositions()
396 | -------------------------------------------------------------------------------- /src/ibex/openfold/utils/protein.py: --------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | # Copyright 2021 AlQuraishi Laboratory
4 | # Copyright 2021 DeepMind Technologies Limited
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | #      http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | 
18 | """Protein data type."""
19 | import dataclasses
20 | import io
21 | import re
22 | import string
23 | from typing import Any, Mapping, Optional, Sequence
24 | 
25 | import numpy as np
26 | from Bio.PDB import PDBParser
27 | 
28 | from ibex.openfold.utils import residue_constants
29 | 
30 | FeatureDict = Mapping[str, np.ndarray]
31 | ModelOutput = Mapping[str, Any]  # Is a nested dict.
32 | PICO_TO_ANGSTROM = 0.01
33 | 
34 | 
35 | @dataclasses.dataclass(frozen=True)
36 | class Protein:
37 |     """Protein structure representation."""
38 | 
39 |     # Cartesian coordinates of atoms in angstroms. The atom types correspond to
40 |     # residue_constants.atom_types, i.e. the first three are N, CA, C.
41 |     atom_positions: np.ndarray  # [num_res, num_atom_type, 3]
42 | 
43 |     # Amino-acid type for each residue represented as an integer between 0 and
44 |     # 20, where 20 is 'X'.
45 |     aatype: np.ndarray  # [num_res]
46 | 
47 |     # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
48 |     # is present and 0.0 if not. This should be used for loss masking.
49 |     atom_mask: np.ndarray  # [num_res, num_atom_type]
50 | 
51 |     # Residue index as used in PDB.
It is not necessarily continuous or 0-indexed. 52 | residue_index: np.ndarray # [num_res] 53 | 54 | # B-factors, or temperature factors, of each residue (in sq. angstroms units), 55 | # representing the displacement of the residue from its ground truth mean 56 | # value. 57 | b_factors: np.ndarray # [num_res, num_atom_type] 58 | 59 | # Chain indices for multi-chain predictions 60 | chain_index: Optional[np.ndarray] = None 61 | 62 | # Optional remark about the protein. Included as a comment in output PDB 63 | # files 64 | remark: Optional[str] = None 65 | 66 | # Templates used to generate this protein (prediction-only) 67 | parents: Optional[Sequence[str]] = None 68 | 69 | # Chain corresponding to each parent 70 | parents_chain_index: Optional[Sequence[int]] = None 71 | 72 | 73 | def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: 74 | """Takes a PDB string and constructs a Protein object. 75 | 76 | WARNING: All non-standard residue types will be converted into UNK. All 77 | non-standard atoms will be ignored. 78 | 79 | Args: 80 | pdb_str: The contents of the pdb file 81 | chain_id: If None, then the pdb file must contain a single chain (which 82 | will be parsed). If chain_id is specified (e.g. A), then only that chain 83 | is parsed. 84 | 85 | Returns: 86 | A new `Protein` parsed from the pdb contents. 87 | """ 88 | pdb_fh = io.StringIO(pdb_str) 89 | parser = PDBParser(QUIET=True) 90 | structure = parser.get_structure("none", pdb_fh) 91 | models = list(structure.get_models()) 92 | if len(models) != 1: 93 | raise ValueError( 94 | f"Only single model PDBs are supported. Found {len(models)} models." 95 | ) 96 | model = models[0] 97 | 98 | atom_positions = [] 99 | aatype = [] 100 | atom_mask = [] 101 | residue_index = [] 102 | chain_ids = [] 103 | b_factors = [] 104 | 105 | for chain in model: 106 | if chain_id is not None and chain.id != chain_id: 107 | continue 108 | for res in chain: 109 | if res.id[2] != " ": 110 | raise ValueError( 111 | f"PDB contains an insertion code at chain {chain.id} and residue " 112 | f"index {res.id[1]}. These are not supported." 113 | ) 114 | res_shortname = residue_constants.restype_3to1.get(res.resname, "X") 115 | restype_idx = residue_constants.restype_order.get( 116 | res_shortname, residue_constants.restype_num 117 | ) 118 | pos = np.zeros((residue_constants.atom_type_num, 3)) 119 | mask = np.zeros((residue_constants.atom_type_num,)) 120 | res_b_factors = np.zeros((residue_constants.atom_type_num,)) 121 | for atom in res: 122 | if atom.name not in residue_constants.atom_types: 123 | continue 124 | pos[residue_constants.atom_order[atom.name]] = atom.coord 125 | mask[residue_constants.atom_order[atom.name]] = 1.0 126 | res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor 127 | if np.sum(mask) < 0.5: 128 | # If no known atom positions are reported for the residue then skip it. 
129 | continue
130 | aatype.append(restype_idx)
131 | atom_positions.append(pos)
132 | atom_mask.append(mask)
133 | residue_index.append(res.id[1])
134 | chain_ids.append(chain.id)
135 | b_factors.append(res_b_factors)
136 | 
137 | parents = None
138 | parents_chain_index = None
139 | if "PARENT" in pdb_str:
140 | parents = []
141 | parents_chain_index = []
142 | chain_id = 0
143 | for l in pdb_str.split("\n"):
144 | if "PARENT" in l:
145 | if "N/A" not in l:
146 | parent_names = l.split()[1:]
147 | parents.extend(parent_names)
148 | parents_chain_index.extend([chain_id for _ in parent_names])
149 | chain_id += 1
150 | 
151 | unique_chain_ids = np.unique(chain_ids)
152 | chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
153 | chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
154 | 
155 | return Protein(
156 | atom_positions=np.array(atom_positions),
157 | atom_mask=np.array(atom_mask),
158 | aatype=np.array(aatype),
159 | residue_index=np.array(residue_index),
160 | chain_index=chain_index,
161 | b_factors=np.array(b_factors),
162 | parents=parents,
163 | parents_chain_index=parents_chain_index,
164 | )
165 | 
166 | 
167 | def from_proteinnet_string(proteinnet_str: str) -> Protein:
168 | tag_re = r"(\[[A-Z]+\]\n)"
169 | tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
170 | groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
171 | 
172 | atoms = ["N", "CA", "C"]
173 | aatype = None
174 | atom_positions = None
175 | atom_mask = None
176 | for g in groups:
177 | if "[PRIMARY]" == g[0]:
178 | seq = list(g[1][0].strip())  # a list, so non-standard residue symbols can be replaced in place
179 | for i in range(len(seq)):
180 | if seq[i] not in residue_constants.restypes:
181 | seq[i] = "X"
182 | aatype = np.array(
183 | [
184 | residue_constants.restype_order.get(
185 | res_symbol, residue_constants.restype_num
186 | )
187 | for res_symbol in seq
188 | ]
189 | )
190 | elif "[TERTIARY]" == g[0]:
191 | tertiary = []
192 | for axis in range(3):
193 | tertiary.append(list(map(float, g[1][axis].split())))
194 | tertiary_np = np.array(tertiary)
195 | atom_positions = np.zeros(
196 | (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)
197 | ).astype(np.float32)
198 | for i, atom in enumerate(atoms):
199 | atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(
200 | tertiary_np[:, i::3]
201 | )
202 | atom_positions *= PICO_TO_ANGSTROM
203 | elif "[MASK]" == g[0]:
204 | mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip())))
205 | atom_mask = np.zeros(
206 | (
207 | len(mask),
208 | residue_constants.atom_type_num,
209 | )
210 | ).astype(np.float32)
211 | for i, atom in enumerate(atoms):
212 | atom_mask[:, residue_constants.atom_order[atom]] = 1
213 | atom_mask *= mask[..., None]
214 | 
215 | return Protein(
216 | atom_positions=atom_positions,
217 | atom_mask=atom_mask,
218 | aatype=aatype,
219 | residue_index=np.arange(len(aatype)),
220 | b_factors=None,
221 | )
222 | 
223 | 
224 | def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
225 | pdb_headers = []
226 | 
227 | remark = prot.remark
228 | if remark is not None:
229 | pdb_headers.append(f"REMARK {remark}")
230 | 
231 | parents = prot.parents
232 | parents_chain_index = prot.parents_chain_index
233 | if parents_chain_index is not None:
234 | parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
235 | 
236 | if parents is None or len(parents) == 0:
237 | parents = ["N/A"]
238 | 
239 | pdb_headers.append(f"PARENT {' '.join(parents)}")
240 | 
241 | return pdb_headers
242 | 
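# Illustrative usage sketch (not part of the original module), assuming `pdb_str`
# holds a single-model PDB string with chain IDs drawn from A-Z and no insertion codes:
#
#     prot = from_pdb_string(pdb_str, chain_id="A")  # parse a single chain
#     assert prot.atom_positions.shape[0] == prot.aatype.shape[0]
#     print(get_pdb_headers(prot))                   # ["PARENT N/A"] when no templates are set
#     pdb_out = to_pdb(prot)                         # defined below; serializes back to PDB text
#
# Note that to_pdb maps chain_index 0/1 to the tags "H"/"L", so round-tripped chain
# IDs follow the antibody heavy/light convention rather than the input letters.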
243 | 244 | def add_pdb_headers(prot: Protein, pdb_str: str) -> str: 245 | """Add pdb headers to an existing PDB string. Useful during multi-chain 246 | recycling 247 | """ 248 | out_pdb_lines = [] 249 | lines = pdb_str.split("\n") 250 | 251 | remark = prot.remark 252 | if remark is not None: 253 | out_pdb_lines.append(f"REMARK {remark}") 254 | 255 | parents_per_chain = None 256 | if prot.parents is not None and len(prot.parents) > 0: 257 | parents_per_chain = [] 258 | if prot.parents_chain_index is not None: 259 | cur_chain = prot.parents_chain_index[0] 260 | parent_dict = {} 261 | for p, i in zip(prot.parents, prot.parents_chain_index): 262 | parent_dict.setdefault(str(i), []) 263 | parent_dict[str(i)].append(p) 264 | 265 | max_idx = max([int(chain_idx) for chain_idx in parent_dict]) 266 | for i in range(max_idx + 1): 267 | chain_parents = parent_dict.get(str(i), ["N/A"]) 268 | parents_per_chain.append(chain_parents) 269 | else: 270 | parents_per_chain.append(prot.parents) 271 | else: 272 | parents_per_chain = [["N/A"]] 273 | 274 | make_parent_line = lambda p: f"PARENT {' '.join(p)}" 275 | 276 | out_pdb_lines.append(make_parent_line(parents_per_chain[0])) 277 | 278 | chain_counter = 0 279 | for i, l in enumerate(lines): 280 | if "PARENT" not in l and "REMARK" not in l: 281 | out_pdb_lines.append(l) 282 | if "TER" in l and not "END" in lines[i + 1]: 283 | chain_counter += 1 284 | if not chain_counter >= len(parents_per_chain): 285 | chain_parents = parents_per_chain[chain_counter] 286 | else: 287 | chain_parents = ["N/A"] 288 | 289 | out_pdb_lines.append(make_parent_line(chain_parents)) 290 | 291 | return "\n".join(out_pdb_lines) 292 | 293 | 294 | def to_pdb(prot: Protein) -> str: 295 | """Converts a `Protein` instance to a PDB string. 296 | 297 | Args: 298 | prot: The protein to convert to PDB. 299 | 300 | Returns: 301 | PDB string. 302 | """ 303 | restypes = residue_constants.restypes + ["X"] 304 | res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK") 305 | atom_types = residue_constants.atom_types 306 | 307 | pdb_lines = [] 308 | 309 | atom_mask = prot.atom_mask 310 | aatype = prot.aatype 311 | atom_positions = prot.atom_positions 312 | residue_index = prot.residue_index.astype(np.int32) 313 | b_factors = prot.b_factors 314 | chain_index = prot.chain_index 315 | 316 | if np.any(aatype > residue_constants.restype_num): 317 | raise ValueError("Invalid aatypes.") 318 | 319 | headers = get_pdb_headers(prot) 320 | if len(headers) > 0: 321 | pdb_lines.extend(headers) 322 | 323 | n = aatype.shape[0] 324 | atom_index = 1 325 | prev_chain_index = 0 326 | chain_tags = ["H", "L"] # string.ascii_uppercase 327 | # Add all atom sites. 328 | for i in range(n): 329 | res_name_3 = res_1to3(aatype[i]) 330 | for atom_name, pos, mask, b_factor in zip( 331 | atom_types, atom_positions[i], atom_mask[i], b_factors[i] 332 | ): 333 | if mask < 0.5: 334 | continue 335 | 336 | record_type = "ATOM" 337 | name = atom_name if len(atom_name) == 4 else f" {atom_name}" 338 | alt_loc = "" 339 | insertion_code = "" 340 | occupancy = 1.00 341 | element = atom_name[0] # Protein supports only C, N, O, S, this works. 342 | charge = "" 343 | 344 | chain_tag = "A" 345 | if chain_index is not None: 346 | chain_tag = chain_tags[chain_index[i]] 347 | 348 | # PDB is a columnar format, every space matters here! 
349 | atom_line = ( 350 | f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}" 351 | f"{res_name_3:>3} {chain_tag:>1}" 352 | f"{residue_index[i]:>4}{insertion_code:>1} " 353 | f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}" 354 | f"{occupancy:>6.2f}{b_factor:>6.2f} " 355 | f"{element:>2}{charge:>2}" 356 | ) 357 | pdb_lines.append(atom_line) 358 | atom_index += 1 359 | 360 | should_terminate = i == n - 1 361 | if chain_index is not None: 362 | if i != n - 1 and chain_index[i + 1] != prev_chain_index: 363 | should_terminate = True 364 | prev_chain_index = chain_index[i + 1] 365 | 366 | if should_terminate: 367 | # Close the chain. 368 | chain_end = "TER" 369 | chain_termination_line = ( 370 | f"{chain_end:<6}{atom_index:>5} " 371 | f"{res_1to3(aatype[i]):>3} " 372 | f"{chain_tag:>1}{residue_index[i]:>4}" 373 | ) 374 | pdb_lines.append(chain_termination_line) 375 | atom_index += 1 376 | 377 | if i != n - 1: 378 | # "prev" is a misnomer here. This happens at the beginning of 379 | # each new chain. 380 | pdb_lines.extend(get_pdb_headers(prot, prev_chain_index)) 381 | 382 | pdb_lines.append("END") 383 | pdb_lines.append("") 384 | return "\n".join(pdb_lines) 385 | 386 | 387 | def ideal_atom_mask(prot: Protein) -> np.ndarray: 388 | """Computes an ideal atom mask. 389 | 390 | `Protein.atom_mask` typically is defined according to the atoms that are 391 | reported in the PDB. This function computes a mask according to heavy atoms 392 | that should be present in the given sequence of amino acids. 393 | 394 | Args: 395 | prot: `Protein` whose fields are `numpy.ndarray` objects. 396 | 397 | Returns: 398 | An ideal atom mask. 399 | """ 400 | return residue_constants.STANDARD_ATOM_MASK[prot.aatype] 401 | 402 | 403 | def from_prediction( 404 | features: FeatureDict, 405 | result: ModelOutput, 406 | b_factors: Optional[np.ndarray] = None, 407 | chain_index: Optional[np.ndarray] = None, 408 | remark: Optional[str] = None, 409 | parents: Optional[Sequence[str]] = None, 410 | parents_chain_index: Optional[Sequence[int]] = None, 411 | ) -> Protein: 412 | """Assembles a protein from a prediction. 413 | 414 | Args: 415 | features: Dictionary holding model inputs. 416 | result: Dictionary holding model outputs. 417 | b_factors: (Optional) B-factors to use for the protein. 418 | chain_index: (Optional) Chain indices for multi-chain predictions 419 | remark: (Optional) Remark about the prediction 420 | parents: (Optional) List of template names 421 | Returns: 422 | A protein instance. 423 | """ 424 | if b_factors is None: 425 | b_factors = np.zeros_like(result["final_atom_mask"]) 426 | 427 | return Protein( 428 | aatype=features["aatype"], 429 | atom_positions=result["final_atom_positions"], 430 | atom_mask=result["final_atom_mask"], 431 | residue_index=features["residue_index"] + 1, 432 | b_factors=b_factors, 433 | chain_index=chain_index, 434 | remark=remark, 435 | parents=parents, 436 | parents_chain_index=parents_chain_index, 437 | ) 438 | -------------------------------------------------------------------------------- /src/ibex/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # Copyright 2024 Exscientia 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | import random
17 | import pandas as pd
18 | import torch
19 | from tqdm import tqdm
20 | import itertools
21 | from math import ceil
22 | from torch import Tensor
23 | from typing import Iterator, List, Dict, Optional, Union
24 | from pathlib import Path
25 | from loguru import logger
26 | from torch.nn.utils.rnn import pad_sequence
27 | from torch.utils.data import Dataset, BatchSampler, Sampler
28 | from collections.abc import Iterable
29 | from torch.utils.data.distributed import DistributedSampler
30 | 
31 | from ibex.openfold.utils.residue_constants import restype_order_with_x
32 | 
33 | # File IDs from the test and validation set that are not present in our version of the dataset, or
34 | # where there are mismatches in the lengths of the chains.
35 | invalid_fileids = [
36 | # "8aon_BC", # valid
37 | # "4hjj_HL", # valid
38 | # "7pa6_KKKKKK", # valid
39 | # "7ce2_ZB", # valid
40 | # "7r8u_HHHLLL", # valid
41 | # "6o89_HL", # valid
42 | # "5d6c_HL", # valid
43 | # "7u8c_HL", # test
44 | # "7seg_HL", # test
45 | # "2a9m_IM", # test
46 | # "4buh_AA", # test
47 | # "6ss5_HHHLLL" # test
48 | "7sgm_HL", # mismatching lengths
49 | "7sgm_XY", # mismatching lengths
50 | "7sgm_KM", # mismatching lengths
51 | "6w9g_HL", # mismatching lengths
52 | "6w9g_KM", # mismatching lengths
53 | "6w9g_XY", # mismatching lengths
54 | "5iwl_AB", # mismatching lengths
55 | "5iwl_BA", # mismatching lengths
56 | "7sen_HL", # mismatching lengths
57 | "6w5a_HL", # mismatching lengths
58 | "6bjz_HL", # mismatching lengths
59 | "7t74_HL", # mismatching lengths
60 | "6dfv_BA", # mismatching lengths TCR
61 | "3tyf_DC", # mismatching lengths TCR
62 | ]
63 | 
64 | class ABDataset(Dataset):
65 | def __init__(
66 | self,
67 | split: str,
68 | split_file: str,
69 | data_dir: str,
70 | limit: Optional[int] = None,
71 | use_vhh: bool = False,
72 | use_tcr: bool = False,
73 | rel_pos_dim: int = 16,
74 | edge_chain_feature: bool = False,
75 | use_plm_embeddings: bool = False,
76 | use_public_only: bool = False,
77 | use_private_only: bool = False,
78 | cluster_column: str = "cluster",
79 | weight_temperature: float = 1.0,
80 | weight_clip_quantile: float = 0.99,
81 | vhh_weight: float = 1.0,
82 | tcr_weight: float = 1.0,
83 | matched_weight: float = 1.0,
84 | conformation_node_feature: bool = False,
85 | use_contrastive: bool = False,
86 | is_matched: bool = False,
87 | use_boltz_only: bool = False,
88 | use_weights: bool = True,
89 | ) -> None:
90 | """Dataset of antibodies suitable for OpenFold structure modules and loss functions.
91 | 
92 | Args:
93 | data_dir (str): root data folder with the per-structure .pt files
94 | split (str): "all", "train", "valid" or "test"
95 | """
96 | super().__init__()
97 | self.split = split
98 | self.limit = limit
99 | self.split_file = split_file
100 | self.data_dir = data_dir
101 | self.rel_pos_dim = rel_pos_dim
102 | self.edge_chain_feature = edge_chain_feature
103 | self.use_plm_embeddings = use_plm_embeddings
104 | self.conformation_node_feature = conformation_node_feature
105 | self.use_contrastive = use_contrastive
106 | self.is_matched = is_matched
107 | self.cluster_column = cluster_column
108 | self.weight_temperature = weight_temperature
109 | self.weight_clip_quantile = weight_clip_quantile
110 | 
111 | self.df = pd.read_csv(split_file, index_col=0, dtype={'file_id': str})
112 | if not use_vhh:
113 | if "is_vhh" in self.df.columns:
114 | self.df = self.df[self.df["is_vhh"] == False]
115 | if not use_tcr:
116 | if "is_tcr" in self.df.columns:
117 | self.df = self.df[self.df["is_tcr"] == False]
118 | if split != "all":
119 | self.df = self.df.query(f"split=='{split}'")
120 | if use_public_only:
121 | self.df = self.df[self.df["is_internal"]==False]
122 | if use_private_only:
123 | self.df = self.df[self.df["is_internal"]]
124 | if 'is_boltz' in self.df.columns and use_boltz_only:
125 | self.df = self.df[self.df["is_boltz"]]
126 | if use_contrastive:
127 | self.df = self.df[self.df["is_matched"]==is_matched]
128 | 
129 | self.df = self.df[~self.df.index.isin(invalid_fileids)]
130 | 
131 | if limit is not None:
132 | self.df = self.df.iloc[:limit]
133 | 
134 | if cluster_column in self.df.columns and use_weights:
135 | # Get basic weight that is 1/cluster size
136 | probs = self.df[cluster_column].map(self.df.groupby(cluster_column).size()).values
137 | probs = torch.tensor(probs, dtype=torch.float32)
138 | probs = probs / probs.sum()
139 | self.weights = 1 / (probs ** self.weight_temperature)
140 | # Cap the weight to e.g. the 0.99 quantile, to avoid bad outliers
141 | self.weights = torch.minimum(self.weights, torch.quantile(self.weights, self.weight_clip_quantile))
142 | # Make the weight sum to 1
143 | self.weights = self.weights / self.weights.sum()
144 | else:
145 | logger.info(f"Using uniform weights for {split_file}")
146 | self.weights = torch.ones(len(self.df)) / len(self.df)
147 | 
148 | # Upweight TCRs and VHHs
149 | if use_vhh:
150 | self.weights[self.df["is_vhh"]] *= vhh_weight
151 | if use_tcr:
152 | self.weights[self.df["is_tcr"]] *= tcr_weight
153 | if 'is_matched' in self.df.columns:
154 | self.weights[self.df["is_matched"]] *= matched_weight
155 | self.weights = self.weights / self.weights.sum()
156 | 
157 | def __len__(self):
158 | return len(self.df)
159 | 
160 | def __getitem__(self, idx: int):
161 | if self.use_contrastive and self.is_matched:
162 | datapoint_pos = self._load_single_sample(self.df.iloc[idx].name)
163 | datapoint_neg = self._load_single_sample(self.df.iloc[idx].matched_index)
164 | 
165 | datapoint_pos["is_matched"] = torch.tensor(True, dtype=torch.bool)
166 | datapoint_neg["is_matched"] = torch.tensor(True, dtype=torch.bool)
167 | 
168 | return datapoint_pos, datapoint_neg
169 | # return datapoint_pos, datapoint_neg, datapoint_anchor
170 | elif self.split=='train' or self.split=='all':
171 | datapoint = self._load_single_sample(self.df.iloc[idx].name)
172 | # Get a random pairing respecting the weights
173 | idx_other = random.choices(
174 | range(len(self.df)),
175 | weights=self.weights,
176 | k=1,
177 | )[0]
178 | datapoint_other = self._load_single_sample(self.df.iloc[idx_other].name)
179 | 
180 | datapoint["is_matched"] = torch.tensor(False, dtype=torch.bool)
181 | datapoint_other["is_matched"] = torch.tensor(False, dtype=torch.bool)
182 | 
183 | return datapoint, datapoint_other
184 | else:
185 | datapoint = self._load_single_sample(self.df.iloc[idx].name)
186 | return datapoint
187 | 
188 | def _load_single_sample(self, file_id):
189 | fname = Path(self.data_dir) / f"{file_id}.pt"
190 | datapoint = torch.load(fname, weights_only=False)
191 | datapoint.update(
192 | self.single_and_double_from_datapoint(
193 | datapoint,
194 | self.rel_pos_dim,
195 | self.edge_chain_feature,
196 | self.use_plm_embeddings,
197 | self.conformation_node_feature,
198 | )
199 | )
200 | for key in datapoint:
201 | if isinstance(datapoint[key], torch.Tensor):
202 | datapoint[key] = datapoint[key].detach()
203 | return datapoint
204 | 
205 | @staticmethod
206 | def single_and_double_from_datapoint(
207 | datapoint: dict,
208 | rel_pos_dim: int,
209 | edge_chain_feature: bool = False,
210 | use_plm_embeddings: bool = False,
211 | conformation_node_feature: bool = False,
212 | ):
213 | """
214 | datapoint is a dict containing:
215 | aatype - [n,] tensor of ints for the amino acid (including unknown)
216 | is_heavy - [n,] tensor of ints where 1 is heavy chain and 0 is light chain.
217 | residue_index - [n,] tensor of ints assigning an integer to each residue
218 | 
219 | rel_pos_dim: integer determining edge feature dimension
220 | 
221 | edge_chain_feature: boolean to add an edge feature z_ij that encodes what chain i and j are in.
222 | 
223 | returns:
224 | A dictionary containing "single", a tensor of size (n, 23), and "pair", a tensor of size (n, n, 2 * rel_pos_dim + 1 + x), where x is 3 if edge_chain_feature and 0 otherwise.
225 | """ 226 | single_aa = torch.nn.functional.one_hot(datapoint["aatype"], 21) 227 | single_chain = torch.nn.functional.one_hot(datapoint["is_heavy"].long(), 2) 228 | if conformation_node_feature: 229 | single_conformation = torch.nn.functional.one_hot(datapoint["is_apo"].long(), 2) 230 | single = torch.cat((single_aa, single_chain, single_conformation), dim=-1) 231 | else: 232 | single = torch.cat((single_aa, single_chain), dim=-1) 233 | pair = datapoint["residue_index"] 234 | pair = pair[None] - pair[:, None] 235 | pair = pair.clamp(-rel_pos_dim, rel_pos_dim) + rel_pos_dim 236 | pair = torch.nn.functional.one_hot(pair, 2 * rel_pos_dim + 1) 237 | if edge_chain_feature: 238 | is_heavy = datapoint["is_heavy"] 239 | is_heavy = 2 * is_heavy.outer(is_heavy) + ( 240 | (1 - is_heavy).outer(1 - is_heavy) 241 | ) 242 | is_heavy = torch.nn.functional.one_hot(is_heavy.long()) 243 | pair = torch.cat((is_heavy, pair), dim=-1) 244 | if use_plm_embeddings: 245 | {"single": single.float(), "pair": pair.float(), "plm_embedding": datapoint["plm_embedding"]} 246 | return {"single": single.float(), "pair": pair.float()} 247 | 248 | def pad_square_tensors(tensors: list[torch.tensor]) -> torch.tensor: 249 | """Pads a list of tensors in the first two dimensions. 250 | 251 | Args: 252 | tensors (list[torch.tensor]): Input tensor are of shape (n_1, n_1, ...), (n_2, n_2, ...). where shape matches in the ... dimensions 253 | 254 | Returns: 255 | torch.tensor: A tensor of size (len(tensor), max(n_1,...), max(n_1,...), ...) 256 | """ 257 | max_len = max(map(len, tensors)) 258 | output = torch.zeros((len(tensors), max_len, max_len, *tensors[0].shape[2:])) 259 | for i, tensor in enumerate(tensors): 260 | output[i, : tensor.size(0), : tensor.size(1)] = tensor 261 | return output 262 | 263 | 264 | pad_first_dim_keys = [ 265 | "atom14_gt_positions", 266 | "atom14_alt_gt_positions", 267 | "atom14_atom_is_ambiguous", 268 | "atom14_gt_exists", 269 | "atom14_alt_gt_exists", 270 | "atom14_atom_exists", 271 | "single", 272 | "plm_embedding", 273 | "seq_mask", 274 | "aatype", 275 | "backbone_rigid_tensor", 276 | "backbone_rigid_mask", 277 | "rigidgroups_gt_frames", 278 | "rigidgroups_alt_gt_frames", 279 | "rigidgroups_gt_exists", 280 | "cdr_mask", 281 | "chi_mask", 282 | "chi_angles_sin_cos", 283 | "residue_index", 284 | "residx_atom14_to_atom37", 285 | "region_numeric", 286 | ] 287 | 288 | pad_first_two_dim_keys = ["pair"] 289 | 290 | 291 | def string_to_input(heavy: str, light: str, apo: bool = False, conformation_aware: bool = False, embedding=None, device: str = "cpu") -> dict: 292 | """Generates an input formatted for an Ibex model from heavy and light chain 293 | strings. 294 | 295 | Args: 296 | heavy (str): heavy chain 297 | light (str): light chain 298 | apo (bool): whether the structure is apo or holo (optional) 299 | 300 | Returns: 301 | dict: A dictionary containing 302 | aatype: an (n,) tensor of integers encoding the amino acid string 303 | is_heavy: an (n,) tensor where is_heavy[i] = 1 means residue i is heavy and 304 | is_heavy[i] = 0 means residue i is light 305 | residue_index: an (n,) tensor with indices for each residue. 
306 | of 500 between the last heavy residue and the first light residue
307 | single: a (1, n, 23) tensor of node features (or (1, n, 25) when conformation_aware)
308 | pair: a (1, n, n, 132) tensor of edge features
309 | """
310 | aatype = []
311 | is_heavy = []
312 | for character in heavy:
313 | is_heavy.append(1)
314 | aatype.append(restype_order_with_x[character])
315 | if light is not None:
316 | for character in light:
317 | is_heavy.append(0)
318 | aatype.append(restype_order_with_x[character])
319 | is_heavy = torch.tensor(is_heavy)
320 | aatype = torch.tensor(aatype)
321 | if light is None:
322 | residue_index = torch.arange(len(heavy))
323 | else:
324 | residue_index = torch.cat(
325 | (torch.arange(len(heavy)), torch.arange(len(light)) + 500)
326 | )
327 | 
328 | model_input = {
329 | "is_heavy": is_heavy,
330 | "aatype": aatype,
331 | "residue_index": residue_index,
332 | }
333 | 
334 | if apo and conformation_aware:
335 | model_input["is_apo"] = torch.ones(len(heavy + (light if light is not None else '')), dtype=torch.int64)
336 | elif conformation_aware:
337 | model_input["is_apo"] = torch.zeros(len(heavy + (light if light is not None else '')), dtype=torch.int64)
338 | if embedding is not None:
339 | model_input["plm_embedding"] = embedding
340 | model_input.update(
341 | ABDataset.single_and_double_from_datapoint(
342 | model_input, 64, edge_chain_feature=True, conformation_node_feature=conformation_aware
343 | )
344 | )
345 | if "plm_embedding" in model_input:
346 | model_input["plm_embedding"] = model_input["plm_embedding"].unsqueeze(0)
347 | model_input["single"] = model_input["single"].unsqueeze(0)
348 | model_input["pair"] = model_input["pair"].unsqueeze(0)
349 | model_input = {k: v.to(device) for k, v in model_input.items()}
350 | return model_input
351 | 
352 | 
353 | def collate_fn(batch: list):
354 | """A collate function so the ABDataset can be used in a torch dataloader.
355 | 
356 | Args:
357 | batch (list): A list of datapoints from ABDataset
358 | 
359 | Returns:
360 | dict: A dictionary with the same keys as the datapoints, where each key maps to a batched tensor with the batch on the leading dimension.
361 | """ 362 | # Flatten the batch structure (needed for contrastive examples, where some of the batch elements are tuples of objects) 363 | flattened_batch = [] 364 | for item in batch: 365 | if isinstance(item, tuple): 366 | flattened_batch.extend(item) 367 | else: 368 | flattened_batch.append(item) 369 | batch = {key: [d[key] for d in flattened_batch] for key in flattened_batch[0]} 370 | output = {} 371 | for key in batch: 372 | if key in pad_first_dim_keys: 373 | output[key] = pad_sequence(batch[key], batch_first=True) 374 | elif key in pad_first_two_dim_keys: 375 | output[key] = pad_square_tensors(batch[key]) 376 | elif key == "resolution": 377 | output[key] = torch.Tensor(batch["resolution"]) 378 | elif key == "is_matched": 379 | output[key] = torch.stack(batch[key]) 380 | return output 381 | 382 | 383 | class SequenceDataset(Dataset): 384 | def __init__(self, fv_heavy_list, fv_light_list, apo_list=None, plm_model=None, batch_size=2, num_workers=0): 385 | self.fv_heavy_list = fv_heavy_list 386 | self.fv_light_list = fv_light_list 387 | self.apo_list = apo_list 388 | if plm_model is not None: 389 | with torch.no_grad(): 390 | self.fv_heavy_embedding = plm_model.embed_dataset( 391 | fv_heavy_list, 392 | batch_size=batch_size, 393 | max_len=max([len(x) for x in fv_heavy_list]), 394 | full_embeddings=True, 395 | full_precision=False, 396 | pooling_type="mean", 397 | num_workers=num_workers, 398 | sql=False 399 | ) 400 | self.fv_light_embedding = plm_model.embed_dataset( 401 | [light for light in fv_light_list if light is not None and light!=''], 402 | batch_size=batch_size, 403 | max_len=max([len(x) for x in fv_light_list]), 404 | full_embeddings=True, 405 | full_precision=False, 406 | pooling_type="mean", 407 | num_workers=num_workers, 408 | sql=False 409 | ) if fv_light_list is not None else None 410 | else: 411 | self.fv_heavy_embedding = None 412 | self.fv_light_embedding = None 413 | 414 | def __len__(self): 415 | return len(self.fv_heavy_list) 416 | 417 | def __getitem__(self, idx): 418 | heavy = self.fv_heavy_list[idx] 419 | light = self.fv_light_list[idx] if self.fv_light_list is not None else None 420 | if self.fv_heavy_embedding is not None and self.fv_light_embedding is not None: 421 | if light is None: 422 | embedding = self.fv_heavy_embedding[heavy] 423 | else: 424 | embedding = torch.concat([self.fv_heavy_embedding[heavy], self.fv_light_embedding[light]]) 425 | else: 426 | embedding = None 427 | if self.apo_list is not None: 428 | apo = self.apo_list[idx] 429 | return string_to_input(heavy, light, apo=apo, conformation_aware=True, embedding=embedding) 430 | return string_to_input(heavy, light, embedding=embedding) 431 | 432 | def sequence_collate_fn(batch): 433 | # Separating each component of the batch 434 | batch_dict = {key: [d[key] if key not in ["single","pair","plm_embedding"] else d[key].squeeze(0) for d in batch] for key in batch[0]} 435 | # Prepare output dictionary 436 | output = {} 437 | # Pad the "single" features 438 | output["single"] = pad_sequence(batch_dict["single"], batch_first=True) 439 | # Pad the "pair" features 440 | output["pair"] = pad_square_tensors(batch_dict["pair"]) 441 | # Copy other keys directly 442 | for key in ["aatype", "residue_index", "is_heavy"]: 443 | output[key] = pad_sequence(batch_dict[key], batch_first=True) 444 | if "is_apo" in batch_dict: 445 | output["is_apo"] = pad_sequence(batch_dict["is_apo"], batch_first=True) 446 | if "plm_embedding" in batch_dict: 447 | output["plm_embedding"] = pad_sequence(batch_dict["plm_embedding"], 
batch_first=True) 448 | # Create mask based on the lengths of "aatype" 449 | mask = [torch.ones(len(aatype)) for aatype in batch_dict["aatype"]] 450 | output["mask"] = pad_sequence(mask, batch_first=True) 451 | return output 452 | 453 | if __name__ == "__main__": 454 | from torch.utils.data import DataLoader 455 | 456 | 457 | dataset = ABDataset("test", split_file="/data/dreyerf1/ibex/split.csv", data_dir="/data/dreyerf1/ibex/structures", edge_chain_feature=True) 458 | dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn) 459 | for batch in dataloader: 460 | break 461 | for key in batch: 462 | print(key, batch[key].shape) 463 | -------------------------------------------------------------------------------- /src/ibex/loss/aligned_rmsd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Genentech 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | CDR_RANGES_AHO = { 18 | "L1": (23,42), 19 | "L2": (56,72), 20 | "L3": (106,138), 21 | "H1": (23,42), 22 | "H2": (56,69), 23 | "H3": (106,138), 24 | } 25 | 26 | region_mapping = { 27 | "cdrh1": 0, 28 | "cdrh2": 1, 29 | "cdrh3": 2, 30 | "cdrl1": 3, 31 | "cdrl2": 4, 32 | "cdrl3": 5, 33 | "fwh1": 6, 34 | "fwh2": 7, 35 | "fwh3": 8, 36 | "fwh4": 9, 37 | "fwl1": 10, 38 | "fwl2": 11, 39 | "fwl3": 12, 40 | "fwl4": 13, 41 | } 42 | 43 | heavy_chain_regions = {0, 1, 2, 6, 7, 8, 9} 44 | light_chain_regions = {3, 4, 5, 10, 11, 12, 13} 45 | 46 | heavy_framework_regions = {6, 7, 8, 9} 47 | light_framework_regions = {10, 11, 12, 13} 48 | 49 | heavy_cdr_regions = {0, 1, 2} 50 | light_cdr_regions = {3, 4, 5} 51 | 52 | 53 | def apply_transformation(coords, R, t): 54 | # Apply inverse rotation and translation 55 | coords_transformed = torch.bmm(R.transpose(-1, -2), (coords - t.unsqueeze(1)).transpose(-1, -2)).transpose(-1, -2) 56 | return coords_transformed 57 | 58 | 59 | def coordinates_to_dihedral(input: torch.Tensor) -> torch.Tensor: 60 | """Compute dihedral angle from a set of four points. 61 | 62 | Given an input tensor with shape (*, 4, 3) representing points (p1, p2, p3, p4) 63 | compute the dihedral angle between the plane defined by (p1, p2, p3) and (p2, p3, p4). 
64 | 65 | Parameters 66 | ---------- 67 | input: torch.Tensor 68 | Shape (*, 4, 3) 69 | 70 | Returns 71 | ------- 72 | torsion: torch.Tensor 73 | Shape (*,) 74 | """ 75 | assert input.ndim >= 3 76 | assert input.shape[-2] == 4 77 | assert input.shape[-1] == 3 78 | 79 | # difference vectors: [a = p2 - p1, b = p3 - p2, c = p4 - p3] 80 | delta = input[..., 1:, :] - input[..., :-1, :] 81 | a, b, c = torch.unbind(delta, dim=-2) 82 | 83 | # torsion angle is angle from axb to bxc counterclockwise around b 84 | # see https://www.math.fsu.edu/~quine/MB_10/6_torsion.pdf 85 | 86 | axb = torch.cross(a, b, dim=-1) 87 | bxc = torch.cross(b, c, dim=-1) 88 | 89 | # orthogonal basis in plane perpendicular to b 90 | # NOTE v1 and v2 are not unit but have the same magnitude 91 | v1 = axb 92 | v2 = torch.cross(torch.nn.functional.normalize(b, dim=-1), axb, dim=-1) 93 | 94 | x = torch.sum(bxc * v1, dim=-1) 95 | y = torch.sum(bxc * v2, dim=-1) 96 | phi = torch.atan2(y, x) 97 | 98 | return phi 99 | 100 | 101 | def _positions_to_phi( 102 | positions: torch.Tensor, 103 | mask: torch.Tensor, 104 | residue_index: torch.Tensor, 105 | chain_index: torch.Tensor, 106 | ): 107 | chain_boundary_mask = torch.cat( 108 | [torch.zeros_like(chain_index[..., :1], dtype=torch.bool), torch.diff(chain_index, n=1, dim=-1) == 0], dim=-1 109 | ) 110 | 111 | chain_break_mask = torch.cat( 112 | [torch.zeros_like(residue_index[..., :1], dtype=torch.bool), torch.diff(residue_index, n=1, dim=-1) == 1], 113 | dim=-1, 114 | ) 115 | 116 | (input, mask) = ( 117 | torch.stack( 118 | [ 119 | x[..., :-1, 2, :], # C(i-1) 120 | x[..., 1:, 0, :], # N(i) 121 | x[..., 1:, 1, :], # CA(i) 122 | x[..., 1:, 2, :], # C(i) 123 | ], 124 | dim=-2, 125 | ) 126 | for x in (positions, mask) 127 | ) # [..., L, 4, 3] 128 | 129 | mask = mask.all(dim=-1).all(dim=-1) # [..., L] 130 | 131 | angles = coordinates_to_dihedral(input) 132 | nan_tensor = torch.full_like(angles[..., :1], float("nan")) 133 | false_tensor = torch.zeros_like(mask[..., :1]) 134 | 135 | angles = torch.cat([nan_tensor, angles], dim=-1) 136 | mask = torch.cat([false_tensor, mask], dim=-1) 137 | mask = mask & chain_boundary_mask & chain_break_mask 138 | 139 | return angles, mask 140 | 141 | 142 | def _positions_to_psi( 143 | positions: torch.Tensor, 144 | mask: torch.Tensor, 145 | residue_index: torch.Tensor, 146 | chain_index: torch.Tensor, 147 | ): 148 | chain_boundary_mask = torch.cat( 149 | [torch.diff(chain_index, n=1, dim=-1) == 0, torch.zeros_like(chain_index[..., :1], dtype=torch.bool)], dim=-1 150 | ) 151 | 152 | chain_break_mask = torch.cat( 153 | [torch.diff(residue_index, n=1, dim=-1) == 1, torch.zeros_like(residue_index[..., :1], dtype=torch.bool)], 154 | dim=-1, 155 | ) 156 | 157 | (input, mask) = ( 158 | torch.stack( 159 | [ 160 | x[..., :-1, 0, :], # N(i) 161 | x[..., :-1, 1, :], # CA(i) 162 | x[..., :-1, 2, :], # C(i) 163 | x[..., 1:, 0, :], # N(i+1) 164 | ], 165 | dim=-2, 166 | ) 167 | for x in (positions, mask) 168 | ) # [..., L, 4, 3] 169 | 170 | mask = mask.all(dim=-1).all(dim=-1) # [..., L] 171 | 172 | angles = coordinates_to_dihedral(input) 173 | nan_tensor = torch.full_like(angles[..., :1], float("nan")) 174 | false_tensor = torch.zeros_like(mask[..., :1]) 175 | 176 | angles = torch.cat([angles, nan_tensor], dim=-1) 177 | mask = torch.cat([mask, false_tensor], dim=-1) 178 | mask = mask & chain_boundary_mask & chain_break_mask 179 | 180 | return angles, mask 181 | 182 | 183 | def _positions_to_omega( 184 | positions: torch.Tensor, 185 | mask: torch.Tensor, 186 | 
residue_index: torch.Tensor,
187 | chain_index: torch.Tensor,
188 | ):
189 | chain_boundary_mask = torch.cat(
190 | [torch.diff(chain_index, n=1, dim=-1) == 0, torch.zeros_like(chain_index[..., :1], dtype=torch.bool)], dim=-1
191 | )
192 | 
193 | chain_break_mask = torch.cat(
194 | [torch.diff(residue_index, n=1, dim=-1) == 1, torch.zeros_like(residue_index[..., :1], dtype=torch.bool)],
195 | dim=-1,
196 | )
197 | 
198 | (input, mask) = (
199 | torch.stack(
200 | [
201 | x[..., :-1, 1, :],  # CA(i)
202 | x[..., :-1, 2, :],  # C(i)
203 | x[..., 1:, 0, :],  # N(i+1)
204 | x[..., 1:, 1, :],  # CA(i+1)
205 | ],
206 | dim=-2,
207 | )
208 | for x in (positions, mask)
209 | )  # [..., L, 4, 3]
210 | 
211 | mask = mask.all(dim=-1).all(dim=-1)  # [..., L]
212 | 
213 | angles = coordinates_to_dihedral(input)
214 | nan_tensor = torch.full_like(angles[..., :1], float("nan"))
215 | false_tensor = torch.zeros_like(mask[..., :1])
216 | 
217 | angles = torch.cat([angles, nan_tensor], dim=-1)
218 | mask = torch.cat([mask, false_tensor], dim=-1)
219 | mask = mask & chain_boundary_mask & chain_break_mask
220 | 
221 | return angles, mask
222 | 
223 | 
224 | def positions_to_backbone_dihedrals(
225 | positions: torch.Tensor, mask: torch.Tensor, residue_index: torch.Tensor | None = None, chain_index: torch.Tensor | None = None
226 | ) -> tuple[torch.Tensor, torch.Tensor]:
227 | """Compute backbone dihedral angles (phi, psi, omega) from the atom-wise coordinates.
228 | 
229 | Parameters
230 | ----------
231 | positions: Tensor
232 | Shape (..., L, 37, 3) tensor of the atom-wise coordinates
233 | mask: BoolTensor
234 | Shape (..., L, 37) boolean tensor indicating which atoms are present
235 | residue_index: Tensor | None = None
236 | Optional shape (..., L) tensor specifying the index of each residue along its chain.
237 | If supplied this is used to mask dihedrals that cross a chain break.
238 | chain_index: Tensor | None = None
239 | Optional shape (..., L) tensor specifying the chain index of each residue.
240 | If supplied this is used to mask dihedrals that cross a chain boundary.
241 | 
242 | 
243 | 244 | Returns 245 | ------- 246 | dihedrals: Tensor 247 | Shape (..., L, 3) tensor of the dihedral angles (phi, psi, omega) 248 | dihedrals_mask: BoolTensor 249 | Shape (..., L, 3) boolean tensor indicating which dihedrals are present 250 | """ 251 | assert positions.ndim >= 3 252 | L = positions.shape[-3] 253 | device = positions.device 254 | 255 | if residue_index is None: 256 | residue_index = torch.arange(L).expand(*positions.shape[:-3], -1) # [..., L] 257 | residue_index = residue_index.to(device) 258 | 259 | if chain_index is None: 260 | chain_index = torch.zeros_like(positions[..., :, 0, 0], dtype=torch.int64) # [..., L] 261 | chain_index = chain_index.to(device) 262 | 263 | mask = mask.unsqueeze(-1).expand(*mask.shape,3) 264 | phi, phi_mask = _positions_to_phi(positions, mask, residue_index=residue_index, chain_index=chain_index) 265 | psi, psi_mask = _positions_to_psi(positions, mask, residue_index=residue_index, chain_index=chain_index) 266 | omega, omega_mask = _positions_to_omega(positions, mask, residue_index=residue_index, chain_index=chain_index) 267 | 268 | dihedrals = torch.stack([phi, psi, omega], dim=-1) 269 | dihedrals_mask = torch.stack([phi_mask, psi_mask, omega_mask], dim=-1) 270 | 271 | return dihedrals, dihedrals_mask 272 | 273 | 274 | def rmsd_summary_calculation( 275 | coords_truth: torch.Tensor, 276 | coords_prediction: torch.Tensor, 277 | sequence_mask: torch.Tensor, 278 | region_mask: torch.Tensor, 279 | chain_mask: torch.Tensor, 280 | batch_average: bool = True, 281 | ) -> dict[str, torch.Tensor]: 282 | """Computes RMSD summary for different regions and chains. 283 | 284 | Args: 285 | coords_truth (torch.Tensor): (B, n, 14/37, 3) ground truth coordinates 286 | coords_prediction (torch.Tensor): (B, n, 14/37, 3) predicted coordinates 287 | sequence_mask (torch.Tensor): (B, n) where [i, j] = 1 if a coordinate for sequence i at residue j exists. 
288 | region_mask (torch.Tensor): (B, n) region mask indicating the region of each residue 289 | chain_mask (torch.Tensor): (B, n) chain mask indicating the chain of each residue (0 for light chain, 1 for heavy chain) 290 | batch_average (bool): if True, average along the batch dimensions 291 | 292 | Returns: 293 | dict[str, torch.Tensor]: RMSD values for each region and chain 294 | """ 295 | results = {} 296 | 297 | # Align and compute RMSD for heavy chain regions 298 | heavy_chain_mask = chain_mask == 1 299 | 300 | heavy_chain_backbone_truth = extract_backbone_coordinates( 301 | coords_truth * heavy_chain_mask.unsqueeze(-1).unsqueeze(-1) 302 | ) 303 | heavy_chain_sequence_mask = extract_backbone_mask(sequence_mask * heavy_chain_mask) 304 | 305 | # Mask for framework regions only 306 | heavy_framework_mask = (region_mask.unsqueeze(-1) == torch.tensor(list(heavy_framework_regions), device=region_mask.device)).any(-1) * heavy_chain_mask 307 | heavy_framework_backbone_truth = extract_backbone_coordinates( 308 | coords_truth * heavy_framework_mask.unsqueeze(-1).unsqueeze(-1) 309 | ) 310 | heavy_framework_backbone_prediction = extract_backbone_coordinates( 311 | coords_prediction * heavy_framework_mask.unsqueeze(-1).unsqueeze(-1) 312 | ) 313 | heavy_framework_sequence_mask = extract_backbone_mask(sequence_mask * heavy_framework_mask) 314 | 315 | # Align framework regions 316 | heavy_framework_backbone_truth, R, t = batch_align( 317 | heavy_framework_backbone_truth, heavy_framework_backbone_prediction, heavy_framework_sequence_mask, return_transform=True 318 | ) 319 | 320 | # Compute RMSD for heavy chain framework as a whole 321 | square_distance = ( 322 | torch.linalg.norm( 323 | heavy_framework_backbone_prediction - heavy_framework_backbone_truth, dim=-1 324 | ) 325 | ** 2 326 | ) 327 | square_distance = square_distance * heavy_framework_sequence_mask 328 | 329 | heavy_framework_msd = torch.sum(square_distance, dim=-1) / heavy_framework_sequence_mask.sum(dim=-1) 330 | heavy_framework_rmsd = torch.sqrt(heavy_framework_msd) 331 | 332 | if batch_average: 333 | heavy_framework_rmsd = heavy_framework_rmsd.mean() 334 | 335 | results["fwh_rmsd"] = heavy_framework_rmsd 336 | 337 | # Apply the same transformation to the CDR regions 338 | heavy_cdr_mask = (region_mask.unsqueeze(-1) == torch.tensor(list(heavy_cdr_regions), device=region_mask.device)).any(-1) * heavy_chain_mask 339 | heavy_cdr_backbone_prediction = extract_backbone_coordinates( 340 | coords_prediction * heavy_cdr_mask.unsqueeze(-1).unsqueeze(-1) 341 | ) 342 | heavy_cdr_backbone_prediction_aligned = apply_transformation(heavy_cdr_backbone_prediction, R, t) 343 | 344 | for region_name, region_idx in region_mapping.items(): 345 | if region_idx in heavy_cdr_regions: 346 | region_mask_region = region_mask == region_idx 347 | region_mask_backbone = extract_backbone_mask(region_mask_region) 348 | 349 | heavy_chain_region_mask = region_mask_backbone * heavy_chain_sequence_mask 350 | square_distance = ( 351 | torch.linalg.norm( 352 | heavy_cdr_backbone_prediction_aligned - heavy_chain_backbone_truth, dim=-1 353 | ) 354 | ** 2 355 | ) 356 | square_distance = square_distance * heavy_chain_region_mask 357 | 358 | region_msd = torch.sum(square_distance, dim=-1) / heavy_chain_region_mask.sum(dim=-1) 359 | region_rmsd = torch.sqrt(region_msd) 360 | 361 | if batch_average: 362 | region_rmsd = region_rmsd.mean() 363 | 364 | results[f"{region_name}_rmsd"] = region_rmsd 365 | 366 | # Align and compute RMSD for light chain regions 367 | light_chain_mask = 
chain_mask == 0 368 | 369 | light_chain_backbone_truth = extract_backbone_coordinates( 370 | coords_truth * light_chain_mask.unsqueeze(-1).unsqueeze(-1) 371 | ) 372 | light_chain_sequence_mask = extract_backbone_mask(sequence_mask * light_chain_mask) 373 | 374 | # Mask for framework regions only 375 | light_framework_mask = (region_mask.unsqueeze(-1) == torch.tensor(list(light_framework_regions), device=region_mask.device)).any(-1) * light_chain_mask 376 | light_framework_backbone_truth = extract_backbone_coordinates( 377 | coords_truth * light_framework_mask.unsqueeze(-1).unsqueeze(-1) 378 | ) 379 | light_framework_backbone_prediction = extract_backbone_coordinates( 380 | coords_prediction * light_framework_mask.unsqueeze(-1).unsqueeze(-1) 381 | ) 382 | light_framework_sequence_mask = extract_backbone_mask(sequence_mask * light_framework_mask) 383 | 384 | # Align framework regions 385 | light_framework_backbone_truth, R, t = batch_align( 386 | light_framework_backbone_truth, light_framework_backbone_prediction, light_framework_sequence_mask, return_transform=True 387 | ) 388 | 389 | # Compute RMSD for light chain framework as a whole 390 | square_distance = ( 391 | torch.linalg.norm( 392 | light_framework_backbone_prediction - light_framework_backbone_truth, dim=-1 393 | ) 394 | ** 2 395 | ) 396 | square_distance = square_distance * light_framework_sequence_mask 397 | 398 | light_framework_msd = torch.sum(square_distance, dim=-1) / light_framework_sequence_mask.sum(dim=-1) 399 | light_framework_rmsd = torch.sqrt(light_framework_msd) 400 | 401 | if batch_average: 402 | light_framework_rmsd = light_framework_rmsd.mean() 403 | 404 | results["fwl_rmsd"] = light_framework_rmsd 405 | 406 | # Apply the same transformation to the CDR regions 407 | light_cdr_mask = (region_mask.unsqueeze(-1) == torch.tensor(list(light_cdr_regions), device=region_mask.device)).any(-1) * light_chain_mask 408 | light_cdr_backbone_prediction = extract_backbone_coordinates( 409 | coords_prediction * light_cdr_mask.unsqueeze(-1).unsqueeze(-1) 410 | ) 411 | light_cdr_backbone_prediction_aligned = apply_transformation(light_cdr_backbone_prediction, R, t) 412 | 413 | for region_name, region_idx in region_mapping.items(): 414 | if region_idx in light_cdr_regions: 415 | region_mask_region = region_mask == region_idx 416 | region_mask_backbone = extract_backbone_mask(region_mask_region) 417 | 418 | light_chain_region_mask = region_mask_backbone * light_chain_sequence_mask 419 | square_distance = ( 420 | torch.linalg.norm( 421 | light_cdr_backbone_prediction_aligned - light_chain_backbone_truth, dim=-1 422 | ) 423 | ** 2 424 | ) 425 | square_distance = square_distance * light_chain_region_mask 426 | 427 | region_msd = torch.sum(square_distance, dim=-1) / light_chain_region_mask.sum(dim=-1) 428 | region_rmsd = torch.sqrt(region_msd) 429 | 430 | if batch_average: 431 | region_rmsd = region_rmsd.mean() 432 | 433 | results[f"{region_name}_rmsd"] = region_rmsd 434 | 435 | return results 436 | 437 | 438 | def aligned_fv_and_cdrh3_rmsd( 439 | coords_truth: torch.Tensor, 440 | coords_prediction: torch.Tensor, 441 | sequence_mask: torch.Tensor, 442 | cdrh3_mask: torch.Tensor, 443 | batch_average: bool = True, 444 | ) -> dict[str, torch.Tensor]: 445 | """Aligns positions_truth to positions_prediction in a batched way. 
446 | 447 | Args: 448 | positions_truth (torch.Tensor): (B, n, 14/37, 3) ground truth coordinates 449 | positions_prediction (torch.Tensor): (B, n, 14/37, 3) predicted coordinates 450 | sequence_mask (torch.Tensor): (B, n) where [i, j] = 1 if a coordinate for sequence i at residue j exists. 451 | cdrh3_mask (torch.Tensor): (B, n) where [i, j] = 1 if a coordinate for sequence i at residue j is part of the cdrh3 loop. 452 | batch_average (bool): if True, average along the batch dimensions 453 | 454 | Returns: 455 | A dictionary[str, torch.Tensor] containing 456 | seq_rmsd: the RMSD of the backbone after backbone alignment 457 | cdrh3_rmsd: the RMSD of the CDRH3 backbone after backbone alignment 458 | """ 459 | 460 | # extract backbones and mask and put in 3d point cloud shape 461 | backbone_truth = extract_backbone_coordinates(coords_truth) 462 | backbone_prediction = extract_backbone_coordinates(coords_prediction) 463 | backbone_sequence_mask = extract_backbone_mask(sequence_mask) 464 | 465 | # align backbones 466 | backbone_truth = batch_align( 467 | backbone_truth, backbone_prediction, backbone_sequence_mask 468 | ) 469 | 470 | square_distance = ( 471 | torch.linalg.norm(backbone_prediction - backbone_truth, dim=-1) ** 2 472 | ) 473 | square_distance = square_distance * backbone_sequence_mask 474 | 475 | seq_msd = square_distance.sum(dim=-1) / backbone_sequence_mask.sum(dim=-1) 476 | seq_rmsd = torch.sqrt(seq_msd) 477 | 478 | backbone_cdrh3_mask = extract_backbone_mask(cdrh3_mask) 479 | square_distance = square_distance * (backbone_cdrh3_mask * backbone_sequence_mask) 480 | cdrh3_msd = torch.sum(square_distance, dim=-1) / backbone_cdrh3_mask.sum(dim=-1) 481 | cdrh3_rmsd = torch.sqrt(cdrh3_msd) 482 | 483 | if batch_average: 484 | seq_rmsd = seq_rmsd.mean() 485 | cdrh3_rmsd = cdrh3_rmsd.mean() 486 | 487 | return {"seq_rmsd": seq_rmsd, "cdrh3_rmsd": cdrh3_rmsd} 488 | 489 | 490 | def extract_backbone_coordinates(positions: torch.Tensor) -> torch.Tensor: 491 | """(B, n, 14/37, 3) -> (B, n * 4, 3)""" 492 | batch_size = positions.size(0) 493 | backbone_positions = positions[:, :, :4, :] # (B, n, 4, 3) 494 | backbone_positions_flat = backbone_positions.reshape( 495 | batch_size, -1, 3 496 | ) # (B, n * 4, 3) 497 | return backbone_positions_flat 498 | 499 | 500 | def extract_backbone_mask(sequence_mask: torch.Tensor) -> torch.Tensor: 501 | """(B, n) -> (B, n * 4)""" 502 | batch_size = sequence_mask.size(0) 503 | return sequence_mask.unsqueeze(-1).repeat(1, 1, 4).view(batch_size, -1) 504 | 505 | 506 | def batch_align(x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor, return_transform=False): 507 | """Aligns 3-dimensional point clouds. Based on section 4 of https://igl.ethz.ch/projects/ARAP/svd_rot.pdf. 508 | 509 | Args: 510 | x (torch.Tensor): A tensor shape (B, n, 3) 511 | y (torch.Tensor): A tensor shape (B, n, 3) 512 | mask (torch.Tensor): A mask of shape (B, n) were mask[i, j]=1 indicates the presence of a point in sample i at location j of both sequences. 513 | return_transform (bool): If True, return rotation and translation matrices. 514 | 515 | Returns: 516 | torch.Tensor: a rototranslated x aligned to y. 517 | torch.Tensor: rotation matrix used for alignment (if return_transform is True). 518 | torch.Tensor: translation matrix used for alignment (if return_transform is True). 519 | """ 520 | 521 | # check inputs 522 | if x.ndim != 3: 523 | raise ValueError(f"Expected x.ndim=3. Instead got {x.ndim=}") 524 | if y.ndim != 3: 525 | raise ValueError(f"Expected y.ndim=3. 
Instead got {y.ndim=}")
526 | if mask.ndim != 2:
527 | raise ValueError(f"Expected mask.ndim=2. Instead got {mask.ndim=}")
528 | if x.size(-1) != 3:
529 | raise ValueError(f"Expected last dim of x to be 3. Instead got {x.size(-1)=}")
530 | if y.size(-1) != 3:
531 | raise ValueError(f"Expected last dim of y to be 3. Instead got {y.size(-1)=}")
532 | 
533 | # (B, n) -> (B, n, 1)
534 | mask = mask.unsqueeze(-1)
535 | 
536 | # zero masked coordinates (the below centroids computation relies on it).
537 | x = x * mask
538 | y = y * mask
539 | 
540 | # centroids (B, 3)
541 | p_bar = x.sum(dim=1) / mask.sum(dim=1)
542 | q_bar = y.sum(dim=1) / mask.sum(dim=1)
543 | 
544 | # centered points (B, n, 3)
545 | x_centered = x - p_bar.unsqueeze(1)
546 | y_centered = y - q_bar.unsqueeze(1)
547 | 
548 | # compute covariance matrices (B, 3, 3)
549 | num_valid_points = mask.sum(dim=1, keepdim=True).sum(dim=2, keepdim=True)
550 | S = torch.bmm(x_centered.transpose(-1, -2), y_centered * mask) / num_valid_points
551 | S = S + 10e-6 * torch.eye(S.size(-1)).unsqueeze(0).to(S.device)
552 | 
553 | # Compute U, V from SVD (B, 3, 3)
554 | U, _, Vh = torch.linalg.svd(S)
555 | V = Vh.transpose(-1, -2)
556 | Uh = U.transpose(-1, -2)
557 | 
558 | # correction that accounts for reflection (B, 3, 3)
559 | correction = torch.eye(x.size(-1)).unsqueeze(0).repeat(x.size(0), 1, 1).to(x.device)
560 | correction[:, -1, -1] = torch.det(torch.bmm(V, Uh).float())
561 | 
562 | # rotation (B, 3, 3)
563 | R = V.bmm(correction).bmm(Uh)
564 | 
565 | # translation (B, 3); squeeze only the last dim so a batch size of 1 survives
566 | t = q_bar - R.bmm(p_bar.unsqueeze(-1)).squeeze(-1)
567 | 
568 | # translate x to align with y
569 | x_rotated = torch.bmm(R, x.transpose(-1, -2)).transpose(-1, -2)
570 | x_aligned = x_rotated + t.unsqueeze(1)
571 | 
572 | if return_transform:
573 | return x_aligned, R, t
574 | else:
575 | return x_aligned
576 | 
577 | if __name__ == "__main__":
578 | from torch.utils.data import DataLoader
579 | from ibex.dataloader import ABDataset, collate_fn
580 | dataset = ABDataset("test", split_file="/data/dreyerf1/ibex/split.csv", data_dir="/data/dreyerf1/ibex/structures", edge_chain_feature=True)
581 | dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn, shuffle=False)
582 | for batch in dataloader:
583 | coords = batch["atom14_gt_positions"]
584 | preds = coords + 10
585 | mask = batch["seq_mask"]
586 | cdrh3_mask = batch["region_numeric"] == 2
587 | print(aligned_fv_and_cdrh3_rmsd(coords, preds, mask, cdrh3_mask))
588 | break
589 | 
590 | 
--------------------------------------------------------------------------------
/src/ibex/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
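# Illustrative self-check sketch for batch_align in aligned_rmsd.py above (not part of
# the original code). batch_align solves a masked orthogonal Procrustes problem per
# batch element, so aligning a point cloud to a rigidly moved copy of itself should
# recover the copy, up to the small epsilon added to the covariance. A sketch under
# that assumption:
#
#     import math
#     import torch
#     from ibex.loss.aligned_rmsd import batch_align
#
#     x = torch.randn(2, 10, 3)                        # batch of point clouds
#     c, s = math.cos(0.3), math.sin(0.3)
#     R_true = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
#     y = x @ R_true.T + 1.0                           # rotated and translated copy
#     mask = torch.ones(2, 10)                         # all points valid
#     x_aligned = batch_align(x, y, mask)
#     assert torch.allclose(x_aligned, y, atol=1e-3)   # aligned up to numerics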
15 | 16 | import lightning.pytorch as pl 17 | import torch 18 | from torch.optim import AdamW 19 | from torch.optim.lr_scheduler import LambdaLR, LinearLR 20 | from torch.utils.data import DataLoader, WeightedRandomSampler 21 | from torch.utils.data import ConcatDataset 22 | from esm.models.esmc import ESMC 23 | from esm.sdk.api import ESMProtein, LogitsConfig 24 | 25 | from ibex.dataloader import ABDataset, collate_fn, string_to_input, sequence_collate_fn 26 | from ibex.loss import IbexLoss 27 | from ibex.utils import output_to_pdb, add_atom37_to_output, output_to_protein, ENSEMBLE_MODELS, checkpoint_path 28 | from ibex.openfold.model import StructureModule 29 | 30 | class IbexDataModule(pl.LightningDataModule): 31 | def __init__( 32 | self, 33 | batch_size: int, 34 | data_dir_sabdab: str, 35 | split_file_sabdab: str, 36 | data_dir_predicted: str, 37 | split_file_predicted: str, 38 | fraction_predicted: float, 39 | data_dir_ig: str, 40 | split_file_ig: str, 41 | fraction_ig: float, 42 | fraction_matched: float, 43 | use_vhh: bool = False, 44 | use_tcr: bool = False, 45 | public_only_data: bool = False, 46 | rel_pos_dim: int = 64, 47 | edge_chain_feature: bool = False, 48 | num_workers: int = 0, 49 | pin_memory: bool = False, 50 | use_plm_embeddings: bool = False, 51 | use_weighted_sampler: bool = False, 52 | cluster_column: str = "cluster", 53 | weight_temperature: float = 1.0, 54 | weight_clip_quantile: float = 0.99, 55 | vhh_weight: float = 1.0, 56 | tcr_weight: float = 1.0, 57 | matched_weight: float = 1.0, 58 | conformation_node_feature: bool = False, 59 | use_contrastive: bool = False, 60 | use_boltz_only: bool = False, 61 | ): 62 | super().__init__() 63 | self.data_dir_sabdab = data_dir_sabdab 64 | self.data_dir_predicted = data_dir_predicted 65 | self.data_dir_ig = data_dir_ig 66 | self.split_file_sabdab = split_file_sabdab 67 | self.split_file_predicted = split_file_predicted 68 | self.split_file_ig = split_file_ig 69 | self.fraction_predicted = fraction_predicted 70 | self.fraction_ig = fraction_ig 71 | self.fraction_matched = fraction_matched if use_contrastive else 0.0 72 | self.batch_size = batch_size 73 | self.use_vhh = use_vhh 74 | self.use_tcr = use_tcr 75 | self.public_only_data = public_only_data 76 | self.rel_pos_dim = rel_pos_dim 77 | self.edge_chain_feature = edge_chain_feature 78 | self.num_workers = num_workers 79 | self.pin_memory = pin_memory 80 | self.use_plm_embeddings = use_plm_embeddings 81 | self.use_weighted_sampler = use_weighted_sampler 82 | self.cluster_column = cluster_column 83 | self.weight_temperature = weight_temperature 84 | self.weight_clip_quantile = weight_clip_quantile 85 | self.vhh_weight = vhh_weight 86 | self.tcr_weight = tcr_weight 87 | self.matched_weight = matched_weight 88 | self.conformation_node_feature = conformation_node_feature 89 | self.use_contrastive = use_contrastive 90 | self.use_boltz_only = use_boltz_only 91 | 92 | def setup(self, stage: str): 93 | sabdab_dataset = ABDataset( 94 | "train", 95 | split_file=self.split_file_sabdab, 96 | data_dir=self.data_dir_sabdab, 97 | use_vhh=self.use_vhh, 98 | use_tcr=self.use_tcr, 99 | rel_pos_dim=self.rel_pos_dim, 100 | edge_chain_feature=self.edge_chain_feature, 101 | use_plm_embeddings=self.use_plm_embeddings, 102 | use_public_only=self.public_only_data, 103 | cluster_column=self.cluster_column, 104 | weight_temperature=self.weight_temperature, 105 | weight_clip_quantile=self.weight_clip_quantile, 106 | vhh_weight=self.vhh_weight, 107 | tcr_weight=self.tcr_weight, 108 | 
matched_weight=self.matched_weight, 109 | conformation_node_feature=self.conformation_node_feature, 110 | use_contrastive=self.use_contrastive, 111 | is_matched=False, 112 | use_weights=self.use_weighted_sampler, 113 | ) 114 | sabdab_dataset_matched = ABDataset( 115 | "train", 116 | split_file=self.split_file_sabdab, 117 | data_dir=self.data_dir_sabdab, 118 | use_vhh=self.use_vhh, 119 | use_tcr=self.use_tcr, 120 | rel_pos_dim=self.rel_pos_dim, 121 | edge_chain_feature=self.edge_chain_feature, 122 | use_plm_embeddings=self.use_plm_embeddings, 123 | use_public_only=self.public_only_data, 124 | cluster_column=self.cluster_column, 125 | weight_temperature=self.weight_temperature, 126 | weight_clip_quantile=self.weight_clip_quantile, 127 | vhh_weight=self.vhh_weight, 128 | tcr_weight=self.tcr_weight, 129 | conformation_node_feature=self.conformation_node_feature, 130 | use_contrastive=self.use_contrastive, 131 | is_matched=True, 132 | use_weights=self.use_weighted_sampler, 133 | ) if self.use_contrastive else None 134 | predicted_dataset = ABDataset( 135 | "all", 136 | split_file=self.split_file_predicted, 137 | data_dir=self.data_dir_predicted, 138 | use_vhh=False, 139 | use_tcr=False, 140 | rel_pos_dim=self.rel_pos_dim, 141 | edge_chain_feature=self.edge_chain_feature, 142 | use_plm_embeddings=self.use_plm_embeddings, 143 | use_public_only=self.public_only_data, 144 | cluster_column=self.cluster_column, 145 | weight_temperature=self.weight_temperature, 146 | weight_clip_quantile=self.weight_clip_quantile, 147 | vhh_weight=self.vhh_weight, 148 | tcr_weight=self.tcr_weight, 149 | conformation_node_feature=self.conformation_node_feature, 150 | use_boltz_only=self.use_boltz_only, 151 | use_weights=self.use_weighted_sampler, 152 | ) if self.fraction_predicted > 0.0 else None 153 | ig_dataset = ABDataset( 154 | "all", 155 | split_file=self.split_file_ig, 156 | data_dir=self.data_dir_ig, 157 | use_vhh=True, 158 | use_tcr=False, 159 | rel_pos_dim=self.rel_pos_dim, 160 | edge_chain_feature=self.edge_chain_feature, 161 | use_plm_embeddings=self.use_plm_embeddings, 162 | use_public_only=self.public_only_data, 163 | cluster_column=self.cluster_column, 164 | weight_temperature=self.weight_temperature, 165 | weight_clip_quantile=self.weight_clip_quantile, 166 | vhh_weight=self.vhh_weight, 167 | tcr_weight=self.tcr_weight, 168 | conformation_node_feature=self.conformation_node_feature, 169 | use_weights=self.use_weighted_sampler, 170 | ) if self.fraction_ig > 0.0 else None 171 | self.train_dataset = ConcatDataset([d for d in [sabdab_dataset, predicted_dataset, ig_dataset, sabdab_dataset_matched] if d is not None]) 172 | self.len_train_dataset = len(sabdab_dataset) 173 | sabdab_fraction = 1.0 - self.fraction_predicted - self.fraction_ig - self.fraction_matched 174 | self.train_dataset_weights = torch.cat( 175 | [d.weights * frac for d, frac in [(sabdab_dataset, sabdab_fraction), (predicted_dataset, self.fraction_predicted), (ig_dataset, self.fraction_ig), (sabdab_dataset_matched, self.fraction_matched)] if d is not None] 176 | ) 177 | self.train_dataset_weights = self.train_dataset_weights / self.train_dataset_weights.sum() 178 | self.valid_dataset = ABDataset( 179 | "valid", 180 | split_file=self.split_file_sabdab, 181 | data_dir=self.data_dir_sabdab, 182 | use_vhh=False, 183 | use_tcr=False, 184 | rel_pos_dim=self.rel_pos_dim, 185 | edge_chain_feature=self.edge_chain_feature, 186 | use_plm_embeddings=self.use_plm_embeddings, 187 | use_public_only=self.public_only_data, 188 | 
            conformation_node_feature=self.conformation_node_feature,
        )
        self.test_dataset = ABDataset(
            "test",
            split_file=self.split_file_sabdab,
            data_dir=self.data_dir_sabdab,
            use_vhh=False,
            use_tcr=False,
            rel_pos_dim=self.rel_pos_dim,
            edge_chain_feature=self.edge_chain_feature,
            use_plm_embeddings=self.use_plm_embeddings,
            use_public_only=self.public_only_data,
            conformation_node_feature=self.conformation_node_feature,
        )
        if stage == "test":
            self.test_public_dataset = ABDataset(
                "test",
                split_file=self.split_file_sabdab,
                data_dir=self.data_dir_sabdab,
                use_vhh=False,
                use_tcr=False,
                rel_pos_dim=self.rel_pos_dim,
                edge_chain_feature=self.edge_chain_feature,
                use_plm_embeddings=self.use_plm_embeddings,
                use_public_only=True,
                conformation_node_feature=self.conformation_node_feature,
            )
            self.test_public_vhh_dataset = ABDataset(
                "test_vhh",
                split_file=self.split_file_sabdab,
                data_dir=self.data_dir_sabdab,
                use_vhh=True,
                use_tcr=False,
                rel_pos_dim=self.rel_pos_dim,
                edge_chain_feature=self.edge_chain_feature,
                use_plm_embeddings=self.use_plm_embeddings,
                use_public_only=True,
                conformation_node_feature=self.conformation_node_feature,
            )
            self.test_public_tcr_dataset = ABDataset(
                "test_tcr",
                split_file=self.split_file_sabdab,
                data_dir=self.data_dir_sabdab,
                use_vhh=False,
                use_tcr=True,
                rel_pos_dim=self.rel_pos_dim,
                edge_chain_feature=self.edge_chain_feature,
                use_plm_embeddings=self.use_plm_embeddings,
                use_public_only=True,
                conformation_node_feature=self.conformation_node_feature,
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            collate_fn=collate_fn,
            # draw with replacement only when weighted sampling is enabled
            sampler=WeightedRandomSampler(
                self.train_dataset_weights,
                num_samples=self.len_train_dataset // 2,
                replacement=self.use_weighted_sampler,
            ),
            # Each element in our dataset actually contains 2 samples,
            # so the batch size (and num_samples above) is halved
            batch_size=self.batch_size // 2,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            collate_fn=collate_fn,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            collate_fn=collate_fn,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def test_public_dataloader(self):
        return DataLoader(
            self.test_public_dataset,
            collate_fn=collate_fn,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def test_public_vhh_dataloader(self):
        return DataLoader(
            self.test_public_vhh_dataset,
            collate_fn=collate_fn,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def test_public_tcr_dataloader(self):
        return DataLoader(
            self.test_public_tcr_dataset,
            collate_fn=collate_fn,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

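# Usage sketch (illustrative only; the paths below are hypothetical
# placeholders, not files shipped with this repository):
#
#     dm = IbexDataModule(
#         batch_size=32,
#         data_dir_sabdab="/data/sabdab",
#         split_file_sabdab="/data/sabdab_splits.csv",
#         data_dir_predicted="/data/predicted",
#         split_file_predicted="/data/predicted_splits.csv",
#         fraction_predicted=0.25,
#         data_dir_ig="/data/ig",
#         split_file_ig="/data/ig_splits.csv",
#         fraction_ig=0.1,
#         fraction_matched=0.0,
#     )
#     dm.setup(stage="fit")
#     train_loader = dm.train_dataloader()
#
# The fraction_* arguments set how much of the sampling weight goes to the
# predicted-structure and immunoglobulin datasets; the experimental SAbDab
# data receives whatever fraction remains.
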
class Ibex(pl.LightningModule):
    def __init__(
        self,
        model_config,
        loss_config,
        optim_config,
        conformation_aware=False,
        use_plm=False,
        ensemble=False,
        models=None,
        stage=None,
        init_plm=False,
    ):
        super().__init__()
        model_config["use_plddt"] = loss_config.plddt.weight > 0
        self.save_hyperparameters()
        self.model_config = model_config
        self.loss_config = loss_config
        self.optim_config = optim_config
        self.loss = IbexLoss(loss_config)
        self.finetune = False
        self.contrastive = False
        self.ensemble = ensemble
        self.conformation_aware = conformation_aware
        self.stage = stage

        if self.ensemble:
            if models is None:
                raise ValueError("Models must be provided for ensemble mode.")
            # Use the provided models to create the EnsembleStructureModule
            self.model = EnsembleStructureModule(models)
        else:
            self.model = StructureModule(**model_config)

        if use_plm:
            # placeholder marking that a PLM is required; the actual ESMC
            # weights are attached lazily via set_plm()
            self.plm_model = True
            if init_plm:
                self.set_plm()
        else:
            self.plm_model = None

    def set_plm(self):
        if self.plm_model is not None:
            self.plm_model = ESMC.from_pretrained("esmc_300m")
            self.plm_model.eval()
            for param in self.plm_model.parameters():
                param.requires_grad = False

    @classmethod
    def load_from_ensemble_checkpoint(cls, checkpoint_path, map_location=None):
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
        state_dicts = checkpoint["state_dicts"]
        model_config = checkpoint["model_config"]
        loss_config = checkpoint["loss_config"]
        optim_config = checkpoint["optim_config"]
        conformation_aware = checkpoint.get("conformation_aware", False)
        init_plm = checkpoint.get("language", False)
        # Initialize the models from state dicts
        models = [StructureModule(**model_config) for _ in state_dicts]
        for model, state_dict in zip(models, state_dicts):
            model.load_state_dict(state_dict)

        return cls(
            model_config,
            loss_config,
            optim_config,
            conformation_aware=conformation_aware,
            ensemble=True,
            models=models,
            use_plm=init_plm,
            init_plm=init_plm,
        )

    @classmethod
    def from_pretrained(cls, model="ibex", map_location=None, cache_dir=None):
        ckpt = checkpoint_path(model, cache_dir=cache_dir)

        if model in ENSEMBLE_MODELS:
            ibex_model = Ibex.load_from_ensemble_checkpoint(ckpt, map_location=map_location)
        else:
            ibex_model = Ibex.load_from_checkpoint(ckpt, map_location=map_location)
        if ibex_model.plm_model is not None:
            ibex_model.set_plm()
        return ibex_model
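    # Loading sketch (illustrative, not executed). Checkpoint names other than
    # the default "ibex" depend on what checkpoint_path/ENSEMBLE_MODELS expose,
    # and the device below is just an example:
    #
    #     model = Ibex.from_pretrained("ibex", map_location="cpu")
    #     model.eval()
    #
    # For ensemble checkpoints, from_pretrained dispatches to
    # load_from_ensemble_checkpoint, which rebuilds one StructureModule per
    # stored state dict and wraps them all in an EnsembleStructureModule.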
    def training_step(self, batch, batch_idx):
        stage = f"_{self.stage}" if self.stage else ""
        loss = self._step(batch, "train" + stage)
        # clip_grad_norm_ returns the total gradient norm, which we log
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.parameters(), max_norm=self.optim_config.max_grad_norm
        )
        self.log(f"monitor{stage}/grad_norm", grad_norm)
        # Log custom epoch and learning rate
        if stage:
            self.log(f"monitor{stage}/epoch", self.current_epoch)
            self.log(
                f"monitor{stage}/learning_rate",
                self.trainer.optimizers[0].param_groups[0]["lr"],
            )
        return loss

    def validation_step(self, batch, batch_idx):
        stage = f"_{self.stage}" if self.stage else ""
        self._step(batch, "valid" + stage)
        self.log(
            f"monitor{stage}/finetune",
            float(self.finetune),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )

    def _step(self, batch, split):
        if self.plm_model is not None:
            output = self.model(
                {
                    "single": batch["single"],
                    "pair": batch["pair"],
                    "plm_embedding": batch["plm_embedding"],
                },
                batch["aatype"],
                batch["seq_mask"],
            )
        else:
            output = self.model(
                {
                    "single": batch["single"],
                    "pair": batch["pair"],
                },
                batch["aatype"],
                batch["seq_mask"],
            )
        loss, loss_dict = self.loss(
            output, batch, self.finetune, (self.contrastive and "train" in split)
        )
        for loss_name in loss_dict:
            self.log(
                f"{split}/{loss_name}",
                loss_dict[loss_name],
                prog_bar=loss_name == "loss",
                on_epoch=True,
                on_step=False,
                sync_dist=True,
            )
        return loss

    def configure_optimizers(self):
        if self.optim_config.optimizer != "AdamW":
            raise ValueError(
                f"Expected AdamW as optimizer. Instead got {self.optim_config.optimizer=}."
            )
        optimizer = AdamW(
            self.parameters(),
            lr=self.optim_config.lr,
            betas=(0.9, 0.99),
            eps=1e-6,
        )
        if self.stage == "first_stage":
            # linear warm-up from a small factor to the base learning rate
            scheduler = LinearLR(
                optimizer, start_factor=1e-3, total_iters=self.optim_config.linear_iters
            )
        else:
            # constant lr for lambda_iters steps, then geometric decay with a floor
            scheduler = LambdaLR(
                optimizer,
                lambda epoch: max(
                    self.optim_config.lambda_lr ** (epoch - self.optim_config.lambda_iters),
                    self.optim_config.lambda_min_factor,
                )
                if epoch >= self.optim_config.lambda_iters
                else 1,
            )

        lr_scheduler = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
            "name": "learning_rate",
        }
        return [optimizer], [lr_scheduler]
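    # Schedule sketch (illustrative numbers, not taken from any shipped
    # config): with interval="step", the LambdaLR above keeps the base
    # learning rate for lambda_iters optimizer steps and then decays it
    # geometrically down to a floor. Assuming lambda_lr=0.95,
    # lambda_iters=1000 and lambda_min_factor=0.1, the lr multiplier is:
    #
    #     factor(step) = 1                                  for step < 1000
    #     factor(step) = max(0.95 ** (step - 1000), 0.1)    otherwise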
    def predict(self, fv_heavy, fv_light, device, ensemble=False, pdb_string=True, apo=False):
        self.model.eval()
        self.to(device)
        if self.plm_model is not None:
            self.plm_model.to(device)
            # embed the heavy chain (the only chain for nanobodies)
            prot_h = ESMProtein(sequence=fv_heavy)
            prot_tensor_h = self.plm_model.encode(prot_h)
            logits_output_h = self.plm_model.logits(
                prot_tensor_h, LogitsConfig(sequence=True, return_embeddings=True)
            )
            # drop the BOS/EOS positions from the per-residue embeddings
            embed_vh = logits_output_h.embeddings[0, 1:-1]
            if fv_light is None or fv_light == "":
                embedding = embed_vh
            else:
                # for antibodies, embed the light chain as well
                prot_l = ESMProtein(sequence=fv_light)
                prot_tensor_l = self.plm_model.encode(prot_l)
                logits_output_l = self.plm_model.logits(
                    prot_tensor_l, LogitsConfig(sequence=True, return_embeddings=True)
                )
                embed_vl = logits_output_l.embeddings[0, 1:-1]
                embedding = torch.concat([embed_vh, embed_vl])
        else:
            # if PLM is not used, set embedding to None and use one-hot encoding instead
            embedding = None
        ab_input = string_to_input(
            heavy=fv_heavy,
            light=fv_light,
            apo=apo,
            conformation_aware=self.conformation_aware,
            embedding=embedding,
            device=device,
        )
        ab_input_batch = {
            key: (value.unsqueeze(0) if key not in ["single", "pair", "plm_embedding"] else value)
            for key, value in ab_input.items()
        }
        # Forward pass with model
        result = self.model(ab_input_batch, ab_input_batch["aatype"], return_all=ensemble)
        if ensemble:
            predictions = []
            for i, output in enumerate(result):
                result[i] = add_atom37_to_output(output, ab_input_batch["aatype"])
                if pdb_string:
                    predictions.append(output_to_pdb(result[i], ab_input))
                else:
                    predictions.append(output_to_protein(result[i], ab_input))
            return predictions

        result = add_atom37_to_output(result, ab_input["aatype"].to(device))
        if pdb_string:
            return output_to_pdb(result, ab_input)
        return output_to_protein(result, ab_input)
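    # Prediction sketch (illustrative only; the truncated sequences are
    # placeholders, not real Fv domains):
    #
    #     model = Ibex.from_pretrained()
    #     pdb_str = model.predict(
    #         fv_heavy="EVQLVESGGGLVQPGG...",
    #         fv_light="DIQMTQSPSSLSASVG...",
    #         device="cuda",
    #     )
    #
    # Passing fv_light=None (or "") takes the nanobody path above, and
    # pdb_string=False returns a Protein object instead of a PDB string.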
    def predict_batch(
        self,
        fv_heavy_batch,
        fv_light_batch,
        device,
        ensemble=False,
        pdb_string=True,
        apo_list=None,
        num_workers=0,  # currently unused; kept for API compatibility
    ):
        self.model.eval()
        self.to(device)

        batch = []
        for i, fv_heavy in enumerate(fv_heavy_batch):
            fv_light = fv_light_batch[i] if fv_light_batch is not None else None
            apo = apo_list[i] if apo_list is not None else False
            if self.plm_model is not None:
                prot_h = ESMProtein(sequence=fv_heavy)
                prot_tensor_h = self.plm_model.encode(prot_h)
                logits_output_h = self.plm_model.logits(
                    prot_tensor_h, LogitsConfig(sequence=True, return_embeddings=True)
                )
                embed_vh = logits_output_h.embeddings[0, 1:-1]
                if fv_light is None or fv_light == "":
                    # for nanobodies, embed the heavy chain only
                    embedding = embed_vh
                else:
                    # for antibodies, embed both chains
                    prot_l = ESMProtein(sequence=fv_light)
                    prot_tensor_l = self.plm_model.encode(prot_l)
                    logits_output_l = self.plm_model.logits(
                        prot_tensor_l, LogitsConfig(sequence=True, return_embeddings=True)
                    )
                    embed_vl = logits_output_l.embeddings[0, 1:-1]
                    embedding = torch.concat([embed_vh, embed_vl])
            else:
                embedding = None
            batch.append(
                string_to_input(
                    heavy=fv_heavy,
                    light=fv_light,
                    apo=apo,
                    conformation_aware=self.conformation_aware,
                    embedding=embedding,
                    device=device,
                )
            )

        ab_input_batch = sequence_collate_fn(batch)
        # Move inputs to the device
        ab_input_batch = {key: value.to(device) for key, value in ab_input_batch.items()}

        # Forward pass with mask
        results = self.model(
            ab_input_batch, ab_input_batch["aatype"], mask=ab_input_batch["mask"], return_all=ensemble
        )

        predictions = []
        batch_size = ab_input_batch["aatype"].size(0)
        if ensemble:
            for i in range(batch_size):
                ensemble_preds = []
                for result in results:
                    # keep only the valid (unpadded) positions of sequence i
                    masked_result = {
                        "positions": result["positions"][-1, i][ab_input_batch["mask"][i] == 1]
                        .unsqueeze(0)
                        .unsqueeze(0),
                    }
                    if "plddt" in result:
                        masked_result["plddt"] = result["plddt"][i][
                            ab_input_batch["mask"][i] == 1
                        ].unsqueeze(0)
                    masked_input = {
                        "aatype": ab_input_batch["aatype"][i][ab_input_batch["mask"][i] == 1],
                        "is_heavy": ab_input_batch["is_heavy"][i][ab_input_batch["mask"][i] == 1],
                    }
                    # Add atom37 coordinates to the output
                    masked_result = add_atom37_to_output(
                        masked_result, masked_input["aatype"].unsqueeze(0)
                    )
                    if pdb_string:
                        ensemble_preds.append(output_to_pdb(masked_result, masked_input))
                    else:
                        ensemble_preds.append(output_to_protein(masked_result, masked_input))
                predictions.append(ensemble_preds)
            return predictions

        # Iterate over each item in the batch
        for i in range(batch_size):
            masked_result = {
                "positions": results["positions"][-1, i][ab_input_batch["mask"][i] == 1]
                .unsqueeze(0)
                .unsqueeze(0),
            }
            if "plddt" in results:
                masked_result["plddt"] = results["plddt"][i][
                    ab_input_batch["mask"][i] == 1
                ].unsqueeze(0)
            masked_input = {
                "aatype": ab_input_batch["aatype"][i][ab_input_batch["mask"][i] == 1],
                "is_heavy": ab_input_batch["is_heavy"][i][ab_input_batch["mask"][i] == 1],
            }

            # Add atom37 coordinates to the output
            masked_result = add_atom37_to_output(masked_result, masked_input["aatype"].unsqueeze(0))
            if pdb_string:
                # Generate a PDB string for the single result
                predictions.append(output_to_pdb(masked_result, masked_input))
            else:
                # Generate a Protein object for the single result
                predictions.append(output_to_protein(masked_result, masked_input))

        return predictions

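# Batch prediction sketch (illustrative; sequences are placeholders and
# `model` is an Ibex instance as in the single-sequence sketch above):
#
#     pdbs = model.predict_batch(
#         fv_heavy_batch=["EVQLVESGG...", "QVQLQQSGA..."],
#         fv_light_batch=["DIQMTQSPS...", None],
#         device="cuda",
#     )
#
# Each entry of the returned list corresponds to one input pair; a None (or
# empty) light chain selects the nanobody path, and with ensemble=True each
# entry is itself a list with one prediction per ensemble member.
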
masked_result["plddt"]=results["plddt"][i][ab_input_batch["mask"][i]==1].unsqueeze(0) 640 | masked_input = { 641 | "aatype": ab_input_batch["aatype"][i][ab_input_batch["mask"][i]==1], 642 | "is_heavy": ab_input_batch["is_heavy"][i][ab_input_batch["mask"][i]==1], 643 | } 644 | 645 | # Add atom37 coordinates to the output 646 | masked_result = add_atom37_to_output(masked_result, masked_input["aatype"].unsqueeze(0)) 647 | if pdb_string: 648 | # Generate a PDB string for the single result 649 | predictions.append(output_to_pdb(masked_result, masked_input)) 650 | else: 651 | # Generate a Protein object for the single result 652 | predictions.append(output_to_protein(masked_result, masked_input)) 653 | 654 | return predictions 655 | 656 | 657 | class EnsembleStructureModule(torch.nn.Module): 658 | def __init__(self, models): 659 | super(EnsembleStructureModule, self).__init__() 660 | if not all(isinstance(model, StructureModule) for model in models): 661 | raise ValueError("All models must be instances of StructureModule") 662 | self.models = torch.nn.ModuleList(models) 663 | 664 | def forward(self, inputs, aatype, mask=None, return_all=False, plddt_select=False): 665 | outputs = [] 666 | for model in self.models: 667 | output = model(inputs, aatype, mask) 668 | outputs.append(output) 669 | 670 | if return_all: 671 | # if all outputs are requested, no need for alignment, just return everything 672 | return outputs 673 | 674 | if aatype.shape[0] == 1: 675 | # if batch size is one, then we can use the original ABB2 implementation 676 | if plddt_select: 677 | from ibex.utils import compute_plddt 678 | plddts = torch.stack([compute_plddt(output["plddt"]).squeeze() for output in outputs]) 679 | plddts = torch.stack([x for x in plddts]) # [E, N, 3] 680 | closest_index = torch.argmax(torch.quantile(plddts, 0.1, dim=1)) 681 | print(torch.quantile(plddts, 0.1, dim=1), closest_index) 682 | return outputs[closest_index] 683 | else: 684 | # Stack positions along a new axis for batch processing 685 | positions = [output['positions'][-1].squeeze() for output in outputs] 686 | traces = torch.stack([x[:,0] for x in positions]) # [E, N, 3] 687 | # find the rotation and translation that aligns the traces 688 | R,t = find_alignment_transform(traces) 689 | aligned_traces = (traces-t) @ R 690 | # compute rmsd to the mean and return the prediction closest to the mean 691 | rmsd_values = (aligned_traces - aligned_traces.mean(0)).square().sum(-1).mean(-1) 692 | closest_index = torch.argmin(rmsd_values) 693 | return outputs[closest_index] 694 | 695 | # for batch size > 1 we need to perform a more sophisticated batching operation 696 | # and keep track of the mask 697 | positions = [output['positions'][-1] for output in outputs] 698 | traces = torch.stack([x[:,:,0] for x in positions]) # [E, B, N, 3] 699 | # Permute dimensions to get [B, E, N, 3] 700 | traces = traces.permute(1, 0, 2, 3) # [B, E, N, 3] 701 | 702 | if mask is not None: 703 | # Expand the mask to [B, E, N, 3] for element-wise operations 704 | mask = mask.unsqueeze(1).unsqueeze(-1).expand(-1, traces.size(1), -1, traces.size(-1)) 705 | R,t = find_alignment_transform_batch(traces, mask) 706 | aligned_traces = (traces-t) @ R 707 | 708 | if mask is not None: 709 | # Compute the mean of aligned traces along the ensemble dimension, considering the mask 710 | masked_aligned_traces = aligned_traces * mask # Mask application 711 | mask_sum = mask.sum(1) 712 | mask_sum[mask_sum == 0] = 1 713 | mean_aligned_traces = masked_aligned_traces.sum(1) / mask_sum # [B, N, 3] 
def find_alignment_transform(traces):
    # traces: [E, N, 3]
    centers = traces.mean(-2, keepdim=True)
    traces = traces - centers

    # SVD-based fit of every ensemble member onto the first trace
    p1, p2 = traces[0], traces[1:]
    C = torch.einsum("i j k, j l -> i k l", p2, p1)
    V, _, W = torch.linalg.svd(C)
    U = torch.matmul(V, W)
    # correct with the determinant so that the result is a proper rotation
    ones = torch.ones(len(p2), device=U.device)
    D = torch.stack([ones, ones, torch.linalg.det(U)], dim=1)[:, :, None]
    U = torch.matmul(D * V, W)

    # the first (reference) trace gets the identity transform
    return torch.cat([torch.eye(3, device=U.device)[None], U]), centers


def find_alignment_transform_batch(traces, mask=None):
    # traces: [B, E, N, 3], mask: [B, E, N, 3]
    if mask is not None:
        centers = (traces * mask).sum(dim=-2, keepdim=True) / mask.sum(dim=-2, keepdim=True)
    else:
        centers = traces.mean(dim=-2, keepdim=True)

    traces = traces - centers

    p1 = traces[:, 0, :, :]  # [B, N, 3]
    p2 = traces[:, 1:, :, :]  # [B, E-1, N, 3]

    if mask is not None:
        C = torch.einsum("...ni,...nj->...ij", p2 * mask[:, 1:], p1.unsqueeze(1) * mask[:, :1])
    else:
        C = torch.einsum("...ni,...nj->...ij", p2, p1.unsqueeze(1))

    V, _, W = torch.linalg.svd(C)

    # Compute the rotation matrix U
    U = torch.matmul(V, W)

    # Ensure U is a proper rotation matrix by checking its determinant
    det_U = torch.linalg.det(U)
    V[..., -1] *= torch.sign(det_U).unsqueeze(-1)

    U = torch.matmul(V, W)

    # Prepare identity matrices for the reference trace
    identity_matrices = torch.eye(3, device=U.device).expand(traces.size(0), 1, 3, 3)  # [B, 1, 3, 3]

    # Concatenate identity matrix for reference with U for ensemble members
    all_transforms = torch.cat([identity_matrices, U], dim=1)  # [B, E, 3, 3]

    return all_transforms, centers
--------------------------------------------------------------------------------