├── .gitignore ├── .python-version ├── LICENSE.txt ├── README.md ├── img └── scaling.jpg ├── pyproject.toml └── src └── flash_ipa ├── edge_embedder.py ├── factorizer.py ├── ipa.py ├── linear.py ├── model.py ├── rigid.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python-generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | *.egg-info 8 | *.ipynb 9 | 10 | # Virtual environments 11 | .venv 12 | uv.lock -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 anonymous 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlashIPA 2 | 3 | Official implementation of FlashIPA, which enhances the efficiency of the IPA module. Our module **reduces training and inference time** and **memory requirements** of standard models. 4 | 5 | ![scalling](img/scaling.jpg) 6 | 7 | ## How to use FlashIPA? 8 | 9 | After following the setup guide, FlashIPA can be integrated into any model using the IPA module by replacing any original IPA layer with our implementation. The primary input difference from the standard IPA module is the **z_factor**, which represents a memory-efficient graph edge embedding. A complete example of an IPA model is provided in [model.py](src/flash_ipa/model.py), including the full computation of the **z_factor**. 
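For orientation, the sketch below shows one way to produce the **z_factor** pair with the bundled `EdgeEmbedder` in `1d` mode before passing it to the IPA layer. The random inputs and the zeroed self-conditioning translations are illustrative placeholders, and the shapes assume the `EdgeEmbedderConfig` defaults (`c_s=256`, `c_p=128`, `z_factor_rank=2`).

```python
import torch
from flash_ipa.edge_embedder import EdgeEmbedder, EdgeEmbedderConfig

B, L = 1, 32
edge_conf = EdgeEmbedderConfig()              # c_s=256, c_p=128, z_factor_rank=2 by default
edge_embedder = EdgeEmbedder(edge_conf, mode="1d")

node_embed = torch.rand(B, L, edge_conf.c_s)  # node/single representation
translations = torch.rand(B, L, 3)            # frame translations
trans_sc = torch.zeros(B, L, 3)               # self-conditioning translations (zeros if unused)
node_mask = torch.ones(B, L)                  # residue mask

# Each factor has shape [B, L, z_factor_rank, c_p] and corresponds to the
# z_factor_1 / z_factor_2 arguments of InvariantPointAttention when use_flash_attn=True.
z_factor_1, z_factor_2 = edge_embedder(node_embed, translations, trans_sc, node_mask)
```

[model.py](src/flash_ipa/model.py) wires this embedder (or the `LinearFactorizer` for factorized 2D biases) into the full IPA trunk for each supported mode.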
10 | 11 | 12 | ### FlashIPA Model 13 | ```python 14 | from flash_ipa.ipa import IPAConfig 15 | from flash_ipa.model import Model, ModelConfig 16 | import torch 17 | 18 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | B, L, D = 1, 32, 256 21 | ipa_conf = IPAConfig(use_flash_attn=True, attn_dtype="bf16") 22 | model_conf = ModelConfig(mode="flash_1d_bias", ipa=ipa_conf) 23 | model = Model(model_conf) 24 | model.to(DEVICE) 25 | batch = { 26 | "node_embeddings": torch.rand(B, L, D).to(DEVICE), 27 | "translations": torch.rand(B, L, 3).to(DEVICE), 28 | "rotations": torch.rand(B, L, 3, 3).to(DEVICE), 29 | "res_mask": torch.ones(B, L).to(DEVICE), 30 | } 31 | output = model(batch) 32 | ``` 33 | 34 | 35 | ### FlashIPA 36 | ```python 37 | from flash_ipa.ipa import InvariantPointAttention, IPAConfig 38 | 39 | ipa_conf = IPAConfig(use_flash_attn=True, attn_dtype='bf16') 40 | ipa_layer = InvariantPointAttention(ipa_conf) 41 | ipa_embed = ipa_layer( 42 | node_embed, 43 | None, 44 | z_factor_1, 45 | z_factor_2, 46 | curr_rigids, 47 | mask=node_mask 48 | ) 49 | ``` 50 | 51 | ### Original IPA 52 | ```python 53 | from flash_ipa.ipa import InvariantPointAttention, IPAConfig 54 | from dataclasses import dataclass 55 | 56 | ipa_conf = IPAConfig(use_flash_attn=False, attn_dtype='fp32') 57 | ipa_layer = InvariantPointAttention(ipa_conf) 58 | ipa_embed = ipa_layer( 59 | node_embed, 60 | edge_embed, 61 | None, 62 | None, 63 | curr_rigids, 64 | mask=node_mask 65 | ) 66 | ``` 67 | 68 | ## Setup Guide 69 | 70 | To manage environments efficiently, we use [uv](https://docs.astral.sh/uv/getting-started/installation/#standalone-installer). It simplifies managing dependencies and executing scripts. 71 | 72 | ### As a Python package in your uv environment 73 | ```bash 74 | uv add "flash_ipa @ git+https://github.com/anonymous/flash_ipa" 75 | ``` 76 | 77 | ### For development 78 | ```bash 79 | git clone https://github.com/anonymous/flash_ipa 80 | cd flash_ipa 81 | uv sync 82 | ``` 83 | 84 | 85 | ## License 86 | 87 | This project is licensed under the MIT License. See [LICENSE](LICENSE.txt) for more details.
88 | 89 | ## Citation 90 | 91 | ``` bash 92 | @article{liu2025flashipa, 93 | title={Flash Invariant Point Attention}, 94 | author={Liu, Andrew and Elaldi, Axel and Franklin, Nicholas T and Russell, Nathan and Atwal, Gurinder S and Ban, Yih-En A and Viessmann, Olivia}, 95 | journal={arXiv preprint arXiv:2505.11580}, 96 | year={2025}, 97 | url={https://arxiv.org/abs/2505.11580} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /img/scaling.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagshippioneering/flash_ipa/b3f757136e89381f777a82b91e3847cfd5988525/img/scaling.jpg -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "flash-ipa" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.11,<3.13" 7 | dependencies = [ 8 | "torch==2.4.1+cu121", 9 | "flash-attn>=2.7.4.post1", 10 | "torchvision>=0.19.1", 11 | "torchaudio>=2.4.1", 12 | "lightning>=2.5.1", 13 | "hydra-core>=1.3.2", 14 | "torch-geometric>=2.6.1", 15 | "torch-scatter>=2.1.2", 16 | "torch-cluster>=1.6.3", 17 | "mdanalysis>=2.9.0", 18 | "mdanalysistests>=2.9.0", 19 | "biopandas>=0.5.1", 20 | "biopython>=1.85", 21 | "rdkit>=2024.9.6", 22 | "mdtraj>=1.10.3", 23 | "graphein>=1.7.7", 24 | "hydra-colorlog>=1.2.0", 25 | "rootutils>=1.0.7", 26 | "rich>=14.0.0", 27 | "matplotlib>=3.10.1", 28 | "networkx>=3.4.2", 29 | "gputil>=1.4.0", 30 | "omegaconf>=2.3.0", 31 | "beartype>=0.20.2", 32 | "jaxtyping>=0.3.1", 33 | "dm-tree>=0.1.8", 34 | "tmtools>=0.2.0", 35 | "pot>=0.9.5", 36 | "iminuit>=2.31.1", 37 | "tmscoring>=0.4.post0", 38 | "biotite>=1.2.0", 39 | "einops>=0.8.1", 40 | "ml-collections>=1.0.0", 41 | "mlflow>=2.21.3", 42 | "hatchling>=1.27.0", 43 | "editables>=0.5", 44 | "setuptools>=78.1.0", 45 | "ipykernel>=6.29.5", 46 | ] 47 | 48 | [tool.uv] 49 | find-links = [ 50 | "https://data.pyg.org/whl/torch-2.4.0%2Bcu121.html" 51 | ] 52 | no-build-package = ["torch", "torchvision", "torchaudio", "torch-geometric", "torch-scatter", "torch-cluster", "mamba-ssm"] 53 | no-build-isolation-package = ["flash-attn"] 54 | 55 | [tool.uv.sources] 56 | torch = { index = "pytorch" } 57 | 58 | [[tool.uv.index]] 59 | name = "pytorch" 60 | url = "https://download.pytorch.org/whl/cu121" 61 | explicit = true 62 | 63 | [build-system] 64 | requires = ["hatchling"] 65 | build-backend = "hatchling.build" 66 | 67 | [tool.hatch.build.targets.wheel] 68 | packages = ["src/flash_ipa"] 69 | -------------------------------------------------------------------------------- /src/flash_ipa/edge_embedder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from 3 | https://github.com/microsoft/protein-frame-flow/blob/main/models/edge_embedder.py 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | from flash_ipa import utils 9 | from einops import rearrange 10 | from dataclasses import dataclass 11 | 12 | 13 | @dataclass 14 | class EdgeEmbedderConfig: 15 | z_factor_rank: int = 2 16 | single_bias_transition_n: int = 2 17 | c_s: int = 256 18 | c_p: int = 128 19 | relpos_k: int = 64 20 | use_rbf: bool = True 21 | num_rbf: int = 32 22 | feat_dim: int = 64 23 | num_bins: int = 22 24 | self_condition: bool = True 25 | k: int = 10 26 | 27 | 28 | class EdgeEmbedder(nn.Module): 29 | 30 | def 
__init__(self, module_cfg, mode): 31 | super(EdgeEmbedder, self).__init__() 32 | self._cfg = module_cfg 33 | 34 | self.c_s = self._cfg.c_s 35 | self.c_p = self._cfg.c_p 36 | self.feat_dim = self._cfg.feat_dim 37 | self.mode = mode 38 | 39 | total_edge_feats = self.feat_dim * 3 + self._cfg.num_bins * 2 40 | 41 | if self.mode == "1d": 42 | self.k = self._cfg.k 43 | self.z_factor_rank = self._cfg.z_factor_rank 44 | self.linear_s_p = nn.Linear(self.c_s, self.feat_dim * self.z_factor_rank * 2) 45 | self.linear_relpos = nn.Linear(self.feat_dim, self.feat_dim * self.z_factor_rank * 4) 46 | self.linear_t_distogram = nn.Linear( 47 | self.k * (self._cfg.num_bins + self.feat_dim), self._cfg.num_bins * self.z_factor_rank * 2 48 | ) 49 | self.linear_sc_distogram = nn.Linear( 50 | self.k * (self._cfg.num_bins + self.feat_dim), self._cfg.num_bins * self.z_factor_rank * 2 51 | ) 52 | elif self.mode == "2d": 53 | self.linear_s_p = nn.Linear(self.c_s, self.feat_dim) 54 | self.linear_relpos = nn.Linear(self.feat_dim, self.feat_dim) 55 | else: 56 | raise ValueError(f"Unknown edge embedder type: {self.mode}. Must be '1d' or '2d'.") 57 | 58 | self.edge_embedder = nn.Sequential( 59 | nn.Linear(total_edge_feats, self.c_p), 60 | nn.ReLU(), 61 | nn.Linear(self.c_p, self.c_p), 62 | nn.ReLU(), 63 | nn.Linear(self.c_p, self.c_p), 64 | nn.LayerNorm(self.c_p), 65 | ) 66 | 67 | def embed_relpos(self, pos): 68 | rel_pos = pos[:, :, None] - pos[:, None, :] 69 | pos_emb = utils.get_index_embedding(rel_pos, self._cfg.feat_dim, max_len=2056) 70 | return self.linear_relpos(pos_emb) 71 | 72 | def _cross_concat(self, feats_1d, num_batch, num_res): 73 | return ( 74 | torch.cat( 75 | [ 76 | torch.tile(feats_1d[:, :, None, :], (1, 1, num_res, 1)), 77 | torch.tile(feats_1d[:, None, :, :], (1, num_res, 1, 1)), 78 | ], 79 | dim=-1, 80 | ) 81 | .float() 82 | .reshape([num_batch, num_res, num_res, -1]) 83 | ) 84 | 85 | def forward(self, s, t, sc_t, p_mask): 86 | """ 87 | s: [B,L,D] 88 | t: [B,L,3] 89 | sc_t: [B,L,3] 90 | p_mask: [B,L,L] or [B,L] 91 | """ 92 | if self.mode == "2d": 93 | return self.fwd_2d(s, t, sc_t, p_mask) 94 | else: 95 | return self.fwd_1d(s, t, sc_t, p_mask) 96 | 97 | def fwd_1d(self, s, t, sc_t, p_mask): 98 | """ 99 | returns [B,L,R,D] 100 | """ 101 | 102 | num_batch, num_res, _ = s.shape 103 | p_i = self.linear_s_p(s) # B, L, 2 * R * D 104 | p_i = rearrange(p_i, "b l (n r d) -> b n l r d", r=self.z_factor_rank, n=2) # B, 2, L, R, D 105 | 106 | pos = torch.arange(num_res, device=s.device).unsqueeze(0).repeat(num_batch, 1) 107 | pos_emb = utils.get_index_embedding(pos, self._cfg.feat_dim, max_len=2056) 108 | pos_emb = self.linear_relpos(pos_emb) # B, L, 2* R * 2 * D 109 | pos_emb = rearrange(pos_emb, "b l (n r d) -> b n l r d", r=self.z_factor_rank, n=2) # B, 2, L, R, 2D 110 | 111 | t_distogram, t_indices = utils.calc_distogram_knn(t, k=self.k, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins) 112 | t_idx_emb = utils.get_index_embedding(rearrange(t_indices, "b l k -> (b l) k"), self._cfg.feat_dim, max_len=2056) 113 | t_idx_emb = rearrange(t_idx_emb, "(b l) k d -> b l k d", b=num_batch) 114 | t_cat_emb = torch.cat([t_distogram, t_idx_emb], dim=-1) 115 | t_embed = self.linear_t_distogram(rearrange(t_cat_emb, "b l k d -> b l (k d)")) # B, L, num_bins * R 116 | t_embed = rearrange(t_embed, "b l (n r d) -> b n l r d", r=self.z_factor_rank, n=2) # B, 2, L, R, num_bins 117 | 118 | sc_distogram, sc_indices = utils.calc_distogram_knn( 119 | sc_t, k=self.k, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins 120 | ) 
121 | sc_idx_emb = utils.get_index_embedding(rearrange(sc_indices, "b l k -> (b l) k"), self._cfg.feat_dim, max_len=2056) 122 | sc_idx_emb = rearrange(sc_idx_emb, "(b l) k d -> b l k d", b=num_batch) 123 | sc_cat_emb = torch.cat([sc_distogram, sc_idx_emb], dim=-1) 124 | sc_embed = self.linear_sc_distogram(rearrange(sc_cat_emb, "b l k d -> b l (k d)")) # B, L, num_bins * R 125 | sc_embed = rearrange(sc_embed, "b l (n r d) -> b n l r d", r=self.z_factor_rank, n=2) # B, 2, L, R, num_bins 126 | 127 | all_edge_feats = torch.concat([p_i, pos_emb, t_embed, sc_embed], dim=-1) 128 | edge_feats = self.edge_embedder(all_edge_feats) 129 | # edge_feats *= p_mask[:, None, :, None, None] #Do we need this? 130 | z_factor_1 = edge_feats[:, 0, :, :, :] 131 | z_factor_2 = edge_feats[:, 1, :, :, :] 132 | 133 | return z_factor_1, z_factor_2 134 | 135 | def fwd_2d(self, s, t, sc_t, p_mask): 136 | """ 137 | s: [B,L,D] 138 | t: [B,L,3] 139 | sc_t: [B,L,3] 140 | p_mask: [B,L,L] 141 | """ 142 | 143 | # raise ValueError(s.shape, t.shape, sc_t.shape, p_mask.shape) 144 | num_batch, num_res, _ = s.shape 145 | p_i = self.linear_s_p(s) # B, L, D 146 | cross_node_feats = self._cross_concat( 147 | p_i, num_batch, num_res 148 | ) # B, L, L, 2*D. Cross concat is similar in spirit to outer product with self. 149 | 150 | pos = torch.arange(num_res, device=s.device).unsqueeze(0).repeat(num_batch, 1) # B, L. 1D positional encoding 151 | relpos_feats = self.embed_relpos(pos) # B, L, L, D. 2D positional encoding 152 | 153 | dist_feats = utils.calc_distogram(t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins) # B, L, L, D 154 | sc_feats = utils.calc_distogram(sc_t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins) # B, L, L, D 155 | 156 | all_edge_feats = torch.concat([cross_node_feats, relpos_feats, dist_feats, sc_feats], dim=-1) 157 | edge_feats = self.edge_embedder(all_edge_feats) 158 | edge_feats *= p_mask.unsqueeze(-1) 159 | return edge_feats 160 | -------------------------------------------------------------------------------- /src/flash_ipa/factorizer.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from einops import rearrange 3 | from typing import Optional, Tuple 4 | from torch import nn, Tensor 5 | import math 6 | 7 | 8 | class LinearFactorizer(nn.Module): 9 | 10 | def __init__(self, in_L, in_D, target_rank=4, target_inner_dim=8): 11 | super().__init__() 12 | self.target_rank = target_rank 13 | self.target_inner_dim = target_inner_dim 14 | # self.linear_col = nn.Linear(in_L * in_D, target_rank * target_inner_dim, bias=False) 15 | # self.linear_row = nn.Linear(in_L * in_D, target_rank * target_inner_dim, bias=False) 16 | self.length_compressor = nn.Linear(in_L, target_rank, bias=False) 17 | self.inner_compressor = nn.Linear(in_D, target_inner_dim, bias=False) 18 | self.in_L = in_L 19 | self.length_norm = nn.LayerNorm(target_rank) 20 | self.inner_norm = nn.LayerNorm(target_inner_dim) 21 | 22 | nn.init.kaiming_uniform_(self.length_compressor.weight, a=math.sqrt(5)) 23 | nn.init.kaiming_uniform_(self.inner_compressor.weight, a=math.sqrt(5)) 24 | 25 | def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 26 | """ 27 | Input: 28 | x: (B, L, L, D) 29 | mask: (B, L), optional length mask 30 | Output: 31 | U: (B * target_inner_dim, L, target_rank) 32 | V: (B * target_inner_dim, L, target_rank) 33 | """ 34 | # Compress along length 35 | L_orig = x.shape[1] 36 | x = F.pad(x, (0, 0, 0, self.in_L - x.shape[2], 0, 
self.in_L - x.shape[1]), value=0.0) 37 | row_embed = self.length_compressor(rearrange(x, "B R C D -> B C D R"))[:, :L_orig, :, :] # (B, L, D, target_rank) 38 | col_embed = self.length_compressor(rearrange(x, "B R C D -> B R D C"))[:, :L_orig, :, :] # (B, L, D, target_rank) 39 | 40 | row_embed = self.length_norm(row_embed) 41 | col_embed = self.length_norm(col_embed) 42 | 43 | row_embed = self.inner_compressor(rearrange(row_embed, "B C D R -> B C R D")) # (B, L, target_rank, target_inner_dim) 44 | col_embed = self.inner_compressor(rearrange(col_embed, "B R D C -> B R C D")) # (B, L, target_rank, target_inner_dim) 45 | 46 | row_embed = self.inner_norm(row_embed) 47 | col_embed = self.inner_norm(col_embed) 48 | if mask is not None: 49 | # Apply mask to row_embed and col_embed 50 | row_embed = row_embed * mask[:, :, None, None] 51 | col_embed = col_embed * mask[:, :, None, None] 52 | 53 | row_embed = rearrange(row_embed, "B C R D -> (B D) C R") / math.sqrt(self.target_rank) # (B * D, L, target_rank) 54 | col_embed = rearrange(col_embed, "B R C D -> (B D) R C") / math.sqrt(self.target_rank) # (B * D, L, target_rank) 55 | 56 | # row_embed = row_embed / self.target_rank 57 | # col_embed = col_embed / self.target_rank 58 | 59 | return row_embed, col_embed 60 | -------------------------------------------------------------------------------- /src/flash_ipa/ipa.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Anonymous 2 | # Copyright 2021 AlQuraishi Laboratory 3 | # Copyright 2021 DeepMind Technologies Limited 4 | 5 | import torch 6 | import torch.nn as nn 7 | from typing import Optional, List, Sequence 8 | import math 9 | import torch.nn.functional as F 10 | from einops import rearrange 11 | from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_func 12 | from flash_ipa.linear import Linear 13 | from flash_ipa.rigid import Rigid 14 | from dataclasses import dataclass 15 | 16 | attn_dtype_dict = { 17 | "fp16": torch.float16, 18 | "bf16": torch.bfloat16, 19 | "fp32": torch.float32, 20 | } 21 | 22 | 23 | @dataclass 24 | class IPAConfig: 25 | use_flash_attn: bool = True 26 | attn_dtype: str = "bf16" # "fp16", "bf16", "fp32". For flash IPA, use bf16 or fp16. For the original IPA, use fp32. 27 | use_packed: bool = True 28 | c_s: int = 256 29 | c_z: int = 128 30 | c_hidden: int = 128 31 | no_heads: int = 8 32 | z_factor_rank: int = 2 # 0 for no factorization 33 | no_qk_points: int = 8 34 | no_v_points: int = 12 35 | seq_tfmr_num_heads: int = 4 36 | seq_tfmr_num_layers: int = 2 37 | num_blocks: int = 6 38 | 39 | 40 | class InvariantPointAttention(nn.Module): 41 | """ 42 | Implements Algorithm 22, with flash IPA.
43 | """ 44 | 45 | def __init__( 46 | self, 47 | ipa_conf: IPAConfig, 48 | inf: float = 1e5, 49 | eps: float = 1e-8, 50 | ): 51 | """ 52 | Args: 53 | c_s: 54 | Single representation channel dimension 55 | c_z: 56 | Pair representation channel dimension 57 | c_hidden: 58 | Hidden channel dimension 59 | no_heads: 60 | Number of attention heads 61 | no_qk_points: 62 | Number of query/key points to generate 63 | no_v_points: 64 | Number of value points to generate 65 | """ 66 | super(InvariantPointAttention, self).__init__() 67 | self._ipa_conf = ipa_conf 68 | 69 | self.use_flash_attn = ipa_conf.use_flash_attn 70 | self.attn_dtype = attn_dtype_dict[ipa_conf.attn_dtype] 71 | self.use_packed = ipa_conf.use_packed 72 | 73 | self.c_s = ipa_conf.c_s 74 | self.c_z = ipa_conf.c_z 75 | self.c_hidden = ipa_conf.c_hidden 76 | self.no_heads = ipa_conf.no_heads 77 | self.no_qk_points = ipa_conf.no_qk_points 78 | self.no_v_points = ipa_conf.no_v_points 79 | self.inf = inf 80 | self.eps = eps 81 | 82 | # These linear layers differ from their specifications in the 83 | # supplement. There, they lack bias and use Glorot initialization. 84 | # Here as in the official source, they have bias and use the default 85 | # Lecun initialization. 86 | hc = self.c_hidden * self.no_heads 87 | self.linear_q = Linear(self.c_s, hc) 88 | self.linear_kv = Linear(self.c_s, 2 * hc) 89 | 90 | hpq = self.no_heads * self.no_qk_points * 3 91 | self.linear_q_points = Linear(self.c_s, hpq) 92 | 93 | hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3 94 | self.linear_kv_points = Linear(self.c_s, hpkv) 95 | if self.c_z > 0: 96 | self.linear_b = Linear(self.c_z, self.no_heads) 97 | self.down_z = Linear(self.c_z, self.c_z // 4) 98 | 99 | self.head_weights = nn.Parameter(torch.zeros((ipa_conf.no_heads))) 100 | ipa_point_weights_init_(self.head_weights) 101 | 102 | concat_out_dim = self.c_z // 4 + self.c_hidden + self.no_v_points * 4 103 | self.linear_out = Linear(self.no_heads * concat_out_dim, self.c_s, init="final") 104 | 105 | self.softmax = nn.Softmax(dim=-1) 106 | self.softplus = nn.Softplus() 107 | 108 | self.headdim_eff = max( 109 | self._ipa_conf.c_hidden + 5 * self.no_qk_points + (self._ipa_conf.z_factor_rank * self.no_heads), 110 | self._ipa_conf.c_hidden + 3 * self.no_v_points + (self._ipa_conf.z_factor_rank * self.c_z // 4), 111 | ) 112 | 113 | if self.headdim_eff > 256: 114 | assert ( 115 | self.use_flash_attn is False or self.attn_dtype == torch.float16 116 | ), "For headdim_eff > 256, you must use either naive attention or FFPA, which requires fp16 dtype." 
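# For reference, with the default IPAConfig (c_hidden=128, no_qk_points=8, no_v_points=12,
# no_heads=8, z_factor_rank=2, c_z=128) the effective head dimension computed above is
#   headdim_eff = max(128 + 5 * 8 + 2 * 8, 128 + 3 * 12 + (2 * 128) // 4) = max(184, 228) = 228,
# which stays under the 256 limit, so flash_ipa_fwd below takes the FA2 (flash_attn_varlen) path;
# only larger configurations trip the assert above and need either the non-flash path or fp16 (FFPA).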
117 | 118 | def flash_ipa_fwd(self, q, k, v, q_pts, k_pts, v_pts, z_factor_1, z_factor_2, r, mask): 119 | """ 120 | Compute squared norm components (used for SE(3) invariance part) 121 | """ 122 | q_pts_norm_sq = torch.norm(q_pts, dim=-1) ** 2 123 | k_pts_norm_sq = torch.norm(k_pts, dim=-1) ** 2 124 | 125 | """ 126 | Compute non-zero padding (used for SE(3) invariance part) 127 | """ 128 | head_weights = self.softplus(self.head_weights) 129 | head_weights = head_weights * math.sqrt(1.0 / (3 * (self.no_qk_points * 9.0 / 2))) 130 | q_pad = torch.ones_like(q_pts_norm_sq) 131 | k_pad = torch.ones_like(k_pts_norm_sq) * (-0.5) * head_weights.view(1, 1, -1, 1) 132 | 133 | """ 134 | Compute pair bias factors 135 | """ 136 | if z_factor_1 is not None and z_factor_2 is not None: 137 | # z_factor_1 has shape [B, N_res, rank, C_z] 138 | z_comb = torch.cat([z_factor_1.unsqueeze(1), z_factor_2.unsqueeze(1)], dim=1) 139 | b = self.linear_b(z_comb) 140 | b1 = b[:, 0, :, :, :].permute(0, 1, 3, 2) # B, N_res, H, rank 141 | b2 = b[:, 1, :, :, :].permute(0, 1, 3, 2) # B, N_res, H, rank 142 | 143 | z_comb_down = self.down_z(z_comb) 144 | z_factor_1 = z_comb_down[:, 0, :, :, :] # B, N_res, rank, C_z//4 145 | z_factor_2 = z_comb_down[:, 1, :, :, :] # B, N_res, rank, C_z//4 146 | 147 | """ 148 | Compute q_aggregated 149 | """ 150 | if z_factor_1 is not None: 151 | q_aggregated = torch.cat( 152 | [q, q_pts.view(q_pts.shape[0], q_pts.shape[1], q_pts.shape[2], -1), q_pts_norm_sq, q_pad, b1], dim=-1 153 | ) 154 | else: 155 | q_aggregated = torch.cat( 156 | [q, q_pts.view(q_pts.shape[0], q_pts.shape[1], q_pts.shape[2], -1), q_pts_norm_sq, q_pad], dim=-1 157 | ) 158 | """ 159 | Compute k_aggregated 160 | """ 161 | k_scaled = k * math.sqrt(1.0 / (3 * self.c_hidden)) 162 | k_pts_scaled = k_pts.view(k_pts.shape[0], k_pts.shape[1], k_pts.shape[2], -1) * head_weights.view(1, 1, -1, 1) 163 | k_pts_norm_sq_scaled = k_pts_norm_sq * (-0.5) * head_weights.view(1, 1, -1, 1) 164 | if z_factor_2 is not None: 165 | k_aggregated = torch.cat([k_scaled, k_pts_scaled, k_pad, k_pts_norm_sq_scaled, b2], dim=-1) 166 | else: 167 | k_aggregated = torch.cat([k_scaled, k_pts_scaled, k_pad, k_pts_norm_sq_scaled], dim=-1) 168 | 169 | """ 170 | Compute v_aggregated 171 | """ 172 | if z_factor_2 is not None: 173 | v_aggregated = torch.cat( 174 | [ 175 | v, 176 | v_pts.view(*v_pts.shape[:3], -1), 177 | z_factor_2.view(*z_factor_2.shape[:2], 1, -1).expand(-1, -1, self.no_heads, -1), 178 | ], 179 | dim=-1, 180 | ) 181 | else: 182 | v_aggregated = torch.cat([v, v_pts.view(v_pts.shape[0], v_pts.shape[1], v_pts.shape[2], -1)], dim=-1) 183 | 184 | if mask is None: 185 | mask = torch.ones((q.shape[0], q.shape[1]), device=q.device, dtype=torch.bool) 186 | 187 | """ 188 | Pass through FA2 or FFPA depending on headdim size 189 | """ 190 | 191 | if self.headdim_eff <= 256: # use FA2 192 | # FA2 requires that QKV have same size for last dimension. So just choose the smallest possible size. 
193 | max_dim_sz = max(q_aggregated.shape[-1], k_aggregated.shape[-1], v_aggregated.shape[-1]) 194 | q_aggregated = F.pad(q_aggregated, (0, max_dim_sz - q_aggregated.shape[-1]), value=0.0) 195 | k_aggregated = F.pad(k_aggregated, (0, max_dim_sz - k_aggregated.shape[-1]), value=0.0) 196 | v_aggregated = F.pad(v_aggregated, (0, max_dim_sz - v_aggregated.shape[-1]), value=0.0) 197 | if self.use_packed: 198 | qkv = torch.cat([q_aggregated.unsqueeze(2), k_aggregated.unsqueeze(2), v_aggregated.unsqueeze(2)], dim=2) 199 | ( 200 | qkv, 201 | indices, 202 | cu_seqlens, 203 | max_seqlen, 204 | _, 205 | ) = unpad_input(qkv, mask) 206 | 207 | if qkv.dtype != self.attn_dtype: 208 | qkv = qkv.to(self.attn_dtype) 209 | 210 | attn_res = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, softmax_scale=1) 211 | 212 | else: 213 | q_aggregated, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q_aggregated, mask) 214 | k_aggregated, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k_aggregated, mask) 215 | v_aggregated, _, cu_seqlens_v, max_seqlen_v, _ = unpad_input(v_aggregated, mask) 216 | 217 | if ( 218 | q_aggregated.dtype != self.attn_dtype 219 | or k_aggregated.dtype != self.attn_dtype 220 | or v_aggregated.dtype != self.attn_dtype 221 | ): 222 | q_aggregated = q_aggregated.to(self.attn_dtype) 223 | k_aggregated = k_aggregated.to(self.attn_dtype) 224 | v_aggregated = v_aggregated.to(self.attn_dtype) 225 | 226 | attn_res = flash_attn_varlen_func( 227 | q_aggregated, 228 | k_aggregated, 229 | v_aggregated, 230 | cu_seqlens_q, 231 | cu_seqlens_k, 232 | max_seqlen_q, 233 | max_seqlen_k, 234 | softmax_scale=1, 235 | ) 236 | attn_res = pad_input( 237 | attn_res, 238 | indices=indices, 239 | batch=q.shape[0], 240 | seqlen=q.shape[1], 241 | ) 242 | else: 243 | raise ValueError(f"self.headdim_eff has to be <= 256 for FA2 to work: {self.headdim_eff}") 244 | 245 | if attn_res.dtype != torch.float32: 246 | attn_res = attn_res.float() 247 | 248 | if z_factor_2 is not None: 249 | attn_res = attn_res[ 250 | :, :, :, : self.c_hidden + 3 * self.no_v_points + self._ipa_conf.z_factor_rank * z_factor_2.shape[-1] 251 | ] 252 | else: 253 | attn_res = attn_res[:, :, :, : self.c_hidden + 3 * self.no_v_points] 254 | 255 | o = attn_res[:, :, :, : self.c_hidden] 256 | o = flatten_final_dims(o, 2) 257 | 258 | # B,L,H,D 259 | o_pt = attn_res[:, :, :, self.c_hidden : self.c_hidden + 3 * self.no_v_points] 260 | # [*, H, 3, N_res, P_v] 261 | o_pt = rearrange(o_pt, "B L H (P_v r) -> B H r L P_v", P_v=self.no_v_points) 262 | 263 | o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) 264 | o_pt = r[..., None, None].invert_apply(o_pt) 265 | 266 | # [*, N_res, H * P_v] 267 | o_pt_dists = torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.eps) 268 | o_pt_norm_feats = flatten_final_dims(o_pt_dists, 2) 269 | 270 | # [*, N_res, H * P_v, 3] 271 | o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) 272 | 273 | # calculate o_pair 274 | if z_factor_1 is not None and z_factor_2 is not None: 275 | o_pair = attn_res[:, :, :, self.c_hidden + 3 * self.no_v_points :].view( 276 | *attn_res.shape[:3], self._ipa_conf.z_factor_rank, -1 277 | ) # B, L, H, rank, C_z//4 278 | o_pair = torch.einsum("b n r d, b n h r d -> b n h d", z_factor_1, o_pair) 279 | o_pair = flatten_final_dims(o_pair, 2) 280 | o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair] 281 | else: 282 | o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats] 283 | 284 | s = self.linear_out(torch.cat(o_feats, dim=-1)) 285 | 286 | return s 287 | 288 | def 
slow_ipa_fwd(self, q, k, v, q_pts, k_pts, v_pts, z, r, mask, _offload_inference=False): 289 | ########################## 290 | # Compute attention scores 291 | ########################## 292 | # [*, N_res, N_res, H] 293 | 294 | if not z is None: 295 | b = self.linear_b(z[0]) 296 | 297 | if _offload_inference: 298 | z[0] = z[0].cpu() 299 | 300 | # [*, H, N_res, N_res] 301 | a = torch.matmul( 302 | permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] 303 | permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] 304 | ) 305 | a *= math.sqrt(1.0 / (3 * self.c_hidden)) 306 | 307 | if not z is None: 308 | a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)) 309 | 310 | # [*, N_res, N_res, H, P_q, 3] 311 | pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) 312 | pt_att = pt_displacement**2 313 | 314 | # [*, N_res, N_res, H, P_q] 315 | pt_att = sum(torch.unbind(pt_att, dim=-1)) 316 | head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1))) 317 | head_weights = head_weights * math.sqrt(1.0 / (3 * (self.no_qk_points * 9.0 / 2))) 318 | pt_att = pt_att * head_weights 319 | 320 | # [*, N_res, N_res, H] 321 | pt_att = torch.sum(pt_att, dim=-1) * (-0.5) 322 | # [*, N_res, N_res] 323 | square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) 324 | square_mask = self.inf * (square_mask - 1) 325 | 326 | # [*, H, N_res, N_res] 327 | pt_att = permute_final_dims(pt_att, (2, 0, 1)) 328 | 329 | a = a + pt_att 330 | a = a + square_mask.unsqueeze(-3) 331 | a = self.softmax(a) 332 | 333 | ################ 334 | # Compute output 335 | ################ 336 | # [*, N_res, H, C_hidden] 337 | o = torch.matmul(a, v.transpose(-2, -3)).transpose(-2, -3) 338 | 339 | # [*, N_res, H * C_hidden] 340 | o = flatten_final_dims(o, 2) 341 | 342 | # [*, H, 3, N_res, P_v] 343 | o_pt = torch.sum( 344 | (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), 345 | dim=-2, 346 | ) 347 | 348 | # [*, N_res, H, P_v, 3] 349 | o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) 350 | o_pt = r[..., None, None].invert_apply(o_pt) 351 | 352 | # [*, N_res, H * P_v] 353 | o_pt_dists = torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.eps) 354 | o_pt_norm_feats = flatten_final_dims(o_pt_dists, 2) 355 | 356 | # [*, N_res, H * P_v, 3] 357 | o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) 358 | 359 | if not z is None: 360 | if _offload_inference: 361 | z[0] = z[0].to(o_pt.device) 362 | 363 | # [*, N_res, H, C_z // 4] 364 | pair_z = self.down_z(z[0]) 365 | o_pair = torch.matmul(a.transpose(-2, -3), pair_z) 366 | 367 | # [*, N_res, H * C_z // 4] 368 | o_pair = flatten_final_dims(o_pair, 2) 369 | 370 | o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair] 371 | else: 372 | o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats] 373 | 374 | # [*, N_res, C_s] 375 | s = self.linear_out(torch.cat(o_feats, dim=-1)) 376 | return s 377 | 378 | def forward( 379 | self, 380 | s: torch.Tensor, 381 | z: Optional[torch.Tensor], 382 | z_factor_1: Optional[torch.Tensor], 383 | z_factor_2: Optional[torch.Tensor], 384 | r: Rigid, 385 | mask: torch.Tensor, 386 | _offload_inference: bool = False, 387 | _z_reference_list: Optional[Sequence[torch.Tensor]] = None, 388 | ) -> torch.Tensor: 389 | """ 390 | Args: 391 | s: 392 | [*, N_res, C_s] single representation 393 | z: 394 | [*, N_res, N_res, C_z] pair representation 395 | r: 396 | [*, N_res] transformation object 397 | mask: 398 | [*, N_res] mask 399 | Returns: 400 | [*, N_res, C_s] single representation update 401 | """ 402 
| if not z is None: 403 | if _offload_inference: 404 | z = _z_reference_list 405 | else: 406 | z = [z] 407 | 408 | ####################################### 409 | # Generate scalar and point activations 410 | ####################################### 411 | # [*, N_res, H * C_hidden] 412 | q = self.linear_q(s) 413 | kv = self.linear_kv(s) 414 | 415 | # [*, N_res, H, C_hidden] 416 | q = q.view(q.shape[:-1] + (self.no_heads, -1)) 417 | 418 | # [*, N_res, H, 2 * C_hidden] 419 | kv = kv.view(kv.shape[:-1] + (self.no_heads, -1)) 420 | 421 | # [*, N_res, H, C_hidden] 422 | k, v = torch.split(kv, self.c_hidden, dim=-1) 423 | 424 | # [*, N_res, H * P_q * 3] 425 | q_pts = self.linear_q_points(s) 426 | 427 | # This is kind of clunky, but it's how the original does it 428 | # [*, N_res, H * P_q, 3] 429 | q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) 430 | q_pts = torch.stack(q_pts, dim=-1) 431 | q_pts = r[..., None].apply(q_pts) 432 | 433 | # [*, N_res, H, P_q, 3] 434 | q_pts = q_pts.view(q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)) 435 | 436 | # [*, N_res, H * (P_q + P_v) * 3] 437 | kv_pts = self.linear_kv_points(s) 438 | 439 | # [*, N_res, H * (P_q + P_v), 3] 440 | kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) 441 | kv_pts = torch.stack(kv_pts, dim=-1) 442 | kv_pts = r[..., None].apply(kv_pts) 443 | 444 | # [*, N_res, H, (P_q + P_v), 3] 445 | kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) 446 | 447 | # [*, N_res, H, P_q/P_v, 3] 448 | k_pts, v_pts = torch.split(kv_pts, [self.no_qk_points, self.no_v_points], dim=-2) 449 | 450 | if self.use_flash_attn: 451 | s = self.flash_ipa_fwd( 452 | q, 453 | k, 454 | v, 455 | q_pts, 456 | k_pts, 457 | v_pts, 458 | z_factor_1, 459 | z_factor_2, 460 | r, 461 | mask=mask, 462 | ) 463 | 464 | else: 465 | s = self.slow_ipa_fwd( 466 | q, 467 | k, 468 | v, 469 | q_pts, 470 | k_pts, 471 | v_pts, 472 | z, 473 | r, 474 | mask=mask, 475 | _offload_inference=_offload_inference, 476 | ) 477 | return s 478 | 479 | 480 | def permute_final_dims(tensor: torch.Tensor, inds: List[int]): 481 | zero_index = -1 * len(inds) 482 | first_inds = list(range(len(tensor.shape[:zero_index]))) 483 | return tensor.permute(first_inds + [zero_index + i for i in inds]) 484 | 485 | 486 | def flatten_final_dims(t: torch.Tensor, no_dims: int): 487 | return t.reshape(t.shape[:-no_dims] + (-1,)) 488 | 489 | 490 | def ipa_point_weights_init_(weights): 491 | with torch.no_grad(): 492 | softplus_inverse_1 = 0.541324854612918 493 | weights.fill_(softplus_inverse_1) 494 | 495 | 496 | """ 497 | Unpadding and padding operations for FlashAttention 498 | """ 499 | 500 | 501 | def unpad_input(hidden_states, attention_mask, unused_mask=None): 502 | """ 503 | Arguments: 504 | hidden_states: (batch, seqlen, ...) 505 | attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. 506 | unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. 507 | Return: 508 | hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. 509 | indices: (total_nnz), the indices of masked tokens from the flattened input sequence. 510 | cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. 511 | max_seqlen_in_batch: int 512 | seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. 
513 | """ 514 | all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask 515 | seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) 516 | used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 517 | indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() 518 | max_seqlen_in_batch = seqlens_in_batch.max().item() 519 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 520 | return ( 521 | rearrange(hidden_states, "b s ... -> (b s) ...")[indices], 522 | indices, 523 | cu_seqlens, 524 | max_seqlen_in_batch, 525 | used_seqlens_in_batch, 526 | ) 527 | 528 | 529 | def pad_input(hidden_states, indices, batch, seqlen): 530 | """ 531 | Arguments: 532 | hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. 533 | indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. 534 | batch: int, batch size for the padded sequence. 535 | seqlen: int, maximum sequence length for the padded sequence. 536 | Return: 537 | hidden_states: (batch, seqlen, ...) 538 | """ 539 | dim = hidden_states.shape[1:] 540 | output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) 541 | output[indices] = hidden_states 542 | return rearrange(output, "(b s) ... -> b s ...", b=batch) 543 | 544 | 545 | class StructureModuleTransition(nn.Module): 546 | def __init__(self, c): 547 | super(StructureModuleTransition, self).__init__() 548 | 549 | self.c = c 550 | 551 | self.linear_1 = Linear(self.c, self.c, init="relu") 552 | self.linear_2 = Linear(self.c, self.c, init="relu") 553 | self.linear_3 = Linear(self.c, self.c, init="final") 554 | self.relu = nn.ReLU() 555 | self.ln = nn.LayerNorm(self.c) 556 | 557 | def forward(self, s): 558 | s_initial = s 559 | s = self.linear_1(s) 560 | s = self.relu(s) 561 | s = self.linear_2(s) 562 | s = self.relu(s) 563 | s = self.linear_3(s) 564 | s = s + s_initial 565 | s = self.ln(s) 566 | 567 | return s 568 | 569 | 570 | class EdgeTransition(nn.Module): 571 | def __init__( 572 | self, 573 | *, 574 | mode, 575 | node_embed_size, 576 | edge_embed_in, 577 | edge_embed_out, 578 | z_factor_rank=0, 579 | num_layers=2, 580 | node_dilation=2, 581 | ): 582 | super(EdgeTransition, self).__init__() 583 | 584 | self.mode = mode 585 | self.z_factor_rank = z_factor_rank 586 | assert mode in ["1d", "2d"], f"Invalid mode: {mode}. Must be '1d' or '2d'." 
587 | bias_embed_size = node_embed_size // node_dilation 588 | 589 | self.initial_embed = Linear(node_embed_size, bias_embed_size, init="relu") 590 | if mode == "1d": 591 | self.edge_bias_linear = Linear(bias_embed_size, 4 * self.z_factor_rank * bias_embed_size, init="final") 592 | hidden_size = bias_embed_size * 2 + edge_embed_in 593 | trunk_layers = [] 594 | for _ in range(num_layers): 595 | trunk_layers.append(Linear(hidden_size, hidden_size, init="relu")) 596 | trunk_layers.append(nn.ReLU()) 597 | self.trunk = nn.Sequential(*trunk_layers) 598 | self.final_layer = Linear(hidden_size, edge_embed_out, init="final") 599 | self.layer_norm = nn.LayerNorm(edge_embed_out) 600 | 601 | def forward(self, node_embed, edge_embed, z_factor_1=None, z_factor_2=None): 602 | if edge_embed is not None: 603 | return self.fwd_2d(node_embed, edge_embed) 604 | elif z_factor_1 is not None and z_factor_2 is not None: 605 | return self.fwd_1d(node_embed, z_factor_1, z_factor_2) 606 | 607 | def fwd_1d(self, node_embed, z_factor_1, z_factor_2): 608 | node_embed = self.initial_embed(node_embed) # B,L,D 609 | 610 | batch_size, num_res, _ = node_embed.shape 611 | rank = z_factor_1.shape[2] 612 | 613 | edge_bias = self.edge_bias_linear(node_embed) 614 | edge_bias = rearrange(edge_bias, "b l (n r d) -> b n l r d", r=self.z_factor_rank, n=2) # B,2,L,R,2*D 615 | 616 | z_agg = torch.cat( 617 | [z_factor_1[:, None, :, :, :], z_factor_2[:, None, :, :, :]], 618 | axis=1, 619 | ) 620 | 621 | edge_embed = torch.cat([z_agg, edge_bias], axis=-1) / math.sqrt(2) 622 | 623 | edge_embed = self.final_layer(self.trunk(edge_embed) + edge_embed) 624 | edge_embed = self.layer_norm(edge_embed) 625 | z_factor_1 = edge_embed[:, 0, :, :, :] 626 | z_factor_2 = edge_embed[:, 1, :, :, :] 627 | return z_factor_1, z_factor_2 628 | 629 | def fwd_2d(self, node_embed, edge_embed): 630 | node_embed = self.initial_embed(node_embed) 631 | batch_size, num_res, _ = node_embed.shape 632 | edge_bias = torch.cat( 633 | [ 634 | torch.tile(node_embed[:, :, None, :], (1, 1, num_res, 1)), 635 | torch.tile(node_embed[:, None, :, :], (1, num_res, 1, 1)), 636 | ], 637 | axis=-1, 638 | ) 639 | edge_embed = torch.cat([edge_embed, edge_bias], axis=-1).reshape(batch_size * num_res**2, -1) # B*L*L,D 640 | edge_embed = self.final_layer(self.trunk(edge_embed) + edge_embed) 641 | edge_embed = self.layer_norm(edge_embed) 642 | edge_embed = edge_embed.reshape(batch_size, num_res, num_res, -1) 643 | return edge_embed 644 | 645 | 646 | class BackboneUpdate(nn.Module): 647 | """ 648 | Implements part of Algorithm 23. 
649 | """ 650 | 651 | def __init__(self, c_s, use_rot_updates): 652 | """ 653 | Args: 654 | c_s: 655 | Single representation channel dimension 656 | """ 657 | super(BackboneUpdate, self).__init__() 658 | 659 | self.c_s = c_s 660 | self._use_rot_updates = use_rot_updates 661 | update_dim = 6 if use_rot_updates else 3 662 | self.linear = Linear(self.c_s, update_dim, init="final") 663 | 664 | def forward(self, s: torch.Tensor): 665 | """ 666 | Args: 667 | [*, N_res, C_s] single representation 668 | Returns: 669 | [*, N_res, 6] update vector 670 | """ 671 | # [*, 6] 672 | update = self.linear(s) 673 | 674 | return update 675 | -------------------------------------------------------------------------------- /src/flash_ipa/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional, Callable 4 | import math 5 | import numpy as np 6 | from scipy.stats import truncnorm 7 | 8 | 9 | def _prod(nums): 10 | out = 1 11 | for n in nums: 12 | out = out * n 13 | return out 14 | 15 | 16 | def _calculate_fan(linear_weight_shape, fan="fan_in"): 17 | fan_out, fan_in = linear_weight_shape 18 | 19 | if fan == "fan_in": 20 | f = fan_in 21 | elif fan == "fan_out": 22 | f = fan_out 23 | elif fan == "fan_avg": 24 | f = (fan_in + fan_out) / 2 25 | else: 26 | raise ValueError("Invalid fan option") 27 | 28 | return f 29 | 30 | 31 | def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): 32 | shape = weights.shape 33 | f = _calculate_fan(shape, fan) 34 | scale = scale / max(1, f) 35 | a = -2 36 | b = 2 37 | std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) 38 | size = _prod(shape) 39 | samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) 40 | samples = np.reshape(samples, shape) 41 | with torch.no_grad(): 42 | weights.copy_(torch.tensor(samples, device=weights.device)) 43 | 44 | 45 | def lecun_normal_init_(weights): 46 | trunc_normal_init_(weights, scale=1.0) 47 | 48 | 49 | def he_normal_init_(weights): 50 | trunc_normal_init_(weights, scale=2.0) 51 | 52 | 53 | def glorot_uniform_init_(weights): 54 | nn.init.xavier_uniform_(weights, gain=1) 55 | 56 | 57 | def final_init_(weights): 58 | with torch.no_grad(): 59 | weights.fill_(0.0) 60 | 61 | 62 | def gating_init_(weights): 63 | with torch.no_grad(): 64 | weights.fill_(0.0) 65 | 66 | 67 | def normal_init_(weights): 68 | torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") 69 | 70 | 71 | class Linear(nn.Linear): 72 | """ 73 | A Linear layer with built-in nonstandard initializations. Called just 74 | like torch.nn.Linear. 75 | 76 | Implements the initializers in 1.11.4, plus some additional ones found 77 | in the code. 78 | """ 79 | 80 | def __init__( 81 | self, 82 | in_dim: int, 83 | out_dim: int, 84 | bias: bool = True, 85 | init: str = "default", 86 | init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, 87 | ): 88 | """ 89 | Args: 90 | in_dim: 91 | The final dimension of inputs to the layer 92 | out_dim: 93 | The final dimension of layer outputs 94 | bias: 95 | Whether to learn an additive bias. True by default 96 | init: 97 | The initializer to use. 
Choose from: 98 | 99 | "default": LeCun fan-in truncated normal initialization 100 | "relu": He initialization w/ truncated normal distribution 101 | "glorot": Fan-average Glorot uniform initialization 102 | "gating": Weights=0, Bias=1 103 | "normal": Normal initialization with std=1/sqrt(fan_in) 104 | "final": Weights=0, Bias=0 105 | 106 | Overridden by init_fn if the latter is not None. 107 | init_fn: 108 | A custom initializer taking weight and bias as inputs. 109 | Overrides init if not None. 110 | """ 111 | super(Linear, self).__init__(in_dim, out_dim, bias=bias) 112 | 113 | if bias: 114 | with torch.no_grad(): 115 | self.bias.fill_(0) 116 | 117 | if init_fn is not None: 118 | init_fn(self.weight, self.bias) 119 | else: 120 | if init == "default": 121 | lecun_normal_init_(self.weight) 122 | elif init == "relu": 123 | he_normal_init_(self.weight) 124 | elif init == "glorot": 125 | glorot_uniform_init_(self.weight) 126 | elif init == "gating": 127 | gating_init_(self.weight) 128 | if bias: 129 | with torch.no_grad(): 130 | self.bias.fill_(1.0) 131 | elif init == "normal": 132 | normal_init_(self.weight) 133 | elif init == "final": 134 | final_init_(self.weight) 135 | else: 136 | raise ValueError("Invalid init string.") 137 | -------------------------------------------------------------------------------- /src/flash_ipa/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of neural network architecture. 3 | 4 | Code adapted from 5 | https://github.com/microsoft/protein-frame-flow/blob/main/models/flow_model.py 6 | """ 7 | 8 | import torch 9 | from torch import nn 10 | from einops import rearrange 11 | 12 | from flash_ipa.edge_embedder import EdgeEmbedder, EdgeEmbedderConfig 13 | from flash_ipa.ipa import StructureModuleTransition, BackboneUpdate, EdgeTransition, InvariantPointAttention, IPAConfig 14 | from flash_ipa.utils import ANG_TO_NM_SCALE, NM_TO_ANG_SCALE 15 | from flash_ipa.rigid import create_rigid 16 | from flash_ipa.factorizer import LinearFactorizer 17 | from flash_ipa.linear import Linear 18 | from dataclasses import dataclass, field 19 | 20 | 21 | @dataclass 22 | class ModelConfig: 23 | # 5 options: "orig_no_bias", "orig_2d_bias", "flash_no_bias", "flash_1d_bias", "flash_2d_factorize_bias" 24 | mode: str = "flash_1d_bias" 25 | max_len: int = 256 26 | node_embed_size: int = 256 27 | edge_embed_size: int = 128 28 | ipa: IPAConfig = field(default_factory=IPAConfig) 29 | edge_features: EdgeEmbedderConfig = field(default_factory=EdgeEmbedderConfig) 30 | 31 | 32 | class Model(nn.Module): 33 | 34 | def __init__(self, model_conf): 35 | super(Model, self).__init__() 36 | self._model_conf = model_conf 37 | self.mode = model_conf.mode 38 | self._ipa_conf = model_conf.ipa 39 | self.rigids_ang_to_nm = lambda x: x.apply_trans_fn(lambda x: x * ANG_TO_NM_SCALE) 40 | self.rigids_nm_to_ang = lambda x: x.apply_trans_fn(lambda x: x * NM_TO_ANG_SCALE) 41 | self.edge_embedder = EdgeEmbedder(model_conf.edge_features, mode="1d" if self.mode == "flash_1d_bias" else "2d") 42 | 43 | """ 44 | Check variables are consistent for experiment. 45 | """ 46 | if self.mode == "orig_no_bias": 47 | assert ( 48 | self._ipa_conf.c_z == 0 and self._ipa_conf.z_factor_rank == 0 and self._ipa_conf.use_flash_attn == False 49 | ), "Expecting self._ipa_conf.c_z == 0 and self._ipa_conf.z_factor_rank == 0 and self._ipa_conf.use_flash_attn == False, but got {self._ipa_conf.c_z}, {self._ipa_conf.z_factor_rank}, {self._ipa_conf.use_flash_attn}." 
50 | elif self.mode == "orig_2d_bias": 51 | assert ( 52 | self._ipa_conf.c_z > 0 and self._ipa_conf.z_factor_rank == 0 and self._ipa_conf.use_flash_attn == False 53 | ), "Expecting self._ipa_conf.c_z > 0 and self._ipa_conf.z_factor_rank == 0 and self._ipa_conf.use_flash_attn == False, but got {self._ipa_conf.c_z}, {self._ipa_conf.z_factor_rank}, {self._ipa_conf.use_flash_attn}." 54 | elif self.mode == "flash_no_bias": 55 | assert ( 56 | self._ipa_conf.c_z == 0 and self._ipa_conf.z_factor_rank == 0 and self._ipa_conf.use_flash_attn == True 57 | ), "Expecting self._ipa_conf.c_z == 0 and self._ipa_conf.z_factor_rank == 0 and self._ipa_conf.use_flash_attn == True, but got {self._ipa_conf.c_z}, {self._ipa_conf.z_factor_rank}, {self._ipa_conf.use_flash_attn}." 58 | elif self.mode == "flash_1d_bias": 59 | assert ( 60 | self._ipa_conf.c_z > 0 and self._ipa_conf.z_factor_rank > 0 and self._ipa_conf.use_flash_attn == True 61 | ), "Expecting self._ipa_conf.c_z > 0 and self._ipa_conf.z_factor_rank > 0 and self._ipa_conf.use_flash_attn == True, but got {self._ipa_conf.c_z}, {self._ipa_conf.z_factor_rank}, {self._ipa_conf.use_flash_attn}." 62 | elif self.mode == "flash_2d_factorize_bias": 63 | assert ( 64 | self._ipa_conf.c_z > 0 and self._ipa_conf.z_factor_rank > 0 and self._ipa_conf.use_flash_attn == True 65 | ), "Expecting self._ipa_conf.c_z > 0 and self._ipa_conf.z_factor_rank > 0 and self._ipa_conf.use_flash_attn == True, but got {self._ipa_conf.c_z}, {self._ipa_conf.z_factor_rank}, {self._ipa_conf.use_flash_attn}." 66 | else: 67 | raise ValueError( 68 | f"Invalid mode: {self.mode}. Must be one of ['orig_no_bias', 'orig_2d_bias', 'flash_no_bias', 'flash_1d_bias', 'flash_2d_factorize_bias']." 69 | ) 70 | 71 | if self.mode == "flash_2d_factorize_bias": 72 | self.factorizer = LinearFactorizer( 73 | in_L=model_conf.max_len, 74 | in_D=self._ipa_conf.c_z, 75 | target_rank=self._ipa_conf.z_factor_rank, 76 | target_inner_dim=self._ipa_conf.c_z, 77 | ) 78 | 79 | # Attention trunk 80 | self.trunk = nn.ModuleDict() 81 | for b in range(self._ipa_conf.num_blocks): 82 | self.trunk[f"ipa_{b}"] = InvariantPointAttention(self._ipa_conf) 83 | self.trunk[f"ipa_ln_{b}"] = nn.LayerNorm(self._ipa_conf.c_s) 84 | tfmr_in = self._ipa_conf.c_s 85 | tfmr_layer = torch.nn.TransformerEncoderLayer( 86 | d_model=tfmr_in, 87 | nhead=self._ipa_conf.seq_tfmr_num_heads, 88 | dim_feedforward=tfmr_in, 89 | batch_first=True, 90 | dropout=0.0, 91 | norm_first=False, 92 | ) 93 | self.trunk[f"seq_tfmr_{b}"] = torch.nn.TransformerEncoder( 94 | tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False 95 | ) 96 | self.trunk[f"post_tfmr_{b}"] = Linear(tfmr_in, self._ipa_conf.c_s, init="final") 97 | self.trunk[f"node_transition_{b}"] = StructureModuleTransition(c=self._ipa_conf.c_s) 98 | self.trunk[f"bb_update_{b}"] = BackboneUpdate(self._ipa_conf.c_s, use_rot_updates=True) 99 | 100 | if b < self._ipa_conf.num_blocks - 1: 101 | # No edge update on the last block. 102 | edge_in = self._model_conf.edge_embed_size 103 | self.trunk[f"edge_transition_{b}"] = EdgeTransition( 104 | mode="2d" if self.mode == "orig_2d_bias" else "1d", 105 | node_embed_size=self._ipa_conf.c_s, 106 | edge_embed_in=edge_in, 107 | edge_embed_out=self._model_conf.edge_embed_size, 108 | z_factor_rank=self._ipa_conf.z_factor_rank, 109 | ) 110 | 111 | def forward(self, input_feats): 112 | """ 113 | Assuming frames are already computed. 
114 | input_feats: 115 | node_embeddings: (B, L, D) 116 | translations: (B, L, 3) 117 | rotations: (B, L, 3, 3) 118 | res_mask: (B, L) 119 | 120 | """ 121 | # Masks 122 | node_mask = input_feats["res_mask"] 123 | if self.mode in ["orig_2d_bias", "flash_2d_factorize_bias"]: 124 | # Edge mask exist only if we use 2d bias 125 | edge_mask = node_mask[:, None] * node_mask[:, :, None] 126 | else: 127 | edge_mask = None 128 | 129 | # Inputs 130 | init_node_embed = input_feats["node_embeddings"] 131 | translations = input_feats["translations"] 132 | rotations = input_feats["rotations"] 133 | 134 | if "trans_sc" not in input_feats: 135 | trans_sc = torch.zeros_like(translations) 136 | else: 137 | trans_sc = input_feats["trans_sc"] 138 | 139 | # Initialize edge embeddings depending on the mode 140 | if self.mode == "orig_no_bias" or self.mode == "flash_no_bias": 141 | init_edge_embed, z_factor_1, z_factor_2 = None, None, None 142 | elif self.mode == "orig_2d_bias": 143 | init_edge_embed = self.edge_embedder(init_node_embed, translations, trans_sc, edge_mask) # 2d mode 144 | elif self.mode == "flash_1d_bias": 145 | z_factor_1, z_factor_2 = self.edge_embedder(init_node_embed, translations, trans_sc, node_mask) # 1d mode 146 | elif self.mode == "flash_2d_factorize_bias": 147 | init_edge_embed = self.edge_embedder(init_node_embed, translations, trans_sc, edge_mask) # 2d mode 148 | z_factor_1, z_factor_2 = self.factorizer(init_edge_embed) 149 | z_factor_1 = rearrange(z_factor_1, "(b d) n r -> b n r d", b=init_node_embed.shape[0]) 150 | z_factor_2 = rearrange(z_factor_2, "(b d) n r -> b n r d", b=init_node_embed.shape[0]) 151 | 152 | # Apply masks 153 | init_node_embed = init_node_embed * node_mask[..., None] 154 | node_embed = init_node_embed * node_mask[..., None] 155 | if self.mode == "orig_2d_bias": 156 | # The edge_embed is used for slow IPA. Otherwise, use z_factor or node_embed. 157 | edge_embed = init_edge_embed * edge_mask[..., None] 158 | else: 159 | edge_embed = None 160 | 161 | # Initial rigids 162 | curr_rigids = create_rigid( 163 | rotations, 164 | translations, 165 | ) 166 | curr_rigids = self.rigids_ang_to_nm(curr_rigids) 167 | 168 | # Main trunk 169 | for b in range(self._ipa_conf.num_blocks): 170 | if self._ipa_conf.use_flash_attn: 171 | # The FlashAttention case uses pseudo-factors of the pair bias. 172 | ipa_embed = self.trunk[f"ipa_{b}"]( 173 | node_embed, 174 | None, 175 | z_factor_1, 176 | z_factor_2, 177 | curr_rigids, 178 | mask=node_mask, 179 | ) 180 | else: 181 | # The non-FlashAttention case uses the full pair bias (edge_embed). 
182 | ipa_embed = self.trunk[f"ipa_{b}"](node_embed, edge_embed, None, None, curr_rigids, node_mask) 183 | 184 | # Update embedings and frame 185 | ipa_embed *= node_mask[..., None] 186 | node_embed = self.trunk[f"ipa_ln_{b}"](node_embed + ipa_embed) 187 | seq_tfmr_out = self.trunk[f"seq_tfmr_{b}"](node_embed, src_key_padding_mask=(1 - node_mask).bool()) 188 | node_embed = node_embed + self.trunk[f"post_tfmr_{b}"](seq_tfmr_out) 189 | node_embed = self.trunk[f"node_transition_{b}"](node_embed) 190 | node_embed = node_embed * node_mask[..., None] 191 | rigid_update = self.trunk[f"bb_update_{b}"](node_embed * node_mask[..., None]) 192 | curr_rigids = curr_rigids.compose_q_update_vec(rigid_update, node_mask[..., None]) 193 | 194 | if b < self._ipa_conf.num_blocks - 1: 195 | if self.mode == "orig_2d_bias": 196 | edge_embed = self.trunk[f"edge_transition_{b}"](node_embed, edge_embed) # edge_embed is B,L,L,D 197 | edge_embed *= edge_mask[..., None] 198 | elif self.mode == "flash_1d_bias" or self.mode == "flash_2d_factorize_bias": 199 | z_factor_1, z_factor_2 = self.trunk[f"edge_transition_{b}"](node_embed, None, z_factor_1, z_factor_2) 200 | z_factor_1 *= node_mask[:, :, None, None] 201 | z_factor_2 *= node_mask[:, :, None, None] 202 | else: 203 | # no bias 204 | continue 205 | 206 | curr_rigids = self.rigids_nm_to_ang(curr_rigids) 207 | pred_trans = curr_rigids.get_trans() 208 | pred_rotmats = curr_rigids.get_rots().get_rot_mats() 209 | 210 | return { 211 | "pred_trans": pred_trans, 212 | "pred_rotmats": pred_rotmats, 213 | } 214 | -------------------------------------------------------------------------------- /src/flash_ipa/rigid.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code adapted from se3_diffusion (https://github.com/jasonkyuyim/se3_diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | """Versions of OpenFold's vector update functions patched to support masking.""" 5 | 6 | import numpy as np 7 | import torch 8 | from beartype.typing import Any, Callable, List, Optional, Tuple, Union 9 | from jaxtyping import Float 10 | 11 | NODE_MASK_TENSOR_TYPE = Float[torch.Tensor, "... num_nodes"] 12 | UPDATE_NODE_MASK_TENSOR_TYPE = Float[torch.Tensor, "... num_nodes 1"] 13 | QUATERNION_TENSOR_TYPE = Float[torch.Tensor, "... num_nodes 4"] 14 | ROTATION_TENSOR_TYPE = Float[torch.Tensor, "... 3 3"] 15 | COORDINATES_TENSOR_TYPE = Float[torch.Tensor, "... num_nodes 3"] 16 | 17 | 18 | def rot_matmul(a: ROTATION_TENSOR_TYPE, b: ROTATION_TENSOR_TYPE) -> ROTATION_TENSOR_TYPE: 19 | """Performs matrix multiplication of two rotation matrix tensors. Written out by hand to avoid 20 | AMP downcasting. 
21 | 22 | Args: 23 | a: [*, 3, 3] left multiplicand 24 | b: [*, 3, 3] right multiplicand 25 | Returns: 26 | The product ab 27 | """ 28 | row_1 = torch.stack( 29 | [ 30 | a[..., 0, 0] * b[..., 0, 0] + a[..., 0, 1] * b[..., 1, 0] + a[..., 0, 2] * b[..., 2, 0], 31 | a[..., 0, 0] * b[..., 0, 1] + a[..., 0, 1] * b[..., 1, 1] + a[..., 0, 2] * b[..., 2, 1], 32 | a[..., 0, 0] * b[..., 0, 2] + a[..., 0, 1] * b[..., 1, 2] + a[..., 0, 2] * b[..., 2, 2], 33 | ], 34 | dim=-1, 35 | ) 36 | row_2 = torch.stack( 37 | [ 38 | a[..., 1, 0] * b[..., 0, 0] + a[..., 1, 1] * b[..., 1, 0] + a[..., 1, 2] * b[..., 2, 0], 39 | a[..., 1, 0] * b[..., 0, 1] + a[..., 1, 1] * b[..., 1, 1] + a[..., 1, 2] * b[..., 2, 1], 40 | a[..., 1, 0] * b[..., 0, 2] + a[..., 1, 1] * b[..., 1, 2] + a[..., 1, 2] * b[..., 2, 2], 41 | ], 42 | dim=-1, 43 | ) 44 | row_3 = torch.stack( 45 | [ 46 | a[..., 2, 0] * b[..., 0, 0] + a[..., 2, 1] * b[..., 1, 0] + a[..., 2, 2] * b[..., 2, 0], 47 | a[..., 2, 0] * b[..., 0, 1] + a[..., 2, 1] * b[..., 1, 1] + a[..., 2, 2] * b[..., 2, 1], 48 | a[..., 2, 0] * b[..., 0, 2] + a[..., 2, 1] * b[..., 1, 2] + a[..., 2, 2] * b[..., 2, 2], 49 | ], 50 | dim=-1, 51 | ) 52 | 53 | return torch.stack([row_1, row_2, row_3], dim=-2) 54 | 55 | 56 | def rot_vec_mul(r: ROTATION_TENSOR_TYPE, t: COORDINATES_TENSOR_TYPE) -> COORDINATES_TENSOR_TYPE: 57 | """Applies a rotation to a vector. Written out by hand to avoid transfer to avoid AMP 58 | downcasting. 59 | 60 | Args: 61 | r: [*, 3, 3] rotation matrices 62 | t: [*, 3] coordinate tensors 63 | Returns: 64 | [*, 3] rotated coordinates 65 | """ 66 | x = t[..., 0] 67 | y = t[..., 1] 68 | z = t[..., 2] 69 | return torch.stack( 70 | [ 71 | r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z, 72 | r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z, 73 | r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z, 74 | ], 75 | dim=-1, 76 | ) 77 | 78 | 79 | def identity_rot_mats( 80 | batch_dims: Union[Union[Tuple[int], Tuple[np.int64]], torch.Size], 81 | dtype: Optional[torch.dtype] = None, 82 | device: Optional[torch.device] = None, 83 | requires_grad: bool = True, 84 | ) -> ROTATION_TENSOR_TYPE: 85 | rots = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad) 86 | rots = rots.view(*((1,) * len(batch_dims)), 3, 3) 87 | rots = rots.expand(*batch_dims, -1, -1) 88 | 89 | return rots 90 | 91 | 92 | def identity_trans( 93 | batch_dims: Union[Union[Tuple[int], Tuple[np.int64]], torch.Size], 94 | dtype: Optional[torch.dtype] = None, 95 | device: Optional[torch.device] = None, 96 | requires_grad: bool = True, 97 | ) -> COORDINATES_TENSOR_TYPE: 98 | trans = torch.zeros((*batch_dims, 3), dtype=dtype, device=device, requires_grad=requires_grad) 99 | return trans 100 | 101 | 102 | def identity_quats( 103 | batch_dims: Union[Union[Tuple[int], Tuple[np.int64]], torch.Size], 104 | dtype: Optional[torch.dtype] = None, 105 | device: Optional[torch.device] = None, 106 | requires_grad: bool = True, 107 | ) -> QUATERNION_TENSOR_TYPE: 108 | quat = torch.zeros((*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad) 109 | 110 | with torch.no_grad(): 111 | quat[..., 0] = 1 112 | 113 | return quat 114 | 115 | 116 | _quat_elements = ["a", "b", "c", "d"] 117 | _qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] 118 | _qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} 119 | 120 | 121 | def _to_mat(pairs: List[Tuple[str, int]]) -> np.ndarray: 122 | mat = np.zeros((4, 4)) 123 | for pair in pairs: 124 | key, value = pair 125 | ind = 
_qtr_ind_dict[key] 126 | mat[ind // 4][ind % 4] = value 127 | 128 | return mat 129 | 130 | 131 | _QTR_MAT = np.zeros((4, 4, 3, 3)) 132 | _QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)]) 133 | _QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)]) 134 | _QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)]) 135 | _QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)]) 136 | _QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)]) 137 | _QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)]) 138 | _QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)]) 139 | _QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)]) 140 | _QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)]) 141 | 142 | 143 | def quat_to_rot(quat: QUATERNION_TENSOR_TYPE) -> ROTATION_TENSOR_TYPE: 144 | """Converts a quaternion to a rotation matrix. 145 | 146 | Args: 147 | quat: [*, 4] quaternions 148 | Returns: 149 | [*, 3, 3] rotation matrices 150 | """ 151 | # [*, 4, 4] 152 | quat = quat[..., None] * quat[..., None, :] 153 | 154 | # [4, 4, 3, 3] 155 | mat = quat.new_tensor(_QTR_MAT, requires_grad=False) 156 | 157 | # [*, 4, 4, 3, 3] 158 | shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape) 159 | quat = quat[..., None, None] * shaped_qtr_mat 160 | 161 | # [*, 3, 3] 162 | return torch.sum(quat, dim=(-3, -4)) 163 | 164 | 165 | def rot_to_quat(rot: ROTATION_TENSOR_TYPE) -> QUATERNION_TENSOR_TYPE: 166 | if rot.shape[-2:] != (3, 3): 167 | raise ValueError("Input rotation is incorrectly shaped") 168 | 169 | rot = [[rot[..., i, j] for j in range(3)] for i in range(3)] 170 | [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot 171 | 172 | k = [ 173 | [ 174 | xx + yy + zz, 175 | zy - yz, 176 | xz - zx, 177 | yx - xy, 178 | ], 179 | [ 180 | zy - yz, 181 | xx - yy - zz, 182 | xy + yx, 183 | xz + zx, 184 | ], 185 | [ 186 | xz - zx, 187 | xy + yx, 188 | yy - xx - zz, 189 | yz + zy, 190 | ], 191 | [ 192 | yx - xy, 193 | xz + zx, 194 | yz + zy, 195 | zz - xx - yy, 196 | ], 197 | ] 198 | 199 | k = (1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2) 200 | 201 | _, vectors = torch.linalg.eigh(k) 202 | return vectors[..., -1] 203 | 204 | 205 | _QUAT_MULTIPLY = np.zeros((4, 4, 4)) 206 | _QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]] 207 | 208 | _QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]] 209 | 210 | _QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]] 211 | 212 | _QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]] 213 | 214 | _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] 215 | 216 | 217 | def quat_multiply(quat1: QUATERNION_TENSOR_TYPE, quat2: QUATERNION_TENSOR_TYPE) -> QUATERNION_TENSOR_TYPE: 218 | """Multiply a quaternion by another quaternion.""" 219 | mat = quat1.new_tensor(_QUAT_MULTIPLY) 220 | reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape) 221 | return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2)) 222 | 223 | 224 | def quat_multiply_by_vec(quat: QUATERNION_TENSOR_TYPE, vec: COORDINATES_TENSOR_TYPE) -> QUATERNION_TENSOR_TYPE: 225 | """Multiply a quaternion by a pure-vector quaternion.""" 226 | mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC) 227 | reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape) 228 | return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2)) 229 | 230 | 231 | def 
invert_rot_mat(rot_mat: ROTATION_TENSOR_TYPE) -> ROTATION_TENSOR_TYPE: 232 | return rot_mat.transpose(-1, -2) 233 | 234 | 235 | def invert_quat(quat: QUATERNION_TENSOR_TYPE, mask: Optional[NODE_MASK_TENSOR_TYPE] = None) -> QUATERNION_TENSOR_TYPE: 236 | quat_prime = quat.clone() 237 | quat_prime[..., 1:] *= -1 238 | if mask is not None: 239 | # avoid creating NaNs with masked nodes' "missing" values via division by zero 240 | inv, quat_mask = quat_prime, mask.bool() 241 | inv[quat_mask] = inv[quat_mask] / torch.sum(quat[quat_mask] ** 2, dim=-1, keepdim=True) 242 | else: 243 | inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True) 244 | return inv 245 | 246 | 247 | class Rotation: 248 | """A 3D rotation. 249 | 250 | Depending on how the object is initialized, the rotation is represented by either a rotation 251 | matrix or a quaternion, though both formats are made available by helper functions. To simplify 252 | gradient computation, the underlying format of the rotation cannot be changed in-place. Like 253 | Rigid, the class is designed to mimic the behavior of a torch Tensor, almost as if each 254 | Rotation object were a tensor of rotations, in one format or another. 255 | """ 256 | 257 | def __init__( 258 | self, 259 | rot_mats: Optional[ROTATION_TENSOR_TYPE] = None, 260 | quats: Optional[QUATERNION_TENSOR_TYPE] = None, 261 | quats_mask: Optional[NODE_MASK_TENSOR_TYPE] = None, 262 | normalize_quats: bool = True, 263 | ): 264 | """ 265 | Args: 266 | rot_mats: 267 | A [*, 3, 3] rotation matrix tensor. Mutually exclusive with 268 | quats 269 | quats: 270 | A [*, 4] quaternion. Mutually exclusive with rot_mats. If 271 | normalize_quats is not True, must be a unit quaternion 272 | quats_mask: 273 | A [*] quaternion mask. If quats is specified and normalize_quats 274 | is True, this will be used to subset the elements of quats 275 | being normalized. 276 | normalize_quats: 277 | If quats is specified, whether to normalize quats 278 | """ 279 | if (rot_mats is None and quats is None) or (rot_mats is not None and quats is not None): 280 | raise ValueError("Exactly one input argument must be specified") 281 | 282 | if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (quats is not None and quats.shape[-1] != 4): 283 | raise ValueError("Incorrectly shaped rotation matrix or quaternion") 284 | 285 | # Force full-precision 286 | if quats is not None: 287 | quats = quats.type(torch.float32) 288 | if rot_mats is not None: 289 | rot_mats = rot_mats.type(torch.float32) 290 | 291 | # Parse mask 292 | if quats is not None and quats_mask is not None: 293 | quats_mask = quats_mask.type(torch.bool) 294 | 295 | if quats is not None and normalize_quats: 296 | if quats_mask is not None: 297 | quats[quats_mask] = quats[quats_mask] / torch.linalg.norm(quats[quats_mask], dim=-1, keepdim=True) 298 | else: 299 | quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) 300 | 301 | self._rot_mats = rot_mats 302 | self._quats = quats 303 | 304 | @staticmethod 305 | def identity( 306 | shape: Tuple[Union[int, np.int64]], 307 | dtype: Optional[torch.dtype] = None, 308 | device: Optional[torch.device] = None, 309 | requires_grad: bool = True, 310 | fmt: str = "quat", 311 | ): 312 | """Returns an identity Rotation. 313 | 314 | Args: 315 | shape: 316 | The "shape" of the resulting Rotation object. 
See documentation 317 | for the shape property 318 | dtype: 319 | The torch dtype for the rotation 320 | device: 321 | The torch device for the new rotation 322 | requires_grad: 323 | Whether the underlying tensors in the new rotation object 324 | should require gradient computation 325 | fmt: 326 | One of "quat" or "rot_mat". Determines the underlying format 327 | of the new object's rotation 328 | Returns: 329 | A new identity rotation 330 | """ 331 | if fmt == "rot_mat": 332 | rot_mats = identity_rot_mats( 333 | shape, 334 | dtype, 335 | device, 336 | requires_grad, 337 | ) 338 | return Rotation(rot_mats=rot_mats, quats=None) 339 | elif fmt == "quat": 340 | quats = identity_quats(shape, dtype, device, requires_grad) 341 | return Rotation(rot_mats=None, quats=quats, normalize_quats=False) 342 | else: 343 | raise ValueError(f"Invalid format: f{fmt}") 344 | 345 | # Magic methods 346 | 347 | def __getitem__(self, index: Any): 348 | """Allows torch-style indexing over the virtual shape of the rotation object. See 349 | documentation for the shape property. 350 | 351 | Args: 352 | index: 353 | A torch index. E.g. (1, 3, 2), or (slice(None,)) 354 | Returns: 355 | The indexed rotation 356 | """ 357 | if type(index) != tuple: 358 | index = (index,) 359 | 360 | if self._rot_mats is not None: 361 | rot_mats = self._rot_mats[index + (slice(None), slice(None))] 362 | return Rotation(rot_mats=rot_mats) 363 | elif self._quats is not None: 364 | quats = self._quats[index + (slice(None),)] 365 | return Rotation(quats=quats, normalize_quats=False) 366 | else: 367 | raise ValueError("Both rotations are None") 368 | 369 | def __mul__(self, right: torch.Tensor) -> "Rotation": 370 | """Pointwise left multiplication of the rotation with a tensor. Can be used to e.g., mask 371 | the Rotation. 372 | 373 | Args: 374 | right: 375 | The tensor multiplicand 376 | Returns: 377 | The product 378 | """ 379 | if not (isinstance(right, torch.Tensor)): 380 | raise TypeError("The other multiplicand must be a Tensor") 381 | 382 | if self._rot_mats is not None: 383 | rot_mats = self._rot_mats * right[..., None, None] 384 | return Rotation(rot_mats=rot_mats, quats=None) 385 | elif self._quats is not None: 386 | quats = self._quats * right[..., None] 387 | return Rotation(rot_mats=None, quats=quats, normalize_quats=False) 388 | else: 389 | raise ValueError("Both rotations are None") 390 | 391 | def __rmul__(self, left: torch.Tensor) -> "Rotation": 392 | """Reverse pointwise multiplication of the rotation with a tensor. 393 | 394 | Args: 395 | left: 396 | The left multiplicand 397 | Returns: 398 | The product 399 | """ 400 | return self.__mul__(left) 401 | 402 | # Properties 403 | 404 | @property 405 | def shape(self) -> torch.Size: 406 | """Returns the virtual shape of the rotation object. This shape is defined as the batch 407 | dimensions of the underlying rotation matrix or quaternion. If the Rotation was initialized 408 | with a [10, 3, 3] rotation matrix tensor, for example, the resulting shape would be [10]. 409 | 410 | Returns: 411 | The virtual shape of the rotation object 412 | """ 413 | s = None 414 | if self._quats is not None: 415 | s = self._quats.shape[:-1] 416 | else: 417 | s = self._rot_mats.shape[:-2] 418 | 419 | return s 420 | 421 | @property 422 | def dtype(self) -> torch.dtype: 423 | """Returns the dtype of the underlying rotation. 
424 | 425 | Returns: 426 | The dtype of the underlying rotation 427 | """ 428 | if self._rot_mats is not None: 429 | return self._rot_mats.dtype 430 | elif self._quats is not None: 431 | return self._quats.dtype 432 | else: 433 | raise ValueError("Both rotations are None") 434 | 435 | @property 436 | def device(self) -> torch.device: 437 | """The device of the underlying rotation. 438 | 439 | Returns: 440 | The device of the underlying rotation 441 | """ 442 | if self._rot_mats is not None: 443 | return self._rot_mats.device 444 | elif self._quats is not None: 445 | return self._quats.device 446 | else: 447 | raise ValueError("Both rotations are None") 448 | 449 | @property 450 | def requires_grad(self) -> bool: 451 | """Returns the requires_grad property of the underlying rotation. 452 | 453 | Returns: 454 | The requires_grad property of the underlying tensor 455 | """ 456 | if self._rot_mats is not None: 457 | return self._rot_mats.requires_grad 458 | elif self._quats is not None: 459 | return self._quats.requires_grad 460 | else: 461 | raise ValueError("Both rotations are None") 462 | 463 | def reshape( 464 | self, 465 | new_rots_shape: Optional[torch.Size] = None, 466 | ) -> "Rotation": 467 | """Returns the corresponding reshaped rotation. 468 | 469 | Returns: 470 | The reshaped rotation 471 | """ 472 | if self._quats is not None: 473 | new_rots = self._quats.reshape(new_rots_shape) if new_rots_shape else self._quats 474 | new_rot = Rotation(quats=new_rots, normalize_quats=False) 475 | else: 476 | new_rots = self._rot_mats.reshape(new_rots_shape) if new_rots_shape else self._rot_mats 477 | new_rot = Rotation(rot_mats=new_rots, normalize_quats=False) 478 | 479 | return new_rot 480 | 481 | def get_rot_mats(self) -> ROTATION_TENSOR_TYPE: 482 | """Returns the underlying rotation as a rotation matrix tensor. 483 | 484 | Returns: 485 | The rotation as a rotation matrix tensor 486 | """ 487 | rot_mats = self._rot_mats 488 | if rot_mats is None: 489 | if self._quats is None: 490 | raise ValueError("Both rotations are None") 491 | else: 492 | rot_mats = quat_to_rot(self._quats) 493 | 494 | return rot_mats 495 | 496 | def get_quats(self) -> QUATERNION_TENSOR_TYPE: 497 | """Returns the underlying rotation as a quaternion tensor. 498 | 499 | Depending on whether the Rotation was initialized with a 500 | quaternion, this function may call torch.linalg.eigh. 501 | 502 | Returns: 503 | The rotation as a quaternion tensor. 504 | """ 505 | quats = self._quats 506 | if quats is None: 507 | if self._rot_mats is None: 508 | raise ValueError("Both rotations are None") 509 | else: 510 | quats = rot_to_quat(self._rot_mats) 511 | 512 | return quats 513 | 514 | def get_cur_rot(self) -> Union[QUATERNION_TENSOR_TYPE, ROTATION_TENSOR_TYPE]: 515 | """Return the underlying rotation in its current form. 516 | 517 | Returns: 518 | The stored rotation 519 | """ 520 | if self._rot_mats is not None: 521 | return self._rot_mats 522 | elif self._quats is not None: 523 | return self._quats 524 | else: 525 | raise ValueError("Both rotations are None") 526 | 527 | def get_rotvec(self, eps: float = 1e-6) -> torch.Tensor: 528 | """Return the underlying axis-angle rotation vector. 529 | 530 | Follow's scipy's implementation: 531 | https://github.com/scipy/scipy/blob/HEAD/scipy/spatial/transform/_rotation.pyx#L1385-L1402 532 | 533 | Returns: 534 | The stored rotation as a axis-angle vector. 
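
        Example (illustrative): the identity rotation, stored as the quaternion
        (1, 0, 0, 0), maps to the zero vector, while a rotation of pi about the
        z-axis, stored as (0, 0, 0, 1), maps to (0, 0, pi).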
535 |         """
536 |         quat = self.get_quats()
537 |         # w > 0 to ensure 0 <= angle <= pi
538 |         flip = (quat[..., :1] < 0).float()
539 |         quat = (-1 * quat) * flip + (1 - flip) * quat
540 | 
541 |         angle = 2 * torch.atan2(torch.linalg.norm(quat[..., 1:], dim=-1), quat[..., 0])
542 | 
543 |         angle2 = angle * angle
544 |         small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
545 |         large_angle_scales = angle / torch.sin(angle / 2 + eps)
546 | 
547 |         small_angles = (angle <= 1e-3).float()
548 |         rot_vec_scale = small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
549 |         rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
550 |         return rot_vec
551 | 
552 |     # Rotation functions
553 | 
554 |     def compose_q_update_vec(
555 |         self,
556 |         q_update_vec: torch.Tensor,
557 |         normalize_quats: bool = True,
558 |         update_mask: Optional[UPDATE_NODE_MASK_TENSOR_TYPE] = None,
559 |     ) -> "Rotation":
560 |         """Returns a new quaternion Rotation after updating the current object's underlying
561 |         rotation with a quaternion update, formatted as a [*, 3] tensor whose three columns
562 |         represent x, y, z such that (1, x, y, z) is the desired (not necessarily unit) quaternion
563 |         update.
564 | 
565 |         Args:
566 |             q_update_vec:
567 |                 A [*, 3] quaternion update tensor
568 |             normalize_quats:
569 |                 Whether to normalize the output quaternion
570 |             update_mask:
571 |                 An optional [*, 1] node mask indicating whether to update a node's geometry.
572 |         Returns:
573 |             An updated Rotation
574 |         """
575 |         quats = self.get_quats()
576 |         quat_update = quat_multiply_by_vec(quats, q_update_vec)
577 |         if update_mask is not None:
578 |             quat_update = quat_update * update_mask
579 |         new_quats = quats + quat_update
580 |         return Rotation(
581 |             rot_mats=None,
582 |             quats=new_quats,
583 |             quats_mask=update_mask.squeeze(-1) if update_mask is not None else None,
584 |             normalize_quats=normalize_quats,
585 |         )
586 | 
587 |     def compose_r(self, r: "Rotation") -> "Rotation":
588 |         """Compose the rotation matrices of the current Rotation object with those of another.
589 | 
590 |         Args:
591 |             r:
592 |                 An update rotation object
593 |         Returns:
594 |             An updated rotation object
595 |         """
596 |         r1 = self.get_rot_mats()
597 |         r2 = r.get_rot_mats()
598 |         new_rot_mats = rot_matmul(r1, r2)
599 |         return Rotation(rot_mats=new_rot_mats, quats=None)
600 | 
601 |     def compose_q(self, r: "Rotation", normalize_quats: bool = True) -> "Rotation":
602 |         """Compose the quaternions of the current Rotation object with those of another.
603 | 
604 |         Depending on whether either Rotation was initialized with
605 |         quaternions, this function may call torch.linalg.eigh.
606 | 
607 |         Args:
608 |             r:
609 |                 An update rotation object
610 |         Returns:
611 |             An updated rotation object
612 |         """
613 |         q1 = self.get_quats()
614 |         q2 = r.get_quats()
615 |         new_quats = quat_multiply(q1, q2)
616 |         return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats)
617 | 
618 |     def apply(self, pts: COORDINATES_TENSOR_TYPE) -> COORDINATES_TENSOR_TYPE:
619 |         """Apply the current Rotation as a rotation matrix to a set of 3D coordinates.
620 | 
621 |         Args:
622 |             pts:
623 |                 A [*, 3] set of points
624 |         Returns:
625 |             [*, 3] rotated points
626 |         """
627 |         rot_mats = self.get_rot_mats()
628 |         return rot_vec_mul(rot_mats, pts)
629 | 
630 |     def invert_apply(self, pts: COORDINATES_TENSOR_TYPE) -> COORDINATES_TENSOR_TYPE:
631 |         """The inverse of the apply() method.
632 | 633 | Args: 634 | pts: 635 | A [*, 3] set of points 636 | Returns: 637 | [*, 3] inverse-rotated points 638 | """ 639 | rot_mats = self.get_rot_mats() 640 | inv_rot_mats = invert_rot_mat(rot_mats) 641 | return rot_vec_mul(inv_rot_mats, pts) 642 | 643 | def invert(self, mask: Optional[NODE_MASK_TENSOR_TYPE] = None) -> "Rotation": 644 | """Returns the inverse of the current Rotation. 645 | 646 | Args: 647 | mask: 648 | An optional node mask indicating whether to invert a node's geometry. 649 | Returns: 650 | The inverse of the current Rotation 651 | """ 652 | if self._rot_mats is not None: 653 | return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None) 654 | elif self._quats is not None: 655 | return Rotation( 656 | rot_mats=None, 657 | quats=invert_quat(self._quats, mask=mask), 658 | normalize_quats=False, 659 | quats_mask=mask, 660 | ) 661 | else: 662 | raise ValueError("Both rotations are None") 663 | 664 | # "Tensor" stuff 665 | 666 | def unsqueeze(self, dim: int) -> "Rotation": 667 | """Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation 668 | object. 669 | 670 | Args: 671 | dim: A positive or negative dimension index. 672 | Returns: 673 | The unsqueezed Rotation. 674 | """ 675 | if dim >= len(self.shape): 676 | raise ValueError("Invalid dimension") 677 | 678 | if self._rot_mats is not None: 679 | rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2) 680 | return Rotation(rot_mats=rot_mats, quats=None) 681 | elif self._quats is not None: 682 | quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1) 683 | return Rotation(rot_mats=None, quats=quats, normalize_quats=False) 684 | else: 685 | raise ValueError("Both rotations are None") 686 | 687 | @staticmethod 688 | def cat(rs, dim: int) -> "Rotation": 689 | """Concatenates rotations along one of the batch dimensions. Analogous to torch.cat(). 690 | 691 | Note that the output of this operation is always a rotation matrix, 692 | regardless of the format of input rotations. 693 | 694 | Args: 695 | rs: 696 | A list of rotation objects 697 | dim: 698 | The dimension along which the rotations should be 699 | concatenated 700 | Returns: 701 | A concatenated Rotation object in rotation matrix format 702 | """ 703 | rot_mats = [r.get_rot_mats() for r in rs] 704 | rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) 705 | 706 | return Rotation(rot_mats=rot_mats, quats=None) 707 | 708 | def map_tensor_fn(self, fn: Callable) -> "Rotation": 709 | """Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the 710 | rotation dimension(s). Can be used e.g. to sum out a one-hot batch dimension. 711 | 712 | Args: 713 | fn: 714 | A Tensor -> Tensor function to be mapped over the Rotation 715 | Returns: 716 | The transformed Rotation object 717 | """ 718 | if self._rot_mats is not None: 719 | rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,)) 720 | rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1) 721 | rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3)) 722 | return Rotation(rot_mats=rot_mats, quats=None) 723 | elif self._quats is not None: 724 | quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1) 725 | return Rotation(rot_mats=None, quats=quats, normalize_quats=False) 726 | else: 727 | raise ValueError("Both rotations are None") 728 | 729 | def cuda(self) -> "Rotation": 730 | """Analogous to the cuda() method of torch Tensors. 
731 | 732 | Returns: 733 | A copy of the Rotation in CUDA memory 734 | """ 735 | if self._rot_mats is not None: 736 | return Rotation(rot_mats=self._rot_mats.cuda(), quats=None) 737 | elif self._quats is not None: 738 | return Rotation(rot_mats=None, quats=self._quats.cuda(), normalize_quats=False) 739 | else: 740 | raise ValueError("Both rotations are None") 741 | 742 | def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]) -> "Rotation": 743 | """Analogous to the to() method of torch Tensors. 744 | 745 | Args: 746 | device: 747 | A torch device 748 | dtype: 749 | A torch dtype 750 | Returns: 751 | A copy of the Rotation using the new device and dtype 752 | """ 753 | if self._rot_mats is not None: 754 | return Rotation( 755 | rot_mats=self._rot_mats.to(device=device, dtype=dtype), 756 | quats=None, 757 | ) 758 | elif self._quats is not None: 759 | return Rotation( 760 | rot_mats=None, 761 | quats=self._quats.to(device=device, dtype=dtype), 762 | normalize_quats=False, 763 | ) 764 | else: 765 | raise ValueError("Both rotations are None") 766 | 767 | def detach(self) -> "Rotation": 768 | """Returns a copy of the Rotation whose underlying Tensor has been detached from its torch 769 | graph. 770 | 771 | Returns: 772 | A copy of the Rotation whose underlying Tensor has been detached 773 | from its torch graph 774 | """ 775 | if self._rot_mats is not None: 776 | return Rotation(rot_mats=self._rot_mats.detach(), quats=None) 777 | elif self._quats is not None: 778 | return Rotation( 779 | rot_mats=None, 780 | quats=self._quats.detach(), 781 | normalize_quats=False, 782 | ) 783 | else: 784 | raise ValueError("Both rotations are None") 785 | 786 | 787 | class Rigid: 788 | """A class representing a rigid transformation. 789 | 790 | Little more than a wrapper around two objects: a Rotation object and a [*, 3] translation 791 | Designed to behave approximately like a single torch tensor with the shape of the shared batch 792 | dimensions of its component parts. 793 | """ 794 | 795 | def __init__( 796 | self, 797 | rots: Optional[Rotation], 798 | trans: Optional[COORDINATES_TENSOR_TYPE], 799 | ): 800 | """ 801 | Args: 802 | rots: A [*, 3, 3] rotation tensor 803 | trans: A corresponding [*, 3] translation tensor 804 | """ 805 | # (we need device, dtype, etc. from at least one input) 806 | 807 | batch_dims, dtype, device, requires_grad = None, None, None, None 808 | if trans is not None: 809 | batch_dims = trans.shape[:-1] 810 | dtype = trans.dtype 811 | device = trans.device 812 | requires_grad = trans.requires_grad 813 | elif rots is not None: 814 | batch_dims = rots.shape 815 | dtype = rots.dtype 816 | device = rots.device 817 | requires_grad = rots.requires_grad 818 | else: 819 | raise ValueError("At least one input argument must be specified") 820 | 821 | if rots is None: 822 | rots = Rotation.identity( 823 | batch_dims, 824 | dtype, 825 | device, 826 | requires_grad, 827 | ) 828 | elif trans is None: 829 | trans = identity_trans( 830 | batch_dims, 831 | dtype, 832 | device, 833 | requires_grad, 834 | ) 835 | 836 | if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device): 837 | raise ValueError("Rots and trans incompatible") 838 | 839 | # Force full precision. Happens to the rotations automatically. 
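        # (Rotation.__init__ already casts rot_mats/quats to float32; only the translation needs an explicit cast here.)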
840 | trans = trans.type(torch.float32) 841 | 842 | self._rots = rots 843 | self._trans = trans 844 | 845 | @staticmethod 846 | def identity( 847 | shape: Tuple[Union[int, np.int64]], 848 | dtype: Optional[torch.dtype] = None, 849 | device: Optional[torch.device] = None, 850 | requires_grad: bool = True, 851 | fmt: str = "quat", 852 | ) -> "Rigid": 853 | """Constructs an identity transformation. 854 | 855 | Args: 856 | shape: 857 | The desired shape 858 | dtype: 859 | The dtype of both internal tensors 860 | device: 861 | The device of both internal tensors 862 | requires_grad: 863 | Whether grad should be enabled for the internal tensors 864 | Returns: 865 | The identity transformation 866 | """ 867 | return Rigid( 868 | Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt), 869 | identity_trans(shape, dtype, device, requires_grad), 870 | ) 871 | 872 | def __getitem__(self, index: Any) -> "Rigid": 873 | """Indexes the affine transformation with PyTorch-style indices. The index is applied to 874 | the shared dimensions of both the rotation and the translation. 875 | 876 | E.g.:: 877 | 878 | r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None) 879 | t = Rigid(r, torch.rand(10, 10, 3)) 880 | indexed = t[3, 4:6] 881 | assert(indexed.shape == (2,)) 882 | assert(indexed.get_rots().shape == (2,)) 883 | assert(indexed.get_trans().shape == (2, 3)) 884 | 885 | Args: 886 | index: A standard torch tensor index. E.g. 8, (10, None, 3), 887 | or (3, slice(0, 1, None)) 888 | Returns: 889 | The indexed tensor 890 | """ 891 | if type(index) != tuple: 892 | index = (index,) 893 | 894 | return Rigid( 895 | self._rots[index], 896 | self._trans[index + (slice(None),)], 897 | ) 898 | 899 | def __mul__(self, right: torch.Tensor) -> "Rigid": 900 | """Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. 901 | mask the Rigid. 902 | 903 | Args: 904 | right: 905 | The tensor multiplicand 906 | Returns: 907 | The product 908 | """ 909 | if not (isinstance(right, torch.Tensor)): 910 | raise TypeError("The other multiplicand must be a Tensor") 911 | 912 | new_rots = self._rots * right 913 | new_trans = self._trans * right[..., None] 914 | 915 | return Rigid(new_rots, new_trans) 916 | 917 | def __rmul__(self, left: torch.Tensor) -> "Rigid": 918 | """Reverse pointwise multiplication of the transformation with a tensor. 919 | 920 | Args: 921 | left: 922 | The left multiplicand 923 | Returns: 924 | The product 925 | """ 926 | return self.__mul__(left) 927 | 928 | @property 929 | def shape(self) -> torch.Size: 930 | """Returns the shape of the shared dimensions of the rotation and the translation. 931 | 932 | Returns: 933 | The shape of the transformation 934 | """ 935 | s = self._trans.shape[:-1] 936 | return s 937 | 938 | @property 939 | def device(self) -> torch.device: 940 | """Returns the device on which the Rigid's tensors are located. 941 | 942 | Returns: 943 | The device on which the Rigid's tensors are located 944 | """ 945 | return self._trans.device 946 | 947 | def reshape( 948 | self, 949 | new_rots_shape: Optional[torch.Size] = None, 950 | new_trans_shape: Optional[torch.Size] = None, 951 | ) -> "Rigid": 952 | """Returns the corresponding reshaped rotation and reshaped translation. 
953 | 954 | Returns: 955 | The reshaped transformation 956 | """ 957 | new_rots = self._rots.reshape(new_rots_shape=new_rots_shape) if new_rots_shape else self._rots 958 | new_trans = self._trans.reshape(new_trans_shape) if new_trans_shape else self._trans 959 | 960 | return Rigid(new_rots, new_trans) 961 | 962 | def get_rots(self) -> Rotation: 963 | """Getter for the rotation. 964 | 965 | Returns: 966 | The rotation object 967 | """ 968 | return self._rots 969 | 970 | def get_trans(self) -> COORDINATES_TENSOR_TYPE: 971 | """Getter for the translation. 972 | 973 | Returns: 974 | The stored translation 975 | """ 976 | return self._trans 977 | 978 | def compose_q_update_vec( 979 | self, 980 | q_update_vec: Float[torch.Tensor, "... num_nodes 6"], # noqa: F722 981 | update_mask: Optional[UPDATE_NODE_MASK_TENSOR_TYPE] = None, 982 | ) -> "Rigid": 983 | """Composes the transformation with a quaternion update vector of shape [*, 6], where the 984 | final 6 columns represent the x, y, and z values of a quaternion of form (1, x, y, z) 985 | followed by a 3D translation. 986 | 987 | Args: 988 | q_update_vec: 989 | The quaternion update vector. 990 | update_mask: 991 | An optional [*, 1] node mask indicating whether to update a node's geometry. 992 | Returns: 993 | The composed transformation. 994 | """ 995 | q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:] 996 | new_rots = self._rots.compose_q_update_vec(q_vec, update_mask=update_mask) 997 | 998 | trans_update = self._rots.apply(t_vec) 999 | if update_mask is not None: 1000 | trans_update = trans_update * update_mask 1001 | new_translation = self._trans + trans_update 1002 | 1003 | return Rigid(new_rots, new_translation) 1004 | 1005 | def compose(self, r: "Rigid") -> "Rigid": 1006 | """Composes the current rigid object with another. 1007 | 1008 | Args: 1009 | r: 1010 | Another Rigid object 1011 | Returns: 1012 | The composition of the two transformations 1013 | """ 1014 | new_rot = self._rots.compose_r(r._rots) 1015 | new_trans = self._rots.apply(r._trans) + self._trans 1016 | return Rigid(new_rot, new_trans) 1017 | 1018 | def compose_r(self, rot: "Rigid", order: str = "right") -> "Rigid": 1019 | """Composes the current rigid object with another. 1020 | 1021 | Args: 1022 | r: 1023 | Another Rigid object 1024 | order: 1025 | Order in which to perform rotation multiplication. 1026 | Returns: 1027 | The composition of the two transformations 1028 | """ 1029 | if order == "right": 1030 | new_rot = self._rots.compose_r(rot) 1031 | elif order == "left": 1032 | new_rot = rot.compose_r(self._rots) 1033 | else: 1034 | raise ValueError(f"Unrecognized multiplication order: {order}") 1035 | return Rigid(new_rot, self._trans) 1036 | 1037 | def apply(self, pts: COORDINATES_TENSOR_TYPE) -> COORDINATES_TENSOR_TYPE: 1038 | """Applies the transformation to a coordinate tensor. 1039 | 1040 | Args: 1041 | pts: A [*, 3] coordinate tensor. 1042 | Returns: 1043 | The transformed points. 1044 | """ 1045 | rotated = self._rots.apply(pts) 1046 | return rotated + self._trans 1047 | 1048 | def invert_apply(self, pts: COORDINATES_TENSOR_TYPE) -> COORDINATES_TENSOR_TYPE: 1049 | """Applies the inverse of the transformation to a coordinate tensor. 1050 | 1051 | Args: 1052 | pts: A [*, 3] coordinate tensor 1053 | Returns: 1054 | The transformed points. 1055 | """ 1056 | pts = pts - self._trans 1057 | return self._rots.invert_apply(pts) 1058 | 1059 | def invert(self) -> "Rigid": 1060 | """Inverts the transformation. 1061 | 1062 | Returns: 1063 | The inverse transformation. 
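
        Example (illustrative): for any transformation t, t.invert().compose(t)
        is approximately the identity, so t.invert_apply(t.apply(pts)) recovers
        pts up to numerical precision.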
1064 | """ 1065 | rot_inv = self._rots.invert() 1066 | trn_inv = rot_inv.apply(self._trans) 1067 | 1068 | return Rigid(rot_inv, -1 * trn_inv) 1069 | 1070 | def map_tensor_fn(self, fn: Callable) -> "Rigid": 1071 | """Apply a Tensor -> Tensor function to underlying translation and rotation tensors, 1072 | mapping over the translation/rotation dimensions respectively. 1073 | 1074 | Args: 1075 | fn: 1076 | A Tensor -> Tensor function to be mapped over the Rigid 1077 | Returns: 1078 | The transformed Rigid object 1079 | """ 1080 | new_rots = self._rots.map_tensor_fn(fn) 1081 | new_trans = torch.stack(list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1) 1082 | 1083 | return Rigid(new_rots, new_trans) 1084 | 1085 | def to_tensor_4x4(self) -> Float[torch.Tensor, "... num_nodes 4 4"]: # noqa: F722 1086 | """Converts a transformation to a homogeneous transformation tensor. 1087 | 1088 | Returns: 1089 | A [*, 4, 4] homogeneous transformation tensor 1090 | """ 1091 | tensor = self._trans.new_zeros((*self.shape, 4, 4)) 1092 | tensor[..., :3, :3] = self._rots.get_rot_mats() 1093 | tensor[..., :3, 3] = self._trans 1094 | tensor[..., 3, 3] = 1 1095 | return tensor 1096 | 1097 | @staticmethod 1098 | def from_tensor_4x4(t: Float[torch.Tensor, "... num_nodes 4 4"]) -> "Rigid": # noqa: F722 1099 | """Constructs a transformation from a homogeneous transformation tensor. 1100 | 1101 | Args: 1102 | t: [*, 4, 4] homogeneous transformation tensor 1103 | Returns: 1104 | T object with shape [*] 1105 | """ 1106 | if t.shape[-2:] != (4, 4): 1107 | raise ValueError("Incorrectly shaped input tensor") 1108 | 1109 | rots = Rotation(rot_mats=t[..., :3, :3], quats=None) 1110 | trans = t[..., :3, 3] 1111 | 1112 | return Rigid(rots, trans) 1113 | 1114 | def to_tensor_7(self) -> Float[torch.Tensor, "... num_nodes 7"]: # noqa: F722 1115 | """Converts a transformation to a tensor with 7 final columns, four for the quaternion 1116 | followed by three for the translation. 1117 | 1118 | Returns: 1119 | A [*, 7] tensor representation of the transformation 1120 | """ 1121 | tensor = self._trans.new_zeros((*self.shape, 7)) 1122 | tensor[..., :4] = self._rots.get_quats() 1123 | tensor[..., 4:] = self._trans 1124 | 1125 | return tensor 1126 | 1127 | @staticmethod 1128 | def from_tensor_7(t: Float[torch.Tensor, "... num_nodes 7"], normalize_quats: bool = False) -> "Rigid": # noqa: F722 1129 | if t.shape[-1] != 7: 1130 | raise ValueError("Incorrectly shaped input tensor") 1131 | 1132 | quats, trans = t[..., :4], t[..., 4:] 1133 | 1134 | rots = Rotation(rot_mats=None, quats=quats, normalize_quats=normalize_quats) 1135 | 1136 | return Rigid(rots, trans) 1137 | 1138 | @staticmethod 1139 | def from_3_points( 1140 | p_neg_x_axis: COORDINATES_TENSOR_TYPE, 1141 | origin: COORDINATES_TENSOR_TYPE, 1142 | p_xy_plane: COORDINATES_TENSOR_TYPE, 1143 | eps: float = 1e-8, 1144 | ) -> "Rigid": 1145 | """Implements algorithm 21. Constructs transformations from sets of 3 points using the 1146 | Gram-Schmidt algorithm. 
1147 | 1148 | Args: 1149 | p_neg_x_axis: [*, 3] coordinates 1150 | origin: [*, 3] coordinates used as frame origins 1151 | p_xy_plane: [*, 3] coordinates 1152 | eps: Small epsilon value 1153 | Returns: 1154 | A transformation object of shape [*] 1155 | """ 1156 | p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) 1157 | origin = torch.unbind(origin, dim=-1) 1158 | p_xy_plane = torch.unbind(p_xy_plane, dim=-1) 1159 | 1160 | e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] 1161 | e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] 1162 | 1163 | denom = torch.sqrt(sum(c * c for c in e0) + eps) 1164 | e0 = [c / denom for c in e0] 1165 | dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) 1166 | e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] 1167 | denom = torch.sqrt(sum(c * c for c in e1) + eps) 1168 | e1 = [c / denom for c in e1] 1169 | e2 = [ 1170 | e0[1] * e1[2] - e0[2] * e1[1], 1171 | e0[2] * e1[0] - e0[0] * e1[2], 1172 | e0[0] * e1[1] - e0[1] * e1[0], 1173 | ] 1174 | 1175 | rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) 1176 | rots = rots.reshape(rots.shape[:-1] + (3, 3)) 1177 | 1178 | rot_obj = Rotation(rot_mats=rots, quats=None) 1179 | 1180 | return Rigid(rot_obj, torch.stack(origin, dim=-1)) 1181 | 1182 | def unsqueeze(self, dim: int) -> "Rigid": 1183 | """Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the 1184 | rotation/translation. 1185 | 1186 | Args: 1187 | dim: A positive or negative dimension index. 1188 | Returns: 1189 | The unsqueezed transformation. 1190 | """ 1191 | if dim >= len(self.shape): 1192 | raise ValueError("Invalid dimension") 1193 | rots = self._rots.unsqueeze(dim) 1194 | trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1) 1195 | 1196 | return Rigid(rots, trans) 1197 | 1198 | @staticmethod 1199 | def cat(ts: List["Rigid"], dim: int) -> "Rigid": 1200 | """Concatenates transformations along a new dimension. 1201 | 1202 | Args: 1203 | ts: 1204 | A list of T objects 1205 | dim: 1206 | The dimension along which the transformations should be 1207 | concatenated 1208 | Returns: 1209 | A concatenated transformation object 1210 | """ 1211 | rots = Rotation.cat([t._rots for t in ts], dim) 1212 | trans = torch.cat([t._trans for t in ts], dim=dim if dim >= 0 else dim - 1) 1213 | 1214 | return Rigid(rots, trans) 1215 | 1216 | def apply_rot_fn(self, fn: Callable) -> "Rigid": 1217 | """Applies a Rotation -> Rotation function to the stored rotation object. 1218 | 1219 | Args: 1220 | fn: A function of type Rotation -> Rotation 1221 | Returns: 1222 | A transformation object with a transformed rotation. 1223 | """ 1224 | return Rigid(fn(self._rots), self._trans) 1225 | 1226 | def apply_trans_fn(self, fn: Callable) -> "Rigid": 1227 | """Applies a Tensor -> Tensor function to the stored translation. 1228 | 1229 | Args: 1230 | fn: 1231 | A function of type Tensor -> Tensor to be applied to the 1232 | translation 1233 | Returns: 1234 | A transformation object with a transformed translation. 1235 | """ 1236 | return Rigid(self._rots, fn(self._trans)) 1237 | 1238 | def scale_translation(self, trans_scale_factor: float) -> "Rigid": 1239 | """Scales the translation by a constant factor. 1240 | 1241 | Args: 1242 | trans_scale_factor: 1243 | The constant factor 1244 | Returns: 1245 | A transformation object with a scaled translation. 1246 | """ 1247 | return self.apply_trans_fn(lambda t: t * trans_scale_factor) 1248 | 1249 | def stop_rot_gradient(self) -> "Rigid": 1250 | """Detaches the underlying rotation object. 
1251 | 
1252 |         Returns:
1253 |             A transformation object with detached rotations
1254 |         """
1255 |         return self.apply_rot_fn(lambda r: r.detach())
1256 | 
1257 |     @staticmethod
1258 |     def make_transform_from_reference(
1259 |         n_xyz: COORDINATES_TENSOR_TYPE,
1260 |         ca_xyz: COORDINATES_TENSOR_TYPE,
1261 |         c_xyz: COORDINATES_TENSOR_TYPE,
1262 |         eps: float = 1e-20,
1263 |     ) -> "Rigid":
1264 |         """Returns a transformation object from reference coordinates.
1265 | 
1266 |         Note that this method does not take care of symmetries. If you
1267 |         provide the atom positions in the non-standard way, the N atom will
1268 |         end up not at [-0.527250, 1.359329, 0.0] but instead at
1269 |         [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
1270 |         your code.
1271 | 
1272 |         Args:
1273 |             n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
1274 |             ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
1275 |             c_xyz: A [*, 3] tensor of carbon xyz coordinates.
1276 |         Returns:
1277 |             A transformation object. After applying the translation and
1278 |             rotation to the reference backbone, the coordinates will
1279 |             approximately equal the input coordinates.
1280 |         """
1281 |         translation = -1 * ca_xyz
1282 |         n_xyz = n_xyz + translation
1283 |         c_xyz = c_xyz + translation
1284 | 
1285 |         c_x, c_y, c_z = (c_xyz[..., i] for i in range(3))
1286 |         norm = torch.sqrt(eps + c_x**2 + c_y**2)
1287 |         sin_c1 = -c_y / norm
1288 |         cos_c1 = c_x / norm
1289 |         zeros = sin_c1.new_zeros(sin_c1.shape)
1290 |         ones = sin_c1.new_ones(sin_c1.shape)
1291 | 
1292 |         c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
1293 |         c1_rots[..., 0, 0] = cos_c1
1294 |         c1_rots[..., 0, 1] = -1 * sin_c1
1295 |         c1_rots[..., 1, 0] = sin_c1
1296 |         c1_rots[..., 1, 1] = cos_c1
1297 |         c1_rots[..., 2, 2] = 1
1298 | 
1299 |         norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
1300 |         sin_c2 = c_z / norm
1301 |         cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm
1302 | 
1303 |         c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
1304 |         c2_rots[..., 0, 0] = cos_c2
1305 |         c2_rots[..., 0, 2] = sin_c2
1306 |         c2_rots[..., 1, 1] = 1
1307 |         c2_rots[..., 2, 0] = -1 * sin_c2
1308 |         c2_rots[..., 2, 2] = cos_c2
1309 | 
1310 |         c_rots = rot_matmul(c2_rots, c1_rots)
1311 |         n_xyz = rot_vec_mul(c_rots, n_xyz)
1312 | 
1313 |         _, n_y, n_z = (n_xyz[..., i] for i in range(3))
1314 |         norm = torch.sqrt(eps + n_y**2 + n_z**2)
1315 |         sin_n = -n_z / norm
1316 |         cos_n = n_y / norm
1317 | 
1318 |         n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
1319 |         n_rots[..., 0, 0] = 1
1320 |         n_rots[..., 1, 1] = cos_n
1321 |         n_rots[..., 1, 2] = -1 * sin_n
1322 |         n_rots[..., 2, 1] = sin_n
1323 |         n_rots[..., 2, 2] = cos_n
1324 | 
1325 |         rots = rot_matmul(n_rots, c_rots)
1326 | 
1327 |         rots = rots.transpose(-1, -2)
1328 |         translation = -1 * translation
1329 | 
1330 |         rot_obj = Rotation(rot_mats=rots, quats=None)
1331 | 
1332 |         return Rigid(rot_obj, translation)
1333 | 
1334 |     def cuda(self) -> "Rigid":
1335 |         """Moves the transformation object to GPU memory.
1336 | 1337 | Returns: 1338 | A version of the transformation on GPU 1339 | """ 1340 | return Rigid(self._rots.cuda(), self._trans.cuda()) 1341 | 1342 | 1343 | def create_rigid(rots, trans): 1344 | rots = Rotation(rot_mats=rots) 1345 | return Rigid(rots=rots, trans=trans) 1346 | -------------------------------------------------------------------------------- /src/flash_ipa/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from 3 | https://github.com/microsoft/protein-frame-flow/blob/main/models/utils.py 4 | """ 5 | 6 | import math 7 | import torch 8 | 9 | NM_TO_ANG_SCALE = 10.0 10 | ANG_TO_NM_SCALE = 1 / NM_TO_ANG_SCALE 11 | 12 | 13 | def calc_distogram(pos, min_bin, max_bin, num_bins): 14 | dists_2d = torch.linalg.norm(pos[:, :, None, :] - pos[:, None, :, :], axis=-1)[..., None] 15 | lower = torch.linspace(min_bin, max_bin, num_bins, device=pos.device) 16 | upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1) 17 | dgram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype) 18 | return dgram 19 | 20 | 21 | def knn_indices(pos, k): 22 | 23 | # pos: [B, L, 3] 24 | B, L, _ = pos.shape 25 | x = pos.unsqueeze(2) # [B, L, 1, 3] 26 | y = pos.unsqueeze(1) # [B, 1, L, 3] 27 | dist = torch.norm(x - y, dim=-1) # [B, L, L] 28 | 29 | # Prevent self-matching 30 | dist += torch.eye(L, device=pos.device).unsqueeze(0) * 1e6 31 | dist_k, idx_k = torch.topk(dist, k=k, dim=-1, largest=False) # [B, L, k] 32 | return dist_k, idx_k 33 | 34 | 35 | def calc_distogram_knn(pos, k, min_bin, max_bin, num_bins): 36 | dist_k, idx_k = knn_indices(pos, k) # [B, L, k] 37 | dists = dist_k.unsqueeze(-1) 38 | 39 | # Bin edges 40 | lower = torch.linspace(min_bin, max_bin, num_bins, device=pos.device) # [num_bins] 41 | upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=0) # [num_bins] 42 | 43 | # Create one-hot bin assignment 44 | dgram = ((dists > lower) & (dists < upper)).type(pos.dtype) # [B, L, k, num_bins] 45 | 46 | return dgram, idx_k 47 | 48 | 49 | def get_index_embedding(indices, embed_size, max_len=2056): 50 | """Creates sine / cosine positional embeddings from a prespecified indices. 51 | 52 | Args: 53 | indices: offsets of size [..., N_edges] of type integer 54 | max_len: maximum length. 55 | embed_size: dimension of the embeddings to create 56 | 57 | Returns: 58 | positional embedding of shape [N, embed_size] 59 | """ 60 | K = torch.arange(embed_size // 2, device=indices.device) 61 | pos_embedding_sin = torch.sin(indices[..., None] * math.pi / (max_len ** (2 * K[None] / embed_size))).to(indices.device) 62 | pos_embedding_cos = torch.cos(indices[..., None] * math.pi / (max_len ** (2 * K[None] / embed_size))).to(indices.device) 63 | pos_embedding = torch.cat([pos_embedding_sin, pos_embedding_cos], axis=-1) 64 | return pos_embedding 65 | --------------------------------------------------------------------------------
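
Below is a minimal, illustrative usage sketch for the rigid-frame and distogram utilities defined in `rigid.py` and `utils.py` above. The batch size, sequence length, `k`, bin edges, and embedding size are arbitrary placeholder values, and the random inputs stand in for real backbone coordinates.

```python
import torch

from flash_ipa.rigid import Rigid
from flash_ipa.utils import calc_distogram_knn, get_index_embedding

B, L = 2, 16  # arbitrary batch size and sequence length

# Build per-residue rigid frames from three backbone-like point sets (Algorithm 21).
n_xyz, ca_xyz, c_xyz = (torch.randn(B, L, 3) for _ in range(3))
rigids = Rigid.from_3_points(n_xyz, ca_xyz, c_xyz)  # virtual shape [B, L]

# Applying a frame and then its inverse recovers the original points.
pts = torch.randn(B, L, 3)
roundtrip_err = (rigids.invert_apply(rigids.apply(pts)) - pts).abs().max().item()
print(f"round-trip error: {roundtrip_err:.2e}")

# k-NN distogram and sinusoidal index embeddings, as consumed by the edge embedder.
dgram, idx_k = calc_distogram_knn(ca_xyz, k=8, min_bin=1e-3, max_bin=20.0, num_bins=22)
pos_emb = get_index_embedding(torch.arange(L)[None].expand(B, -1), embed_size=32)
print(dgram.shape, idx_k.shape, pos_emb.shape)  # (B, L, 8, 22), (B, L, 8), (B, L, 32)
```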