├── src
│   └── ibex
│       ├── openfold
│       │   ├── utils
│       │   │   ├── __init__.py
│       │   │   ├── precision_utils.py
│       │   │   ├── tensor_utils.py
│       │   │   ├── data_transforms.py
│       │   │   ├── feats.py
│       │   │   └── protein.py
│       │   ├── resources
│       │   │   ├── __init__.py
│       │   │   └── stereo_chemical_props.txt
│       │   └── __init__.py
│       ├── __init__.py
│       ├── py.typed
│       ├── loss
│       │   ├── __init__.py
│       │   ├── loss.py
│       │   └── aligned_rmsd.py
│       ├── inference.py
│       ├── predict.py
│       ├── utils.py
│       ├── refine.py
│       ├── dataloader.py
│       └── model.py
├── docs
│   ├── assets
│   │   ├── ibex.png
│   │   └── ibex_transparent.png
│   ├── ibex_test.csv
│   └── Genentech_license_weights_ibex
├── pyproject.toml
├── .gitignore
├── README.md
└── LICENSE
/src/ibex/openfold/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/ibex/openfold/resources/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/assets/ibex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prescient-design/ibex/HEAD/docs/assets/ibex.png
--------------------------------------------------------------------------------
/src/ibex/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Ibex
2 | from .predict import inference, batch_inference
--------------------------------------------------------------------------------
/docs/assets/ibex_transparent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prescient-design/ibex/HEAD/docs/assets/ibex_transparent.png
--------------------------------------------------------------------------------
/src/ibex/openfold/__init__.py:
--------------------------------------------------------------------------------
1 | from . import model, resources, utils
2 |
3 | __all__ = ["model", "utils", "resources"]
4 |
--------------------------------------------------------------------------------
/src/ibex/py.typed:
--------------------------------------------------------------------------------
1 | Marks this package as providing type-annotations.
2 | See https://www.python.org/dev/peps/pep-0561/ for details.
3 |
--------------------------------------------------------------------------------
/src/ibex/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from ibex.loss.aligned_rmsd import aligned_fv_and_cdrh3_rmsd, batch_align
2 | from ibex.loss.loss import IbexLoss
3 |
--------------------------------------------------------------------------------
/src/ibex/openfold/utils/precision_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | # Copyright 2022 AlQuraishi Laboratory
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 |
19 |
20 | def is_fp16_enabled():
21 | # Autocast world
22 | fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
23 | fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
24 |
25 | return fp16_enabled
26 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 |
2 | [project]
3 | name = "prescient-ibex"
4 | description = "Lightweight antibody structure prediction tool (arXiv:2507.09054)"
5 | authors = [{name = "Prescient Design"}]
6 | dynamic = ["version", "readme"]
7 | requires-python = ">=3.10"
8 | dependencies = [
9 | "click<8.2.0",
10 | "dm-tree",
11 | "einops",
12 | "esm==3.1.6",
13 | "hydra-core",
14 | "levenshtein",
15 | "lightning>=2.5.0.post0",
16 | "loguru>=0.7.3",
17 | "ml-collections>=1.0.0",
18 | "numpy>=1.26.4",
19 | "omegaconf>=2.3.0",
20 | "openmm>=8.2.0",
21 | "pandas>=2.2.3",
22 | "pdbfixer",
23 | "pyparsing>=3.2.3",
24 | "python-box>=7.3.2",
25 | "s3fs>=2025.2.0",
26 | "scipy>=1.10.1",
27 | "sentencepiece",
28 | "torch>=2.5.1",
29 | "tqdm",
30 | "transformers",
31 | "typer>=0.15.2",
32 | "wandb>=0.19.8",
33 | ]
34 |
35 | [tool.setuptools.packages.find]
36 | where = ["src"]
37 |
38 | [tool.setuptools.package-data]
39 | "ibex.configs" = ["**/*.yaml"]
40 |
41 | [tool.setuptools.dynamic]
42 | readme = {file = "README.md", content-type = "text/markdown"}
43 |
44 | [project.scripts]
45 | ibex = "ibex.inference:app"
46 |
47 | [project.optional-dependencies]
48 | analysis = [
49 | "matplotlib>=3.10.1",
50 | "py3dmol>=2.4.2",
51 | "seaborn>=0.13.2",
52 | ]
53 |
54 | [build-system]
55 | requires = ["setuptools >= 65", "setuptools_scm[toml]>=6.2"]
56 | build-backend = 'setuptools.build_meta'
57 |
58 | [dependency-groups]
59 | dev = [
60 | "coverage[toml]>=7.8.0",
61 | "ipykernel>=6.29.5",
62 | "ipython>=8.35.0",
63 | "pip>=25.0.1",
64 | "pre-commit>=4.2.0",
65 | "pytest>=8.3.5",
66 | "ruff>=0.11.5",
67 | ]
68 |
69 | [tool.setuptools_scm]
70 | search_parent_directories = true
71 | version_scheme = "no-guess-dev"
72 | local_scheme = "node-and-date"
73 | fallback_version = "0.0.1"
74 |
75 | [tool.uv.sources]
76 | pdbfixer = { git = "https://github.com/openmm/pdbfixer.git" }
77 |
--------------------------------------------------------------------------------
/docs/ibex_test.csv:
--------------------------------------------------------------------------------
1 | pdb_id,is_vhh,is_tcr
2 | 8f5i,False,False
3 | 8byu,False,False
4 | 8f6l,False,False
5 | 8bse,False,False
6 | 8dy3,False,False
7 | 7wsl,False,False
8 | 7ycl,False,False
9 | 8ek6,False,False
10 | 7ucx,False,False
11 | 8d29,False,False
12 | 7s0j,False,False
13 | 7sts,False,False
14 | 7unb,False,False
15 | 7skz,False,False
16 | 8cwu,True,False
17 | 7t0j,False,False
18 | 7omn,True,False
19 | 7jkm,False,False
20 | 7r63,True,False
21 | 7vux,False,False
22 | 4lkx,False,False
23 | 1jps,False,False
24 | 7ue9,False,False
25 | 6xjq,False,False
26 | 7ndf,True,False
27 | 7zxu,True,False
28 | 6e65,False,False
29 | 7fau,True,False
30 | 5kzp,False,False
31 | 6l8t,False,False
32 | 7zmv,True,False
33 | 5c6w,False,False
34 | 7olz,True,False
35 | 6o26,False,False
36 | 7q6c,True,False
37 | 6cr1,False,False
38 | 4m6n,False,False
39 | 7nfr,True,False
40 | 7t0i,False,False
41 | 5chn,False,False
42 | 7zwi,False,False
43 | 6o3a,False,False
44 | 7sem,False,False
45 | 7s4g,False,False
46 | 6dcv,False,False
47 | 7o06,True,False
48 | 6o3k,False,False
49 | 7apj,True,False
50 | 6blh,False,False
51 | 7z0x,False,False
52 | 7sg6,False,False
53 | 7djx,True,False
54 | 7kpg,False,False
55 | 7rt9,False,False
56 | 3lmj,False,False
57 | 7rqr,False,False
58 | 7lcv,False,False
59 | 7ps3,False,False
60 | 7vyr,False,False
61 | 5fhb,False,False
62 | 7kfy,False,False
63 | 7raq,False,False
64 | 7sg5,False,False
65 | 7so5,False,False
66 | 7rxi,False,False
67 | 7qu1,False,False
68 | 7rxl,False,False
69 | 7urq,False,False
70 | 5ob5,False,False
71 | 1uj3,False,False
72 | 7vnb,True,False
73 | 7w1s,True,False
74 | 7o0s,True,False
75 | 3h0t,False,False
76 | 3n9g,False,False
77 | 7q0g,False,False
78 | 7np9,True,False
79 | 7neh,False,False
80 | 7ar0,True,False
81 | 1t3f,False,False
82 | 7f1g,True,False
83 | 7zf6,False,False
84 | 7sjs,False,False
85 | 6wyt,False,False
86 | 7ps4,False,False
87 | 7om5,True,False
88 | 7rqq,False,False
89 | 7vfb,True,False
90 | 7n4n,True,False
91 | 6m3b,False,False
92 | 5e8e,False,False
93 | 6dc4,False,False
94 | 7anq,True,False
95 | 7nfq,True,False
96 | 6mxs,False,False
97 | 7qf0,False,False
98 | 7s7r,True,False
99 | 7sn1,False,False
100 | 7e53,True,False
101 | 6yio,False,False
102 | 6pzg,False,False
103 | 6vjt,False,False
104 | 7omt,True,False
105 | 6e56,False,False
106 | 7r20,True,False
107 | 7f07,True,False
108 | 7vfa,True,False
109 | 7ryu,False,False
110 | 1gaf,False,False
111 | 7tp4,False,False
112 | 4jam,False,False
113 | 7rg7,True,False
114 | 7q4q,False,False
115 | 7r73,True,False
116 | 4d9q,False,False
117 | 7k8o,False,False
118 | 1h0d,False,False
119 | 3grw,False,False
120 | 5wn9,False,False
121 | 7ps6,False,False
122 | 3tje,False,False
123 | 4j6r,False,False
124 | 7rp2,False,False
125 | 6uce,False,False
126 | 7ttm,False,False
127 | 4xvs,False,False
128 | 7n3d,False,False
129 | 2a9m,False,False
130 | 4buh,False,False
131 | 6ss5,False,False
132 | 7seg,False,False
133 | 7u8c,False,False
134 | 7pbe,False,True
135 | 7rrg,False,True
136 | 7fjd,False,True
137 | 6zkz,False,True
138 | 7amp,False,True
139 | 7r7z,False,True
140 | 7fjf,False,True
141 | 7dzm,False,True
142 | 7n5c,False,True
143 | 7qpj,False,True
144 | 7na5,False,True
145 | 7fje,False,True
146 | 7pb2,False,True
147 | 7n5p,False,True
148 | 6zkx,False,True
149 | 7z50,False,True
150 | 7l1d,False,True
151 | 7f5k,False,True
152 | 6zky,False,True
153 | 7r80,False,True
154 | 7su9,False,True
155 |
--------------------------------------------------------------------------------
/src/ibex/openfold/utils/tensor_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | # Copyright 2021 AlQuraishi Laboratory
4 | # Copyright 2021 DeepMind Technologies Limited
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from functools import partial
19 | from typing import List
20 |
21 | import torch
22 |
23 |
24 | def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
25 | zero_index = -1 * len(inds)
26 | first_inds = list(range(len(tensor.shape[:zero_index])))
27 | return tensor.permute(first_inds + [zero_index + i for i in inds])
28 |
29 |
30 | def flatten_final_dims(t: torch.Tensor, no_dims: int):
31 | return t.reshape(t.shape[:-no_dims] + (-1,))
32 |
33 |
34 | def masked_mean(mask, value, dim, eps=1e-4):
35 | mask = mask.expand(*value.shape)
36 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
37 |
38 |
39 | def dict_multimap(fn, dicts):
40 | first = dicts[0]
41 | new_dict = {}
42 | for k, v in first.items():
43 | all_v = [d[k] for d in dicts]
44 | if type(v) is dict:
45 | new_dict[k] = dict_multimap(fn, all_v)
46 | else:
47 | new_dict[k] = fn(all_v)
48 |
49 | return new_dict
50 |
51 |
52 | def one_hot(x, v_bins):
53 | reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
54 | diffs = x[..., None] - reshaped_bins
55 | am = torch.argmin(torch.abs(diffs), dim=-1)
56 |     return torch.nn.functional.one_hot(am, num_classes=len(v_bins)).float()
57 |
58 |
59 | def batched_gather(data, inds, dim=0, no_batch_dims=0):
60 | ranges = []
61 | for i, s in enumerate(data.shape[:no_batch_dims]):
62 | r = torch.arange(s)
63 | r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
64 | ranges.append(r)
65 |
66 | remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
67 | remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
68 | ranges.extend(remaining_dims)
69 | return data[tuple(ranges)]
70 |
71 |
72 | # With tree_map, a poor man's JAX tree_map
73 | def dict_map(fn, dic, leaf_type):
74 | new_dict = {}
75 | for k, v in dic.items():
76 | if type(v) is dict:
77 | new_dict[k] = dict_map(fn, v, leaf_type)
78 | else:
79 | new_dict[k] = tree_map(fn, v, leaf_type)
80 |
81 | return new_dict
82 |
83 |
84 | def tree_map(fn, tree, leaf_type):
85 | if isinstance(tree, dict):
86 | return dict_map(fn, tree, leaf_type)
87 | elif isinstance(tree, list):
88 | return [tree_map(fn, x, leaf_type) for x in tree]
89 | elif isinstance(tree, tuple):
90 | return tuple([tree_map(fn, x, leaf_type) for x in tree])
91 | elif isinstance(tree, leaf_type):
92 | return fn(tree)
93 | else:
94 | print(type(tree))
95 | raise ValueError("Not supported")
96 |
97 |
98 | tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
99 |
--------------------------------------------------------------------------------
/src/ibex/openfold/utils/data_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | # Copyright 2021 AlQuraishi Laboratory
4 | # Copyright 2021 DeepMind Technologies Limited
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import torch
19 |
20 | from ibex.openfold.utils import residue_constants as rc
21 |
22 |
23 | def make_atom14_masks(protein):
24 | """Construct denser atom positions (14 dimensions instead of 37)."""
25 | restype_atom14_to_atom37 = []
26 | restype_atom37_to_atom14 = []
27 | restype_atom14_mask = []
28 |
29 | for rt in rc.restypes:
30 | atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
31 | restype_atom14_to_atom37.append(
32 | [(rc.atom_order[name] if name else 0) for name in atom_names]
33 | )
34 | atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
35 | restype_atom37_to_atom14.append(
36 | [
37 | (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
38 | for name in rc.atom_types
39 | ]
40 | )
41 |
42 | restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])
43 |
44 | # Add dummy mapping for restype 'UNK'
45 | restype_atom14_to_atom37.append([0] * 14)
46 | restype_atom37_to_atom14.append([0] * 37)
47 | restype_atom14_mask.append([0.0] * 14)
48 |
49 | restype_atom14_to_atom37 = torch.tensor(
50 | restype_atom14_to_atom37,
51 | dtype=torch.int32,
52 | device=protein["aatype"].device,
53 | )
54 | restype_atom37_to_atom14 = torch.tensor(
55 | restype_atom37_to_atom14,
56 | dtype=torch.int32,
57 | device=protein["aatype"].device,
58 | )
59 | restype_atom14_mask = torch.tensor(
60 | restype_atom14_mask,
61 | dtype=torch.float32,
62 | device=protein["aatype"].device,
63 | )
64 | protein_aatype = protein["aatype"].to(torch.long)
65 |
66 | # create the mapping for (residx, atom14) --> atom37, i.e. an array
67 | # with shape (num_res, 14) containing the atom37 indices for this protein
68 | residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
69 | residx_atom14_mask = restype_atom14_mask[protein_aatype]
70 |
71 | protein["atom14_atom_exists"] = residx_atom14_mask
72 | protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
73 |
74 | # create the gather indices for mapping back
75 | residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
76 | protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
77 |
78 | # create the corresponding mask
79 | restype_atom37_mask = torch.zeros(
80 | [21, 37], dtype=torch.float32, device=protein["aatype"].device
81 | )
82 | for restype, restype_letter in enumerate(rc.restypes):
83 | restype_name = rc.restype_1to3[restype_letter]
84 | atom_names = rc.residue_atoms[restype_name]
85 | for atom_name in atom_names:
86 | atom_type = rc.atom_order[atom_name]
87 | restype_atom37_mask[restype, atom_type] = 1
88 |
89 | residx_atom37_mask = restype_atom37_mask[protein_aatype]
90 | protein["atom37_atom_exists"] = residx_atom37_mask
91 |
92 | return protein
93 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Abstra
171 | # Abstra is an AI-powered process automation framework.
172 | # Ignore directories containing user credentials, local state, and settings.
173 | # Learn more at https://abstra.io/docs
174 | .abstra/
175 |
176 | # Visual Studio Code
177 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
178 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
179 | # and can be added to the global gitignore or merged into this file. However, if you prefer,
180 | # you could uncomment the following to ignore the entire vscode folder
181 | # .vscode/
182 |
183 | # Ruff stuff:
184 | .ruff_cache/
185 |
186 | # PyPI configuration file
187 | .pypirc
188 |
189 | # Cursor
190 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
191 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
192 | # refer to https://docs.cursor.com/context/ignore-files
193 | .cursorignore
194 | .cursorindexingignore
--------------------------------------------------------------------------------
/src/ibex/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from pathlib import Path
17 |
18 | import pandas as pd
19 | from loguru import logger
20 | import typer
21 |
22 | from ibex.model import Ibex
23 | from ibex.predict import inference, batch_inference
24 | from ibex.utils import MODEL_CHECKPOINTS, ENSEMBLE_MODELS, checkpoint_path
25 |
26 | warnings.filterwarnings("ignore")
27 |
28 |
29 | def main(
30 |     abodybuilder3: bool = typer.Option(False, help="Use the ABodyBuilder3 model instead of Ibex for inference."),
31 | ckpt: str = typer.Option("", help="Path to model checkpoint. This is only needed to load a user specified checkpoint."),
32 | fv_heavy: str = typer.Option("", help="Sequence of the heavy chain."),
33 | fv_light: str = typer.Option("", help="Sequence of the light chain."),
34 |     csv: str = typer.Option("", help="CSV file containing sequences of heavy and light chains. Columns should be named 'fv_heavy' and 'fv_light'. Output file names can be provided in an 'id' column."),
35 |     parquet: str = typer.Option("", help="Parquet file containing sequences of heavy and light chains. Columns should be named 'fv_heavy' and 'fv_light'. Output file names can be provided in an 'id' column."),
36 | output: Path = typer.Option("prediction.pdb", help="Output file for the PDB structure, or path to the output folder when a parquet or csv file is provided."),
37 | batch_size: int = typer.Option(32, help="Batch size for inference if a parquet or csv file is provided."),
38 | ensemble: bool = typer.Option(False, help="Specify if checkpoint provided is an ensemble (only needed if using the explicit --ckpt flag)."),
39 | save_all: bool = typer.Option(False, help="Save all structures of the ensemble as output files."),
40 | refine: bool = typer.Option(False, help="Refine the output structures with openMM."),
41 | refine_checks: bool = typer.Option(False, help="Additional checks to fix cis-isomers and D-stereoisomers during refinement."),
42 | apo: bool = typer.Option(False, help="Predict structures in the apo conformation."),
43 | ):
44 | """
45 | Model can be specified by name or by providing a checkpoint path. If a CSV or parquet file is provided, the model will perform inference on all sequences in the file, otherwise it will perform inference on a single heavy and light sequence pair.
46 | """
47 | model = "ibex"
48 | if abodybuilder3:
49 | model = "abodybuilder3"
50 | if ckpt=="" and model not in MODEL_CHECKPOINTS:
51 | typer.echo(f"Invalid model name: {model}. Valid options are: {', '.join(MODEL_CHECKPOINTS.keys())} or provide a checkpoint path with the --ckpt option.")
52 | raise typer.Exit(code=1)
53 | if ckpt == "":
54 | ckpt = checkpoint_path(model)
55 |     ckpt = Path(ckpt)
56 |
57 | if ensemble or model in ENSEMBLE_MODELS:
58 | logger.info(f"Loading ensemble model from {ckpt=}")
59 | ibex_model = Ibex.load_from_ensemble_checkpoint(ckpt)
60 | else:
61 | logger.info(f"Loading single model from {ckpt=}")
62 | ibex_model = Ibex.load_from_checkpoint(ckpt)
63 | if ibex_model.plm_model is not None:
64 | ibex_model.set_plm()
65 | if fv_heavy:
66 | logger.info("Performing inference on a single heavy and light sequence pair.")
67 | inference(ibex_model, fv_heavy, fv_light, output, save_all=save_all, refine=refine, refine_checks=refine_checks, apo=apo)
68 | elif csv or parquet:
69 | if save_all:
70 | logger.warning("save_all was set to True, but ensemble output is not implemented for batched inference. Setting save_all to False.")
71 |             save_all = False
72 |         if output == Path("prediction.pdb"):
73 | # overwrite default for batch inference
74 | output = Path("predictions")
75 | if csv:
76 | csv = Path(csv)
77 | logger.info(f"Performing inference on sequences from {csv=}")
78 | df = pd.read_csv(csv)
79 | else:
80 | parquet = Path(parquet)
81 | logger.info(f"Performing inference on sequences from {parquet=}")
82 | df = pd.read_parquet(parquet)
83 | df['fv_light'] = df['fv_light'].fillna('')
84 | fv_heavy_list = df["fv_heavy"].tolist()
85 | fv_light_list = df["fv_light"].tolist()
86 | apo_list = None
87 | if ibex_model.conformation_aware:
88 | apo_list = [apo]*len(fv_heavy_list)
89 | if "id" in df.columns:
90 | names = df["id"].tolist()
91 | batch_inference(
92 | ibex_model,
93 | fv_heavy_list,
94 | fv_light_list,
95 | output,
96 | batch_size,
97 | names,
98 | refine=refine,
99 | refine_checks=refine_checks,
100 | apo_list=apo_list
101 | )
102 | else:
103 | batch_inference(
104 | ibex_model,
105 | fv_heavy_list,
106 | fv_light_list,
107 | output,
108 | batch_size,
109 | refine=refine,
110 | refine_checks=refine_checks,
111 | apo_list=apo_list
112 | )
113 | else:
114 | typer.echo("Please provide sequences of heavy and light chains or a csv/parquet file containing sequences.")
115 | raise typer.Exit(code=1)
116 |
117 |
118 | def app():
119 | typer.run(main)
120 |
--------------------------------------------------------------------------------
/src/ibex/predict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from pathlib import Path
17 | from typing import Optional
18 | import torch
19 | import tempfile
20 | from tqdm import tqdm
21 | from loguru import logger
22 |
23 | from ibex.model import Ibex, EnsembleStructureModule
24 | from ibex.refine import refine_file
25 |
26 |
27 | def process_file(pdb_string, output_file, refine, refine_checks=False):
28 | if not refine:
29 | with open(output_file, "w") as f:
30 | f.write(pdb_string)
31 | else:
32 | try:
33 | with tempfile.NamedTemporaryFile(delete=True) as tmpfile:
34 | tmpfile.write(pdb_string.encode('utf-8'))
35 | tmpfile.flush()
36 | refine_file(tmpfile.name, output_file, checks=refine_checks)
37 | except Exception as e:
38 | logger.warning(f"Refinement failed with error: {e}")
39 | with open(output_file, "w") as f:
40 | f.write(pdb_string)
41 |
42 |
43 | def inference(
44 | model: Ibex,
45 | fv_heavy: str,
46 | fv_light: str,
47 | output_file: Path,
48 | logging: bool = True,
49 |     save_all: bool = False,
50 | refine: bool = False,
51 | refine_checks: bool = False,
52 | apo: bool = False,
53 | return_pdb: bool = True
54 | ):
55 | device = "cuda" if torch.cuda.is_available() else "cpu"
56 | if device == "cpu" and logging:
57 | logger.warning("Inference is being done on CPU as GPU not found.")
58 |     if save_all and not isinstance(model.model, EnsembleStructureModule):
59 |         raise ValueError("save_all is set to True but model is not an ensemble model.")
60 |     if not return_pdb and refine:
61 |         raise ValueError("Cannot return a protein object and refine at the same time. To run refinement, output format must be a PDB file (return_pdb==True).")
62 |     if not return_pdb and save_all:
63 |         raise ValueError("Cannot return a protein object and save all outputs at the same time. To save all, output format must be a PDB file (return_pdb==True).")
64 | with torch.no_grad():
65 | pdb_string_or_protein = model.predict(fv_heavy, fv_light, device=device, ensemble=save_all, pdb_string=return_pdb, apo=apo)
66 | if not return_pdb:
67 | if logging:
68 | logger.info("Inference complete. Returning a protein object.")
69 | return pdb_string_or_protein
70 | if save_all:
71 | ensemble_files = []
72 | for i, pdb_string_current in enumerate(pdb_string_or_protein):
73 | output_file_current = output_file.parent / f"{output_file.stem}_{i+1}{output_file.suffix}"
74 | process_file(pdb_string_current, output_file_current, refine, refine_checks)
75 | ensemble_files.append(str(output_file_current))
76 | output_file = ensemble_files
77 | else:
78 | process_file(pdb_string_or_protein, output_file, refine, refine_checks)
79 | if logging:
80 | logger.info(f"Inference complete. Wrote PDB file to {output_file=}")
81 | return output_file
82 |
83 |
84 | def batch_inference(
85 | model: Ibex,
86 | fv_heavy_list: list[str],
87 | fv_light_list: list[str],
88 | output_dir: Path,
89 | batch_size: int,
90 | output_names: Optional[list[str]] = None,
91 | logging: bool = True,
92 | refine: bool = False,
93 | refine_checks: bool = False,
94 |     apo_list: Optional[list[bool]] = None,
95 | return_pdb: bool = True
96 | ):
97 | if output_names is None:
98 | output_names = [f"output_{i}" for i in range(len(fv_heavy_list))]
99 |
100 | if len(fv_heavy_list) != len(fv_light_list) or len(fv_heavy_list) != len(output_names):
101 | raise ValueError("Input lists must have the same length.")
102 |
103 | device = "cuda" if torch.cuda.is_available() else "cpu"
104 | if device == "cpu" and logging:
105 | logger.warning("Inference is being done on CPU as GPU not found.")
106 |
107 | output_dir = Path(output_dir)
108 | output_dir.mkdir(parents=True, exist_ok=True)
109 |
110 | if apo_list is not None and not model.conformation_aware:
111 | raise ValueError("Model is not conformation-aware, but apo_list was provided.")
112 |     if not return_pdb and refine:
113 | raise ValueError("Cannot return a protein object and refine at the same time. To run refinement, output format must be a PDB file (return_pdb==True).")
114 |
115 | if model.plm_model is not None:
116 | model.plm_model = model.plm_model.to(device)
117 |
118 | name_idx = 0 # Index for tracking the position in output_names
119 | result_files = []
120 | for i in tqdm(range(0, len(fv_heavy_list), batch_size), desc="Processing batches"):
121 | fv_heavy_batch = fv_heavy_list[i:i+batch_size]
122 | fv_light_batch = fv_light_list[i:i+batch_size] if fv_light_list else None
123 | with torch.no_grad():
124 | pdb_strings_or_proteins = model.predict_batch(
125 | fv_heavy_batch, fv_light_batch, device=device, pdb_string=return_pdb, apo_list=apo_list
126 | )
127 |         if not return_pdb:  # note: protein objects are returned for the first batch only
128 |             if logging:
129 |                 logger.info("Inference complete. Returning a protein object.")
130 | return pdb_strings_or_proteins
131 | for pdb_string in pdb_strings_or_proteins:
132 | output_file = output_dir / f"{output_names[name_idx]}.pdb"
133 | process_file(pdb_string, output_file, refine, refine_checks)
134 | result_files.append(output_file)
135 | name_idx += 1
136 | if logging:
137 | logger.info(f"Inference complete. Wrote {name_idx} PDB files to {output_dir=}")
138 |
139 | return result_files
140 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://doi.org/10.5281/zenodo.15866555)
2 |
3 | # Ibex 🐐
4 |
5 | [Ibex](https://www.tandfonline.com/doi/full/10.1080/19420862.2025.2602217) is a lightweight antibody and TCR structure prediction model.
6 |
7 |
8 |
9 |
10 |
11 | ## Installation
12 |
13 | Ibex can be installed through pip with
14 | ```bash
15 | pip install prescient-ibex
16 | ```
17 | Alternatively, to work from a clone of this repository, you can use `uv` and create a new virtual environment
18 | ```bash
19 | uv venv --python 3.10
20 | source .venv/bin/activate
21 | uv pip install -e .
22 | ```
23 |
24 | ## Usage
25 |
26 | The simplest way to run inference is through the `ibex` command, e.g.
27 |
28 | ```bash
29 | ibex --fv-heavy EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVSS --fv-light DIQMTQSPSSLSASVGDRVTITCRASQDVNTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQHYTTPPTFGQGTKVEIK --output prediction.pdb
30 | ```
31 | You can provide a csv (with the `--csv` argument) or a parquet file (with the `--parquet` argument) and run a batched inference writing the output into a specified directory with
32 | ```bash
33 | ibex --csv sequences.csv --output predictions
34 | ```
35 | where `sequences.csv` should contain a `fv_heavy` and `fv_light` column with heavy and light chain sequences, and optionally an `id` column with a string that will be used as part of the output PDB filenames.
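For illustration, a minimal `sequences.csv` might look like the following, where the sequences are truncated placeholders and the empty `fv_light` field in the second row denotes a nanobody:
```
id,fv_heavy,fv_light
ab1,EVQLVESGGGLVQPGGSLRLSCAAS...,DIQMTQSPSSLSASVGDRVTITCRA...
nb1,EVQLVESGGGLVQPGGSLRLSCAAS...,
```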
36 |
37 | By default, structures are predicted in the holo conformation. To predict the apo state, use the `--apo` flag.
38 |
39 | To run a refinement step on the predicted structures, use the `--refine` option. Additional checks to fix cis-isomers and D-stereoisomers during refinement can be activated with `--refine-checks`.
40 |
41 | Instead of running Ibex, you can use `--abodybuilder3` to run inference with the [ABodyBuilder3](https://academic.oup.com/bioinformatics/article/40/10/btae576/7810444) model.
Below is a summary of the main options:
43 | ```
44 | --abodybuilder3 Use the ABodyBuilder3 model instead of Ibex for inference. [default: no-abodybuilder3]
45 | --fv-heavy Sequence of the heavy chain.
46 | --fv-light Sequence of the light chain.
47 | --csv CSV file containing sequences of heavy and light chains. Columns should be named 'fv_heavy' and 'fv_light'. Output file names can
48 | be provided in an 'id' column.
49 | --parquet Parquet file containing sequences of heavy and light chains. Columns should be named 'fv_heavy' and 'fv_light'. Output file names
50 | can be provided in an 'id' column.
51 | --output Output file for the PDB structure, or path to the output folder when a parquet or csv file is provided. [default: prediction.pdb]
52 | --batch-size Batch size for inference if a parquet or csv file is provided. [default: 32]
53 | --save-all Save all structures of the ensemble as output files. [default: no-save-all]
54 | --refine Refine the output structures with openMM. [default: no-refine]
55 | --refine-checks Additional checks to fix cis-isomers and D-stereoisomers during refinement. [default: no-refine-checks]
56 | --apo Predict structures in the apo conformation. [default: no-apo]
57 | ```
58 |
59 | To run Ibex programmatically, you can use
60 | ```python
61 | from ibex import Ibex, inference
62 | ibex_model = Ibex.from_pretrained("ibex") # or "abodybuilder3"
63 | inference(ibex_model, fv_heavy, fv_light, "prediction.pdb")
64 | ```
65 | To predict structures for multiple sequence pairs, use `batch_inference` instead of `inference`, as sketched below.
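A minimal sketch of batched programmatic prediction, assuming placeholder sequence variables (the keyword arguments follow `batch_inference` in `src/ibex/predict.py`; `refine=True` or an `apo_list` can be passed analogously to the CLI flags):
```python
from ibex import Ibex, batch_inference

ibex_model = Ibex.from_pretrained("ibex")
# fv_heavy_1, fv_light_1, ... are placeholders for your chain sequences.
pdb_files = batch_inference(
    ibex_model,
    [fv_heavy_1, fv_heavy_2],     # heavy chain sequences
    [fv_light_1, fv_light_2],     # light chain sequences ("" for nanobodies)
    "predictions",                # output directory
    batch_size=32,
    output_names=["ab1", "ab2"],  # optional; defaults to output_0, output_1, ...
)
```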
66 |
67 | ## Predictions on nanobodies and TCRs
68 |
69 | To predict nanobody structures, leave out the `fv_light` argument, or set it to `""` or `None` in the csv column.
70 |
71 | For inference on TCRs, you should provide the variable beta chain sequence as `fv_heavy` and the alpha chain as `fv_light`. Ibex has not been trained on gamma and delta chains.
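The same conventions apply programmatically; a short sketch, with placeholder sequence variables:
```python
from pathlib import Path

from ibex import Ibex, inference

ibex_model = Ibex.from_pretrained("ibex")
# Nanobody: pass an empty string for the light chain.
inference(ibex_model, vhh_sequence, "", Path("nanobody.pdb"))
# TCR: beta chain variable domain as fv_heavy, alpha chain as fv_light.
inference(ibex_model, beta_sequence, alpha_sequence, Path("tcr.pdb"))
```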
72 |
73 |
74 | ## License
75 | The Ibex codebase is available under an [Apache 2.0 license](http://www.apache.org/licenses/LICENSE-2.0), and the [ABodyBuilder3](https://doi.org/10.5281/zenodo.11354576) model weights under a [Creative Commons Attribution 4.0 International license](https://creativecommons.org/licenses/by/4.0/legalcode), both of which allow for commercial use.
76 |
77 | The [Ibex model weights](https://doi.org/10.5281/zenodo.15866555) are available under a [Genentech Apache 2.0 Non-Commercial license](https://raw.githubusercontent.com/prescient-design/ibex/refs/heads/main/docs/Genentech_license_weights_ibex), which allows their use for non-commercial academic research purposes.
78 |
79 | Ibex uses embeddings from ESMC 300M as its input representation; ESMC 300M is licensed under the [EvolutionaryScale Cambrian Open License Agreement](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement).
80 |
81 | ## Citation
82 | When using Ibex in your work, please cite the following paper
83 |
84 | ```bibtex
85 | @article{ibex,
86 | author = {Frédéric A. Dreyer and Jan Ludwiczak and Karolis Martinkus and Brennan Abanades and Robert G. Alberstein and Pan Kessel and Pranav Rao and Jae Hyeon Lee and Richard Bonneau and Andrew M. Watkins and Franziska Seeger},
87 | title = {Conformation-aware structure prediction of antigen-recognizing immune proteins},
88 | journal = {mAbs},
89 | volume = {18},
90 | number = {1},
91 | pages = {2602217},
92 | year = {2026},
93 | publisher = {Taylor \& Francis},
94 | doi = {10.1080/19420862.2025.2602217},
95 |   note = {PMID: 41378904},
96 | URL = {https://doi.org/10.1080/19420862.2025.2602217},
97 | eprint = {https://doi.org/10.1080/19420862.2025.2602217}
98 | }
99 | ```
100 |
101 | If you use the ABodyBuilder3 model weights, you should also cite
102 | ```bibtex
103 | @article{abodybuilder3,
104 | author = {Kenlay, Henry and Dreyer, Frédéric A and Cutting, Daniel and Nissley, Daniel and Deane, Charlotte M},
105 | title = "{ABodyBuilder3: improved and scalable antibody structure predictions}",
106 | journal = {Bioinformatics},
107 | volume = {40},
108 | number = {10},
109 | pages = {btae576},
110 | year = {2024},
111 | month = {10},
112 | issn = {1367-4811},
113 | doi = {10.1093/bioinformatics/btae576}
114 | }
115 | ```
116 |
--------------------------------------------------------------------------------
/src/ibex/loss/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | # Copyright 2021 AlQuraishi Laboratory
4 | # Copyright 2021 DeepMind Technologies Limited
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | # Based on AlphaFoldLoss class from https://github.com/aqlaboratory/openfold/blob/main/openfold/utils/loss.py#L1685
19 |
20 | import ml_collections
21 | import torch
22 | import torch.nn.functional as F
23 | from loguru import logger
24 |
25 |
26 | from ibex.loss.aligned_rmsd import aligned_fv_and_cdrh3_rmsd
27 | from ibex.openfold.utils.loss import (
28 | compute_renamed_ground_truth,
29 | fape_loss,
30 | final_output_backbone_loss,
31 | find_structural_violations,
32 | lddt_loss,
33 | supervised_chi_loss,
34 | violation_loss_bondangle,
35 | violation_loss_bondlength,
36 | violation_loss_clash,
37 | )
38 |
39 |
40 | def triplet_loss(positive, negative, anchor, margin):
41 | # Compute cosine similarity
42 | sim_pos = F.cosine_similarity(anchor, positive, dim=-1)
43 | sim_neg = F.cosine_similarity(anchor, negative, dim=-1)
44 | # Triplet loss
45 | loss_triplet = torch.mean(torch.clamp(margin - sim_pos + sim_neg, min=0.0))
46 | return loss_triplet
47 |
48 |
49 | def triplet_loss_batch(single: torch.Tensor, cdr_mask: torch.Tensor, margin: float = 1.0):
50 | cdr_mask_expanded = cdr_mask.unsqueeze(-1).expand_as(single)
51 | masked_embeddings = single * cdr_mask_expanded
52 | sum_embeddings = masked_embeddings.sum(dim=1)
53 | count_non_zero = cdr_mask.sum(dim=1).unsqueeze(-1)
54 | # Avoid division by zero
55 | count_non_zero = torch.where(count_non_zero == 0, torch.ones_like(count_non_zero), count_non_zero)
56 | # Compute the average
57 | average_embeddings = sum_embeddings / count_non_zero
58 | loss = triplet_loss(average_embeddings[-3], average_embeddings[-2], average_embeddings[-1], margin)
59 | loss += triplet_loss(average_embeddings[-6], average_embeddings[-5], average_embeddings[-4], margin)
60 | return loss
61 |
62 |
63 | def negative_pairs_loss(positive: torch.Tensor, negative: torch.Tensor, threshold: float):
64 | cosine_sim = F.cosine_similarity(positive, negative, dim=-1)
65 | # Compute loss
66 | loss = torch.mean(torch.clamp(threshold - cosine_sim, min=0.0))
67 | return loss
68 |
69 | def negative_pairs_loss_batch(single: torch.Tensor, cdr_mask: torch.Tensor, threshold: float, pair_mask: torch.Tensor):
70 | if torch.tensor(pair_mask).sum() == 0:
71 | return torch.tensor(0.0, device=single.device)
72 | cdr_mask_expanded = cdr_mask.unsqueeze(-1).expand_as(single)
73 | masked_embeddings = single * cdr_mask_expanded
74 | sum_embeddings = masked_embeddings.sum(dim=1)
75 | count_non_zero = cdr_mask.sum(dim=1).unsqueeze(-1)
76 | # Avoid division by zero
77 | count_non_zero = torch.where(count_non_zero == 0, torch.ones_like(count_non_zero), count_non_zero)
78 | # Compute the average
79 | average_embeddings = sum_embeddings / count_non_zero
80 | paired_samples = average_embeddings[pair_mask]
81 | pos_samples = paired_samples[0::2]
82 | neg_samples = paired_samples[1::2]
83 | loss = negative_pairs_loss(pos_samples, neg_samples, threshold)
84 | return loss
85 |
86 | class IbexLoss(torch.nn.Module):
87 | def __init__(self, config: ml_collections.config_dict.ConfigDict):
88 | super().__init__()
89 | self.config = config
90 | self.dist_and_angle_annealing = 0.0
91 |
92 | def forward(self, output: dict, batch: dict, finetune: bool = False, contrastive: bool = False):
93 | if finetune:
94 | output["violation"] = find_structural_violations(
95 | batch,
96 | output["positions"][-1],
97 | **self.config.violation,
98 | )
99 |
100 | if "renamed_atom14_gt_positions" not in output.keys():
101 | batch.update(
102 | compute_renamed_ground_truth(
103 | batch,
104 | output["positions"][-1],
105 | )
106 | )
107 |
108 | loss_fns = {
109 | "fape": lambda: fape_loss(
110 | {"sm": output},
111 | batch,
112 | self.config.fape,
113 | ),
114 | "supervised_chi": lambda: supervised_chi_loss(
115 | output["angles"],
116 | output["unnormalized_angles"],
117 | **{**batch, **self.config.supervised_chi},
118 | ),
119 | "final_output_backbone_loss": lambda: final_output_backbone_loss(
120 | output, batch
121 | ),
122 | }
123 | if "plddt" in output:
124 | loss_fns.update(
125 | {
126 | "plddt": lambda: lddt_loss(
127 | output["plddt"],
128 | output["positions"][-1],
129 | batch["atom14_gt_positions"],
130 | batch["atom14_atom_exists"],
131 | batch["resolution"],
132 | ),
133 | }
134 | )
135 |
136 | if finetune:
137 | loss_fns.update(
138 | {
139 | "violation_loss_bondlength": lambda: violation_loss_bondlength(
140 | output["violation"]
141 | ),
142 | "violation_loss_bondangle": lambda: violation_loss_bondangle(
143 | output["violation"]
144 | ),
145 | "violation_loss_clash": lambda: violation_loss_clash(
146 | output["violation"], **batch
147 | ),
148 | }
149 | )
150 | if contrastive:
151 | loss_fns.update(
152 | {
153 | "contrastive": lambda: negative_pairs_loss_batch(
154 | output["single"],
155 | batch["cdr_mask"],
156 | self.config.contrastive.margin,
157 | batch["is_matched"],
158 | ),
159 | # "contrastive": lambda: triplet_loss_batch(
160 | # output["single"],
161 | # batch["cdr_mask"],
162 | # self.config.contrastive.margin,
163 | # ),
164 | }
165 | )
166 | cum_loss = 0.0
167 | losses = {}
168 | for loss_name, loss_fn in loss_fns.items():
169 | weight = self.config[loss_name].weight
170 | loss = loss_fn()
171 |             if torch.isnan(loss) or torch.isinf(loss):
172 |                 logger.warning(f"{loss_name} loss is NaN or Inf. Skipping...")
173 | loss = loss.new_tensor(0.0, requires_grad=True)
174 | if loss_name in ["violation_loss_bondlength", "violation_loss_bondangle"]:
175 | weight *= min(self.dist_and_angle_annealing / 50, 1)
176 | cum_loss = cum_loss + weight * loss
177 | losses[loss_name] = loss.detach().clone()
178 | losses[f"{loss_name}_weighted"] = weight * losses[loss_name]
179 | losses[f"{loss_name}_weight"] = weight
180 |
181 | # aligned_rmsd (not added to cum_loss)
182 | with torch.no_grad():
183 | losses.update(
184 | aligned_fv_and_cdrh3_rmsd(
185 | coords_truth=batch["atom14_gt_positions"],
186 | coords_prediction=output["positions"][-1],
187 | sequence_mask=batch["seq_mask"],
188 | cdrh3_mask=batch["region_numeric"] == 2,
189 | )
190 | )
191 |
192 | losses["loss"] = cum_loss.detach().clone()
193 |
194 | return cum_loss, losses
195 |
--------------------------------------------------------------------------------
/docs/Genentech_license_weights_ibex:
--------------------------------------------------------------------------------
1 | ACADEMIC SOFTWARE LICENSE AGREEMENT FOR END-USERS AT PUBLICLY FUNDED ACADEMIC, EDUCATION OR RESEARCH INSTITUTIONS FOR THE USE OF IBEX
2 |
3 | By downloading, installing, or using the Licensed Software you consent to be bound by and become a party to this agreement (hereinafter "Agreement") as a "LICENSEE". If you do not agree to all of the terms of this Agreement, you must not download, install, or use the Licensed Software, and you do not become a LICENSEE under this Agreement.
4 |
5 | If you are not a member of an academic research institution, you must obtain a commercial license; please send requests via email to dreyer.frederic@gene.com. This Agreement is entered into by and between Genentech, Inc. (hereinafter "GENENTECH") and the LICENSEE.
6 | WHEREAS GENENTECH has the right to license all copyrights and other property rights in the Licensed Software identified as Ibex and developed by GENENTECH and GENENTECH desires to license the Licensed Software so that it becomes available for internal teaching and non-commercial academic research purposes only.
7 | WHEREAS LICENSEE is a publicly funded academic and/or education and/or research institution.
8 | WHEREAS LICENSEE desires to acquire a free non-exclusive license to use the Licensed Software for internal teaching and non-commercial academic research purposes only.
9 | NOW, THEREFORE, in consideration of the mutual promises and covenants contained herein, the parties agree as follows:
10 |
11 | 1. Definitions
12 | "Licensed Software" means the specific version of Ibex pursuant to this Agreement.
13 |
14 | 2. License
15 | Subject to the terms and conditions of this Agreement a non-exclusive, non-transferable license to use and copy the Licensed Software is made available free of charge for the LICENSEE which is a non-profit educational, academic and/or research institution. The license is only granted for internal teaching and non-commercial academic research purposes at one Site, where a Site is defined as a set of contiguous buildings in one location. The Licensed Software will be used at only one location of LICENSEE.
16 | This license does not entitle Licensee to receive from GENENTECH copies of the Licensed Software on disks, tapes or CD's, hard-copy documentation, technical support, telephone assistance, or enhancements or updates to the Licensed Software.
17 | The user and any research assistants or co-workers who may use the Licensed Software agree to not give the program to third parties or grant licenses on software, which include the Licensed Software, alone or integrated into other software, to third parties. Modification of the source code is prohibited without the prior written consent of GENENTECH.
18 |
19 | 3. Ownership
20 | Except as expressly licensed in this Agreement, GENENTECH shall retain title to the Licensed Software, and any upgrades and modifications created by GENENTECH.
21 |
22 | 4. Consideration
23 | In consideration for the license rights granted by GENENTECH, LICENSEE will obtain this license free of charge.
24 |
25 | 5. Copies
26 | LICENSEE shall have the right to make copies of the Licensed Software for personal and internal teaching and non-commercial academic research purposes at the Site and for back-up purposes under this Agreement, but agrees that all such copies shall contain a copy of this Agreement, copyright notices, and all other reasonable and appropriate proprietary markings or confidential legends that appear on the Licensed Software.
27 |
28 | 6. Support
29 | GENENTECH shall have no obligation to offer support services to LICENSEE, and nothing contained herein shall be interpreted as to require GENENTECH to provide maintenance, installation services, version updates, debugging, consultation or end-user support of any kind.
30 |
31 | 7. Software Protection
32 | LICENSEE acknowledges that the Licensed Software is proprietary to GENENTECH.
33 | Except as otherwise expressly permitted in this Agreement, Licensee may not (i) modify or create any derivative works of the Licensed Software or documentation, including customization, translation or localization; (ii) decompile, disassemble, reverse engineer, or otherwise attempt to derive the source code or model weights of the Licensed Software; (iii) redistribute, encumber, sell, rent, lease, sublicense, or otherwise transfer rights to the Licensed Software; (iv) remove or alter any trademark, logo, copyright or other proprietary notices, legends, symbols or labels in the Licensed Software; or (v) publish any results of benchmark tests run on the Licensed Software to a third party without GENENTECH's prior written consent.
34 |
35 | 8. Representations of GENENTECH to LICENSEE
36 | GENENTECH represents to LICENSEE that GENENTECH has the right to grant the License and to enter into this Agreement.
37 |
38 | 9. Indemnity and Disclaimer of Warranties
39 | GENENTECH makes no representations or warranties, express or implied, of any kind.
40 | The Licensed Software is provided free of charge, and, therefore, on an "as is" basis, without warranty of any kind, express or implied, including without limitation the warranties that it is free of defects, virus free, able to operate on an uninterrupted basis, merchantable, fit for a particular purpose or non-interfering. The entire risk as to the quality and performance of the Licensed Software is borne by LICENSEE.
41 | By way of example, but not limitation, GENENTECH makes no representations or warranties of merchantability or fitness for any particular application or that the use of the Licensed Software will not infringe any patents, copyrights or trademarks or other rights of third parties. The entire risk as to the quality and performance of the Licensed Software is borne by LICENSEE. GENENTECH shall not be liable for any liability or damages with respect to any claim by LICENSEE or any third party on account of, or arising from the license or use of the Licensed Software.
42 | Should the Licensed Software prove defective in any respect, LICENSEE, and not GENENTECH, shall assume the entire cost of any service and repair. This disclaimer of warranty constitutes an essential part of this Agreement. No use of the Licensed Software is authorized hereunder except under this disclaimer.
43 | In no event will GENENTECH be liable for any indirect, special, incidental or consequential damages arising out of the use of or inability to use the Licensed Software, including, without limitation, damages for lost profits, loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses, even if advised of the possibility thereof, and regardless of the legal or equitable theory (contract, tort or otherwise) upon which the claim is based.
44 |
45 | 10. Promotional Advertising & References
46 | LICENSEE may not use the name of the Licensed Software in its promotional advertising, product literature, and other similar promotional materials to be disseminated to the public or any portion thereof. LICENSEE agrees not to identify GENENTECH in any promotional advertising or other promotional materials to be disseminated to the public, or any portion thereof without GENENTECH's prior written consent.
47 | GENENTECH shall not use LICENSEE's name in publicity or advertising involving this Agreement or otherwise without LICENSEE's prior written consent which may be withheld at LICENSEE's sole discretion.
48 |
49 | 11. Term
50 | This Agreement and the license rights granted herein shall become effective as of the date this Agreement is executed by both parties and shall be perpetual unless terminated in accordance with this Section.
51 | GENENTECH may terminate this Agreement at any time.
52 | Either party may terminate this Agreement at any time effective upon the other party's breach of any agreement, covenant, or representation made in this Agreement, such breach remaining uncorrected sixty (60) days after written notice thereof.
53 | LICENSEE shall have the right, at any time, to terminate this Agreement without cause by written notice to GENENTECH specifying the date of termination.
54 | Upon termination, LICENSEE shall destroy all full and partial copies of the Licensed Software.
55 |
56 | 12. Governing Law
57 | This Agreement shall be construed in accordance with the laws of California.
58 |
59 | 13. General
60 | The parties agree that this Agreement is the complete and exclusive agreement among the parties and supersedes all proposals and prior agreements whether written or oral, and all other communications among the parties relating to the subject matter of this Agreement.
61 | This Agreement cannot be modified except in writing and signed by both parties. Failure by either party at any time to enforce any of the provisions of this Agreement shall not constitute a waiver by such party of such provision nor in any way affect the validity of this Agreement.
62 | The invalidity of singular provisions does not affect the validity of the entire understanding. The parties are obligated, however, to replace the invalid provisions by a regulation which comes closest to the economic intent of the invalid provision. The same shall apply mutatis mutandis in case of a gap.
63 |
--------------------------------------------------------------------------------
/src/ibex/openfold/utils/feats.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | # Copyright 2021 AlQuraishi Laboratory
4 | # Copyright 2021 DeepMind Technologies Limited
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from typing import Dict
19 |
20 | import torch
21 | import torch.nn as nn
22 |
23 | import ibex.openfold.utils.residue_constants as rc
24 | from ibex.openfold.utils.rigid_utils import Rigid, Rotation
25 | from ibex.openfold.utils.tensor_utils import batched_gather
26 |
27 |
28 | def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
29 | is_gly = aatype == rc.restype_order["G"]
30 | ca_idx = rc.atom_order["CA"]
31 | cb_idx = rc.atom_order["CB"]
32 | pseudo_beta = torch.where(
33 | is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
34 | all_atom_positions[..., ca_idx, :],
35 | all_atom_positions[..., cb_idx, :],
36 | )
37 |
38 | if all_atom_masks is not None:
39 | pseudo_beta_mask = torch.where(
40 | is_gly,
41 | all_atom_masks[..., ca_idx],
42 | all_atom_masks[..., cb_idx],
43 | )
44 | return pseudo_beta, pseudo_beta_mask
45 | else:
46 | return pseudo_beta
47 |
48 |
49 | def atom14_to_atom37(atom14, batch):
50 | atom37_data = batched_gather(
51 | atom14,
52 | batch["residx_atom37_to_atom14"],
53 | dim=-2,
54 | no_batch_dims=len(atom14.shape[:-2]),
55 | )
56 |
57 | atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
58 |
59 | return atom37_data
60 |
61 |
62 | def build_template_angle_feat(template_feats):
63 | template_aatype = template_feats["template_aatype"]
64 | torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
65 | alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
66 | torsion_angles_mask = template_feats["template_torsion_angles_mask"]
67 | template_angle_feat = torch.cat(
68 | [
69 | nn.functional.one_hot(template_aatype, 22),
70 | torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
71 | alt_torsion_angles_sin_cos.reshape(
72 | *alt_torsion_angles_sin_cos.shape[:-2], 14
73 | ),
74 | torsion_angles_mask,
75 | ],
76 | dim=-1,
77 | )
78 |
79 | return template_angle_feat
80 |
81 |
82 | def build_template_pair_feat(
83 | batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8
84 | ):
85 | template_mask = batch["template_pseudo_beta_mask"]
86 | template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
87 |
88 | # Compute distogram (this seems to differ slightly from Alg. 5)
89 | tpb = batch["template_pseudo_beta"]
90 | dgram = torch.sum(
91 | (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
92 | )
93 | lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
94 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
95 | dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
96 |
97 | to_concat = [dgram, template_mask_2d[..., None]]
98 |
99 | aatype_one_hot = nn.functional.one_hot(
100 | batch["template_aatype"],
101 | rc.restype_num + 2,
102 | )
103 |
104 | n_res = batch["template_aatype"].shape[-1]
105 | to_concat.append(
106 | aatype_one_hot[..., None, :, :].expand(
107 | *aatype_one_hot.shape[:-2], n_res, -1, -1
108 | )
109 | )
110 | to_concat.append(
111 | aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)
112 | )
113 |
114 | n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
115 | rigids = Rigid.make_transform_from_reference(
116 | n_xyz=batch["template_all_atom_positions"][..., n, :],
117 | ca_xyz=batch["template_all_atom_positions"][..., ca, :],
118 | c_xyz=batch["template_all_atom_positions"][..., c, :],
119 | eps=eps,
120 | )
121 | points = rigids.get_trans()[..., None, :, :]
122 | rigid_vec = rigids[..., None].invert_apply(points)
123 |
124 | inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
125 |
126 | t_aa_masks = batch["template_all_atom_mask"]
127 | template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
128 | template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
129 |
130 | inv_distance_scalar = inv_distance_scalar * template_mask_2d
131 | unit_vector = rigid_vec * inv_distance_scalar[..., None]
132 |
133 | if not use_unit_vector:
134 | unit_vector = unit_vector * 0.0
135 |
136 | to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
137 | to_concat.append(template_mask_2d[..., None])
138 |
139 | act = torch.cat(to_concat, dim=-1)
140 | act = act * template_mask_2d[..., None]
141 |
142 | return act
143 |
144 |
145 | def build_extra_msa_feat(batch):
146 | msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
147 | msa_feat = [
148 | msa_1hot,
149 | batch["extra_has_deletion"].unsqueeze(-1),
150 | batch["extra_deletion_value"].unsqueeze(-1),
151 | ]
152 | return torch.cat(msa_feat, dim=-1)
153 |
154 |
155 | def torsion_angles_to_frames(
156 | r: Rigid,
157 | alpha: torch.Tensor,
158 | aatype: torch.Tensor,
159 | rrgdf: torch.Tensor,
160 | ):
161 | # [*, N, 8, 4, 4]
162 | default_4x4 = rrgdf[aatype, ...]
163 |
164 | # [*, N, 8] transformations, i.e.
165 | # One [*, N, 8, 3, 3] rotation matrix and
166 | # One [*, N, 8, 3] translation matrix
167 | default_r = r.from_tensor_4x4(default_4x4)
168 |
169 | bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
170 | bb_rot[..., 1] = 1
171 |
172 | # [*, N, 8, 2]
173 | alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
174 |
175 | # [*, N, 8, 3, 3]
176 | # Produces rotation matrices of the form:
177 | # [
178 | # [1, 0 , 0 ],
179 | # [0, a_2,-a_1],
180 | # [0, a_1, a_2]
181 | # ]
182 | # This follows the original code rather than the supplement, which uses
183 | # different indices.
184 |
185 | all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
186 | all_rots[..., 0, 0] = 1
187 | all_rots[..., 1, 1] = alpha[..., 1]
188 | all_rots[..., 1, 2] = -alpha[..., 0]
189 | all_rots[..., 2, 1:] = alpha
190 |
191 | all_rots = Rigid(Rotation(rot_mats=all_rots), None)
192 |
193 | all_frames = default_r.compose(all_rots)
194 |
195 | chi2_frame_to_frame = all_frames[..., 5]
196 | chi3_frame_to_frame = all_frames[..., 6]
197 | chi4_frame_to_frame = all_frames[..., 7]
198 |
199 | chi1_frame_to_bb = all_frames[..., 4]
200 | chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
201 | chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
202 | chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
203 |
204 | all_frames_to_bb = Rigid.cat(
205 | [
206 | all_frames[..., :5],
207 | chi2_frame_to_bb.unsqueeze(-1),
208 | chi3_frame_to_bb.unsqueeze(-1),
209 | chi4_frame_to_bb.unsqueeze(-1),
210 | ],
211 | dim=-1,
212 | )
213 |
214 | all_frames_to_global = r[..., None].compose(all_frames_to_bb)
215 |
216 | return all_frames_to_global
217 |
218 |
219 | def frames_and_literature_positions_to_atom14_pos(
220 | r: Rigid,
221 | aatype: torch.Tensor,
222 | default_frames,
223 | group_idx,
224 | atom_mask,
225 | lit_positions,
226 | ):
227 | # [*, N, 14, 4, 4]
228 | default_4x4 = default_frames[aatype, ...]
229 |
230 | # [*, N, 14]
231 | group_mask = group_idx[aatype, ...]
232 |
233 | # [*, N, 14, 8]
234 | group_mask = nn.functional.one_hot(
235 | group_mask,
236 | num_classes=default_frames.shape[-3],
237 | )
238 |
239 | # [*, N, 14, 8]
240 | t_atoms_to_global = r[..., None, :] * group_mask
241 |
242 | # [*, N, 14]
243 | t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
244 |
245 | # [*, N, 14, 1]
246 | atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
247 |
248 | # [*, N, 14, 3]
249 | lit_positions = lit_positions[aatype, ...]
250 | pred_positions = t_atoms_to_global.apply(lit_positions)
251 | pred_positions = pred_positions * atom_mask
252 |
253 | return pred_positions
254 |
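255 | 
256 | # Illustrative usage sketch (the toy tensors below are hypothetical and not
257 | # part of the Ibex pipeline): glycine has no CB atom, so pseudo_beta_fn falls
258 | # back to the CA position for those residues.
259 | if __name__ == "__main__":
260 |     n_res = 4
261 |     aatype = torch.randint(0, 20, (n_res,))
262 |     all_atom_positions = torch.randn(n_res, 37, 3)
263 |     all_atom_masks = torch.ones(n_res, 37)
264 |     pb, pb_mask = pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks)
265 |     assert pb.shape == (n_res, 3) and pb_mask.shape == (n_res,)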
--------------------------------------------------------------------------------
/src/ibex/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import io
15 | import os
16 | import requests
17 | import hashlib
18 | from pathlib import Path
19 |
20 | from tqdm import tqdm
21 | from loguru import logger
22 | import numpy as np
23 | import torch
24 | from typing import MutableMapping
25 | from omegaconf import DictConfig
26 |
27 | from ibex.openfold.utils.data_transforms import make_atom14_masks
28 | from ibex.openfold.utils.protein import Protein, to_pdb
29 | from ibex.openfold.utils.feats import atom14_to_atom37
30 | from ibex.loss.aligned_rmsd import positions_to_backbone_dihedrals, region_mapping, CDR_RANGES_AHO
31 |
32 |
33 | MODEL_CHECKPOINTS = {
34 | "ibex": "https://zenodo.org/records/15866556/files/ibex_v1.ckpt",
35 | "abodybuilder3": "https://zenodo.org/records/15866556/files/abb3.ckpt",
36 | }
37 |
38 | ENSEMBLE_MODELS = ["ibex"]
39 |
40 |
41 | def dihedral_distance_per_loop(
42 | positions_predicted: torch.Tensor,
43 | positions_reference: torch.Tensor,
44 | region_mask: torch.Tensor,
45 | mask: torch.Tensor | None = None,
46 | residue_index: torch.Tensor | None = None,
47 | chain_index: torch.Tensor | None = None
48 | ) -> dict[str, float]:
49 | """Computes the dihedral distance between predicted and reference backbone coordinates for each loop region.
50 |
51 | Args:
52 | positions_predicted (torch.Tensor): Predicted atomic coordinates of the protein backbone.
53 | positions_reference (torch.Tensor): Reference (ground truth) atomic coordinates of the protein backbone.
54 | region_mask (torch.Tensor): Integer tensor assigning each residue a region identifier (e.g., a CDR or framework region from region_mapping).
55 | mask (torch.Tensor | None, optional): Boolean mask indicating valid residues in the sequence.
56 | residue_index (torch.Tensor | None, optional): Indices of residues in the sequence. Defaults to None.
57 | chain_index (torch.Tensor | None, optional): Indices of chains in the structure. Defaults to None.
58 | Returns:
59 | dict[str, float]: Dictionary mapping each loop region to its computed dihedral distance between predicted and reference structures.
60 |
61 | """
62 | if mask is None:
63 | mask = torch.ones(positions_reference.shape[:-1]).to(device=positions_reference.device)
64 | dihedral_angles_predicted, dihedral_mask = positions_to_backbone_dihedrals(positions_predicted, mask, residue_index, chain_index)
65 | dihedral_angles_reference, _ = positions_to_backbone_dihedrals(positions_reference, mask, residue_index, chain_index)
66 | results = {}
67 | dihedral_differences = 2 * (1 - torch.cos(dihedral_angles_predicted - dihedral_angles_reference))
68 | for region_name, region_idx in region_mapping.items():
69 | if region_name.startswith("cdr") and (region_mask == region_idx).any():
70 | results[region_name] = torch.mean(dihedral_differences[region_mask == region_idx]).item()
71 | return results
72 |
73 | def region_mask_from_aho(fv_heavy_aho: str, fv_light_aho: str = "") -> torch.Tensor:
74 | """Return a tensor with CDR and framework identifiers for a given aho string of heavy and light chains.
75 | Args:
76 | fv_heavy_aho (str): Heavy chain aho string
77 | fv_light_aho (str): Light chain aho string
78 | Returns:
79 | torch.Tensor: Tensor of shape [N]
80 | """
81 | region_len = {}
82 | region_len["fwh1"] = len(fv_heavy_aho[:CDR_RANGES_AHO["H1"][0]].replace('-', ''))
83 | region_len["cdrh1"] = len(fv_heavy_aho[CDR_RANGES_AHO["H1"][0]:CDR_RANGES_AHO["H1"][1]].replace('-', ''))
84 | region_len["fwh2"] = len(fv_heavy_aho[CDR_RANGES_AHO["H1"][1]:CDR_RANGES_AHO["H2"][0]].replace('-', ''))
85 | region_len["cdrh2"] = len(fv_heavy_aho[CDR_RANGES_AHO["H2"][0]:CDR_RANGES_AHO["H2"][1]].replace('-', ''))
86 | region_len["fwh3"] = len(fv_heavy_aho[CDR_RANGES_AHO["H2"][1]:CDR_RANGES_AHO["H3"][0]].replace('-', ''))
87 | region_len["cdrh3"] = len(fv_heavy_aho[CDR_RANGES_AHO["H3"][0]:CDR_RANGES_AHO["H3"][1]].replace('-', ''))
88 | region_len["fwh4"] = len(fv_heavy_aho[CDR_RANGES_AHO["H3"][1]:].replace('-', ''))
89 |
90 | if fv_light_aho:
91 | region_len["fwl1"] = len(fv_light_aho[:CDR_RANGES_AHO["L1"][0]].replace('-', ''))
92 | region_len["cdrl1"] = len(fv_light_aho[CDR_RANGES_AHO["L1"][0]:CDR_RANGES_AHO["L1"][1]].replace('-', ''))
93 | region_len["fwl2"] = len(fv_light_aho[CDR_RANGES_AHO["L1"][1]:CDR_RANGES_AHO["L2"][0]].replace('-', ''))
94 | region_len["cdrl2"] = len(fv_light_aho[CDR_RANGES_AHO["L2"][0]:CDR_RANGES_AHO["L2"][1]].replace('-', ''))
95 | region_len["fwl3"] = len(fv_light_aho[CDR_RANGES_AHO["L2"][1]:CDR_RANGES_AHO["L3"][0]].replace('-', ''))
96 | region_len["cdrl3"] = len(fv_light_aho[CDR_RANGES_AHO["L3"][0]:CDR_RANGES_AHO["L3"][1]].replace('-', ''))
97 | region_len["fwl4"] = len(fv_light_aho[CDR_RANGES_AHO["L3"][1]:].replace('-', ''))
98 |
99 |
100 | res = []
101 | for region in region_len:
102 | res.append(torch.ones(region_len[region], dtype=torch.int) * region_mapping[region])
103 | return torch.cat(res)
104 |
105 | def compute_plddt(plddt: torch.Tensor) -> torch.Tensor:
106 | """Computes plddt from the model output. The output is a histogram of unnormalised
107 | plddt.
108 |
109 | Args:
110 | plddt (torch.Tensor): (B, n, 50) output from the model
111 |
112 | Returns:
113 | torch.Tensor: (B, n) plddt scores
114 | """
115 | pdf = torch.nn.functional.softmax(plddt, dim=-1)
116 | vbins = torch.arange(1, 101, 2).to(plddt.device).float()
117 | output = pdf @ vbins # (B, n)
118 | return output
119 |
120 |
121 | def add_atom37_to_output(output: dict, aatype: torch.Tensor):
122 | """Adds atom37 coordinates to an output dictionary containing atom14 coordinates."""
123 | atom14 = output["positions"][-1, 0]
124 | batch = make_atom14_masks({"aatype": aatype.squeeze()})
125 | atom37 = atom14_to_atom37(atom14, batch)
126 | output["atom37"] = atom37
127 | output["atom37_atom_exists"] = batch["atom37_atom_exists"]
128 | return output
129 |
130 |
131 | def output_to_protein(output: dict, model_input: dict) -> Protein:
132 | """Generates a Protein object from Ibex predictions.
133 |
134 | Args:
135 | output (dict): Ibex output dictionary
136 | model_input (dict): Ibex input dictionary
137 |
138 | Returns:
139 | Protein: a Protein object built from the predicted structure.
140 | """
141 | aatype = model_input["aatype"].squeeze().cpu().numpy().astype(int)
142 | atom37 = output["atom37"]
143 | chain_index = 1 - model_input["is_heavy"].cpu().numpy().astype(int)
144 | atom_mask = output["atom37_atom_exists"].cpu().numpy().astype(int)
145 | residue_index = np.arange(len(atom37))
146 | if "plddt" in output:
147 | plddt = compute_plddt(output["plddt"]).squeeze().detach().cpu().numpy()
148 | b_factors = np.expand_dims(plddt, 1).repeat(37, 1)
149 | else:
150 | b_factors = np.zeros_like(atom_mask)
151 | protein = Protein(
152 | aatype=aatype,
153 | atom_positions=atom37,
154 | atom_mask=atom_mask,
155 | residue_index=residue_index,
156 | b_factors=b_factors,
157 | chain_index=chain_index,
158 | )
159 |
160 | return protein
161 |
162 | def output_to_pdb(output: dict, model_input: dict) -> str:
163 | """Generates a pdb file from Ibex predictions.
164 |
165 | Args:
166 | output (dict): Ibex output dictionary
167 | model_input (dict): Ibex input dictionary
168 |
169 | Returns:
170 | str: the contents of a pdb file in string format.
171 | """
172 | return to_pdb(output_to_protein(output, model_input))
173 |
174 | def download_from_url(url, local_path):
175 | """Downloads a file from a URL with a progress bar."""
176 | try:
177 | with requests.get(url, stream=True) as r:
178 | r.raise_for_status()
179 | total_size_in_bytes = int(r.headers.get('content-length', 0))
180 | block_size = 1024 # 1 Kibibyte
181 |
182 | with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Downloading {os.path.basename(local_path)}") as progress_bar:
183 | with open(local_path, 'wb') as f:
184 | for chunk in r.iter_content(block_size):
185 | progress_bar.update(len(chunk))
186 | f.write(chunk)
187 |
188 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
189 | raise IOError("ERROR: Something went wrong during download")
190 |
191 | except requests.exceptions.RequestException as e:
192 | logger.error(f"Failed to download {url}: {e}")
193 | # Clean up partially downloaded file
194 | if os.path.exists(local_path):
195 | os.remove(local_path)
196 | raise
197 |
198 | def get_checkpoint_path(checkpoint_url, cache_dir=None):
199 | """Generates a unique local path for a given URL."""
200 | cache_dir = cache_dir or os.path.join(Path.home(), ".cache/ibex")
201 | os.makedirs(cache_dir, exist_ok=True)
202 | url_filename = os.path.basename(checkpoint_url)
203 | url_hash = hashlib.md5(checkpoint_url.encode()).hexdigest()[:8]
204 | filename = f"{url_hash}_{url_filename}"
205 | return os.path.join(cache_dir, filename)
206 |
207 |
208 | def checkpoint_path(model_name, cache_dir=None):
209 | """
210 | Ensures the model checkpoint exists locally, downloading it if necessary.
211 | """
212 | if model_name not in MODEL_CHECKPOINTS:
213 | raise ValueError(f"Invalid model name: {model_name}")
214 |
215 | checkpoint_url = MODEL_CHECKPOINTS[model_name]
216 | local_path = get_checkpoint_path(checkpoint_url, cache_dir)
217 |
218 | if not os.path.exists(local_path):
219 | logger.info(f"Downloading checkpoint from {checkpoint_url} to {local_path}")
220 | # Call the new download function
221 | download_from_url(checkpoint_url, local_path)
222 |
223 | return local_path
224 |
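225 | 
226 | # Illustrative usage sketches (the tensors and assertions below are
227 | # hypothetical examples, not part of the Ibex API).
228 | if __name__ == "__main__":
229 |     # compute_plddt: uniform logits over the 50 bins (centers 1, 3, ..., 99)
230 |     # yield an expected pLDDT of exactly 50.
231 |     logits = torch.zeros(2, 10, 50)
232 |     assert torch.allclose(compute_plddt(logits), torch.full((2, 10), 50.0))
233 | 
234 |     # get_checkpoint_path: the cached filename is an 8-character md5 prefix
235 |     # of the URL followed by the original filename.
236 |     path = get_checkpoint_path(MODEL_CHECKPOINTS["ibex"])
237 |     assert path.endswith("_ibex_v1.ckpt")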
--------------------------------------------------------------------------------
/src/ibex/openfold/resources/stereo_chemical_props.txt:
--------------------------------------------------------------------------------
1 | Bond Residue Mean StdDev
2 | CA-CB ALA 1.520 0.021
3 | N-CA ALA 1.459 0.020
4 | CA-C ALA 1.525 0.026
5 | C-O ALA 1.229 0.019
6 | CA-CB ARG 1.535 0.022
7 | CB-CG ARG 1.521 0.027
8 | CG-CD ARG 1.515 0.025
9 | CD-NE ARG 1.460 0.017
10 | NE-CZ ARG 1.326 0.013
11 | CZ-NH1 ARG 1.326 0.013
12 | CZ-NH2 ARG 1.326 0.013
13 | N-CA ARG 1.459 0.020
14 | CA-C ARG 1.525 0.026
15 | C-O ARG 1.229 0.019
16 | CA-CB ASN 1.527 0.026
17 | CB-CG ASN 1.506 0.023
18 | CG-OD1 ASN 1.235 0.022
19 | CG-ND2 ASN 1.324 0.025
20 | N-CA ASN 1.459 0.020
21 | CA-C ASN 1.525 0.026
22 | C-O ASN 1.229 0.019
23 | CA-CB ASP 1.535 0.022
24 | CB-CG ASP 1.513 0.021
25 | CG-OD1 ASP 1.249 0.023
26 | CG-OD2 ASP 1.249 0.023
27 | N-CA ASP 1.459 0.020
28 | CA-C ASP 1.525 0.026
29 | C-O ASP 1.229 0.019
30 | CA-CB CYS 1.526 0.013
31 | CB-SG CYS 1.812 0.016
32 | N-CA CYS 1.459 0.020
33 | CA-C CYS 1.525 0.026
34 | C-O CYS 1.229 0.019
35 | CA-CB GLU 1.535 0.022
36 | CB-CG GLU 1.517 0.019
37 | CG-CD GLU 1.515 0.015
38 | CD-OE1 GLU 1.252 0.011
39 | CD-OE2 GLU 1.252 0.011
40 | N-CA GLU 1.459 0.020
41 | CA-C GLU 1.525 0.026
42 | C-O GLU 1.229 0.019
43 | CA-CB GLN 1.535 0.022
44 | CB-CG GLN 1.521 0.027
45 | CG-CD GLN 1.506 0.023
46 | CD-OE1 GLN 1.235 0.022
47 | CD-NE2 GLN 1.324 0.025
48 | N-CA GLN 1.459 0.020
49 | CA-C GLN 1.525 0.026
50 | C-O GLN 1.229 0.019
51 | N-CA GLY 1.456 0.015
52 | CA-C GLY 1.514 0.016
53 | C-O GLY 1.232 0.016
54 | CA-CB HIS 1.535 0.022
55 | CB-CG HIS 1.492 0.016
56 | CG-ND1 HIS 1.369 0.015
57 | CG-CD2 HIS 1.353 0.017
58 | ND1-CE1 HIS 1.343 0.025
59 | CD2-NE2 HIS 1.415 0.021
60 | CE1-NE2 HIS 1.322 0.023
61 | N-CA HIS 1.459 0.020
62 | CA-C HIS 1.525 0.026
63 | C-O HIS 1.229 0.019
64 | CA-CB ILE 1.544 0.023
65 | CB-CG1 ILE 1.536 0.028
66 | CB-CG2 ILE 1.524 0.031
67 | CG1-CD1 ILE 1.500 0.069
68 | N-CA ILE 1.459 0.020
69 | CA-C ILE 1.525 0.026
70 | C-O ILE 1.229 0.019
71 | CA-CB LEU 1.533 0.023
72 | CB-CG LEU 1.521 0.029
73 | CG-CD1 LEU 1.514 0.037
74 | CG-CD2 LEU 1.514 0.037
75 | N-CA LEU 1.459 0.020
76 | CA-C LEU 1.525 0.026
77 | C-O LEU 1.229 0.019
78 | CA-CB LYS 1.535 0.022
79 | CB-CG LYS 1.521 0.027
80 | CG-CD LYS 1.520 0.034
81 | CD-CE LYS 1.508 0.025
82 | CE-NZ LYS 1.486 0.025
83 | N-CA LYS 1.459 0.020
84 | CA-C LYS 1.525 0.026
85 | C-O LYS 1.229 0.019
86 | CA-CB MET 1.535 0.022
87 | CB-CG MET 1.509 0.032
88 | CG-SD MET 1.807 0.026
89 | SD-CE MET 1.774 0.056
90 | N-CA MET 1.459 0.020
91 | CA-C MET 1.525 0.026
92 | C-O MET 1.229 0.019
93 | CA-CB PHE 1.535 0.022
94 | CB-CG PHE 1.509 0.017
95 | CG-CD1 PHE 1.383 0.015
96 | CG-CD2 PHE 1.383 0.015
97 | CD1-CE1 PHE 1.388 0.020
98 | CD2-CE2 PHE 1.388 0.020
99 | CE1-CZ PHE 1.369 0.019
100 | CE2-CZ PHE 1.369 0.019
101 | N-CA PHE 1.459 0.020
102 | CA-C PHE 1.525 0.026
103 | C-O PHE 1.229 0.019
104 | CA-CB PRO 1.531 0.020
105 | CB-CG PRO 1.495 0.050
106 | CG-CD PRO 1.502 0.033
107 | CD-N PRO 1.474 0.014
108 | N-CA PRO 1.468 0.017
109 | CA-C PRO 1.524 0.020
110 | C-O PRO 1.228 0.020
111 | CA-CB SER 1.525 0.015
112 | CB-OG SER 1.418 0.013
113 | N-CA SER 1.459 0.020
114 | CA-C SER 1.525 0.026
115 | C-O SER 1.229 0.019
116 | CA-CB THR 1.529 0.026
117 | CB-OG1 THR 1.428 0.020
118 | CB-CG2 THR 1.519 0.033
119 | N-CA THR 1.459 0.020
120 | CA-C THR 1.525 0.026
121 | C-O THR 1.229 0.019
122 | CA-CB TRP 1.535 0.022
123 | CB-CG TRP 1.498 0.018
124 | CG-CD1 TRP 1.363 0.014
125 | CG-CD2 TRP 1.432 0.017
126 | CD1-NE1 TRP 1.375 0.017
127 | NE1-CE2 TRP 1.371 0.013
128 | CD2-CE2 TRP 1.409 0.012
129 | CD2-CE3 TRP 1.399 0.015
130 | CE2-CZ2 TRP 1.393 0.017
131 | CE3-CZ3 TRP 1.380 0.017
132 | CZ2-CH2 TRP 1.369 0.019
133 | CZ3-CH2 TRP 1.396 0.016
134 | N-CA TRP 1.459 0.020
135 | CA-C TRP 1.525 0.026
136 | C-O TRP 1.229 0.019
137 | CA-CB TYR 1.535 0.022
138 | CB-CG TYR 1.512 0.015
139 | CG-CD1 TYR 1.387 0.013
140 | CG-CD2 TYR 1.387 0.013
141 | CD1-CE1 TYR 1.389 0.015
142 | CD2-CE2 TYR 1.389 0.015
143 | CE1-CZ TYR 1.381 0.013
144 | CE2-CZ TYR 1.381 0.013
145 | CZ-OH TYR 1.374 0.017
146 | N-CA TYR 1.459 0.020
147 | CA-C TYR 1.525 0.026
148 | C-O TYR 1.229 0.019
149 | CA-CB VAL 1.543 0.021
150 | CB-CG1 VAL 1.524 0.021
151 | CB-CG2 VAL 1.524 0.021
152 | N-CA VAL 1.459 0.020
153 | CA-C VAL 1.525 0.026
154 | C-O VAL 1.229 0.019
155 | -
156 |
157 | Angle Residue Mean StdDev
158 | N-CA-CB ALA 110.1 1.4
159 | CB-CA-C ALA 110.1 1.5
160 | N-CA-C ALA 111.0 2.7
161 | CA-C-O ALA 120.1 2.1
162 | N-CA-CB ARG 110.6 1.8
163 | CB-CA-C ARG 110.4 2.0
164 | CA-CB-CG ARG 113.4 2.2
165 | CB-CG-CD ARG 111.6 2.6
166 | CG-CD-NE ARG 111.8 2.1
167 | CD-NE-CZ ARG 123.6 1.4
168 | NE-CZ-NH1 ARG 120.3 0.5
169 | NE-CZ-NH2 ARG 120.3 0.5
170 | NH1-CZ-NH2 ARG 119.4 1.1
171 | N-CA-C ARG 111.0 2.7
172 | CA-C-O ARG 120.1 2.1
173 | N-CA-CB ASN 110.6 1.8
174 | CB-CA-C ASN 110.4 2.0
175 | CA-CB-CG ASN 113.4 2.2
176 | CB-CG-ND2 ASN 116.7 2.4
177 | CB-CG-OD1 ASN 121.6 2.0
178 | ND2-CG-OD1 ASN 121.9 2.3
179 | N-CA-C ASN 111.0 2.7
180 | CA-C-O ASN 120.1 2.1
181 | N-CA-CB ASP 110.6 1.8
182 | CB-CA-C ASP 110.4 2.0
183 | CA-CB-CG ASP 113.4 2.2
184 | CB-CG-OD1 ASP 118.3 0.9
185 | CB-CG-OD2 ASP 118.3 0.9
186 | OD1-CG-OD2 ASP 123.3 1.9
187 | N-CA-C ASP 111.0 2.7
188 | CA-C-O ASP 120.1 2.1
189 | N-CA-CB CYS 110.8 1.5
190 | CB-CA-C CYS 111.5 1.2
191 | CA-CB-SG CYS 114.2 1.1
192 | N-CA-C CYS 111.0 2.7
193 | CA-C-O CYS 120.1 2.1
194 | N-CA-CB GLU 110.6 1.8
195 | CB-CA-C GLU 110.4 2.0
196 | CA-CB-CG GLU 113.4 2.2
197 | CB-CG-CD GLU 114.2 2.7
198 | CG-CD-OE1 GLU 118.3 2.0
199 | CG-CD-OE2 GLU 118.3 2.0
200 | OE1-CD-OE2 GLU 123.3 1.2
201 | N-CA-C GLU 111.0 2.7
202 | CA-C-O GLU 120.1 2.1
203 | N-CA-CB GLN 110.6 1.8
204 | CB-CA-C GLN 110.4 2.0
205 | CA-CB-CG GLN 113.4 2.2
206 | CB-CG-CD GLN 111.6 2.6
207 | CG-CD-OE1 GLN 121.6 2.0
208 | CG-CD-NE2 GLN 116.7 2.4
209 | OE1-CD-NE2 GLN 121.9 2.3
210 | N-CA-C GLN 111.0 2.7
211 | CA-C-O GLN 120.1 2.1
212 | N-CA-C GLY 113.1 2.5
213 | CA-C-O GLY 120.6 1.8
214 | N-CA-CB HIS 110.6 1.8
215 | CB-CA-C HIS 110.4 2.0
216 | CA-CB-CG HIS 113.6 1.7
217 | CB-CG-ND1 HIS 123.2 2.5
218 | CB-CG-CD2 HIS 130.8 3.1
219 | CG-ND1-CE1 HIS 108.2 1.4
220 | ND1-CE1-NE2 HIS 109.9 2.2
221 | CE1-NE2-CD2 HIS 106.6 2.5
222 | NE2-CD2-CG HIS 109.2 1.9
223 | CD2-CG-ND1 HIS 106.0 1.4
224 | N-CA-C HIS 111.0 2.7
225 | CA-C-O HIS 120.1 2.1
226 | N-CA-CB ILE 110.8 2.3
227 | CB-CA-C ILE 111.6 2.0
228 | CA-CB-CG1 ILE 111.0 1.9
229 | CB-CG1-CD1 ILE 113.9 2.8
230 | CA-CB-CG2 ILE 110.9 2.0
231 | CG1-CB-CG2 ILE 111.4 2.2
232 | N-CA-C ILE 111.0 2.7
233 | CA-C-O ILE 120.1 2.1
234 | N-CA-CB LEU 110.4 2.0
235 | CB-CA-C LEU 110.2 1.9
236 | CA-CB-CG LEU 115.3 2.3
237 | CB-CG-CD1 LEU 111.0 1.7
238 | CB-CG-CD2 LEU 111.0 1.7
239 | CD1-CG-CD2 LEU 110.5 3.0
240 | N-CA-C LEU 111.0 2.7
241 | CA-C-O LEU 120.1 2.1
242 | N-CA-CB LYS 110.6 1.8
243 | CB-CA-C LYS 110.4 2.0
244 | CA-CB-CG LYS 113.4 2.2
245 | CB-CG-CD LYS 111.6 2.6
246 | CG-CD-CE LYS 111.9 3.0
247 | CD-CE-NZ LYS 111.7 2.3
248 | N-CA-C LYS 111.0 2.7
249 | CA-C-O LYS 120.1 2.1
250 | N-CA-CB MET 110.6 1.8
251 | CB-CA-C MET 110.4 2.0
252 | CA-CB-CG MET 113.3 1.7
253 | CB-CG-SD MET 112.4 3.0
254 | CG-SD-CE MET 100.2 1.6
255 | N-CA-C MET 111.0 2.7
256 | CA-C-O MET 120.1 2.1
257 | N-CA-CB PHE 110.6 1.8
258 | CB-CA-C PHE 110.4 2.0
259 | CA-CB-CG PHE 113.9 2.4
260 | CB-CG-CD1 PHE 120.8 0.7
261 | CB-CG-CD2 PHE 120.8 0.7
262 | CD1-CG-CD2 PHE 118.3 1.3
263 | CG-CD1-CE1 PHE 120.8 1.1
264 | CG-CD2-CE2 PHE 120.8 1.1
265 | CD1-CE1-CZ PHE 120.1 1.2
266 | CD2-CE2-CZ PHE 120.1 1.2
267 | CE1-CZ-CE2 PHE 120.0 1.8
268 | N-CA-C PHE 111.0 2.7
269 | CA-C-O PHE 120.1 2.1
270 | N-CA-CB PRO 103.3 1.2
271 | CB-CA-C PRO 111.7 2.1
272 | CA-CB-CG PRO 104.8 1.9
273 | CB-CG-CD PRO 106.5 3.9
274 | CG-CD-N PRO 103.2 1.5
275 | CA-N-CD PRO 111.7 1.4
276 | N-CA-C PRO 112.1 2.6
277 | CA-C-O PRO 120.2 2.4
278 | N-CA-CB SER 110.5 1.5
279 | CB-CA-C SER 110.1 1.9
280 | CA-CB-OG SER 111.2 2.7
281 | N-CA-C SER 111.0 2.7
282 | CA-C-O SER 120.1 2.1
283 | N-CA-CB THR 110.3 1.9
284 | CB-CA-C THR 111.6 2.7
285 | CA-CB-OG1 THR 109.0 2.1
286 | CA-CB-CG2 THR 112.4 1.4
287 | OG1-CB-CG2 THR 110.0 2.3
288 | N-CA-C THR 111.0 2.7
289 | CA-C-O THR 120.1 2.1
290 | N-CA-CB TRP 110.6 1.8
291 | CB-CA-C TRP 110.4 2.0
292 | CA-CB-CG TRP 113.7 1.9
293 | CB-CG-CD1 TRP 127.0 1.3
294 | CB-CG-CD2 TRP 126.6 1.3
295 | CD1-CG-CD2 TRP 106.3 0.8
296 | CG-CD1-NE1 TRP 110.1 1.0
297 | CD1-NE1-CE2 TRP 109.0 0.9
298 | NE1-CE2-CD2 TRP 107.3 1.0
299 | CE2-CD2-CG TRP 107.3 0.8
300 | CG-CD2-CE3 TRP 133.9 0.9
301 | NE1-CE2-CZ2 TRP 130.4 1.1
302 | CE3-CD2-CE2 TRP 118.7 1.2
303 | CD2-CE2-CZ2 TRP 122.3 1.2
304 | CE2-CZ2-CH2 TRP 117.4 1.0
305 | CZ2-CH2-CZ3 TRP 121.6 1.2
306 | CH2-CZ3-CE3 TRP 121.2 1.1
307 | CZ3-CE3-CD2 TRP 118.8 1.3
308 | N-CA-C TRP 111.0 2.7
309 | CA-C-O TRP 120.1 2.1
310 | N-CA-CB TYR 110.6 1.8
311 | CB-CA-C TYR 110.4 2.0
312 | CA-CB-CG TYR 113.4 1.9
313 | CB-CG-CD1 TYR 121.0 0.6
314 | CB-CG-CD2 TYR 121.0 0.6
315 | CD1-CG-CD2 TYR 117.9 1.1
316 | CG-CD1-CE1 TYR 121.3 0.8
317 | CG-CD2-CE2 TYR 121.3 0.8
318 | CD1-CE1-CZ TYR 119.8 0.9
319 | CD2-CE2-CZ TYR 119.8 0.9
320 | CE1-CZ-CE2 TYR 119.8 1.6
321 | CE1-CZ-OH TYR 120.1 2.7
322 | CE2-CZ-OH TYR 120.1 2.7
323 | N-CA-C TYR 111.0 2.7
324 | CA-C-O TYR 120.1 2.1
325 | N-CA-CB VAL 111.5 2.2
326 | CB-CA-C VAL 111.4 1.9
327 | CA-CB-CG1 VAL 110.9 1.5
328 | CA-CB-CG2 VAL 110.9 1.5
329 | CG1-CB-CG2 VAL 110.9 1.6
330 | N-CA-C VAL 111.0 2.7
331 | CA-C-O VAL 120.1 2.1
332 | -
333 |
334 | Non-bonded distance Minimum Dist Tolerance
335 | C-C 3.4 1.5
336 | C-N 3.25 1.5
337 | C-S 3.5 1.5
338 | C-O 3.22 1.5
339 | N-N 3.1 1.5
340 | N-S 3.35 1.5
341 | N-O 3.07 1.5
342 | O-S 3.32 1.5
343 | O-O 3.04 1.5
344 | S-S 2.03 1.0
345 | -
346 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/src/ibex/refine.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright (c) 2022, Brennan Abanades Kenyon
3 | # All rights reserved.
4 |
5 | #
6 | # BSD 3-Clause License
7 | #
8 | # Redistribution and use in source and binary forms, with or without
9 | # modification, are permitted provided that the following conditions are met:
10 | #
11 | # 1. Redistributions of source code must retain the above copyright notice, this
12 | # list of conditions and the following disclaimer.
13 | #
14 | # 2. Redistributions in binary form must reproduce the above copyright notice,
15 | # this list of conditions and the following disclaimer in the documentation
16 | # and/or other materials provided with the distribution.
17 | #
18 | # 3. Neither the name of the copyright holder nor the names of its
19 | # contributors may be used to endorse or promote products derived from
20 | # this software without specific prior written permission.
21 | #
22 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
26 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
28 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 | #
33 | # Implementation taken from https://github.com/brennanaba/ImmuneBuilder/blob/main/ImmuneBuilder/refine.py
34 |
35 | import os
36 | import numpy as np
37 | from openmm import app, LangevinIntegrator, CustomExternalForce, CustomTorsionForce, OpenMMException, Platform, unit
38 | from scipy import spatial
39 | import logging
40 | import pdbfixer
41 | logging.disable()
42 |
43 | ENERGY = unit.kilocalories_per_mole
44 | LENGTH = unit.angstroms
45 | spring_unit = ENERGY / (LENGTH ** 2)
46 |
47 | CLASH_CUTOFF = 0.63
48 |
49 | # Atomic radii for various atom types.
50 | atom_radii = {"C": 1.70, "N": 1.55, 'O': 1.52, 'S': 1.80}
51 |
52 | # Sum of van-der-waals radii
53 | radii_sums = dict(
54 | [(i + j, (atom_radii[i] + atom_radii[j])) for i in list(atom_radii.keys()) for j in list(atom_radii.keys())])
55 | # Clash_cutoff-based radii values
56 | cutoffs = dict(
57 | [(i + j, CLASH_CUTOFF * (radii_sums[i + j])) for i in list(atom_radii.keys()) for j in list(atom_radii.keys())])
58 |
59 | # Using amber14 recommended protein force field
60 | forcefield = app.ForceField("amber14/protein.ff14SB.xml")
61 |
62 |
63 | def refine_file(input_file, output_file, checks=True, tries=3, n=6, n_threads=-1):
64 | for _ in range(tries):
65 | if refine_once(input_file, output_file, checks=checks, n=n, n_threads=n_threads):
66 | return True
67 | return False
68 |
69 |
70 | def refine_once(input_file, output_file, checks=True, n=6, n_threads=-1):
71 | k1s = [2.5,1,0.5,0.25,0.1,0.001]
72 | k2s = [2.5,5,7.5,15,25,50]
73 | success = False
74 |
75 | fixer = pdbfixer.PDBFixer(input_file)
76 |
77 | fixer.findMissingResidues()
78 | fixer.findMissingAtoms()
79 | fixer.addMissingAtoms()
80 |
81 | k1 = k1s[0]
82 | if checks:
83 | k2 = -1 if (cis_check(fixer.topology, fixer.positions)) else k2s[0]
84 | else:
85 | k2 = -1
86 |
87 | topology, positions = fixer.topology, fixer.positions
88 |
89 | for i in range(n):
90 | try:
91 | simulation = minimize_energy(topology, positions, k1=k1, k2 = k2, n_threads=n_threads)
92 | topology, positions = simulation.topology, simulation.context.getState(getPositions=True).getPositions()
93 | if checks:
94 | acceptable_bonds, trans_peptide_bonds = bond_check(topology, positions), cis_check(topology, positions)
95 | except OpenMMException as e:
96 | if (i == n-1) and ("positions" not in locals()):
97 | print("OpenMM failed to refine {}".format(input_file), flush=True)
98 | raise e
99 | else:
100 | topology, positions = fixer.topology, fixer.positions
101 | continue
102 |
103 | if checks:
104 |
105 | # If peptide bonds are the wrong length, decrease the strength of the positional restraint
106 | if not acceptable_bonds:
107 | k1 = k1s[min(i, len(k1s)-1)]
108 |
109 | # If there are still cis isomers in the model, increase the force to fix these
110 | if not trans_peptide_bonds:
111 | k2 = k2s[min(i, len(k2s)-1)]
112 | else:
113 | k2 = -1
114 |
115 | if acceptable_bonds and trans_peptide_bonds:
116 | # If peptide bond lengths and torsions are okay, check and fix the chirality.
117 | try:
118 | simulation = chirality_fixer(simulation)
119 | topology, positions = simulation.topology, simulation.context.getState(getPositions=True).getPositions()
120 | except OpenMMException as e:
121 | topology, positions = fixer.topology, fixer.positions
122 | continue
123 |
124 | # If all other checks pass, check and fix strained sidechain bonds:
125 | try:
126 | strained_bonds = strained_sidechain_bonds_check(topology, positions)
127 | if len(strained_bonds) > 0:
128 | needs_recheck = True
129 | topology, positions = strained_sidechain_bonds_fixer(strained_bonds, topology, positions, n_threads=n_threads)
130 | else:
131 | needs_recheck = False
132 | except OpenMMException as e:
133 | topology, positions = fixer.topology, fixer.positions
134 | continue
135 |
136 | # If it passes all the tests, we are done
137 | tests = bond_check(topology, positions) and cis_check(topology, positions)
138 | if needs_recheck:
139 | tests = tests and len(strained_sidechain_bonds_check(topology, positions)) == 0
140 | if tests and stereo_check(topology, positions) and clash_check(topology, positions):
141 | success = True
142 | break
143 |
144 | else:
145 | success = True
146 | break
147 |
148 | with open(output_file, "w") as out_handle:
149 | app.PDBFile.writeFile(topology, positions, out_handle, keepIds=True)
150 |
151 | return success
152 |
153 |
154 | def minimize_energy(topology, positions, k1=2.5, k2=2.5, n_threads=-1):
155 | # Fill in the gaps with OpenMM Modeller
156 | modeller = app.Modeller(topology, positions)
157 | modeller.addHydrogens(forcefield)
158 |
159 | # Set up force field
160 | system = forcefield.createSystem(modeller.topology)
161 |
162 | # Keep atoms close to initial prediction
163 | force = CustomExternalForce("k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
164 | force.addGlobalParameter("k", k1 * spring_unit)
165 | for p in ["x0", "y0", "z0"]:
166 | force.addPerParticleParameter(p)
167 |
168 | for residue in modeller.topology.residues():
169 | for atom in residue.atoms():
170 | if atom.name in ["CA", "CB", "N", "C"]:
171 | force.addParticle(atom.index, modeller.positions[atom.index])
172 |
173 | system.addForce(force)
174 |
175 | if k2 > 0.0:
176 | cis_force = CustomTorsionForce("10*k2*(1+cos(theta))^2")
177 | cis_force.addGlobalParameter("k2", k2 * ENERGY)
178 |
179 | for chain in modeller.topology.chains():
180 | residues = [res for res in chain.residues()]
181 | relevant_atoms = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "CA", "C"]} for res in residues]
182 | for i in range(1,len(residues)):
183 | if residues[i].name == "PRO":
184 | continue
185 |
186 | resi = relevant_atoms[i-1]
187 | n_resi = relevant_atoms[i]
188 | cis_force.addTorsion(resi["CA"], resi["C"], n_resi["N"], n_resi["CA"])
189 |
190 | system.addForce(cis_force)
191 |
192 | # Set up integrator
193 | integrator = LangevinIntegrator(0, 0.01, 0.0)
194 |
195 | # Set up the simulation
196 | if n_threads > 0:
197 | # Set number of threads used by OpenMM
198 | platform = Platform.getPlatformByName('CPU')
199 | simulation = app.Simulation(modeller.topology, system, integrator, platform, {'Threads': str(n_threads)})
200 | else:
201 | simulation = app.Simulation(modeller.topology, system, integrator)
202 | simulation.context.setPositions(modeller.positions)
203 |
204 | # Minimize the energy
205 | simulation.minimizeEnergy()
206 |
207 | return simulation
208 |
209 |
210 | def chirality_fixer(simulation):
211 | topology = simulation.topology
212 | positions = simulation.context.getState(getPositions=True).getPositions()
213 |
214 | d_stereoisomers = []
215 | for residue in topology.residues():
216 | if residue.name == "GLY":
217 | continue
218 |
219 | atom_indices = {atom.name:atom.index for atom in residue.atoms() if atom.name in ["N", "CA", "C", "CB"]}
220 | vectors = [positions[atom_indices[i]] - positions[atom_indices["CA"]] for i in ["N", "C", "CB"]]
221 |
222 | if np.dot(np.cross(vectors[0], vectors[1]), vectors[2]) < .0*LENGTH**3:
223 | # If it is a D-stereoisomer then flip its H atom
224 | indices = {x.name:x.index for x in residue.atoms() if x.name in ["HA", "CA"]}
225 | positions[indices["HA"]] = 2*positions[indices["CA"]] - positions[indices["HA"]]
226 |
227 | # Fix the H atom in place
228 | particle_mass = simulation.system.getParticleMass(indices["HA"])
229 | simulation.system.setParticleMass(indices["HA"], 0.0)
230 | d_stereoisomers.append((indices["HA"], particle_mass))
231 |
232 | if len(d_stereoisomers) > 0:
233 | simulation.context.setPositions(positions)
234 |
235 | # Minimize the energy with the evil hydrogens fixed
236 | simulation.minimizeEnergy()
237 |
238 | # Minimize the energy letting the hydrogens move
239 | for atom in d_stereoisomers:
240 | simulation.system.setParticleMass(*atom)
241 | simulation.minimizeEnergy()
242 |
243 | return simulation
244 |
245 |
246 | def bond_check(topology, positions):
247 | for chain in topology.chains():
248 | residues = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "C"]} for res in chain.residues()]
249 | for i in range(len(residues)-1):
250 | # For simplicity we only check the peptide bond length; the remaining bond lengths are hard-coded and should already be correct
251 | v = np.linalg.norm(positions[residues[i]["C"]] - positions[residues[i+1]["N"]])
252 | if abs(v - 1.329*LENGTH) > 0.1*LENGTH:
253 | return False
254 | return True
255 |
256 |
257 | def cis_bond(p0, p1, p2, p3):
258 |     ab = p1 - p0  # CA(i-1) -> C(i-1)
259 |     cd = p2 - p1  # C(i-1) -> N(i), the peptide bond
260 |     db = p3 - p2  # N(i) -> CA(i)
261 | 
262 |     u = np.cross(-ab, cd)
263 |     v = np.cross(db, cd)
264 |     return np.dot(u, v) > 0  # True when the omega torsion is cis-like (|omega| < 90 deg)
265 |
266 |
267 | def cis_check(topology, positions):
268 | pos = np.array(positions.value_in_unit(LENGTH))
269 | for chain in topology.chains():
270 | residues = [res for res in chain.residues()]
271 | relevant_atoms = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "CA", "C"]} for res in residues]
272 | for i in range(1,len(residues)):
273 | if residues[i].name == "PRO":
274 | continue
275 |
276 | resi = relevant_atoms[i-1]
277 | n_resi = relevant_atoms[i]
278 | p0,p1,p2,p3 = pos[resi["CA"]],pos[resi["C"]],pos[n_resi["N"]],pos[n_resi["CA"]]
279 | if cis_bond(p0,p1,p2,p3):
280 | return False
281 | return True
282 |
283 |
284 | def stereo_check(topology, positions):
285 | pos = np.array(positions.value_in_unit(LENGTH))
286 | for residue in topology.residues():
287 | if residue.name == "GLY":
288 | continue
289 |
290 | atom_indices = {atom.name:atom.index for atom in residue.atoms() if atom.name in ["N", "CA", "C", "CB"]}
291 | vectors = pos[[atom_indices[i] for i in ["N", "C", "CB"]]] - pos[atom_indices["CA"]]
292 |
293 | if np.linalg.det(vectors) < 0:
294 | return False
295 | return True
296 |
297 |
298 | def clash_check(topology, positions):
299 | heavies = [x for x in topology.atoms() if x.element.symbol != "H"]
300 | pos = np.array(positions.value_in_unit(LENGTH))[[x.index for x in heavies]]
301 |
302 | tree = spatial.KDTree(pos)
303 | pairs = tree.query_pairs(r=max(cutoffs.values()))
304 |
305 | for pair in pairs:
306 | atom_i, atom_j = heavies[pair[0]], heavies[pair[1]]
307 |
308 | if atom_i.residue.index == atom_j.residue.index:
309 | continue
310 | elif (atom_i.name == "C" and atom_j.name == "N") or (atom_i.name == "N" and atom_j.name == "C"):
311 | continue
312 |
313 | atom_distance = np.linalg.norm(pos[pair[0]] - pos[pair[1]])
314 |
315 | if (atom_i.name == "SG" and atom_j.name == "SG") and atom_distance > 1.88:
316 | continue
317 |
318 | elif atom_distance < (cutoffs[atom_i.element.symbol + atom_j.element.symbol]):
319 | return False
320 | return True
321 |
322 |
323 | def strained_sidechain_bonds_check(topology, positions):
324 | atoms = list(topology.atoms())
325 | pos = np.array(positions.value_in_unit(LENGTH))
326 |
327 | system = forcefield.createSystem(topology)
328 | bonds = [x for x in system.getForces() if type(x).__name__ == "HarmonicBondForce"][0]
329 |
330 | # Initialise arrays for bond details
331 | n_bonds = bonds.getNumBonds()
332 | i = np.empty(n_bonds, dtype=int)
333 | j = np.empty(n_bonds, dtype=int)
334 | k = np.empty(n_bonds)
335 | x0 = np.empty(n_bonds)
336 |
337 | # Extract bond details to arrays
338 | for n in range(n_bonds):
339 | i[n],j[n],_x0,_k = bonds.getBondParameters(n)
340 | k[n] = _k.value_in_unit(spring_unit)
341 | x0[n] = _x0.value_in_unit(LENGTH)
342 |
343 | # Check if there are any abnormally strained bonds
344 | distance = np.linalg.norm(pos[i] - pos[j], axis=-1)
345 | check = k*(distance - x0)**2 > 100
346 |
347 | # Return residues with strained bonds if any
348 | return [atoms[x].residue for x in i[check]]
349 |
350 |
351 | def strained_sidechain_bonds_fixer(strained_residues, topology, positions, n_threads=-1):
352 | # Delete all atoms except the main chain for badly refined residues.
353 | bb_atoms = ["N","CA","C"]
354 | bad_side_chains = sum([[atom for atom in residue.atoms() if atom.name not in bb_atoms] for residue in strained_residues],[])
355 | modeller = app.Modeller(topology, positions)
356 | modeller.delete(bad_side_chains)
357 |
358 | # Save model with deleted side chains to temporary file.
359 | random_number = str(int(np.random.rand()*10**8))
360 | tmp_file = f"side_chain_fix_tmp_{random_number}.pdb"
361 | with open(tmp_file,"w") as handle:
362 | app.PDBFile.writeFile(modeller.topology, modeller.positions, handle, keepIds=True)
363 |
364 | # Load model into pdbfixer
365 | fixer = pdbfixer.PDBFixer(tmp_file)
366 | os.remove(tmp_file)
367 |
368 | # Repair deleted side chains
369 | fixer.findMissingResidues()
370 | fixer.findMissingAtoms()
371 | fixer.addMissingAtoms()
372 |
373 | # Fill in the gaps with OpenMM Modeller
374 | modeller = app.Modeller(fixer.topology, fixer.positions)
375 | modeller.addHydrogens(forcefield)
376 |
377 | # Set up force field
378 | system = forcefield.createSystem(modeller.topology)
379 |
380 | # Set up integrator
381 | integrator = LangevinIntegrator(0, 0.01, 0.0)
382 |
383 | # Set up the simulation
384 | if n_threads > 0:
385 | # Set number of threads used by OpenMM
386 | platform = Platform.getPlatformByName('CPU')
387 | simulation = app.Simulation(modeller.topology, system, integrator, platform, {'Threads': str(n_threads)})
388 | else:
389 | simulation = app.Simulation(modeller.topology, system, integrator)
390 | simulation.context.setPositions(modeller.positions)
391 |
392 | # Minimize the energy
393 | simulation.minimizeEnergy()
394 |
395 | return simulation.topology, simulation.context.getState(getPositions=True).getPositions()
396 |
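397 | 
398 | # Sanity sketch for cis_bond on toy coordinates (hypothetical values):
399 | # a CA-C-N-CA quadruple with both CA atoms on the same side of the peptide
400 | # bond is cis (omega ~ 0 deg); on opposite sides it is trans (omega ~ 180 deg).
401 | if __name__ == "__main__":
402 |     p_ca0, p_c, p_n = np.array([-0.5, 1.0, 0.0]), np.zeros(3), np.array([1.0, 0.0, 0.0])
403 |     assert cis_bond(p_ca0, p_c, p_n, np.array([1.5, 1.0, 0.0]))
404 |     assert not cis_bond(p_ca0, p_c, p_n, np.array([1.5, -1.0, 0.0]))
405 |     # Hypothetical end-to-end call (paths are placeholders):
406 |     # refine_file("model_raw.pdb", "model_refined.pdb", checks=True, n_threads=4)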
--------------------------------------------------------------------------------
/src/ibex/openfold/utils/protein.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | # Copyright 2021 AlQuraishi Laboratory
4 | # Copyright 2021 DeepMind Technologies Limited
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | """Protein data type."""
19 | import dataclasses
20 | import io
21 | import re
22 | import string
23 | from typing import Any, Mapping, Optional, Sequence
24 |
25 | import numpy as np
26 | from Bio.PDB import PDBParser
27 |
28 | from ibex.openfold.utils import residue_constants
29 |
30 | FeatureDict = Mapping[str, np.ndarray]
31 | ModelOutput = Mapping[str, Any] # Is a nested dict.
32 | PICO_TO_ANGSTROM = 0.01
33 |
34 |
35 | @dataclasses.dataclass(frozen=True)
36 | class Protein:
37 | """Protein structure representation."""
38 |
39 | # Cartesian coordinates of atoms in angstroms. The atom types correspond to
40 | # residue_constants.atom_types, i.e. the first three are N, CA, CB.
41 | atom_positions: np.ndarray # [num_res, num_atom_type, 3]
42 |
43 | # Amino-acid type for each residue represented as an integer between 0 and
44 | # 20, where 20 is 'X'.
45 | aatype: np.ndarray # [num_res]
46 |
47 | # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
48 | # is present and 0.0 if not. This should be used for loss masking.
49 | atom_mask: np.ndarray # [num_res, num_atom_type]
50 |
51 | # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
52 | residue_index: np.ndarray # [num_res]
53 |
54 | # B-factors, or temperature factors, of each residue (in sq. angstroms units),
55 | # representing the displacement of the residue from its ground truth mean
56 | # value.
57 | b_factors: np.ndarray # [num_res, num_atom_type]
58 |
59 | # Chain indices for multi-chain predictions
60 | chain_index: Optional[np.ndarray] = None
61 |
62 | # Optional remark about the protein. Included as a comment in output PDB
63 | # files
64 | remark: Optional[str] = None
65 |
66 | # Templates used to generate this protein (prediction-only)
67 | parents: Optional[Sequence[str]] = None
68 |
69 | # Chain corresponding to each parent
70 | parents_chain_index: Optional[Sequence[int]] = None
71 |
72 |
73 | def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
74 | """Takes a PDB string and constructs a Protein object.
75 |
76 | WARNING: All non-standard residue types will be converted into UNK. All
77 | non-standard atoms will be ignored.
78 |
79 | Args:
80 | pdb_str: The contents of the pdb file
81 | chain_id: If None, then the pdb file must contain a single chain (which
82 | will be parsed). If chain_id is specified (e.g. A), then only that chain
83 | is parsed.
84 |
85 | Returns:
86 | A new `Protein` parsed from the pdb contents.
87 | """
88 | pdb_fh = io.StringIO(pdb_str)
89 | parser = PDBParser(QUIET=True)
90 | structure = parser.get_structure("none", pdb_fh)
91 | models = list(structure.get_models())
92 | if len(models) != 1:
93 | raise ValueError(
94 | f"Only single model PDBs are supported. Found {len(models)} models."
95 | )
96 | model = models[0]
97 |
98 | atom_positions = []
99 | aatype = []
100 | atom_mask = []
101 | residue_index = []
102 | chain_ids = []
103 | b_factors = []
104 |
105 | for chain in model:
106 | if chain_id is not None and chain.id != chain_id:
107 | continue
108 | for res in chain:
109 | if res.id[2] != " ":
110 | raise ValueError(
111 | f"PDB contains an insertion code at chain {chain.id} and residue "
112 | f"index {res.id[1]}. These are not supported."
113 | )
114 | res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
115 | restype_idx = residue_constants.restype_order.get(
116 | res_shortname, residue_constants.restype_num
117 | )
118 | pos = np.zeros((residue_constants.atom_type_num, 3))
119 | mask = np.zeros((residue_constants.atom_type_num,))
120 | res_b_factors = np.zeros((residue_constants.atom_type_num,))
121 | for atom in res:
122 | if atom.name not in residue_constants.atom_types:
123 | continue
124 | pos[residue_constants.atom_order[atom.name]] = atom.coord
125 | mask[residue_constants.atom_order[atom.name]] = 1.0
126 | res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
127 | if np.sum(mask) < 0.5:
128 | # If no known atom positions are reported for the residue then skip it.
129 | continue
130 | aatype.append(restype_idx)
131 | atom_positions.append(pos)
132 | atom_mask.append(mask)
133 | residue_index.append(res.id[1])
134 | chain_ids.append(chain.id)
135 | b_factors.append(res_b_factors)
136 |
137 | parents = None
138 | parents_chain_index = None
139 | if "PARENT" in pdb_str:
140 | parents = []
141 | parents_chain_index = []
142 | chain_id = 0
143 | for l in pdb_str.split("\n"):
144 | if "PARENT" in l:
145 |                 if "N/A" not in l:
146 | parent_names = l.split()[1:]
147 | parents.extend(parent_names)
148 | parents_chain_index.extend([chain_id for _ in parent_names])
149 | chain_id += 1
150 |
151 | unique_chain_ids = np.unique(chain_ids)
152 | chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
153 | chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
154 |
155 | return Protein(
156 | atom_positions=np.array(atom_positions),
157 | atom_mask=np.array(atom_mask),
158 | aatype=np.array(aatype),
159 | residue_index=np.array(residue_index),
160 | chain_index=chain_index,
161 | b_factors=np.array(b_factors),
162 | parents=parents,
163 | parents_chain_index=parents_chain_index,
164 | )
165 |
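    | # Example (illustrative; "my_antibody.pdb" is a placeholder path): parse one
    | # chain of a PDB file into the array representation defined above.
    | #
    | #     with open("my_antibody.pdb") as fh:
    | #         prot = from_pdb_string(fh.read(), chain_id="H")
    | #     prot.atom_positions.shape  # (num_res, 37, 3), atom37 ordering
    | #     prot.atom_mask.shape       # (num_res, 37)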
166 |
167 | def from_proteinnet_string(proteinnet_str: str) -> Protein:
168 | tag_re = r"(\[[A-Z]+\]\n)"
169 | tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
170 | groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
171 |
172 | atoms = ["N", "CA", "C"]
173 | aatype = None
174 | atom_positions = None
175 | atom_mask = None
176 | for g in groups:
177 | if "[PRIMARY]" == g[0]:
178 |             seq = list(g[1][0].strip())  # use a list: str does not support item assignment
179 |             for i in range(len(seq)):
180 |                 if seq[i] not in residue_constants.restypes:
181 |                     seq[i] = "X"
182 | aatype = np.array(
183 | [
184 | residue_constants.restype_order.get(
185 | res_symbol, residue_constants.restype_num
186 | )
187 | for res_symbol in seq
188 | ]
189 | )
190 | elif "[TERTIARY]" == g[0]:
191 | tertiary = []
192 | for axis in range(3):
193 | tertiary.append(list(map(float, g[1][axis].split())))
194 | tertiary_np = np.array(tertiary)
195 | atom_positions = np.zeros(
196 | (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)
197 | ).astype(np.float32)
198 | for i, atom in enumerate(atoms):
199 | atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(
200 | tertiary_np[:, i::3]
201 | )
202 | atom_positions *= PICO_TO_ANGSTROM
203 | elif "[MASK]" == g[0]:
204 | mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip())))
205 | atom_mask = np.zeros(
206 | (
207 | len(mask),
208 | residue_constants.atom_type_num,
209 | )
210 | ).astype(np.float32)
211 | for i, atom in enumerate(atoms):
212 | atom_mask[:, residue_constants.atom_order[atom]] = 1
213 | atom_mask *= mask[..., None]
214 |
215 | return Protein(
216 | atom_positions=atom_positions,
217 | atom_mask=atom_mask,
218 | aatype=aatype,
219 | residue_index=np.arange(len(aatype)),
220 | b_factors=None,
221 | )
222 |
223 |
224 | def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
225 | pdb_headers = []
226 |
227 | remark = prot.remark
228 | if remark is not None:
229 | pdb_headers.append(f"REMARK {remark}")
230 |
231 | parents = prot.parents
232 | parents_chain_index = prot.parents_chain_index
233 | if parents_chain_index is not None:
234 | parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
235 |
236 | if parents is None or len(parents) == 0:
237 | parents = ["N/A"]
238 |
239 | pdb_headers.append(f"PARENT {' '.join(parents)}")
240 |
241 | return pdb_headers
242 |
243 |
244 | def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
245 | """Add pdb headers to an existing PDB string. Useful during multi-chain
246 | recycling
247 | """
248 | out_pdb_lines = []
249 | lines = pdb_str.split("\n")
250 |
251 | remark = prot.remark
252 | if remark is not None:
253 | out_pdb_lines.append(f"REMARK {remark}")
254 |
255 | parents_per_chain = None
256 | if prot.parents is not None and len(prot.parents) > 0:
257 | parents_per_chain = []
258 | if prot.parents_chain_index is not None:
259 | cur_chain = prot.parents_chain_index[0]
260 | parent_dict = {}
261 | for p, i in zip(prot.parents, prot.parents_chain_index):
262 | parent_dict.setdefault(str(i), [])
263 | parent_dict[str(i)].append(p)
264 |
265 | max_idx = max([int(chain_idx) for chain_idx in parent_dict])
266 | for i in range(max_idx + 1):
267 | chain_parents = parent_dict.get(str(i), ["N/A"])
268 | parents_per_chain.append(chain_parents)
269 | else:
270 | parents_per_chain.append(prot.parents)
271 | else:
272 | parents_per_chain = [["N/A"]]
273 |
274 | make_parent_line = lambda p: f"PARENT {' '.join(p)}"
275 |
276 | out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
277 |
278 | chain_counter = 0
279 | for i, l in enumerate(lines):
280 | if "PARENT" not in l and "REMARK" not in l:
281 | out_pdb_lines.append(l)
282 | if "TER" in l and not "END" in lines[i + 1]:
283 | chain_counter += 1
284 |             if chain_counter < len(parents_per_chain):
285 | chain_parents = parents_per_chain[chain_counter]
286 | else:
287 | chain_parents = ["N/A"]
288 |
289 | out_pdb_lines.append(make_parent_line(chain_parents))
290 |
291 | return "\n".join(out_pdb_lines)
292 |
293 |
294 | def to_pdb(prot: Protein) -> str:
295 | """Converts a `Protein` instance to a PDB string.
296 |
297 | Args:
298 | prot: The protein to convert to PDB.
299 |
300 | Returns:
301 | PDB string.
302 | """
303 | restypes = residue_constants.restypes + ["X"]
304 | res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
305 | atom_types = residue_constants.atom_types
306 |
307 | pdb_lines = []
308 |
309 | atom_mask = prot.atom_mask
310 | aatype = prot.aatype
311 | atom_positions = prot.atom_positions
312 | residue_index = prot.residue_index.astype(np.int32)
313 | b_factors = prot.b_factors
314 | chain_index = prot.chain_index
315 |
316 | if np.any(aatype > residue_constants.restype_num):
317 | raise ValueError("Invalid aatypes.")
318 |
319 | headers = get_pdb_headers(prot)
320 | if len(headers) > 0:
321 | pdb_lines.extend(headers)
322 |
323 | n = aatype.shape[0]
324 | atom_index = 1
325 | prev_chain_index = 0
326 |     chain_tags = ["H", "L"]  # antibody heavy/light; generic multimers would use string.ascii_uppercase
327 | # Add all atom sites.
328 | for i in range(n):
329 | res_name_3 = res_1to3(aatype[i])
330 | for atom_name, pos, mask, b_factor in zip(
331 | atom_types, atom_positions[i], atom_mask[i], b_factors[i]
332 | ):
333 | if mask < 0.5:
334 | continue
335 |
336 | record_type = "ATOM"
337 | name = atom_name if len(atom_name) == 4 else f" {atom_name}"
338 | alt_loc = ""
339 | insertion_code = ""
340 | occupancy = 1.00
341 | element = atom_name[0] # Protein supports only C, N, O, S, this works.
342 | charge = ""
343 |
344 | chain_tag = "A"
345 | if chain_index is not None:
346 | chain_tag = chain_tags[chain_index[i]]
347 |
348 | # PDB is a columnar format, every space matters here!
349 | atom_line = (
350 | f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
351 | f"{res_name_3:>3} {chain_tag:>1}"
352 | f"{residue_index[i]:>4}{insertion_code:>1} "
353 | f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
354 | f"{occupancy:>6.2f}{b_factor:>6.2f} "
355 | f"{element:>2}{charge:>2}"
356 | )
357 | pdb_lines.append(atom_line)
358 | atom_index += 1
359 |
360 | should_terminate = i == n - 1
361 | if chain_index is not None:
362 | if i != n - 1 and chain_index[i + 1] != prev_chain_index:
363 | should_terminate = True
364 | prev_chain_index = chain_index[i + 1]
365 |
366 | if should_terminate:
367 | # Close the chain.
368 | chain_end = "TER"
369 | chain_termination_line = (
370 | f"{chain_end:<6}{atom_index:>5} "
371 | f"{res_1to3(aatype[i]):>3} "
372 | f"{chain_tag:>1}{residue_index[i]:>4}"
373 | )
374 | pdb_lines.append(chain_termination_line)
375 | atom_index += 1
376 |
377 | if i != n - 1:
378 | # "prev" is a misnomer here. This happens at the beginning of
379 | # each new chain.
380 | pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
381 |
382 | pdb_lines.append("END")
383 | pdb_lines.append("")
384 | return "\n".join(pdb_lines)
385 |
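    | # Example (illustrative): for a Protein `prot` with chain_index values in
    | # {0, 1} (as produced by from_prediction below), to_pdb followed by
    | # from_pdb_string preserves the observed coordinates up to the %8.3f
    | # formatting of the ATOM records above.
    | #
    | #     prot2 = from_pdb_string(to_pdb(prot))
    | #     mask = prot.atom_mask.astype(bool)
    | #     np.allclose(prot.atom_positions[mask], prot2.atom_positions[mask], atol=1e-3)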
386 |
387 | def ideal_atom_mask(prot: Protein) -> np.ndarray:
388 | """Computes an ideal atom mask.
389 |
390 | `Protein.atom_mask` typically is defined according to the atoms that are
391 | reported in the PDB. This function computes a mask according to heavy atoms
392 | that should be present in the given sequence of amino acids.
393 |
394 | Args:
395 | prot: `Protein` whose fields are `numpy.ndarray` objects.
396 |
397 | Returns:
398 | An ideal atom mask.
399 | """
400 | return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
401 |
402 |
403 | def from_prediction(
404 | features: FeatureDict,
405 | result: ModelOutput,
406 | b_factors: Optional[np.ndarray] = None,
407 | chain_index: Optional[np.ndarray] = None,
408 | remark: Optional[str] = None,
409 | parents: Optional[Sequence[str]] = None,
410 | parents_chain_index: Optional[Sequence[int]] = None,
411 | ) -> Protein:
412 | """Assembles a protein from a prediction.
413 |
414 | Args:
415 | features: Dictionary holding model inputs.
416 | result: Dictionary holding model outputs.
417 | b_factors: (Optional) B-factors to use for the protein.
418 | chain_index: (Optional) Chain indices for multi-chain predictions
419 | remark: (Optional) Remark about the prediction
420 |         parents: (Optional) List of template names; parents_chain_index: (Optional) their chain indices
421 | Returns:
422 | A protein instance.
423 | """
424 | if b_factors is None:
425 | b_factors = np.zeros_like(result["final_atom_mask"])
426 |
427 | return Protein(
428 | aatype=features["aatype"],
429 | atom_positions=result["final_atom_positions"],
430 | atom_mask=result["final_atom_mask"],
431 | residue_index=features["residue_index"] + 1,
432 | b_factors=b_factors,
433 | chain_index=chain_index,
434 | remark=remark,
435 | parents=parents,
436 | parents_chain_index=parents_chain_index,
437 | )
438 |
--------------------------------------------------------------------------------
/src/ibex/dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import random
17 | import pandas as pd
18 | import torch
19 | from tqdm import tqdm
20 | import itertools
21 | from math import ceil
22 | from torch import Tensor
23 | from typing import Iterator, List, Dict, Optional, Union
24 | from pathlib import Path
25 | from loguru import logger
26 | from torch.nn.utils.rnn import pad_sequence
27 | from torch.utils.data import Dataset, BatchSampler, Sampler
28 | from collections.abc import Iterable
29 | from torch.utils.data.distributed import DistributedSampler
30 |
31 | from ibex.openfold.utils.residue_constants import restype_order_with_x
32 |
33 | # Fileids from the test and validation set that are not present in our version of the dataset, or
34 | # where there are mismatches in the lengths of the chains.
35 | invalid_fileids = [
36 | # "8aon_BC", # valid
37 | # "4hjj_HL", # valid
38 | # "7pa6_KKKKKK", # valid
39 | # "7ce2_ZB", # valid
40 | # "7r8u_HHHLLL", # valid
41 | # "6o89_HL", # valid
42 | # "5d6c_HL", # valid
43 | # "7u8c_HL", # test
44 | # "7seg_HL", # test
45 | # "2a9m_IM", # test
46 | # "4buh_AA", # test
47 | # "6ss5_HHHLLL" # test
48 | "7sgm_HL", # mismatching lengths
49 | "7sgm_XY", # mismatching lengths
50 | "7sgm_KM", # mismatching lengths
51 | "6w9g_HL", # mismatching lengths
52 | "6w9g_KM", # mismatching lengths
53 | "6w9g_XY", # mismatching lengths
54 | "5iwl_AB", # mismatching lengths
55 | "5iwl_BA", # mismatching lengths
56 | "7sen_HL", # mismatching lengths
57 | "6w5a_HL", # mismatching lengths
58 | "6bjz_HL", # mismatching lengths
59 | "7t74_HL", # mismatching lengths
60 | "6dfv_BA", # mismatching lengths TCR
61 | "3tyf_DC", # mismatching lengths TCR
62 | ]
63 |
64 | class ABDataset(Dataset):
65 | def __init__(
66 | self,
67 | split: str,
68 | split_file: str,
69 | data_dir: str,
70 | limit: Optional[int] = None,
71 | use_vhh: bool = False,
72 | use_tcr: bool = False,
73 | rel_pos_dim: int = 16,
74 | edge_chain_feature: bool = False,
75 | use_plm_embeddings: bool = False,
76 | use_public_only: bool = False,
77 | use_private_only: bool = False,
78 | cluster_column: str = "cluster",
79 | weight_temperature: float = 1.0,
80 | weight_clip_quantile: float = 0.99,
81 | vhh_weight: float = 1.0,
82 | tcr_weight: float = 1.0,
83 | matched_weight: float = 1.0,
84 | conformation_node_feature: bool = False,
85 | use_contrastive: bool = False,
86 | is_matched: bool = False,
87 | use_boltz_only: bool = False,
88 | use_weights: bool = True,
89 | ) -> None:
90 | """Dataset of antibodies suitable for openfold structuremodules and loss functions.
91 |
92 | Args:
93 | path (str): root data folder
94 | split (str): "all", "train", "valid" or "test"
95 | """
96 | super().__init__()
97 | self.split = split
98 | self.limit = limit
99 | self.split_file = split_file
100 | self.data_dir = data_dir
101 | self.rel_pos_dim = rel_pos_dim
102 | self.edge_chain_feature = edge_chain_feature
103 | self.use_plm_embeddings = use_plm_embeddings
104 | self.conformation_node_feature = conformation_node_feature
105 | self.use_contrastive = use_contrastive
106 | self.is_matched = is_matched
107 | self.cluster_column = cluster_column
108 | self.weight_temperature = weight_temperature
109 | self.weight_clip_quantile = weight_clip_quantile
110 |
111 | self.df = pd.read_csv(split_file, index_col=0, dtype={'file_id': str})
112 | if not use_vhh:
113 | if "is_vhh" in self.df.columns:
114 | self.df = self.df[self.df["is_vhh"] == False]
115 | if not use_tcr:
116 | if "is_tcr" in self.df.columns:
117 | self.df = self.df[self.df["is_tcr"] == False]
118 | if split != "all":
119 | self.df = self.df.query(f"split=='{split}'")
120 | if use_public_only:
121 | self.df = self.df[self.df["is_internal"]==False]
122 | if use_private_only:
123 | self.df = self.df[self.df["is_internal"]]
124 | if 'is_boltz' in self.df.columns and use_boltz_only:
125 | self.df = self.df[self.df["is_boltz"]]
126 | if use_contrastive:
127 | self.df = self.df[self.df["is_matched"]==is_matched]
128 |
129 | self.df = self.df[~self.df.index.isin(invalid_fileids)]
130 |
131 | if limit is not None:
132 | self.df = self.df.iloc[:limit]
133 |
134 | if cluster_column in self.df.columns and use_weights:
135 | # Get basic weight that is 1/cluster size
136 | probs = self.df[cluster_column].map(self.df.groupby(cluster_column).size()).values
137 | probs = torch.tensor(probs, dtype=torch.float32)
138 | probs = probs / probs.sum()
139 | self.weights = 1 / (probs ** self.weight_temperature)
140 | # Cap the weight to e.g. 0.99 quantile, to avoid bad outliers
141 | self.weights = torch.minimum(self.weights, torch.quantile(self.weights, self.weight_clip_quantile))
142 | # Make the weight sum to 1
143 | self.weights = self.weights / self.weights.sum()
144 | else:
145 | logger.info(f"Using uniform weights for {split_file}")
146 | self.weights = torch.ones(len(self.df)) / len(self.df)
147 |
148 | # Upweight TCRs and VHHs
149 | if use_vhh:
150 | self.weights[self.df["is_vhh"]] *= vhh_weight
151 | if use_tcr:
152 | self.weights[self.df["is_tcr"]] *= tcr_weight
153 | if 'is_matched' in self.df.columns:
154 | self.weights[self.df["is_matched"]] *= matched_weight
155 | self.weights = self.weights / self.weights.sum()
156 |
157 | def __len__(self):
158 | return len(self.df)
159 |
160 | def __getitem__(self, idx: int):
161 | if self.use_contrastive and self.is_matched:
162 | datapoint_pos = self._load_single_sample(self.df.iloc[idx].name)
163 | datapoint_neg = self._load_single_sample(self.df.iloc[idx].matched_index)
164 |
165 | datapoint_pos["is_matched"] = torch.tensor(True, dtype=torch.bool)
166 | datapoint_neg["is_matched"] = torch.tensor(True, dtype=torch.bool)
167 |
168 | return datapoint_pos, datapoint_neg
169 | # return datapoint_pos, datapoint_neg, datapoint_anchor
170 | elif self.split=='train' or self.split=='all':
171 | datapoint = self._load_single_sample(self.df.iloc[idx].name)
172 | # Get a random pairing respecting the weights
173 | idx_other = random.choices(
174 | range(len(self.df)),
175 | weights=self.weights,
176 | k=1,
177 | )[0]
178 | datapoint_other = self._load_single_sample(self.df.iloc[idx_other].name)
179 |
180 | datapoint["is_matched"] = torch.tensor(False, dtype=torch.bool)
181 | datapoint_other["is_matched"] = torch.tensor(False, dtype=torch.bool)
182 |
183 | return datapoint, datapoint_other
184 | else:
185 | datapoint = self._load_single_sample(self.df.iloc[idx].name)
186 | return datapoint
187 |
188 | def _load_single_sample(self, file_id):
189 | fname = Path(self.data_dir) / f"{file_id}.pt"
190 | datapoint = torch.load(fname, weights_only=False)
191 | datapoint.update(
192 | self.single_and_double_from_datapoint(
193 | datapoint,
194 | self.rel_pos_dim,
195 | self.edge_chain_feature,
196 | self.use_plm_embeddings,
197 | self.conformation_node_feature,
198 | )
199 | )
200 | for key in datapoint:
201 | if isinstance(datapoint[key], torch.Tensor):
202 | datapoint[key] = datapoint[key].detach()
203 | return datapoint
204 |
205 | @staticmethod
206 | def single_and_double_from_datapoint(
207 | datapoint: dict,
208 | rel_pos_dim: int,
209 | edge_chain_feature: bool = False,
210 | use_plm_embeddings: bool = False,
211 | conformation_node_feature: bool = False,
212 | ):
213 | """
214 | datapoint is a dict containing:
215 | aatype - [n,] tensor of ints for the amino acid (including unknown)
216 | is_heavy - [n,] tensor of ints where 1 is heavy chain and 0 is light chain.
217 |             residue_index - [n,] tensor of ints assigning an integer to each residue
218 |
219 | rel_pos_dim: integer determining edge feature dimension
220 |
221 | edge_chain_feature: boolean to add an edge feature z_ij that encodes what chain i and j are in.
222 |
223 | returns:
224 |             A dictionary with "single", a tensor of size (n, 23) ((n, 25) if conformation_node_feature), and "pair", a tensor of size (n, n, 2 * rel_pos_dim + 1 + x) where x is 3 if edge_chain_feature else 0.
225 | """
226 | single_aa = torch.nn.functional.one_hot(datapoint["aatype"], 21)
227 | single_chain = torch.nn.functional.one_hot(datapoint["is_heavy"].long(), 2)
228 | if conformation_node_feature:
229 | single_conformation = torch.nn.functional.one_hot(datapoint["is_apo"].long(), 2)
230 | single = torch.cat((single_aa, single_chain, single_conformation), dim=-1)
231 | else:
232 | single = torch.cat((single_aa, single_chain), dim=-1)
233 | pair = datapoint["residue_index"]
234 | pair = pair[None] - pair[:, None]
235 | pair = pair.clamp(-rel_pos_dim, rel_pos_dim) + rel_pos_dim
236 | pair = torch.nn.functional.one_hot(pair, 2 * rel_pos_dim + 1)
237 | if edge_chain_feature:
238 | is_heavy = datapoint["is_heavy"]
239 | is_heavy = 2 * is_heavy.outer(is_heavy) + (
240 | (1 - is_heavy).outer(1 - is_heavy)
241 | )
242 | is_heavy = torch.nn.functional.one_hot(is_heavy.long())
243 | pair = torch.cat((is_heavy, pair), dim=-1)
244 |         if use_plm_embeddings:
245 |             return {"single": single.float(), "pair": pair.float(), "plm_embedding": datapoint["plm_embedding"]}
246 |         return {"single": single.float(), "pair": pair.float()}
247 |
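    | # Worked example (illustrative): with rel_pos_dim=64 and edge_chain_feature=True
    | # the pair feature has 2 * 64 + 1 + 3 = 132 channels, matching string_to_input below.
    | #
    | #     dp = {
    | #         "aatype": torch.zeros(10, dtype=torch.long),
    | #         "is_heavy": torch.ones(10, dtype=torch.long),
    | #         "residue_index": torch.arange(10),
    | #     }
    | #     feats = ABDataset.single_and_double_from_datapoint(dp, 64, edge_chain_feature=True)
    | #     feats["single"].shape  # (10, 23)
    | #     feats["pair"].shape    # (10, 10, 132)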
248 | def pad_square_tensors(tensors: list[torch.Tensor]) -> torch.Tensor:
249 |     """Pads a list of tensors in the first two dimensions.
250 | 
251 |     Args:
252 |         tensors (list[torch.Tensor]): Input tensors of shape (n_1, n_1, ...), (n_2, n_2, ...), where shapes match in the trailing (...) dimensions
253 | 
254 |     Returns:
255 |         torch.Tensor: A tensor of size (len(tensors), max(n_1, n_2, ...), max(n_1, n_2, ...), ...)
256 | """
257 | max_len = max(map(len, tensors))
258 | output = torch.zeros((len(tensors), max_len, max_len, *tensors[0].shape[2:]))
259 | for i, tensor in enumerate(tensors):
260 | output[i, : tensor.size(0), : tensor.size(1)] = tensor
261 | return output
262 |
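    | # Example (illustrative): two pair tensors of sizes (3, 3, 4) and (5, 5, 4)
    | # are zero-padded into a single (2, 5, 5, 4) batch.
    | #
    | #     batch = pad_square_tensors([torch.ones(3, 3, 4), torch.ones(5, 5, 4)])
    | #     batch.shape  # torch.Size([2, 5, 5, 4])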
263 |
264 | pad_first_dim_keys = [
265 | "atom14_gt_positions",
266 | "atom14_alt_gt_positions",
267 | "atom14_atom_is_ambiguous",
268 | "atom14_gt_exists",
269 | "atom14_alt_gt_exists",
270 | "atom14_atom_exists",
271 | "single",
272 | "plm_embedding",
273 | "seq_mask",
274 | "aatype",
275 | "backbone_rigid_tensor",
276 | "backbone_rigid_mask",
277 | "rigidgroups_gt_frames",
278 | "rigidgroups_alt_gt_frames",
279 | "rigidgroups_gt_exists",
280 | "cdr_mask",
281 | "chi_mask",
282 | "chi_angles_sin_cos",
283 | "residue_index",
284 | "residx_atom14_to_atom37",
285 | "region_numeric",
286 | ]
287 |
288 | pad_first_two_dim_keys = ["pair"]
289 |
290 |
291 | def string_to_input(heavy: str, light: str, apo: bool = False, conformation_aware: bool = False, embedding=None, device: str = "cpu") -> dict:
292 | """Generates an input formatted for an Ibex model from heavy and light chain
293 | strings.
294 |
295 | Args:
296 | heavy (str): heavy chain
297 | light (str): light chain
298 | apo (bool): whether the structure is apo or holo (optional)
299 |
300 | Returns:
301 | dict: A dictionary containing
302 | aatype: an (n,) tensor of integers encoding the amino acid string
303 | is_heavy: an (n,) tensor where is_heavy[i] = 1 means residue i is heavy and
304 | is_heavy[i] = 0 means residue i is light
305 | residue_index: an (n,) tensor with indices for each residue. There is a gap
306 | of 500 between the last heavy residue and the first light residue
307 | single: a (1, n, 23) tensor of node features
308 | pair: a (1, n, n, 132) tensor of edge features
309 | """
310 | aatype = []
311 | is_heavy = []
312 | for character in heavy:
313 | is_heavy.append(1)
314 | aatype.append(restype_order_with_x[character])
315 | if light is not None:
316 | for character in light:
317 | is_heavy.append(0)
318 | aatype.append(restype_order_with_x[character])
319 | is_heavy = torch.tensor(is_heavy)
320 | aatype = torch.tensor(aatype)
321 | if light is None:
322 | residue_index = torch.arange(len(heavy))
323 | else:
324 | residue_index = torch.cat(
325 | (torch.arange(len(heavy)), torch.arange(len(light)) + 500)
326 | )
327 |
328 | model_input = {
329 | "is_heavy": is_heavy,
330 | "aatype": aatype,
331 | "residue_index": residue_index,
332 | }
333 |     n_residues = len(heavy) + (len(light) if light is not None else 0)
334 |     if apo and conformation_aware:
335 |         model_input["is_apo"] = torch.ones(n_residues, dtype=torch.int64)
336 |     elif conformation_aware:
337 |         model_input["is_apo"] = torch.zeros(n_residues, dtype=torch.int64)
338 | if embedding is not None:
339 | model_input["plm_embedding"] = embedding
340 | model_input.update(
341 | ABDataset.single_and_double_from_datapoint(
342 | model_input, 64, edge_chain_feature=True, conformation_node_feature=conformation_aware
343 | )
344 | )
345 | if "plm_embedding" in model_input:
346 | model_input["plm_embedding"] = model_input["plm_embedding"].unsqueeze(0)
347 | model_input["single"] = model_input["single"].unsqueeze(0)
348 | model_input["pair"] = model_input["pair"].unsqueeze(0)
349 | model_input = {k: v.to(device) for k, v in model_input.items()}
350 | return model_input
351 |
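    | # Example (illustrative sequences, not from any dataset):
    | #
    | #     inp = string_to_input("EVQLVESGGGLVQ", "DIQMTQSPSSLSA")
    | #     inp["single"].shape            # (1, 26, 23)
    | #     inp["pair"].shape              # (1, 26, 26, 132)
    | #     int(inp["residue_index"][13])  # 500, the jump marking the chain boundary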
352 |
353 | def collate_fn(batch: list) -> dict:
354 | """A collate function so the ABDataset can be used in a torch dataloader.
355 |
356 | Args:
357 |         batch (list): A list of datapoints from ABDataset
358 |
359 | Returns:
360 | dict: A dictionary where the keys are the same as batch but map to a batched tensor where the batch is on the leading dimension.
361 | """
362 | # Flatten the batch structure (needed for contrastive examples, where some of the batch elements are tuples of objects)
363 | flattened_batch = []
364 | for item in batch:
365 | if isinstance(item, tuple):
366 | flattened_batch.extend(item)
367 | else:
368 | flattened_batch.append(item)
369 | batch = {key: [d[key] for d in flattened_batch] for key in flattened_batch[0]}
370 | output = {}
371 | for key in batch:
372 | if key in pad_first_dim_keys:
373 | output[key] = pad_sequence(batch[key], batch_first=True)
374 | elif key in pad_first_two_dim_keys:
375 | output[key] = pad_square_tensors(batch[key])
376 | elif key == "resolution":
377 | output[key] = torch.Tensor(batch["resolution"])
378 | elif key == "is_matched":
379 | output[key] = torch.stack(batch[key])
380 | return output
381 |
382 |
383 | class SequenceDataset(Dataset):
384 | def __init__(self, fv_heavy_list, fv_light_list, apo_list=None, plm_model=None, batch_size=2, num_workers=0):
385 | self.fv_heavy_list = fv_heavy_list
386 | self.fv_light_list = fv_light_list
387 | self.apo_list = apo_list
388 | if plm_model is not None:
389 | with torch.no_grad():
390 | self.fv_heavy_embedding = plm_model.embed_dataset(
391 | fv_heavy_list,
392 | batch_size=batch_size,
393 | max_len=max([len(x) for x in fv_heavy_list]),
394 | full_embeddings=True,
395 | full_precision=False,
396 | pooling_type="mean",
397 | num_workers=num_workers,
398 | sql=False
399 | )
400 | self.fv_light_embedding = plm_model.embed_dataset(
401 | [light for light in fv_light_list if light is not None and light!=''],
402 | batch_size=batch_size,
403 |                 max_len=max([len(x) for x in fv_light_list if x is not None and x != '']),
404 | full_embeddings=True,
405 | full_precision=False,
406 | pooling_type="mean",
407 | num_workers=num_workers,
408 | sql=False
409 | ) if fv_light_list is not None else None
410 | else:
411 | self.fv_heavy_embedding = None
412 | self.fv_light_embedding = None
413 |
414 | def __len__(self):
415 | return len(self.fv_heavy_list)
416 |
417 | def __getitem__(self, idx):
418 | heavy = self.fv_heavy_list[idx]
419 | light = self.fv_light_list[idx] if self.fv_light_list is not None else None
420 | if self.fv_heavy_embedding is not None and self.fv_light_embedding is not None:
421 | if light is None:
422 | embedding = self.fv_heavy_embedding[heavy]
423 | else:
424 | embedding = torch.concat([self.fv_heavy_embedding[heavy], self.fv_light_embedding[light]])
425 | else:
426 | embedding = None
427 | if self.apo_list is not None:
428 | apo = self.apo_list[idx]
429 | return string_to_input(heavy, light, apo=apo, conformation_aware=True, embedding=embedding)
430 | return string_to_input(heavy, light, embedding=embedding)
431 |
432 | def sequence_collate_fn(batch):
433 | # Separating each component of the batch
434 | batch_dict = {key: [d[key] if key not in ["single","pair","plm_embedding"] else d[key].squeeze(0) for d in batch] for key in batch[0]}
435 | # Prepare output dictionary
436 | output = {}
437 | # Pad the "single" features
438 | output["single"] = pad_sequence(batch_dict["single"], batch_first=True)
439 | # Pad the "pair" features
440 | output["pair"] = pad_square_tensors(batch_dict["pair"])
441 | # Copy other keys directly
442 | for key in ["aatype", "residue_index", "is_heavy"]:
443 | output[key] = pad_sequence(batch_dict[key], batch_first=True)
444 | if "is_apo" in batch_dict:
445 | output["is_apo"] = pad_sequence(batch_dict["is_apo"], batch_first=True)
446 | if "plm_embedding" in batch_dict:
447 | output["plm_embedding"] = pad_sequence(batch_dict["plm_embedding"], batch_first=True)
448 | # Create mask based on the lengths of "aatype"
449 | mask = [torch.ones(len(aatype)) for aatype in batch_dict["aatype"]]
450 | output["mask"] = pad_sequence(mask, batch_first=True)
451 | return output
452 |
453 | if __name__ == "__main__":
454 | from torch.utils.data import DataLoader
455 |
456 |
457 | dataset = ABDataset("test", split_file="/data/dreyerf1/ibex/split.csv", data_dir="/data/dreyerf1/ibex/structures", edge_chain_feature=True)
458 | dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
459 | for batch in dataloader:
460 | break
461 | for key in batch:
462 | print(key, batch[key].shape)
463 |
--------------------------------------------------------------------------------
/src/ibex/loss/aligned_rmsd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | CDR_RANGES_AHO = {
18 | "L1": (23,42),
19 | "L2": (56,72),
20 | "L3": (106,138),
21 | "H1": (23,42),
22 | "H2": (56,69),
23 | "H3": (106,138),
24 | }
25 |
26 | region_mapping = {
27 | "cdrh1": 0,
28 | "cdrh2": 1,
29 | "cdrh3": 2,
30 | "cdrl1": 3,
31 | "cdrl2": 4,
32 | "cdrl3": 5,
33 | "fwh1": 6,
34 | "fwh2": 7,
35 | "fwh3": 8,
36 | "fwh4": 9,
37 | "fwl1": 10,
38 | "fwl2": 11,
39 | "fwl3": 12,
40 | "fwl4": 13,
41 | }
42 |
43 | heavy_chain_regions = {0, 1, 2, 6, 7, 8, 9}
44 | light_chain_regions = {3, 4, 5, 10, 11, 12, 13}
45 |
46 | heavy_framework_regions = {6, 7, 8, 9}
47 | light_framework_regions = {10, 11, 12, 13}
48 |
49 | heavy_cdr_regions = {0, 1, 2}
50 | light_cdr_regions = {3, 4, 5}
51 |
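    | # Example (illustrative): region_mapping gives the integer codes used for the
    | # per-residue "region_numeric" tensors elsewhere in this package, e.g. the
    | # CDR-H3 mask in the __main__ demo at the bottom of this file:
    | #
    | #     cdrh3_mask = region_numeric == region_mapping["cdrh3"]  # i.e. == 2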
52 |
53 | def apply_transformation(coords, R, t):
54 | # Apply inverse rotation and translation
55 | coords_transformed = torch.bmm(R.transpose(-1, -2), (coords - t.unsqueeze(1)).transpose(-1, -2)).transpose(-1, -2)
56 | return coords_transformed
57 |
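    | # Note: this inverts the map x -> R @ x + t returned by
    | # batch_align(..., return_transform=True) below, carrying aligned
    | # coordinates back into the pre-alignment frame via R^T (coords - t).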
58 |
59 | def coordinates_to_dihedral(input: torch.Tensor) -> torch.Tensor:
60 | """Compute dihedral angle from a set of four points.
61 |
62 | Given an input tensor with shape (*, 4, 3) representing points (p1, p2, p3, p4)
63 | compute the dihedral angle between the plane defined by (p1, p2, p3) and (p2, p3, p4).
64 |
65 | Parameters
66 | ----------
67 | input: torch.Tensor
68 | Shape (*, 4, 3)
69 |
70 | Returns
71 | -------
72 | torsion: torch.Tensor
73 | Shape (*,)
74 | """
75 | assert input.ndim >= 3
76 | assert input.shape[-2] == 4
77 | assert input.shape[-1] == 3
78 |
79 | # difference vectors: [a = p2 - p1, b = p3 - p2, c = p4 - p3]
80 | delta = input[..., 1:, :] - input[..., :-1, :]
81 | a, b, c = torch.unbind(delta, dim=-2)
82 |
83 | # torsion angle is angle from axb to bxc counterclockwise around b
84 | # see https://www.math.fsu.edu/~quine/MB_10/6_torsion.pdf
85 |
86 | axb = torch.cross(a, b, dim=-1)
87 | bxc = torch.cross(b, c, dim=-1)
88 |
89 | # orthogonal basis in plane perpendicular to b
90 | # NOTE v1 and v2 are not unit but have the same magnitude
91 | v1 = axb
92 | v2 = torch.cross(torch.nn.functional.normalize(b, dim=-1), axb, dim=-1)
93 |
94 | x = torch.sum(bxc * v1, dim=-1)
95 | y = torch.sum(bxc * v2, dim=-1)
96 | phi = torch.atan2(y, x)
97 |
98 | return phi
99 |
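    | # Sanity check (illustrative): a planar cis arrangement of the four points
    | # gives a dihedral of 0 and a planar trans arrangement gives pi.
    | #
    | #     pts = torch.tensor([[[1., 0., 0.], [0., 0., 0.], [0., 1., 0.], [1., 1., 0.]]])
    | #     coordinates_to_dihedral(pts)   # tensor([0.])
    | #     pts[0, 3] = torch.tensor([-1., 1., 0.])
    | #     coordinates_to_dihedral(pts)   # tensor([3.1416])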
100 |
101 | def _positions_to_phi(
102 | positions: torch.Tensor,
103 | mask: torch.Tensor,
104 | residue_index: torch.Tensor,
105 | chain_index: torch.Tensor,
106 | ):
107 | chain_boundary_mask = torch.cat(
108 | [torch.zeros_like(chain_index[..., :1], dtype=torch.bool), torch.diff(chain_index, n=1, dim=-1) == 0], dim=-1
109 | )
110 |
111 | chain_break_mask = torch.cat(
112 | [torch.zeros_like(residue_index[..., :1], dtype=torch.bool), torch.diff(residue_index, n=1, dim=-1) == 1],
113 | dim=-1,
114 | )
115 |
116 | (input, mask) = (
117 | torch.stack(
118 | [
119 | x[..., :-1, 2, :], # C(i-1)
120 | x[..., 1:, 0, :], # N(i)
121 | x[..., 1:, 1, :], # CA(i)
122 | x[..., 1:, 2, :], # C(i)
123 | ],
124 | dim=-2,
125 | )
126 | for x in (positions, mask)
127 | ) # [..., L, 4, 3]
128 |
129 | mask = mask.all(dim=-1).all(dim=-1) # [..., L]
130 |
131 | angles = coordinates_to_dihedral(input)
132 | nan_tensor = torch.full_like(angles[..., :1], float("nan"))
133 | false_tensor = torch.zeros_like(mask[..., :1])
134 |
135 | angles = torch.cat([nan_tensor, angles], dim=-1)
136 | mask = torch.cat([false_tensor, mask], dim=-1)
137 | mask = mask & chain_boundary_mask & chain_break_mask
138 |
139 | return angles, mask
140 |
141 |
142 | def _positions_to_psi(
143 | positions: torch.Tensor,
144 | mask: torch.Tensor,
145 | residue_index: torch.Tensor,
146 | chain_index: torch.Tensor,
147 | ):
148 | chain_boundary_mask = torch.cat(
149 | [torch.diff(chain_index, n=1, dim=-1) == 0, torch.zeros_like(chain_index[..., :1], dtype=torch.bool)], dim=-1
150 | )
151 |
152 | chain_break_mask = torch.cat(
153 | [torch.diff(residue_index, n=1, dim=-1) == 1, torch.zeros_like(residue_index[..., :1], dtype=torch.bool)],
154 | dim=-1,
155 | )
156 |
157 | (input, mask) = (
158 | torch.stack(
159 | [
160 | x[..., :-1, 0, :], # N(i)
161 | x[..., :-1, 1, :], # CA(i)
162 | x[..., :-1, 2, :], # C(i)
163 | x[..., 1:, 0, :], # N(i+1)
164 | ],
165 | dim=-2,
166 | )
167 | for x in (positions, mask)
168 | ) # [..., L, 4, 3]
169 |
170 | mask = mask.all(dim=-1).all(dim=-1) # [..., L]
171 |
172 | angles = coordinates_to_dihedral(input)
173 | nan_tensor = torch.full_like(angles[..., :1], float("nan"))
174 | false_tensor = torch.zeros_like(mask[..., :1])
175 |
176 | angles = torch.cat([angles, nan_tensor], dim=-1)
177 | mask = torch.cat([mask, false_tensor], dim=-1)
178 | mask = mask & chain_boundary_mask & chain_break_mask
179 |
180 | return angles, mask
181 |
182 |
183 | def _positions_to_omega(
184 | positions: torch.Tensor,
185 | mask: torch.Tensor,
186 | residue_index: torch.Tensor,
187 | chain_index: torch.Tensor,
188 | ):
189 | chain_boundary_mask = torch.cat(
190 | [torch.diff(chain_index, n=1, dim=-1) == 0, torch.zeros_like(chain_index[..., :1], dtype=torch.bool)], dim=-1
191 | )
192 |
193 | chain_break_mask = torch.cat(
194 | [torch.diff(residue_index, n=1, dim=-1) == 1, torch.zeros_like(residue_index[..., :1], dtype=torch.bool)],
195 | dim=-1,
196 | )
197 |
198 | (input, mask) = (
199 | torch.stack(
200 | [
201 | x[..., :-1, 1, :], # CA(i)
202 | x[..., :-1, 2, :], # C(i)
203 | x[..., 1:, 0, :], # N(i+1)
204 | x[..., 1:, 1, :], # CA(i+1)
205 | ],
206 | dim=-2,
207 | )
208 | for x in (positions, mask)
209 | ) # [..., L, 4, 3]
210 |
211 | mask = mask.all(dim=-1).all(dim=-1) # [..., L]
212 |
213 | angles = coordinates_to_dihedral(input)
214 | nan_tensor = torch.full_like(angles[..., :1], float("nan"))
215 | false_tensor = torch.zeros_like(mask[..., :1])
216 |
217 | angles = torch.cat([angles, nan_tensor], dim=-1)
218 | mask = torch.cat([mask, false_tensor], dim=-1)
219 | mask = mask & chain_boundary_mask & chain_break_mask
220 |
221 | return angles, mask
222 |
223 |
224 | def positions_to_backbone_dihedrals(
225 | positions: torch.Tensor, mask: torch.Tensor, residue_index: torch.Tensor | None = None, chain_index: torch.Tensor | None = None
226 | ) -> tuple[torch.Tensor, torch.Tensor]:
227 | """Compute Backbone dihedral angles (phi, psi, omega) from the atom-wise coordinates.
228 |
229 | Parameters
230 | ----------
231 | positions: Tensor
232 | Shape (..., L, 37, 3) tensor of the atom-wise coordinates
233 | mask: BoolTensor
234 | Shape (..., L, 37) boolean tensor indicating which atoms are present
235 | residue_index: Tensor | None = None
236 | Optional shape (..., L) tensor specifying the index of each residue along its chain.
237 | If supplied this is used to mask dihedrals that cross a chain break.
238 | chain_index: Tensor | None = None
239 | Optional shape (..., L) tensor specifying the chain index of each residue.
240 | If supplied this is used to mask dihedrals that cross a chain boundary.
241 |
242 |
243 |
244 | Returns
245 | -------
246 | dihedrals: Tensor
247 | Shape (..., L, 3) tensor of the dihedral angles (phi, psi, omega)
248 | dihedrals_mask: BoolTensor
249 | Shape (..., L, 3) boolean tensor indicating which dihedrals are present
250 | """
251 | assert positions.ndim >= 3
252 | L = positions.shape[-3]
253 | device = positions.device
254 |
255 | if residue_index is None:
256 | residue_index = torch.arange(L).expand(*positions.shape[:-3], -1) # [..., L]
257 | residue_index = residue_index.to(device)
258 |
259 | if chain_index is None:
260 | chain_index = torch.zeros_like(positions[..., :, 0, 0], dtype=torch.int64) # [..., L]
261 | chain_index = chain_index.to(device)
262 |
263 | mask = mask.unsqueeze(-1).expand(*mask.shape,3)
264 | phi, phi_mask = _positions_to_phi(positions, mask, residue_index=residue_index, chain_index=chain_index)
265 | psi, psi_mask = _positions_to_psi(positions, mask, residue_index=residue_index, chain_index=chain_index)
266 | omega, omega_mask = _positions_to_omega(positions, mask, residue_index=residue_index, chain_index=chain_index)
267 |
268 | dihedrals = torch.stack([phi, psi, omega], dim=-1)
269 | dihedrals_mask = torch.stack([phi_mask, psi_mask, omega_mask], dim=-1)
270 |
271 | return dihedrals, dihedrals_mask
272 |
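    | # Example (illustrative): for a single fully observed 8-residue chain, the
    | # first phi and the last psi/omega are masked out by construction.
    | #
    | #     pos = torch.randn(1, 8, 37, 3)
    | #     atom_mask = torch.ones(1, 8, 37, dtype=torch.bool)
    | #     dihedrals, dihedrals_mask = positions_to_backbone_dihedrals(pos, atom_mask)
    | #     dihedrals.shape       # (1, 8, 3): (phi, psi, omega) per residue
    | #     dihedrals_mask[0, 0]  # tensor([False,  True,  True])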
273 |
274 | def rmsd_summary_calculation(
275 | coords_truth: torch.Tensor,
276 | coords_prediction: torch.Tensor,
277 | sequence_mask: torch.Tensor,
278 | region_mask: torch.Tensor,
279 | chain_mask: torch.Tensor,
280 | batch_average: bool = True,
281 | ) -> dict[str, torch.Tensor]:
282 | """Computes RMSD summary for different regions and chains.
283 |
284 | Args:
285 | coords_truth (torch.Tensor): (B, n, 14/37, 3) ground truth coordinates
286 | coords_prediction (torch.Tensor): (B, n, 14/37, 3) predicted coordinates
287 | sequence_mask (torch.Tensor): (B, n) where [i, j] = 1 if a coordinate for sequence i at residue j exists.
288 | region_mask (torch.Tensor): (B, n) region mask indicating the region of each residue
289 | chain_mask (torch.Tensor): (B, n) chain mask indicating the chain of each residue (0 for light chain, 1 for heavy chain)
290 | batch_average (bool): if True, average along the batch dimensions
291 |
292 | Returns:
293 | dict[str, torch.Tensor]: RMSD values for each region and chain
294 | """
295 | results = {}
296 |
297 | # Align and compute RMSD for heavy chain regions
298 | heavy_chain_mask = chain_mask == 1
299 |
300 | heavy_chain_backbone_truth = extract_backbone_coordinates(
301 | coords_truth * heavy_chain_mask.unsqueeze(-1).unsqueeze(-1)
302 | )
303 | heavy_chain_sequence_mask = extract_backbone_mask(sequence_mask * heavy_chain_mask)
304 |
305 | # Mask for framework regions only
306 | heavy_framework_mask = (region_mask.unsqueeze(-1) == torch.tensor(list(heavy_framework_regions), device=region_mask.device)).any(-1) * heavy_chain_mask
307 | heavy_framework_backbone_truth = extract_backbone_coordinates(
308 | coords_truth * heavy_framework_mask.unsqueeze(-1).unsqueeze(-1)
309 | )
310 | heavy_framework_backbone_prediction = extract_backbone_coordinates(
311 | coords_prediction * heavy_framework_mask.unsqueeze(-1).unsqueeze(-1)
312 | )
313 | heavy_framework_sequence_mask = extract_backbone_mask(sequence_mask * heavy_framework_mask)
314 |
315 | # Align framework regions
316 | heavy_framework_backbone_truth, R, t = batch_align(
317 | heavy_framework_backbone_truth, heavy_framework_backbone_prediction, heavy_framework_sequence_mask, return_transform=True
318 | )
319 |
320 | # Compute RMSD for heavy chain framework as a whole
321 | square_distance = (
322 | torch.linalg.norm(
323 | heavy_framework_backbone_prediction - heavy_framework_backbone_truth, dim=-1
324 | )
325 | ** 2
326 | )
327 | square_distance = square_distance * heavy_framework_sequence_mask
328 |
329 | heavy_framework_msd = torch.sum(square_distance, dim=-1) / heavy_framework_sequence_mask.sum(dim=-1)
330 | heavy_framework_rmsd = torch.sqrt(heavy_framework_msd)
331 |
332 | if batch_average:
333 | heavy_framework_rmsd = heavy_framework_rmsd.mean()
334 |
335 | results["fwh_rmsd"] = heavy_framework_rmsd
336 |
337 | # Apply the same transformation to the CDR regions
338 | heavy_cdr_mask = (region_mask.unsqueeze(-1) == torch.tensor(list(heavy_cdr_regions), device=region_mask.device)).any(-1) * heavy_chain_mask
339 | heavy_cdr_backbone_prediction = extract_backbone_coordinates(
340 | coords_prediction * heavy_cdr_mask.unsqueeze(-1).unsqueeze(-1)
341 | )
342 | heavy_cdr_backbone_prediction_aligned = apply_transformation(heavy_cdr_backbone_prediction, R, t)
343 |
344 | for region_name, region_idx in region_mapping.items():
345 | if region_idx in heavy_cdr_regions:
346 | region_mask_region = region_mask == region_idx
347 | region_mask_backbone = extract_backbone_mask(region_mask_region)
348 |
349 | heavy_chain_region_mask = region_mask_backbone * heavy_chain_sequence_mask
350 | square_distance = (
351 | torch.linalg.norm(
352 | heavy_cdr_backbone_prediction_aligned - heavy_chain_backbone_truth, dim=-1
353 | )
354 | ** 2
355 | )
356 | square_distance = square_distance * heavy_chain_region_mask
357 |
358 | region_msd = torch.sum(square_distance, dim=-1) / heavy_chain_region_mask.sum(dim=-1)
359 | region_rmsd = torch.sqrt(region_msd)
360 |
361 | if batch_average:
362 | region_rmsd = region_rmsd.mean()
363 |
364 | results[f"{region_name}_rmsd"] = region_rmsd
365 |
366 | # Align and compute RMSD for light chain regions
367 | light_chain_mask = chain_mask == 0
368 |
369 | light_chain_backbone_truth = extract_backbone_coordinates(
370 | coords_truth * light_chain_mask.unsqueeze(-1).unsqueeze(-1)
371 | )
372 | light_chain_sequence_mask = extract_backbone_mask(sequence_mask * light_chain_mask)
373 |
374 | # Mask for framework regions only
375 | light_framework_mask = (region_mask.unsqueeze(-1) == torch.tensor(list(light_framework_regions), device=region_mask.device)).any(-1) * light_chain_mask
376 | light_framework_backbone_truth = extract_backbone_coordinates(
377 | coords_truth * light_framework_mask.unsqueeze(-1).unsqueeze(-1)
378 | )
379 | light_framework_backbone_prediction = extract_backbone_coordinates(
380 | coords_prediction * light_framework_mask.unsqueeze(-1).unsqueeze(-1)
381 | )
382 | light_framework_sequence_mask = extract_backbone_mask(sequence_mask * light_framework_mask)
383 |
384 | # Align framework regions
385 | light_framework_backbone_truth, R, t = batch_align(
386 | light_framework_backbone_truth, light_framework_backbone_prediction, light_framework_sequence_mask, return_transform=True
387 | )
388 |
389 | # Compute RMSD for light chain framework as a whole
390 | square_distance = (
391 | torch.linalg.norm(
392 | light_framework_backbone_prediction - light_framework_backbone_truth, dim=-1
393 | )
394 | ** 2
395 | )
396 | square_distance = square_distance * light_framework_sequence_mask
397 |
398 | light_framework_msd = torch.sum(square_distance, dim=-1) / light_framework_sequence_mask.sum(dim=-1)
399 | light_framework_rmsd = torch.sqrt(light_framework_msd)
400 |
401 | if batch_average:
402 | light_framework_rmsd = light_framework_rmsd.mean()
403 |
404 | results["fwl_rmsd"] = light_framework_rmsd
405 |
406 | # Apply the same transformation to the CDR regions
407 | light_cdr_mask = (region_mask.unsqueeze(-1) == torch.tensor(list(light_cdr_regions), device=region_mask.device)).any(-1) * light_chain_mask
408 | light_cdr_backbone_prediction = extract_backbone_coordinates(
409 | coords_prediction * light_cdr_mask.unsqueeze(-1).unsqueeze(-1)
410 | )
411 | light_cdr_backbone_prediction_aligned = apply_transformation(light_cdr_backbone_prediction, R, t)
412 |
413 | for region_name, region_idx in region_mapping.items():
414 | if region_idx in light_cdr_regions:
415 | region_mask_region = region_mask == region_idx
416 | region_mask_backbone = extract_backbone_mask(region_mask_region)
417 |
418 | light_chain_region_mask = region_mask_backbone * light_chain_sequence_mask
419 | square_distance = (
420 | torch.linalg.norm(
421 | light_cdr_backbone_prediction_aligned - light_chain_backbone_truth, dim=-1
422 | )
423 | ** 2
424 | )
425 | square_distance = square_distance * light_chain_region_mask
426 |
427 | region_msd = torch.sum(square_distance, dim=-1) / light_chain_region_mask.sum(dim=-1)
428 | region_rmsd = torch.sqrt(region_msd)
429 |
430 | if batch_average:
431 | region_rmsd = region_rmsd.mean()
432 |
433 | results[f"{region_name}_rmsd"] = region_rmsd
434 |
435 | return results
436 |
437 |
438 | def aligned_fv_and_cdrh3_rmsd(
439 | coords_truth: torch.Tensor,
440 | coords_prediction: torch.Tensor,
441 | sequence_mask: torch.Tensor,
442 | cdrh3_mask: torch.Tensor,
443 | batch_average: bool = True,
444 | ) -> dict[str, torch.Tensor]:
445 | """Aligns positions_truth to positions_prediction in a batched way.
446 |
447 | Args:
448 | positions_truth (torch.Tensor): (B, n, 14/37, 3) ground truth coordinates
449 | positions_prediction (torch.Tensor): (B, n, 14/37, 3) predicted coordinates
450 | sequence_mask (torch.Tensor): (B, n) where [i, j] = 1 if a coordinate for sequence i at residue j exists.
451 | cdrh3_mask (torch.Tensor): (B, n) where [i, j] = 1 if a coordinate for sequence i at residue j is part of the cdrh3 loop.
452 | batch_average (bool): if True, average along the batch dimensions
453 |
454 | Returns:
455 | A dictionary[str, torch.Tensor] containing
456 | seq_rmsd: the RMSD of the backbone after backbone alignment
457 | cdrh3_rmsd: the RMSD of the CDRH3 backbone after backbone alignment
458 | """
459 |
460 | # extract backbones and mask and put in 3d point cloud shape
461 | backbone_truth = extract_backbone_coordinates(coords_truth)
462 | backbone_prediction = extract_backbone_coordinates(coords_prediction)
463 | backbone_sequence_mask = extract_backbone_mask(sequence_mask)
464 |
465 | # align backbones
466 | backbone_truth = batch_align(
467 | backbone_truth, backbone_prediction, backbone_sequence_mask
468 | )
469 |
470 | square_distance = (
471 | torch.linalg.norm(backbone_prediction - backbone_truth, dim=-1) ** 2
472 | )
473 | square_distance = square_distance * backbone_sequence_mask
474 |
475 | seq_msd = square_distance.sum(dim=-1) / backbone_sequence_mask.sum(dim=-1)
476 | seq_rmsd = torch.sqrt(seq_msd)
477 |
478 | backbone_cdrh3_mask = extract_backbone_mask(cdrh3_mask)
479 | square_distance = square_distance * (backbone_cdrh3_mask * backbone_sequence_mask)
480 | cdrh3_msd = torch.sum(square_distance, dim=-1) / backbone_cdrh3_mask.sum(dim=-1)
481 | cdrh3_rmsd = torch.sqrt(cdrh3_msd)
482 |
483 | if batch_average:
484 | seq_rmsd = seq_rmsd.mean()
485 | cdrh3_rmsd = cdrh3_rmsd.mean()
486 |
487 | return {"seq_rmsd": seq_rmsd, "cdrh3_rmsd": cdrh3_rmsd}
488 |
489 |
490 | def extract_backbone_coordinates(positions: torch.Tensor) -> torch.Tensor:
491 | """(B, n, 14/37, 3) -> (B, n * 4, 3)"""
492 | batch_size = positions.size(0)
493 | backbone_positions = positions[:, :, :4, :] # (B, n, 4, 3)
494 | backbone_positions_flat = backbone_positions.reshape(
495 | batch_size, -1, 3
496 | ) # (B, n * 4, 3)
497 | return backbone_positions_flat
498 |
499 |
500 | def extract_backbone_mask(sequence_mask: torch.Tensor) -> torch.Tensor:
501 | """(B, n) -> (B, n * 4)"""
502 | batch_size = sequence_mask.size(0)
503 | return sequence_mask.unsqueeze(-1).repeat(1, 1, 4).view(batch_size, -1)
504 |
505 |
506 | def batch_align(x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor, return_transform=False):
507 | """Aligns 3-dimensional point clouds. Based on section 4 of https://igl.ethz.ch/projects/ARAP/svd_rot.pdf.
508 |
509 | Args:
510 | x (torch.Tensor): A tensor shape (B, n, 3)
511 | y (torch.Tensor): A tensor shape (B, n, 3)
512 |         mask (torch.Tensor): A mask of shape (B, n) where mask[i, j]=1 indicates the presence of a point in sample i at location j of both sequences.
513 | return_transform (bool): If True, return rotation and translation matrices.
514 |
515 | Returns:
516 | torch.Tensor: a rototranslated x aligned to y.
517 | torch.Tensor: rotation matrix used for alignment (if return_transform is True).
518 | torch.Tensor: translation matrix used for alignment (if return_transform is True).
519 | """
520 |
521 | # check inputs
522 | if x.ndim != 3:
523 | raise ValueError(f"Expected x.ndim=3. Instead got {x.ndim=}")
524 | if y.ndim != 3:
525 |         raise ValueError(f"Expected y.ndim=3. Instead got {y.ndim=}")
526 | if mask.ndim != 2:
527 | raise ValueError(f"Expected mask.ndim=2. Instead got {mask.ndim=}")
528 | if x.size(-1) != 3:
529 | raise ValueError(f"Expected last dim of x to be 3. Instead got {x.size(-1)=}")
530 | if y.size(-1) != 3:
531 | raise ValueError(f"Expected last dim of y to be 3. Instead got {y.size(-1)=}")
532 |
533 | # (B, n) -> (B, n, 1)
534 | mask = mask.unsqueeze(-1)
535 |
536 | # zero masked coordinates (the below centroids computation relies on it).
537 | x = x * mask
538 | y = y * mask
539 |
540 | # centroids (B, 3)
541 | p_bar = x.sum(dim=1) / mask.sum(dim=1)
542 | q_bar = y.sum(dim=1) / mask.sum(dim=1)
543 |
544 | # centered points (B, n, 3)
545 | x_centered = x - p_bar.unsqueeze(1)
546 | y_centered = y - q_bar.unsqueeze(1)
547 |
548 | # compute covariance matrices (B, 3, 3)
549 | num_valid_points = mask.sum(dim=1, keepdim=True).sum(dim=2, keepdim=True)
550 | S = torch.bmm(x_centered.transpose(-1, -2), y_centered * mask) / num_valid_points
551 |     S = S + 1e-5 * torch.eye(S.size(-1)).unsqueeze(0).to(S.device)  # small diagonal regularizer for SVD stability
552 |
553 | # Compute U, V from SVD (B, 3, 3)
554 | U, _, Vh = torch.linalg.svd(S)
555 | V = Vh.transpose(-1, -2)
556 | Uh = U.transpose(-1, -2)
557 |
558 | # correction that accounts for reflection (B, 3, 3)
559 | correction = torch.eye(x.size(-1)).unsqueeze(0).repeat(x.size(0), 1, 1).to(x.device)
560 | correction[:, -1, -1] = torch.det(torch.bmm(V, Uh).float())
561 |
562 | # rotation (B, 3, 3)
563 | R = V.bmm(correction).bmm(Uh)
564 |
565 | # translation (B, 3)
566 |     t = q_bar - R.bmm(p_bar.unsqueeze(-1)).squeeze(-1)  # squeeze(-1) so a batch of size 1 keeps its batch dim
567 |
568 | # translate x to align with y
569 | x_rotated = torch.bmm(R, x.transpose(-1, -2)).transpose(-1, -2)
570 | x_aligned = x_rotated + t.unsqueeze(1)
571 |
572 | if return_transform:
573 | return x_aligned, R, t
574 | else:
575 | return x_aligned
576 |
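    | # Sanity check (illustrative): aligning a point cloud onto a rigidly
    | # transformed copy of itself recovers the copy.
    | #
    | #     x = torch.randn(2, 10, 3)
    | #     R0, _ = torch.linalg.qr(torch.randn(3, 3))
    | #     if torch.det(R0) < 0:
    | #         R0 = -R0  # make it a proper rotation (det +1)
    | #     y = x @ R0.T + torch.tensor([1., 2., 3.])
    | #     x_aligned = batch_align(x, y, torch.ones(2, 10))
    | #     torch.allclose(x_aligned, y, atol=1e-3)  # True, up to the SVD regularizer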
577 | if __name__ == "__main__":
578 | from torch.utils.data import DataLoader
579 | from ibex.dataloader import ABDataset, collate_fn
580 | dataset = ABDataset("test", split_file="/data/dreyerf1/ibex/split.csv", data_dir="/data/dreyerf1/ibex/structures", edge_chain_feature=True)
581 | dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn, shuffle=False)
582 | for batch in dataloader:
583 | coords = batch["atom14_gt_positions"]
584 | preds = coords + 10
585 | mask = batch["seq_mask"]
586 | cdrh3_mask = batch["region_numeric"] == 2
587 | print(aligned_fv_and_cdrh3_rmsd(coords, preds, mask, cdrh3_mask))
588 | break
589 |
590 |
--------------------------------------------------------------------------------
/src/ibex/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Genentech
2 | # Copyright 2024 Exscientia
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import lightning.pytorch as pl
17 | import torch
18 | from torch.optim import AdamW
19 | from torch.optim.lr_scheduler import LambdaLR, LinearLR
20 | from torch.utils.data import DataLoader, WeightedRandomSampler
21 | from torch.utils.data import ConcatDataset
22 | from esm.models.esmc import ESMC
23 | from esm.sdk.api import ESMProtein, LogitsConfig
24 |
25 | from ibex.dataloader import ABDataset, collate_fn, string_to_input, sequence_collate_fn
26 | from ibex.loss import IbexLoss
27 | from ibex.utils import output_to_pdb, add_atom37_to_output, output_to_protein, ENSEMBLE_MODELS, checkpoint_path
28 | from ibex.openfold.model import StructureModule
29 |
30 | class IbexDataModule(pl.LightningDataModule):
31 | def __init__(
32 | self,
33 | batch_size: int,
34 | data_dir_sabdab: str,
35 | split_file_sabdab: str,
36 | data_dir_predicted: str,
37 | split_file_predicted: str,
38 | fraction_predicted: float,
39 | data_dir_ig: str,
40 | split_file_ig: str,
41 | fraction_ig: float,
42 | fraction_matched: float,
43 | use_vhh: bool = False,
44 | use_tcr: bool = False,
45 | public_only_data: bool = False,
46 | rel_pos_dim: int = 64,
47 | edge_chain_feature: bool = False,
48 | num_workers: int = 0,
49 | pin_memory: bool = False,
50 | use_plm_embeddings: bool = False,
51 | use_weighted_sampler: bool = False,
52 | cluster_column: str = "cluster",
53 | weight_temperature: float = 1.0,
54 | weight_clip_quantile: float = 0.99,
55 | vhh_weight: float = 1.0,
56 | tcr_weight: float = 1.0,
57 | matched_weight: float = 1.0,
58 | conformation_node_feature: bool = False,
59 | use_contrastive: bool = False,
60 | use_boltz_only: bool = False,
61 | ):
62 | super().__init__()
63 | self.data_dir_sabdab = data_dir_sabdab
64 | self.data_dir_predicted = data_dir_predicted
65 | self.data_dir_ig = data_dir_ig
66 | self.split_file_sabdab = split_file_sabdab
67 | self.split_file_predicted = split_file_predicted
68 | self.split_file_ig = split_file_ig
69 | self.fraction_predicted = fraction_predicted
70 | self.fraction_ig = fraction_ig
71 | self.fraction_matched = fraction_matched if use_contrastive else 0.0
72 | self.batch_size = batch_size
73 | self.use_vhh = use_vhh
74 | self.use_tcr = use_tcr
75 | self.public_only_data = public_only_data
76 | self.rel_pos_dim = rel_pos_dim
77 | self.edge_chain_feature = edge_chain_feature
78 | self.num_workers = num_workers
79 | self.pin_memory = pin_memory
80 | self.use_plm_embeddings = use_plm_embeddings
81 | self.use_weighted_sampler = use_weighted_sampler
82 | self.cluster_column = cluster_column
83 | self.weight_temperature = weight_temperature
84 | self.weight_clip_quantile = weight_clip_quantile
85 | self.vhh_weight = vhh_weight
86 | self.tcr_weight = tcr_weight
87 | self.matched_weight = matched_weight
88 | self.conformation_node_feature = conformation_node_feature
89 | self.use_contrastive = use_contrastive
90 | self.use_boltz_only = use_boltz_only
91 |
92 | def setup(self, stage: str):
93 | sabdab_dataset = ABDataset(
94 | "train",
95 | split_file=self.split_file_sabdab,
96 | data_dir=self.data_dir_sabdab,
97 | use_vhh=self.use_vhh,
98 | use_tcr=self.use_tcr,
99 | rel_pos_dim=self.rel_pos_dim,
100 | edge_chain_feature=self.edge_chain_feature,
101 | use_plm_embeddings=self.use_plm_embeddings,
102 | use_public_only=self.public_only_data,
103 | cluster_column=self.cluster_column,
104 | weight_temperature=self.weight_temperature,
105 | weight_clip_quantile=self.weight_clip_quantile,
106 | vhh_weight=self.vhh_weight,
107 | tcr_weight=self.tcr_weight,
108 | matched_weight=self.matched_weight,
109 | conformation_node_feature=self.conformation_node_feature,
110 | use_contrastive=self.use_contrastive,
111 | is_matched=False,
112 | use_weights=self.use_weighted_sampler,
113 | )
114 | sabdab_dataset_matched = ABDataset(
115 | "train",
116 | split_file=self.split_file_sabdab,
117 | data_dir=self.data_dir_sabdab,
118 | use_vhh=self.use_vhh,
119 | use_tcr=self.use_tcr,
120 | rel_pos_dim=self.rel_pos_dim,
121 | edge_chain_feature=self.edge_chain_feature,
122 | use_plm_embeddings=self.use_plm_embeddings,
123 | use_public_only=self.public_only_data,
124 | cluster_column=self.cluster_column,
125 | weight_temperature=self.weight_temperature,
126 | weight_clip_quantile=self.weight_clip_quantile,
127 | vhh_weight=self.vhh_weight,
128 | tcr_weight=self.tcr_weight,
129 | conformation_node_feature=self.conformation_node_feature,
130 | use_contrastive=self.use_contrastive,
131 | is_matched=True,
132 | use_weights=self.use_weighted_sampler,
133 | ) if self.use_contrastive else None
134 | predicted_dataset = ABDataset(
135 | "all",
136 | split_file=self.split_file_predicted,
137 | data_dir=self.data_dir_predicted,
138 | use_vhh=False,
139 | use_tcr=False,
140 | rel_pos_dim=self.rel_pos_dim,
141 | edge_chain_feature=self.edge_chain_feature,
142 | use_plm_embeddings=self.use_plm_embeddings,
143 | use_public_only=self.public_only_data,
144 | cluster_column=self.cluster_column,
145 | weight_temperature=self.weight_temperature,
146 | weight_clip_quantile=self.weight_clip_quantile,
147 | vhh_weight=self.vhh_weight,
148 | tcr_weight=self.tcr_weight,
149 | conformation_node_feature=self.conformation_node_feature,
150 | use_boltz_only=self.use_boltz_only,
151 | use_weights=self.use_weighted_sampler,
152 | ) if self.fraction_predicted > 0.0 else None
153 | ig_dataset = ABDataset(
154 | "all",
155 | split_file=self.split_file_ig,
156 | data_dir=self.data_dir_ig,
157 | use_vhh=True,
158 | use_tcr=False,
159 | rel_pos_dim=self.rel_pos_dim,
160 | edge_chain_feature=self.edge_chain_feature,
161 | use_plm_embeddings=self.use_plm_embeddings,
162 | use_public_only=self.public_only_data,
163 | cluster_column=self.cluster_column,
164 | weight_temperature=self.weight_temperature,
165 | weight_clip_quantile=self.weight_clip_quantile,
166 | vhh_weight=self.vhh_weight,
167 | tcr_weight=self.tcr_weight,
168 | conformation_node_feature=self.conformation_node_feature,
169 | use_weights=self.use_weighted_sampler,
170 | ) if self.fraction_ig > 0.0 else None
171 | self.train_dataset = ConcatDataset([d for d in [sabdab_dataset, predicted_dataset, ig_dataset, sabdab_dataset_matched] if d is not None])
172 | self.len_train_dataset = len(sabdab_dataset)
173 | sabdab_fraction = 1.0 - self.fraction_predicted - self.fraction_ig - self.fraction_matched
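        # scale each source's per-sample weights by its mixture fraction, then renormalise,
        # so the weighted sampler draws roughly that share of each batch from each dataset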
174 | self.train_dataset_weights = torch.cat(
175 | [d.weights * frac for d, frac in [(sabdab_dataset, sabdab_fraction), (predicted_dataset, self.fraction_predicted), (ig_dataset, self.fraction_ig), (sabdab_dataset_matched, self.fraction_matched)] if d is not None]
176 | )
177 | self.train_dataset_weights = self.train_dataset_weights / self.train_dataset_weights.sum()
178 | self.valid_dataset = ABDataset(
179 | "valid",
180 | split_file=self.split_file_sabdab,
181 | data_dir=self.data_dir_sabdab,
182 | use_vhh=False,
183 | use_tcr=False,
184 | rel_pos_dim=self.rel_pos_dim,
185 | edge_chain_feature=self.edge_chain_feature,
186 | use_plm_embeddings=self.use_plm_embeddings,
187 | use_public_only=self.public_only_data,
188 | conformation_node_feature=self.conformation_node_feature,
189 | )
190 | self.test_dataset = ABDataset(
191 | "test",
192 | split_file=self.split_file_sabdab,
193 | data_dir=self.data_dir_sabdab,
194 | use_vhh=False,
195 | use_tcr=False,
196 | rel_pos_dim=self.rel_pos_dim,
197 | edge_chain_feature=self.edge_chain_feature,
198 | use_plm_embeddings=self.use_plm_embeddings,
199 | use_public_only=self.public_only_data,
200 | conformation_node_feature=self.conformation_node_feature,
201 | )
202 | if stage == "test":
203 | self.test_public_dataset = ABDataset(
204 | "test",
205 | split_file=self.split_file_sabdab,
206 | data_dir=self.data_dir_sabdab,
207 | use_vhh=False,
208 | use_tcr=False,
209 | rel_pos_dim=self.rel_pos_dim,
210 | edge_chain_feature=self.edge_chain_feature,
211 | use_plm_embeddings=self.use_plm_embeddings,
212 | use_public_only=True,
213 | conformation_node_feature=self.conformation_node_feature,
214 | )
215 | self.test_public_vhh_dataset = ABDataset(
216 | "test_vhh",
217 | split_file=self.split_file_sabdab,
218 | data_dir=self.data_dir_sabdab,
219 | use_vhh=True,
220 | use_tcr=False,
221 | rel_pos_dim=self.rel_pos_dim,
222 | edge_chain_feature=self.edge_chain_feature,
223 | use_plm_embeddings=self.use_plm_embeddings,
224 | use_public_only=True,
225 | conformation_node_feature=self.conformation_node_feature,
226 | )
227 | self.test_public_tcr_dataset = ABDataset(
228 | "test_tcr",
229 | split_file=self.split_file_sabdab,
230 | data_dir=self.data_dir_sabdab,
231 | use_vhh=False,
232 | use_tcr=True,
233 | rel_pos_dim=self.rel_pos_dim,
234 | edge_chain_feature=self.edge_chain_feature,
235 | use_plm_embeddings=self.use_plm_embeddings,
236 | use_public_only=True,
237 | conformation_node_feature=self.conformation_node_feature,
238 | )
239 |
240 | def train_dataloader(self):
241 | return DataLoader(
242 | self.train_dataset,
243 | collate_fn=collate_fn,
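            # one "epoch" draws len(sabdab)//2 elements from the weighted mixture;
            # sampling with replacement only when weighted sampling is enabled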
244 | sampler=WeightedRandomSampler(self.train_dataset_weights, num_samples=self.len_train_dataset // 2, replacement=self.use_weighted_sampler),
245 | batch_size=self.batch_size // 2, # Each element in our dataset actually contains 2 samples
246 | num_workers=self.num_workers,
247 | pin_memory=self.pin_memory,
248 | )
249 |
250 | def val_dataloader(self):
251 | return DataLoader(
252 | self.valid_dataset,
253 | collate_fn=collate_fn,
254 | batch_size=self.batch_size,
255 | shuffle=False,
256 | num_workers=self.num_workers,
257 | pin_memory=self.pin_memory,
258 | )
259 |
260 | def test_dataloader(self):
261 | return DataLoader(
262 | self.test_dataset,
263 | collate_fn=collate_fn,
264 | batch_size=self.batch_size,
265 | shuffle=False,
266 | num_workers=self.num_workers,
267 | pin_memory=self.pin_memory,
268 | )
269 |
270 | def test_public_dataloader(self):
271 | return DataLoader(
272 | self.test_public_dataset,
273 | collate_fn=collate_fn,
274 | batch_size=self.batch_size,
275 | shuffle=False,
276 | num_workers=self.num_workers,
277 | pin_memory=self.pin_memory,
278 | )
279 |
280 | def test_public_vhh_dataloader(self):
281 | return DataLoader(
282 | self.test_public_vhh_dataset,
283 | collate_fn=collate_fn,
284 | batch_size=self.batch_size,
285 | shuffle=False,
286 | num_workers=self.num_workers,
287 | pin_memory=self.pin_memory,
288 | )
289 |
290 | def test_public_tcr_dataloader(self):
291 | return DataLoader(
292 | self.test_public_tcr_dataset,
293 | collate_fn=collate_fn,
294 | batch_size=self.batch_size,
295 | shuffle=False,
296 | num_workers=self.num_workers,
297 | pin_memory=self.pin_memory,
298 | )
299 |
300 | class Ibex(pl.LightningModule):
301 | def __init__(self, model_config, loss_config, optim_config, conformation_aware=False, use_plm=False, ensemble=False, models=None, stage=None, init_plm=False):
302 | super().__init__()
303 | model_config["use_plddt"] = loss_config.plddt.weight > 0
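        # build the pLDDT confidence head only when its loss term has nonzero weight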
304 | self.save_hyperparameters()
305 | self.model_config = model_config
306 | self.loss_config = loss_config
307 | self.optim_config = optim_config
308 | self.loss = IbexLoss(loss_config)
309 | self.finetune = False
310 | self.contrastive = False
311 | self.ensemble = ensemble
312 | self.conformation_aware = conformation_aware
313 | self.stage = stage
314 |
315 | if self.ensemble:
316 | if models is None:
317 | raise ValueError("Models must be provided for ensemble mode.")
318 | # Use the provided models to create the EnsembleStructureModule
319 | self.model = EnsembleStructureModule(models)
320 | else:
321 | self.model = StructureModule(**model_config)
322 |
323 | if use_plm:
324 |             # placeholder flag; the actual ESM-C model is attached lazily via set_plm()
329 | self.plm_model = True
330 | if init_plm:
331 | self.set_plm()
333 | else:
334 | self.plm_model = None
335 |
336 | def set_plm(self):
337 | if self.plm_model is not None:
338 | self.plm_model = ESMC.from_pretrained("esmc_300m")
339 | self.plm_model.eval()
340 | for param in self.plm_model.parameters():
341 | param.requires_grad = False
342 |
343 | @classmethod
344 | def load_from_ensemble_checkpoint(cls, checkpoint_path, map_location=None):
345 | checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
346 | state_dicts = checkpoint['state_dicts']
347 | model_config = checkpoint['model_config']
348 | loss_config = checkpoint['loss_config']
349 | optim_config = checkpoint['optim_config']
350 |         conformation_aware = checkpoint.get('conformation_aware', False)
351 |         init_plm = checkpoint.get('language', False)
358 | # Initialize the models from state dicts
359 | models = [StructureModule(**model_config) for _ in state_dicts]
360 | for model, state_dict in zip(models, state_dicts):
361 | model.load_state_dict(state_dict)
362 |
363 | return cls(model_config, loss_config, optim_config, conformation_aware=conformation_aware, ensemble=True, models=models, use_plm=init_plm, init_plm=init_plm)
364 |
365 | @classmethod
366 | def from_pretrained(cls, model="ibex", map_location=None, cache_dir=None):
367 | ckpt = checkpoint_path(model, cache_dir=cache_dir)
368 |
369 | if model in ENSEMBLE_MODELS:
370 |             ibex_model = cls.load_from_ensemble_checkpoint(ckpt, map_location=map_location)
371 |         else:
372 |             ibex_model = cls.load_from_checkpoint(ckpt, map_location=map_location)
373 | if ibex_model.plm_model is not None:
374 | ibex_model.set_plm()
375 | return ibex_model
376 |
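    # Usage sketch (sequences are illustrative placeholders, not from this repo):
    #   model = Ibex.from_pretrained("ibex")
    #   pdb_str = model.predict(fv_heavy="EVQL...", fv_light="DIQM...", device="cuda")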
377 | def training_step(self, batch, batch_idx):
378 |         stage = f"_{self.stage}" if self.stage else ""
379 | loss = self._step(batch, "train"+stage)
380 | grad_norm = torch.nn.utils.clip_grad_norm_(
381 | self.parameters(), max_norm=self.optim_config.max_grad_norm
382 | )
383 | self.log(f"monitor{stage}/grad_norm", grad_norm)
384 | # Log custom epoch and learning rate
385 | if stage:
386 | self.log(
387 | f"monitor{stage}/epoch",
388 | self.current_epoch
389 | )
390 | self.log(
391 | f"monitor{stage}/learning_rate",
392 | self.trainer.optimizers[0].param_groups[0]['lr']
393 | )
394 | return loss
395 |
396 | def validation_step(self, batch, batch_idx):
397 |         stage = f"_{self.stage}" if self.stage else ""
398 | self._step(batch, "valid"+stage)
399 | self.log(
400 | f"monitor{stage}/finetune",
401 | float(self.finetune),
402 | on_step=False,
403 | on_epoch=True,
404 | sync_dist=True,
405 | )
406 |
407 | def _step(self, batch, split):
408 | if self.plm_model is not None:
409 | output = self.model(
410 | {
411 | "single": batch["single"],
412 | "pair": batch["pair"],
413 | "plm_embedding": batch["plm_embedding"],
414 | },
415 | batch["aatype"],
416 | batch["seq_mask"],
417 | )
418 | else:
419 | output = self.model(
420 | {
421 | "single": batch["single"],
422 | "pair": batch["pair"],
423 | },
424 | batch["aatype"],
425 | batch["seq_mask"],
426 | )
427 | loss, loss_dict = self.loss(output, batch, self.finetune, (self.contrastive and "train" in split))
428 | for loss_name in loss_dict:
429 | self.log(
430 | f"{split}/{loss_name}",
431 | loss_dict[loss_name],
432 | prog_bar=loss_name == "loss",
433 | on_epoch=True,
434 | on_step=False,
435 | sync_dist=True,
436 | )
437 | return loss
438 |
439 | def configure_optimizers(self):
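        # first_stage uses a linear warmup; later stages decay the LR exponentially once
        # lambda_iters steps have passed, floored at lambda_min_factor (the scheduler is
        # stepped per optimizer step, see interval="step" below)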
440 |         if self.optim_config.optimizer == "AdamW":
441 |             optimizer = AdamW(
442 |                 self.parameters(),
443 |                 lr=self.optim_config.lr,
444 |                 betas=(0.9, 0.99),
445 |                 eps=1e-6,
446 |             )
447 |             if self.stage == 'first_stage':
448 |                 scheduler = LinearLR(optimizer, start_factor=1e-3, total_iters=self.optim_config.linear_iters)
449 |             else:
450 |                 scheduler = LambdaLR(
451 |                     optimizer,
452 |                     lambda epoch: max(
453 |                         self.optim_config.lambda_lr ** (epoch - self.optim_config.lambda_iters),
454 |                         self.optim_config.lambda_min_factor,
455 |                     ) if epoch >= self.optim_config.lambda_iters else 1,
456 |                 )
460 |
461 | lr_scheduler = {
462 | "scheduler": scheduler,
463 | "interval": "step",
464 | "frequency": 1,
465 | "name": "learning_rate",
466 | }
467 | return [optimizer], [lr_scheduler]
468 | else:
469 | raise ValueError(
470 | "Expected AdamW as optimizer. Instead got"
471 | f" {self.optim_config.optimizer=}."
472 | )
473 |
474 | def predict(self, fv_heavy, fv_light, device, ensemble=False, pdb_string=True, apo=False):
475 | self.model.eval()
476 | self.to(device)
477 | if self.plm_model is not None:
478 | self.plm_model.to(device)
479 | if fv_light is None or fv_light=="":
480 | # for nanobodies, embed the heavy chain
481 | prot_h = ESMProtein(sequence=fv_heavy)
482 | prot_tensor_h = self.plm_model.encode(prot_h)
483 | logits_output_h = self.plm_model.logits(
484 | prot_tensor_h, LogitsConfig(sequence=True, return_embeddings=True)
485 | )
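                # slice off the BOS/EOS special-token embeddings added by the ESM-C tokenizer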
486 | embedding = logits_output_h.embeddings[0,1:-1]
490 | else:
491 | # for antibodies, embed both chains
492 | prot_h = ESMProtein(sequence=fv_heavy)
493 | prot_tensor_h = self.plm_model.encode(prot_h)
494 | logits_output_h = self.plm_model.logits(
495 | prot_tensor_h, LogitsConfig(sequence=True, return_embeddings=True)
496 | )
497 | embed_vh = logits_output_h.embeddings[0,1:-1]
498 |
499 | prot_l = ESMProtein(sequence=fv_light)
500 | prot_tensor_l = self.plm_model.encode(prot_l)
501 | logits_output_l = self.plm_model.logits(
502 | prot_tensor_l, LogitsConfig(sequence=True, return_embeddings=True)
503 | )
504 | embed_vl = logits_output_l.embeddings[0,1:-1]
509 | embedding = torch.concat([embed_vh, embed_vl])
510 | else:
511 | # if PLM is not used, set embedding to None and use one-hot encoding instead
512 | embedding = None
513 | ab_input = string_to_input(heavy=fv_heavy, light=fv_light, apo=apo, conformation_aware=self.conformation_aware, embedding=embedding, device=device)
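        # unsqueeze adds a batch dimension; "single", "pair" and "plm_embedding"
        # are left exactly as returned by string_to_input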
514 | ab_input_batch = {
515 | key: (value.unsqueeze(0) if key not in ["single", "pair", "plm_embedding"] else value)
516 | for key, value in ab_input.items()
517 | }
518 | # Forward pass with model
519 | result = self.model(ab_input_batch, ab_input_batch["aatype"], return_all=ensemble)
520 | if ensemble:
521 | predictions = []
522 | for i, output in enumerate(result):
523 | result[i] = add_atom37_to_output(output, ab_input_batch["aatype"])
524 | if pdb_string:
525 | predictions.append(output_to_pdb(result[i], ab_input))
526 | else:
527 | predictions.append(output_to_protein(result[i], ab_input))
528 | return predictions
529 |
530 | result = add_atom37_to_output(result, ab_input["aatype"].to(device))
531 | # collate the results
532 | if pdb_string:
533 | return output_to_pdb(result, ab_input)
534 | return output_to_protein(result, ab_input)
535 |
536 |
537 | def predict_batch(self, fv_heavy_batch, fv_light_batch, device, ensemble=False, pdb_string=True, apo_list=None, num_workers=0):
538 | self.model.eval()
539 | self.to(device)
540 |
566 | batch = []
567 | for i, fv_heavy in enumerate(fv_heavy_batch):
568 | fv_light = fv_light_batch[i] if fv_light_batch is not None else None
569 | apo = apo_list[i] if apo_list is not None else False
579 | if self.plm_model is not None:
580 | prot_h = ESMProtein(sequence=fv_heavy)
581 | prot_tensor_h = self.plm_model.encode(prot_h)
582 | logits_output_h = self.plm_model.logits(
583 | prot_tensor_h, LogitsConfig(sequence=True, return_embeddings=True)
584 | )
585 | embed_vh = logits_output_h.embeddings[0,1:-1]
586 | if fv_light is None or fv_light=="":
587 | # for nanobodies, embed the heavy chain
588 | embedding = embed_vh
589 | else:
590 | # for antibodies, embed both chains
591 | prot_l = ESMProtein(sequence=fv_light)
592 | prot_tensor_l = self.plm_model.encode(prot_l)
593 | logits_output_l = self.plm_model.logits(
594 | prot_tensor_l, LogitsConfig(sequence=True, return_embeddings=True)
595 | )
596 | embed_vl = logits_output_l.embeddings[0,1:-1]
597 | embedding = torch.concat([embed_vh, embed_vl])
598 | else:
599 | embedding = None
600 | batch.append(string_to_input(heavy=fv_heavy, light=fv_light, apo=apo, conformation_aware=self.conformation_aware, embedding=embedding, device=device))
601 |
602 | ab_input_batch = sequence_collate_fn(batch)
603 | # Move inputs to the device
604 | ab_input_batch = {key: value.to(device) for key, value in ab_input_batch.items()}
605 |
606 | # Forward pass with mask
607 | results = self.model(ab_input_batch, ab_input_batch["aatype"], mask=ab_input_batch["mask"], return_all=ensemble)
608 |
609 | predictions = []
610 | batch_size = ab_input_batch["aatype"].size(0)
611 | if ensemble:
612 | for i in range(batch_size):
613 | ensemble_preds = []
614 | for result in results:
615 | masked_result = {
616 | "positions": result["positions"][-1,i][ab_input_batch["mask"][i]==1].unsqueeze(0).unsqueeze(0),
617 | }
618 | if "plddt" in result:
619 | masked_result["plddt"]=result["plddt"][i][ab_input_batch["mask"][i]==1].unsqueeze(0)
620 | masked_input = {
621 | "aatype": ab_input_batch["aatype"][i][ab_input_batch["mask"][i]==1],
622 | "is_heavy": ab_input_batch["is_heavy"][i][ab_input_batch["mask"][i]==1],
623 | }
624 | # Add atom37 coordinates to the output
625 | masked_result = add_atom37_to_output(masked_result, masked_input["aatype"].unsqueeze(0))
626 | if pdb_string:
627 | ensemble_preds.append(output_to_pdb(masked_result, masked_input))
628 | else:
629 | ensemble_preds.append(output_to_protein(masked_result, masked_input))
630 | predictions.append(ensemble_preds)
631 | return predictions
632 |
633 | # Iterate over each item in the batch
634 | for i in range(batch_size):
635 | masked_result = {
636 | "positions": results["positions"][-1,i][ab_input_batch["mask"][i]==1].unsqueeze(0).unsqueeze(0),
637 | }
638 | if "plddt" in results:
639 | masked_result["plddt"]=results["plddt"][i][ab_input_batch["mask"][i]==1].unsqueeze(0)
640 | masked_input = {
641 | "aatype": ab_input_batch["aatype"][i][ab_input_batch["mask"][i]==1],
642 | "is_heavy": ab_input_batch["is_heavy"][i][ab_input_batch["mask"][i]==1],
643 | }
644 |
645 | # Add atom37 coordinates to the output
646 | masked_result = add_atom37_to_output(masked_result, masked_input["aatype"].unsqueeze(0))
647 | if pdb_string:
648 | # Generate a PDB string for the single result
649 | predictions.append(output_to_pdb(masked_result, masked_input))
650 | else:
651 | # Generate a Protein object for the single result
652 | predictions.append(output_to_protein(masked_result, masked_input))
653 |
654 | return predictions
655 |
656 |
657 | class EnsembleStructureModule(torch.nn.Module):
658 | def __init__(self, models):
659 |         super().__init__()
660 | if not all(isinstance(model, StructureModule) for model in models):
661 | raise ValueError("All models must be instances of StructureModule")
662 | self.models = torch.nn.ModuleList(models)
663 |
664 | def forward(self, inputs, aatype, mask=None, return_all=False, plddt_select=False):
665 | outputs = []
666 | for model in self.models:
667 | output = model(inputs, aatype, mask)
668 | outputs.append(output)
669 |
670 | if return_all:
671 | # if all outputs are requested, no need for alignment, just return everything
672 | return outputs
673 |
674 | if aatype.shape[0] == 1:
675 | # if batch size is one, then we can use the original ABB2 implementation
676 | if plddt_select:
677 | from ibex.utils import compute_plddt
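                # pick the member whose 10th-percentile pLDDT is highest, i.e. the
                # prediction whose least-confident regions score best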
678 |                 plddts = torch.stack([compute_plddt(output["plddt"]).squeeze() for output in outputs])  # [E, N]
679 |                 closest_index = torch.argmax(torch.quantile(plddts, 0.1, dim=1))
682 | return outputs[closest_index]
683 | else:
684 | # Stack positions along a new axis for batch processing
685 | positions = [output['positions'][-1].squeeze() for output in outputs]
686 | traces = torch.stack([x[:,0] for x in positions]) # [E, N, 3]
687 | # find the rotation and translation that aligns the traces
688 | R,t = find_alignment_transform(traces)
689 | aligned_traces = (traces-t) @ R
690 | # compute rmsd to the mean and return the prediction closest to the mean
691 | rmsd_values = (aligned_traces - aligned_traces.mean(0)).square().sum(-1).mean(-1)
692 | closest_index = torch.argmin(rmsd_values)
693 | return outputs[closest_index]
694 |
695 | # for batch size > 1 we need to perform a more sophisticated batching operation
696 | # and keep track of the mask
697 | positions = [output['positions'][-1] for output in outputs]
698 | traces = torch.stack([x[:,:,0] for x in positions]) # [E, B, N, 3]
699 | # Permute dimensions to get [B, E, N, 3]
700 | traces = traces.permute(1, 0, 2, 3) # [B, E, N, 3]
701 |
702 | if mask is not None:
703 | # Expand the mask to [B, E, N, 3] for element-wise operations
704 | mask = mask.unsqueeze(1).unsqueeze(-1).expand(-1, traces.size(1), -1, traces.size(-1))
705 | R,t = find_alignment_transform_batch(traces, mask)
706 | aligned_traces = (traces-t) @ R
707 |
708 | if mask is not None:
709 | # Compute the mean of aligned traces along the ensemble dimension, considering the mask
710 | masked_aligned_traces = aligned_traces * mask # Mask application
711 | mask_sum = mask.sum(1)
712 | mask_sum[mask_sum == 0] = 1
713 | mean_aligned_traces = masked_aligned_traces.sum(1) / mask_sum # [B, N, 3]
714 | # Compute RMSD, taking the mask into account -> [B, E, N]
715 | rmsd_values = ((aligned_traces - mean_aligned_traces.unsqueeze(1)) * mask).square().sum(-1)
716 | # Normalize by the number of valid elements per sequence
717 | rmsd_values = (rmsd_values * mask[:,:,:,0]).sum(-1) / mask[:,:,:,0].sum(-1)
718 | else:
719 | # rmsd_values = (aligned_traces - aligned_traces.mean(0)).square().sum(-1).mean(-1)
720 | rmsd_values = (aligned_traces - aligned_traces.mean(1).unsqueeze(1)).square().sum(-1).mean(-1)
721 |
722 | # Find the prediction with the minimum RMSD
723 | closest_index = torch.argmin(rmsd_values, dim=-1)
724 | result = outputs[-1]
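        # note: this mutates outputs[-1] in place, replacing each batch element
        # with the prediction from its closest-to-mean ensemble member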
725 | # now iterate over the batch and select the best prediction for each sequence
726 | for ibatch in range(len(closest_index)):
727 | for k in result:
728 | if k in ['single', 'plddt']:
729 | result[k][ibatch] = outputs[closest_index[ibatch]][k][ibatch]
730 | else:
731 | result[k][:, ibatch] = outputs[closest_index[ibatch]][k][:, ibatch]
732 | return result
733 |
734 |
735 | def find_alignment_transform(traces):
736 | centers = traces.mean(-2, keepdim=True)
737 | traces = traces - centers
738 |
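    # Kabsch-style alignment: build each member's covariance with the reference
    # trace (traces[0]), take the SVD, and correct the sign so det(R) = +1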
739 | p1, p2 = traces[0], traces[1:]
740 | C = torch.einsum("i j k, j l -> i k l", p2, p1)
741 | V, _, W = torch.linalg.svd(C)
742 | U = torch.matmul(V, W)
743 | U = torch.matmul(torch.stack([torch.ones(len(p2), device=U.device),torch.ones(len(p2), device=U.device),torch.linalg.det(U)], dim=1)[:,:,None] * V, W)
744 |
745 | return torch.cat([torch.eye(3, device=U.device)[None], U]), centers
746 |
747 |
748 | def find_alignment_transform_batch(traces, mask=None):
749 | # traces: [B, E, N, 3], mask: [B, E, N, 3]
750 | if mask is not None:
751 | centers = (traces * mask).sum(dim=-2, keepdim=True) / mask.sum(dim=-2, keepdim=True)
752 | else:
753 | centers = traces.mean(dim=-2, keepdim=True)
754 |
755 | traces = traces - centers
756 |
757 | p1 = traces[:, 0, :, :] # [B, N, 3]
758 | p2 = traces[:, 1:, :, :] # [B, E-1, N, 3]
759 |
760 | if mask is not None:
761 | C = torch.einsum("...ni,...nj->...ij", p2 * mask[:, 1:], p1.unsqueeze(1) * mask[:, :1])
762 | else:
763 | C = torch.einsum("...ni,...nj->...ij", p2, p1.unsqueeze(1))
764 |
765 | V, _, W = torch.linalg.svd(C)
766 |
767 | # Compute the rotation matrix U
768 | U = torch.matmul(V, W)
769 |
770 | # Ensure U is a proper rotation matrix by checking its determinant
771 | det_U = torch.linalg.det(U)
772 | V[..., -1] *= torch.sign(det_U).unsqueeze(-1)
773 |
774 | U = torch.matmul(V, W)
775 |
776 | # Prepare identity matrices for the reference trace
777 | identity_matrices = torch.eye(3, device=U.device).expand(traces.size(0), 1, 3, 3) # [B, 1, 3, 3]
778 |
779 | # Concatenate identity matrix for reference with U for ensemble members
780 | all_transforms = torch.cat([identity_matrices, U], dim=1) # [B, E, 3, 3]
781 |
782 | return all_transforms, centers
783 |
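# Usage sketch (illustrative only): align an ensemble of traces to member 0,
# mirroring what EnsembleStructureModule.forward does:
#   traces = torch.randn(5, 120, 3)   # [E, N, 3] hypothetical backbone-trace coordinates
#   R, t = find_alignment_transform(traces)
#   aligned = (traces - t) @ R        # member 0 keeps the identity transform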
--------------------------------------------------------------------------------