├── .gitignore ├── LICENSE ├── README.md ├── contact_pred ├── IPCA.py ├── __init__.py ├── building_blocks.py ├── data_utils.py ├── linear_mem_attn.py ├── linear_mem_attn_utils.py ├── linear_mem_point_attn.py ├── models.py ├── modules.py ├── residue_constants.py ├── structure.py ├── training_utils.py └── utils.py ├── environment.yml ├── eval_notebooks └── eval_structure.ipynb ├── examples ├── 4mds_23H_ligand.sdf ├── 4mds_CA.pdb ├── 4mds_protein_noHET.pdb ├── 7s3s_860_ligand.sdf ├── 7s3s_protein.pdb └── ligand_pred_7s3s.sdf ├── train.py ├── train ├── .gitignore ├── setup_summit.sh └── structure_summit.lsf ├── unit_test ├── ATTNvsLMATTN.py ├── IPCAvsLMIPCA.py ├── README.md └── test_kabsch_rmsd.py └── utils └── token_coords.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.so 3 | slurm*out 4 | slurm*err 5 | structure.o* 6 | structure.e* 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018-2022, UT-Battelle 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Protein-ligand structure prediction 2 | 3 | This repository contains models and training scripts to predict the structure of protein-ligand complexes. 
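Inference can be run end-to-end with the bundled `StructurePredictionPipeline`. A minimal sketch (the checkpoint path and tokenizer names below are placeholders, not shipped artifacts; substitute your own trained weights and the protein/SMILES tokenizers the model was trained with):

```python
from transformers import AutoTokenizer

from contact_pred.models import StructurePrediction
from contact_pred.data_utils import StructurePredictionPipeline

# hypothetical paths -- replace with a real checkpoint and matching tokenizers
model = StructurePrediction.from_pretrained('path/to/checkpoint')
seq_tokenizer = AutoTokenizer.from_pretrained('path/to/protein_tokenizer')
smiles_tokenizer = AutoTokenizer.from_pretrained('path/to/smiles_tokenizer')

pipe = StructurePredictionPipeline(model, seq_tokenizer, smiles_tokenizer)

# each record pairs a protein sequence with a ligand SMILES string
result = pipe({'protein': 'MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ',
               'ligand': 'CC(=O)Oc1ccccc1C(=O)O'})
# result is a dict of numpy arrays, including the predicted receptor/ligand
# frame coordinates ('receptor_frames_xyz', 'ligand_frames_xyz') and rotations
```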
4 | 5 | - `train.py`: Fine-tune sequence embeddings on structure predictions 6 | 7 | -------------------------------------------------------------------------------- /contact_pred/IPCA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | from .utils import get_extended_attention_mask 6 | 7 | from .linear_mem_point_attn import point_attention 8 | 9 | class InvariantPointCrossAttention(nn.Module): 10 | def __init__(self, config, other_config, is_cross_attention=True): 11 | super().__init__() 12 | if config.bert_config['hidden_size'] % config.num_ipa_heads != 0 and not hasattr(config, "embedding_size"): 13 | raise ValueError( 14 | f"The hidden size ({config.bert_config['hidden_size']}) is not a multiple of the number of IPA attention " 15 | f"heads ({config.num_ipa_heads})" 16 | ) 17 | 18 | self.num_ipa_heads = config.num_ipa_heads 19 | self.attention_head_size = int(config.bert_config['hidden_size'] / self.num_ipa_heads) 20 | self.all_head_size = self.num_ipa_heads * self.attention_head_size 21 | 22 | self.query = nn.Linear(config.bert_config['hidden_size'], self.all_head_size, bias=False) 23 | self.key = nn.Linear(other_config.bert_config['hidden_size'], self.all_head_size, bias=False) 24 | self.value = nn.Linear(other_config.bert_config['hidden_size'], self.all_head_size, bias=False) 25 | 26 | self.num_query_points = config.num_points 27 | self.num_value_points = other_config.num_points 28 | 29 | # points in R3 30 | self.query_point = nn.Linear(config.bert_config['hidden_size'], 31 | self.num_ipa_heads*self.num_query_points*3, bias=False) 32 | self.key_point = nn.Linear(other_config.bert_config['hidden_size'], 33 | self.num_ipa_heads*self.num_query_points*3, bias=False) 34 | self.value_point = nn.Linear(other_config.bert_config['hidden_size'], 35 | self.num_ipa_heads*self.num_value_points*3, bias=False) 36 | 37 | self.head_weight = torch.nn.Parameter(torch.zeros(config.num_ipa_heads)) 38 | torch.nn.init.normal_(self.head_weight) 39 | 40 | # scalar self attention weights 41 | self.n_pair_channels = config.bert_config['num_attention_heads'] 42 | if is_cross_attention: 43 | self.n_pair_channels += other_config.bert_config['num_attention_heads'] 44 | 45 | self.pair_attention = nn.Linear(self.n_pair_channels, self.num_ipa_heads, bias=False) 46 | 47 | self.output_layer = nn.Linear(self.num_ipa_heads * (self.n_pair_channels + 48 | self.attention_head_size + self.num_value_points*(3+1)), config.bert_config['hidden_size']) 49 | 50 | def transpose_for_scores(self, x): 51 | new_x_shape = x.size()[:-1] + (self.num_ipa_heads, self.attention_head_size) 52 | x = x.view(new_x_shape) 53 | return x.permute(0, 2, 1, 3) 54 | 55 | def transpose_points_for_scores(self, x, num_points): 56 | new_x_shape = x.size()[:-1] + (self.num_ipa_heads, num_points, 3) 57 | x = x.view(new_x_shape) 58 | return x.permute(0, 2, 1, 3, 4) 59 | 60 | def transpose_pair_representation(self, x): 61 | return x.permute(0, 2, 3, 1) 62 | 63 | def transpose_pair_attention(self, x): 64 | return x.permute(0, 3, 1, 2) 65 | 66 | def forward( 67 | self, 68 | hidden_states, 69 | attention_mask=None, 70 | encoder_hidden_states=None, 71 | encoder_attention_mask=None, 72 | pair_representation=None, 73 | rigid_rotations=None, 74 | rigid_translations=None, 75 | encoder_rigid_rotations=None, 76 | encoder_rigid_translations=None, 77 | **kwargs 78 | ): 79 | is_cross_attention = encoder_hidden_states is not None 80 | 81 | if not is_cross_attention: 82 | 
encoder_hidden_states = hidden_states 83 | encoder_rigid_translations = rigid_translations 84 | encoder_rigid_rotations = rigid_rotations 85 | encoder_attention_mask = attention_mask 86 | 87 | inv_attention_mask = None 88 | if encoder_attention_mask is not None: 89 | if encoder_attention_mask.dtype != torch.float32 and encoder_attention_mask.dtype != torch.float16: 90 | inv_attention_mask = get_extended_attention_mask( 91 | encoder_attention_mask, 92 | encoder_hidden_states.shape[:-1], 93 | encoder_hidden_states.device, 94 | encoder_hidden_states.dtype 95 | ) 96 | 97 | mixed_query_layer = self.query(hidden_states) 98 | query_layer = self.transpose_for_scores(mixed_query_layer) 99 | 100 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 101 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 102 | 103 | # Take the dot product between "query" and "key" to get the raw attention scores. 104 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-2, -1)) 105 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 106 | 107 | # pair representation contribution 108 | attention_scores = attention_scores + self.transpose_pair_attention(self.pair_attention(self.transpose_pair_representation(pair_representation))) 109 | 110 | # point contribution 111 | query_points = self.transpose_points_for_scores(self.query_point(hidden_states), self.num_query_points) 112 | key_points = self.transpose_points_for_scores(self.key_point(encoder_hidden_states), self.num_query_points) 113 | value_points = self.transpose_points_for_scores(self.value_point(encoder_hidden_states), self.num_value_points) 114 | 115 | # rigid update 116 | a = torch.einsum('bnij,bhnpj->bhnpi', rigid_rotations, query_points) 117 | b = torch.einsum('bnij,bhnpj->bhnpi', encoder_rigid_rotations, key_points) 118 | a = a + rigid_translations[:,None,:,None] 119 | b = b + encoder_rigid_translations[:,None,:,None] 120 | 121 | weight_points = math.sqrt(2/(9*self.num_query_points)) 122 | gamma = torch.nn.functional.softplus(self.head_weight) 123 | 124 | # invariant = torch.sum((a[:,:,:,None,:,:] - b[:,:,None,:,:,:])**2,dim=[-2,-1]) 125 | a_sq = torch.sum(a**2,dim=[-2,-1]) 126 | b_sq = torch.sum(b**2,dim=[-2,-1]) 127 | invariant = a_sq[:,:,:,None] + b_sq[:,:,None,:] - 2*torch.einsum('bhnpi,bhmpi->bhnm',a,b) 128 | 129 | attention_scores = attention_scores - 0.5*weight_points*gamma[None,:,None,None] * invariant 130 | 131 | # overall scaling 132 | attention_scores = attention_scores / math.sqrt(3) 133 | 134 | if inv_attention_mask is not None: 135 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 136 | attention_scores = attention_scores + inv_attention_mask 137 | 138 | # Normalize the attention scores to probabilities. 139 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 140 | 141 | pair_output = torch.einsum('bhnm,bcnm->bnhc', attention_probs, pair_representation) 142 | 143 | # rigid update on output [Eq. 
10] 144 | c = torch.einsum('bnij,bhnpj->bhnpi', encoder_rigid_rotations, value_points) 145 | c = c + encoder_rigid_translations[:,None,:,None] 146 | point_output = torch.einsum('bhij,bhjpk->bhipk', attention_probs, c) 147 | 148 | # transpose of an orthogonal matrix == inverse 149 | point_output = point_output - rigid_translations[:,None,:,None] 150 | point_output = torch.einsum('bnji,bhnpj->bhnpi', rigid_rotations, point_output) 151 | 152 | context_layer = torch.matmul(attention_probs, value_layer) 153 | 154 | new_pair_shape = pair_output.size()[:2] + (self.n_pair_channels*self.num_ipa_heads,) 155 | pair_output = pair_output.view(new_pair_shape) 156 | 157 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 158 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 159 | context_layer = context_layer.view(new_context_layer_shape) 160 | 161 | point_output = point_output.permute(0, 2, 1, 3, 4).contiguous() 162 | # the vector norm is a separate channel [in Eq. 11] 163 | point_output_sq = torch.sqrt(torch.sum(point_output*point_output,-1)) 164 | 165 | new_point_shape = point_output.size()[:-3] + (3*self.num_value_points*self.num_ipa_heads,) 166 | point_output = point_output.view(new_point_shape) 167 | 168 | new_point_sq_shape = point_output_sq.size()[:-2] + (self.num_value_points*self.num_ipa_heads,) 169 | point_output_sq = point_output_sq.view(new_point_sq_shape) 170 | 171 | return self.output_layer(torch.cat([pair_output, context_layer, point_output, point_output_sq], dim=-1)) 172 | 173 | class LinearMemInvariantPointCrossAttention(nn.Module): 174 | def __init__(self, config, other_config, is_cross_attention=True): 175 | super().__init__() 176 | if config.bert_config['hidden_size'] % config.num_ipa_heads != 0 and not hasattr(config, "embedding_size"): 177 | raise ValueError( 178 | f"The hidden size ({config.bert_config['hidden_size']}) is not a multiple of the number of IPA attention " 179 | f"heads ({config.num_ipa_heads})" 180 | ) 181 | 182 | self.num_ipa_heads = config.num_ipa_heads 183 | self.attention_head_size = int(config.bert_config['hidden_size'] / self.num_ipa_heads) 184 | self.all_head_size = self.num_ipa_heads * self.attention_head_size 185 | 186 | self.query = nn.Linear(config.bert_config['hidden_size'], self.all_head_size, bias=False) 187 | self.key = nn.Linear(other_config.bert_config['hidden_size'], self.all_head_size, bias=False) 188 | self.value = nn.Linear(other_config.bert_config['hidden_size'], self.all_head_size, bias=False) 189 | 190 | self.num_query_points = config.num_points 191 | self.num_value_points = other_config.num_points 192 | 193 | # points in R3 194 | self.query_point = nn.Linear(config.bert_config['hidden_size'], 195 | self.num_ipa_heads*self.num_query_points*3, bias=False) 196 | self.key_point = nn.Linear(other_config.bert_config['hidden_size'], 197 | self.num_ipa_heads*self.num_query_points*3, bias=False) 198 | self.value_point = nn.Linear(other_config.bert_config['hidden_size'], 199 | self.num_ipa_heads*self.num_value_points*3, bias=False) 200 | 201 | self.head_weight = torch.nn.Parameter(torch.zeros(config.num_ipa_heads)) 202 | torch.nn.init.normal_(self.head_weight) 203 | 204 | # scalar self attention weights 205 | self.n_pair_channels = config.bert_config['num_attention_heads'] 206 | if is_cross_attention: 207 | self.n_pair_channels += other_config.bert_config['num_attention_heads'] 208 | 209 | self.pair_attention = nn.Linear(self.n_pair_channels, self.num_ipa_heads, bias=False) 210 | 211 | self.output_layer = 
nn.Linear(self.num_ipa_heads * (self.n_pair_channels +
212 |             self.attention_head_size + self.num_value_points*(3+1)), config.bert_config['hidden_size'])
213 | 
214 |     def view_for_scores(self, x):
215 |         new_x_shape = x.size()[:-1] + (self.num_ipa_heads, self.attention_head_size)
216 |         return x.view(new_x_shape)
217 | 
218 |     def view_points_for_scores(self, x, num_points):
219 |         new_x_shape = x.size()[:-1] + (self.num_ipa_heads, num_points, 3)
220 |         return x.view(new_x_shape)
221 | 
222 |     def transpose_pair_representation(self, x):
223 |         return x.permute(0, 2, 3, 1)
224 | 
225 |     def transpose_pair_representation_for_value(self, x):
226 |         return x.permute(0, 2, 1, 3)
227 | 
228 |     def transpose_pair_attention(self, x):
229 |         return x.permute(0, 1, 3, 2)
230 | 
231 |     def forward(
232 |         self,
233 |         hidden_states,
234 |         attention_mask=None,
235 |         encoder_hidden_states=None,
236 |         encoder_attention_mask=None,
237 |         pair_representation=None,
238 |         rigid_rotations=None,
239 |         rigid_translations=None,
240 |         encoder_rigid_rotations=None,
241 |         encoder_rigid_translations=None,
242 |         query_chunk_size = 1024,
243 |         key_chunk_size = 4096,
244 |     ):
245 |         is_cross_attention = encoder_hidden_states is not None
246 | 
247 |         if not is_cross_attention:
248 |             encoder_hidden_states = hidden_states
249 |             encoder_rigid_translations = rigid_translations
250 |             encoder_rigid_rotations = rigid_rotations
251 |             encoder_attention_mask = attention_mask
252 | 
253 |         key_layer = self.view_for_scores(self.key(encoder_hidden_states))
254 |         value_layer = self.view_for_scores(self.value(encoder_hidden_states))
255 |         key_points = self.view_points_for_scores(self.key_point(encoder_hidden_states), self.num_query_points)
256 |         value_points = self.view_points_for_scores(self.value_point(encoder_hidden_states), self.num_value_points)
258 | 
259 |         mixed_query_layer = self.query(hidden_states)
260 |         query_layer = self.view_for_scores(mixed_query_layer)
261 | 
262 |         # pair representation
263 |         pair = self.transpose_pair_attention(self.pair_attention(self.transpose_pair_representation(pair_representation)))
264 |         pair_value = self.transpose_pair_representation_for_value(pair_representation)
265 | 
266 |         # points
267 |         query_points = self.view_points_for_scores(self.query_point(hidden_states), self.num_query_points)
268 | 
269 |         weight_kv = 1 / math.sqrt(self.attention_head_size)
270 |         weight_points = math.sqrt(2/(9*self.num_query_points))
271 |         gamma = torch.nn.functional.softplus(self.head_weight)
272 |         point_attention_output = point_attention(
273 |             query=query_layer,
274 |             key=key_layer,
275 |             value=value_layer,
276 |             pair=pair,
277 |             pair_value=pair_value,
278 |             rotations=rigid_rotations,
279 |             translations=rigid_translations,
280 |             encoder_rotations=encoder_rigid_rotations,
281 |             encoder_translations=encoder_rigid_translations,
282 |             points_query=query_points,
283 |             points_key=key_points,
284 |             points_value=value_points,
285 |             mask=encoder_attention_mask,
286 |             weight_kv=weight_kv,
287 |             weight_points=weight_points,
288 |             gamma=gamma,
289 |             query_chunk_size=query_chunk_size,
290 |             key_chunk_size=key_chunk_size,
291 |         )
292 | 
293 |         context, pair_output, points_output = point_attention_output
294 | 
295 |         new_context_layer_shape = context.size()[:-2] + (self.all_head_size,)
296 |         context_layer = context.view(new_context_layer_shape)
297 | 
298 |         new_pair_shape = pair_output.size()[:2] + (self.n_pair_channels*self.num_ipa_heads,)
299 |         pair_output = 
pair_output.view(new_pair_shape) 300 | 301 | # the vector norm is a separate channel [in Eq. 11] 302 | points_output_sq = torch.sqrt(torch.sum(points_output*points_output,-1)) 303 | new_points_sq_shape = points_output_sq.size()[:-2] + (self.num_value_points*self.num_ipa_heads,) 304 | points_output_sq = points_output_sq.view(new_points_sq_shape) 305 | 306 | new_point_shape = points_output.size()[:-3] + (3*self.num_value_points*self.num_ipa_heads,) 307 | points_output = points_output.view(new_point_shape) 308 | 309 | return self.output_layer(torch.cat([pair_output, context_layer, points_output, points_output_sq], dim=-1)) 310 | -------------------------------------------------------------------------------- /contact_pred/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORNL/TwoFold_DL/d90bab7d95dec5f7820a7fc81b2e6121300a60a6/contact_pred/__init__.py -------------------------------------------------------------------------------- /contact_pred/building_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MLP(torch.nn.Module): 5 | ''' 6 | Multilayer Perceptron. 7 | ''' 8 | def __init__(self, n, nlayers): 9 | super().__init__() 10 | hidden_layers = [(torch.nn.Linear(n, n),torch.nn.GELU()) 11 | for _ in range(nlayers-1)] 12 | self.layers = torch.nn.Sequential( 13 | *[item for layer_pair in hidden_layers for item in layer_pair], 14 | torch.nn.Linear(n,n), 15 | ) 16 | 17 | def forward(self, x): 18 | '''Forward pass''' 19 | return self.layers(x) 20 | -------------------------------------------------------------------------------- /contact_pred/data_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import DataCollatorWithPadding, default_data_collator, PreTrainedTokenizerBase 2 | from transformers.data.data_collator import default_data_collator, DataCollatorMixin, DataCollatorForLanguageModeling 3 | 4 | from transformers import Pipeline 5 | 6 | from torch.utils.data import DataLoader 7 | 8 | from typing import List, Dict, Any 9 | 10 | import os 11 | import collections 12 | from collections.abc import Iterable 13 | 14 | import logging 15 | logger = logging.getLogger() 16 | 17 | class EnsembleDataCollatorWithPadding: 18 | def __init__(self, 19 | smiles_tokenizer, 20 | seq_tokenizer, 21 | smiles_padding=True, 22 | smiles_max_length=None, 23 | seq_padding=True, 24 | seq_max_length=None): 25 | 26 | self.smiles_collator = DataCollatorWithPadding(smiles_tokenizer, smiles_padding, smiles_max_length) 27 | self.seq_collator = DataCollatorWithPadding(seq_tokenizer, seq_padding, seq_max_length) 28 | 29 | def __call__(self, features): 30 | # individually collate protein and ligand sequences into batches 31 | batch_1 = self.seq_collator([{'input_ids': b['input_ids_1'], 'attention_mask': b['attention_mask_1']} for b in features]) 32 | batch_2 = self.smiles_collator([{'input_ids': b['input_ids_2'], 'attention_mask': b['attention_mask_2']} for b in features]) 33 | 34 | batch_merged = default_data_collator([{k: v for k,v in f.items() 35 | if k not in ('input_ids_1','attention_mask_1','input_ids_2','attention_mask_2')} 36 | for f in features]) 37 | batch_merged['input_ids_1'] = batch_1['input_ids'] 38 | batch_merged['attention_mask_1'] = batch_1['attention_mask'] 39 | batch_merged['input_ids_2'] = batch_2['input_ids'] 40 | batch_merged['attention_mask_2'] = batch_2['attention_mask'] 
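        # batch_merged now holds the default-collated per-example features
        # (e.g. coordinate labels), plus the separately padded protein tensors
        # (input_ids_1 / attention_mask_1) and ligand tensors (input_ids_2 / attention_mask_2)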
41 | return batch_merged 42 | 43 | class EnsembleTokenizer: 44 | def __init__(self, 45 | smiles_tokenizer, 46 | seq_tokenizer, 47 | ): 48 | self.smiles_tokenizer = smiles_tokenizer 49 | self.seq_tokenizer = seq_tokenizer 50 | 51 | def __call__(self, features, **kwargs): 52 | item = dict(features) 53 | 54 | is_batched = isinstance(features, Iterable) and not isinstance(features, dict) 55 | 56 | seq_args = {} 57 | smiles_args = {} 58 | if 'seq_padding' in kwargs: 59 | seq_args['padding'] = kwargs['seq_padding'] 60 | if 'smiles_padding' in kwargs: 61 | smiles_args['padding'] = kwargs['smiles_padding'] 62 | if 'seq_max_length' in kwargs: 63 | seq_args['max_length'] = kwargs['seq_max_length'] 64 | if 'smiles_max_length' in kwargs: 65 | smiles_args['max_length'] = kwargs['smiles_max_length'] 66 | if 'seq_truncation' in kwargs: 67 | seq_args['truncation'] = kwargs['seq_truncation'] 68 | if 'smiles_truncation' in kwargs: 69 | smiles_args['truncation'] = kwargs['smiles_truncation'] 70 | 71 | if is_batched: 72 | seq_encodings = self.seq_tokenizer([f['protein'] for f in features], **seq_args) 73 | else: 74 | seq_encodings = self.seq_tokenizer(features['protein'], **seq_args) 75 | 76 | item.pop('protein') 77 | item['input_ids_1'] = seq_encodings['input_ids'] 78 | item['attention_mask_1'] = seq_encodings['attention_mask'] 79 | 80 | if is_batched: 81 | smiles_encodings = self.smiles_tokenizer([f['ligand'] for f in features], **smiles_args) 82 | else: 83 | smiles_encodings = self.smiles_tokenizer(features['ligand'], **smiles_args) 84 | 85 | item.pop('ligand') 86 | item['input_ids_2'] = smiles_encodings['input_ids'] 87 | item['attention_mask_2'] = smiles_encodings['attention_mask'] 88 | 89 | return item 90 | 91 | class StructurePredictionPipeline(Pipeline): 92 | def _sanitize_parameters(self, **kwargs): 93 | preprocess_kwargs = {} 94 | 95 | if 'seq_padding' in kwargs: 96 | preprocess_kwargs['seq_padding'] = kwargs['seq_padding'] 97 | else: 98 | preprocess_kwargs['seq_padding'] = True 99 | 100 | if 'smiles_padding' in kwargs: 101 | preprocess_kwargs['smiles_padding'] = kwargs['smiles_padding'] 102 | else: 103 | preprocess_kwargs['smiles_padding'] = True 104 | 105 | if 'seq_truncation' in kwargs: 106 | preprocess_kwargs['seq_truncation'] = kwargs['seq_truncation'] 107 | else: 108 | preprocess_kwargs['seq_truncation'] = True 109 | 110 | if 'seq_max_length' in kwargs: 111 | preprocess_kwargs['seq_max_length'] = kwargs['seq_max_length'] 112 | else: 113 | preprocess_kwargs['seq_max_length'] = None 114 | 115 | if 'smiles_truncation' in kwargs: 116 | preprocess_kwargs['smiles_truncation'] = kwargs['smiles_truncation'] 117 | else: 118 | preprocess_kwargs['smiles_truncation'] = True 119 | 120 | if 'smiles_max_length' in kwargs: 121 | preprocess_kwargs['smiles_max_length'] = kwargs['smiles_max_length'] 122 | else: 123 | preprocess_kwargs['smiles_max_length'] = None 124 | 125 | return preprocess_kwargs, {}, {} 126 | 127 | def __init__(self, 128 | model, 129 | seq_tokenizer, 130 | smiles_tokenizer, 131 | output_prediction_scores=False, 132 | **kwargs 133 | ): 134 | self.seq_tokenizer = seq_tokenizer 135 | self.smiles_tokenizer = smiles_tokenizer 136 | self.data_collator = EnsembleDataCollatorWithPadding(self.smiles_tokenizer, 137 | self.seq_tokenizer) 138 | self.output_prediction_scores = output_prediction_scores 139 | model.eval() 140 | super().__init__(model=model, 141 | tokenizer=EnsembleTokenizer(self.smiles_tokenizer, 142 | self.seq_tokenizer), 143 | **kwargs) 144 | 145 | def preprocess(self, inputs, **kwargs): 
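        # `inputs` is a record such as {'protein': <sequence>, 'ligand': <SMILES>, ...};
        # EnsembleTokenizer pops those keys and adds input_ids_1/attention_mask_1
        # (protein) and input_ids_2/attention_mask_2 (ligand)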
146 | tokenized_input = self.tokenizer(inputs, **kwargs) 147 | return tokenized_input 148 | 149 | def _forward(self, model_inputs): 150 | outputs = self.model(**model_inputs, 151 | return_dict=True) 152 | return outputs 153 | 154 | def postprocess(self, model_outputs): 155 | if isinstance(model_outputs, dict): 156 | return {k: v.numpy() for k,v in model_outputs.items()} 157 | else: 158 | return model_outputs.numpy() 159 | 160 | def get_iterator( 161 | self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params 162 | ): 163 | from transformers.pipelines.pt_utils import PipelineDataset, PipelineIterator 164 | if isinstance(inputs, collections.abc.Sized): 165 | dataset = PipelineDataset(inputs, self.preprocess, preprocess_params) 166 | else: 167 | if num_workers > 1: 168 | logger.warning( 169 | "For iterable dataset using num_workers>1 is likely to result" 170 | " in errors since everything is iterable, setting `num_workers=1`" 171 | " to guarantee correctness." 172 | ) 173 | num_workers = 1 174 | dataset = PipelineIterator(inputs, self.preprocess, preprocess_params) 175 | if "TOKENIZERS_PARALLELISM" not in os.environ: 176 | logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already") 177 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 178 | 179 | collate_fn = self.data_collator 180 | dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn) 181 | model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) 182 | final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params) 183 | return final_iterator 184 | 185 | -------------------------------------------------------------------------------- /contact_pred/linear_mem_attn.py: -------------------------------------------------------------------------------- 1 | # adapted to torch from: https://arxiv.org/abs/2112.05682 2 | import math 3 | import torch 4 | from torch.utils import checkpoint 5 | from typing import Tuple, Optional 6 | 7 | from .linear_mem_attn_utils import dynamic_length_slice, dynamic_slice, torch_map, torch_scan 8 | 9 | def query_chunk_attention( 10 | query: torch.Tensor, 11 | key: torch.Tensor, 12 | value: torch.Tensor, 13 | mask: Optional[torch.Tensor] = None, 14 | key_chunk_size: int = 4096, 15 | ) -> torch.Tensor: 16 | """Multi-head dot product attention with a limited number of queries.""" 17 | device, dtype = query.device, query.dtype 18 | batch, num_kv, num_heads, k_features = key.shape 19 | v_features = value.shape[-1] 20 | query_chunk = query.shape[1] # b n h d 21 | key_chunk_size = min(key_chunk_size, num_kv) 22 | query = query / k_features ** 0.5 23 | 24 | # @functools.partial(jax.checkpoint, prevent_cse=False) 25 | def summarize_chunk( 26 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor 27 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 28 | 29 | attn_scores = torch.einsum("bqhd,bkhd->bqhk", query, key) 30 | attn_weights = attn_scores.clone() 31 | 32 | if mask is not None: 33 | mask = mask.unsqueeze(1).unsqueeze(2) 34 | max_neg = -torch.finfo(attn_weights.dtype).max 35 | mask = mask.bool() 36 | attn_weights.masked_fill_(~mask, max_neg) 37 | 38 | max_score = torch.amax(attn_weights, dim=-1, keepdim=True).detach() 39 | exp_weights = torch.exp(attn_weights - max_score) 40 | 41 | exp_values = torch.einsum("bvhf,bqhv->bqhf", value, exp_weights) 42 | # (b q h f), (b q h), (b q h 1) 43 | 44 | 
attn_scores_pad = torch.nn.functional.pad(attn_scores, (0, key_chunk_size-key.shape[1])) 45 | 46 | return exp_values, exp_weights.sum(dim=-1), max_score.squeeze(dim=-1), attn_scores_pad.permute(3,0,1,2) 47 | 48 | def chunk_scanner( 49 | chunk_idx: int, 50 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 51 | 52 | key_chunk = dynamic_length_slice(key, chunk_idx, key_chunk_size) 53 | value_chunk = dynamic_length_slice(value, chunk_idx, key_chunk_size) 54 | 55 | mask_chunk = None 56 | if mask is not None: 57 | mask_chunk = dynamic_length_slice(mask, chunk_idx, key_chunk_size) 58 | 59 | return checkpoint.checkpoint( 60 | summarize_chunk, query, key_chunk, value_chunk, mask_chunk 61 | ) 62 | 63 | num_chunks = math.ceil(num_kv / key_chunk_size) 64 | chunk_values = torch.zeros( 65 | num_chunks, 66 | batch, 67 | query_chunk, 68 | num_heads, 69 | v_features, 70 | dtype=dtype, 71 | device=device, 72 | ) 73 | chunk_exp_weights = torch.zeros( 74 | num_chunks, 75 | batch, 76 | query_chunk, 77 | num_heads, 78 | dtype=dtype, 79 | device=device, 80 | ) 81 | chunk_max = torch.zeros( 82 | num_chunks, 83 | batch, 84 | query_chunk, 85 | num_heads, 86 | dtype=dtype, 87 | device=device, 88 | ) 89 | chunk_weights = torch.zeros( 90 | num_chunks, 91 | key_chunk_size, 92 | batch, 93 | query_chunk, 94 | num_heads, 95 | dtype=dtype, 96 | device=device, 97 | ) 98 | 99 | for i in range(num_chunks): 100 | chunk_values[i], chunk_exp_weights[i], chunk_max[i], chunk_weights[i] = chunk_scanner( 101 | i * key_chunk_size 102 | ) 103 | 104 | max_diffs = torch.exp(chunk_max - chunk_max.amax(dim=0)) 105 | 106 | all_values = (max_diffs.unsqueeze(dim=-1) * chunk_values).sum(dim=0) 107 | all_exp_weights = (max_diffs * chunk_exp_weights).sum(dim=0).unsqueeze(dim=-1) 108 | 109 | all_scores = chunk_weights.view(-1,batch,query_chunk,num_heads).permute(1,2,0,3) 110 | all_scores = all_scores[:,:,:num_kv] 111 | 112 | return all_values / all_exp_weights, all_scores 113 | 114 | 115 | def attention( 116 | query: torch.Tensor, 117 | key: torch.Tensor, 118 | value: torch.Tensor, 119 | mask: Optional[torch.Tensor] = None, 120 | query_chunk_size: int = 1024, 121 | key_chunk_size: int = 4096, 122 | ) -> torch.Tensor: 123 | """Memory-efficient multi-head dot product attention. 124 | Inputs: 125 | * q, k, v: (b n h d) torch tensors 126 | * mask: (b n) 127 | * query_chunk_size: int. 128 | * key_chunk_size: int. 129 | Outputs: (b n h d) torch tensor (qk-weighted sum of v), 130 | and (b n m h) attention scores 131 | """ 132 | batch, num_q, num_heads, q_features = query.shape 133 | 134 | def chunk_scanner(chunk_idx: int, _): 135 | query_chunk = dynamic_length_slice(query, chunk_idx, query_chunk_size) 136 | 137 | return ( 138 | chunk_idx + query_chunk_size, 139 | query_chunk_attention( 140 | query_chunk, key, value, mask, key_chunk_size=key_chunk_size 141 | ), 142 | ) 143 | 144 | _, res = torch_scan( 145 | chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) 146 | ) 147 | 148 | return res 149 | -------------------------------------------------------------------------------- /contact_pred/linear_mem_attn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple, Any, List 3 | from types import FunctionType 4 | 5 | @torch.jit.script 6 | def dynamic_length_slice( 7 | x: torch.Tensor, start: int = 0, size: int = 1024 8 | ) -> torch.Tensor: 9 | """Slices a tensor along the second axis. 
10 | Ex: (b n h d) -> (b n[start:start+size] h d) 11 | """ 12 | # avoid slicing overhead if not needed 13 | if start == 0 and start + size >= x.shape[1]: 14 | return x 15 | else: 16 | return x[:, start : start + size] 17 | 18 | 19 | @torch.jit.script 20 | def dynamic_slice( 21 | x: torch.Tensor, 22 | start: Tuple[int, int, int], 23 | slice_sizes: Tuple[int, int, int], 24 | ) -> torch.Tensor: 25 | """approx like jax.lax.dynamic_slice. 26 | * NOTE: assumes we dont work on first dim 27 | Ex: 28 | dynamic_slice( 29 | x, 30 | slices=(0, 0, 0), 31 | slice_sizes=(16, 64, 64) 32 | ) 33 | """ 34 | return x[ 35 | :, 36 | start[0] : start[0] + slice_sizes[0], 37 | start[1] : start[1] + slice_sizes[1], 38 | start[2] : start[2] + slice_sizes[2], 39 | ] 40 | 41 | 42 | def torch_map(fn, xs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 43 | """approx like jax.lax.map""" 44 | return 45 | 46 | 47 | def torch_scan( 48 | f: FunctionType, init: int = 0, xs: Optional[List] = None, length: int = 0 49 | ): 50 | if xs is None: 51 | xs = [None] * length 52 | carry = init 53 | ys = [] 54 | for x in xs: 55 | carry, y = f(carry, x) 56 | ys.append(y) 57 | 58 | if len(ys) > 0 and isinstance(ys[0], tuple): 59 | return carry, tuple((torch.cat([y[i] for y in ys], dim=1) for i in range(len(ys[0])))) 60 | 61 | return carry, torch.cat(ys, dim=1) 62 | -------------------------------------------------------------------------------- /contact_pred/linear_mem_point_attn.py: -------------------------------------------------------------------------------- 1 | # adapted to torch from: https://arxiv.org/abs/2112.05682 2 | # inspired by https://github.com/CHARM-Tx/linear_mem_attention_pytorch 3 | 4 | import math 5 | import torch 6 | from torch.utils import checkpoint 7 | from typing import Tuple, Optional 8 | 9 | from .linear_mem_attn_utils import dynamic_length_slice, dynamic_slice, torch_map, torch_scan 10 | 11 | def query_chunk_attention( 12 | query: torch.Tensor, 13 | key: torch.Tensor, 14 | value: torch.Tensor, 15 | pair: torch.Tensor, 16 | pair_value: torch.Tensor, 17 | rotations: torch.Tensor, 18 | translations: torch.Tensor, 19 | encoder_rotations: torch.Tensor, 20 | encoder_translations: torch.Tensor, 21 | points_query: torch.Tensor, 22 | points_key: torch.Tensor, 23 | points_value: torch.Tensor, 24 | mask: Optional[torch.Tensor] = None, 25 | key_chunk_size: int = 4096, 26 | weight_kv = None, 27 | weight_points = None, 28 | gamma: torch.Tensor = None, 29 | ): 30 | """Multi-head dot product attention with a limited number of queries.""" 31 | device, dtype = query.device, query.dtype 32 | batch, num_kv, num_heads, k_features = key.shape 33 | v_features = value.shape[-1] 34 | query_chunk = query.shape[1] # b n h d 35 | key_chunk_size = min(key_chunk_size, num_kv) 36 | 37 | gamma = gamma.repeat(batch,num_kv,1) 38 | 39 | # @functools.partial(jax.checkpoint, prevent_cse=False) 40 | def summarize_chunk( 41 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, 42 | pair: torch.Tensor, pair_value: torch.Tensor, 43 | mask: torch.Tensor, 44 | translations: torch.Tensor, encoder_translations: torch.Tensor, 45 | rotations: torch.Tensor, encoder_rotations: torch.Tensor, 46 | points_query: torch.Tensor, points_key: torch.Tensor, points_value: torch.Tensor, 47 | gamma: torch.Tensor 48 | ): 49 | 50 | # sequences 51 | attn_weights = torch.einsum("bqhd,bkhd->bqhk", query, key) 52 | attn_weights = attn_weights * weight_kv 53 | 54 | # pair representation 55 | pair = pair.transpose(3,1) 56 | attn_weights = attn_weights + pair 57 | 
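        # note: the squared point distance below is expanded as
        #   sum_p |a_p - b_p|^2 = |a|^2 + |b|^2 - 2 a.b
        # so the (query, key) invariant can be assembled from einsums without
        # materializing a (batch, q, k, heads, points, 3) difference tensor
        # (compare the commented-out direct form in IPCA.py)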
58 | # point invariant 59 | a = torch.einsum('bnij,bnhpj->bnhpi', rotations, points_query) 60 | b = torch.einsum('bnij,bnhpj->bnhpi', encoder_rotations, points_key) 61 | a = a + translations[:,:,None,None] 62 | b = b + encoder_translations[:,:,None,None] 63 | 64 | a_sq = torch.sum(a**2,dim=[-2,-1]) 65 | b_sq = torch.sum(b**2,dim=[-2,-1]) 66 | b_sq = b_sq.transpose(2,1) 67 | invariant = a_sq[:,:,:,None] + b_sq[:,None] - 2*torch.einsum('bnhpi,bmhpi->bnhm',a,b) 68 | 69 | gamma = gamma.transpose(-2,-1) 70 | attn_weights = attn_weights - 0.5*weight_points*gamma[:,None,:,:] * invariant 71 | 72 | # overall scaling 73 | attn_weights = attn_weights / math.sqrt(3) 74 | 75 | if mask is not None: 76 | mask = mask.unsqueeze(1).unsqueeze(2) 77 | max_neg = -torch.finfo(attn_weights.dtype).max 78 | mask = mask.bool() 79 | attn_weights.masked_fill_(~mask, max_neg) 80 | 81 | max_score = torch.amax(attn_weights, dim=-1, keepdim=True).detach() 82 | exp_weights = torch.exp(attn_weights - max_score) 83 | 84 | # context 85 | exp_values = torch.einsum("bvhf,bqhv->bqhf", value, exp_weights) 86 | 87 | # pair output 88 | pair_value = pair_value.transpose(3,1) 89 | exp_pair = torch.einsum('bqcv,bqhv->bqhc', pair_value, exp_weights) 90 | 91 | # point output 92 | c = torch.einsum('bnij,bnhpj->bnhpi', encoder_rotations, points_value) 93 | c = c + encoder_translations[:,:,None,None] 94 | exp_value_points = torch.einsum("bvhpi,bqhv->bqhpi", c, exp_weights) 95 | 96 | # ((b q h f), (b q h c), (b q h p i)), (b q h k), (b q h 1) 97 | return exp_values, exp_pair, exp_value_points, exp_weights.sum(dim=-1), max_score.squeeze(dim=-1) 98 | 99 | def chunk_scanner( 100 | chunk_idx: int, 101 | ): 102 | key_chunk = dynamic_length_slice(key, chunk_idx, key_chunk_size) 103 | value_chunk = dynamic_length_slice(value, chunk_idx, key_chunk_size) 104 | 105 | pair_chunk = dynamic_length_slice(pair.transpose(3,1), chunk_idx, key_chunk_size) 106 | pair_value_chunk = dynamic_length_slice(pair_value.transpose(3,1), chunk_idx, key_chunk_size) 107 | 108 | encoder_rotations_chunk = dynamic_length_slice(encoder_rotations, chunk_idx, key_chunk_size) 109 | encoder_translations_chunk = dynamic_length_slice(encoder_translations, chunk_idx, key_chunk_size) 110 | 111 | points_key_chunk = dynamic_length_slice(points_key, chunk_idx, key_chunk_size) 112 | points_value_chunk = dynamic_length_slice(points_value, chunk_idx, key_chunk_size) 113 | 114 | mask_chunk = None 115 | if mask is not None: 116 | mask_chunk = dynamic_length_slice(mask, chunk_idx, key_chunk_size) 117 | 118 | gamma_chunk = dynamic_length_slice(gamma, chunk_idx, key_chunk_size) 119 | 120 | return checkpoint.checkpoint( 121 | summarize_chunk, query, key_chunk, value_chunk, 122 | pair_chunk, pair_value_chunk, 123 | mask_chunk, 124 | translations, encoder_translations_chunk, 125 | rotations, encoder_rotations_chunk, 126 | points_query, points_key_chunk, points_value_chunk, 127 | gamma_chunk 128 | ) 129 | 130 | num_chunks = int(math.ceil(num_kv / key_chunk_size)) 131 | chunk_values = torch.zeros( 132 | num_chunks, 133 | batch, 134 | query_chunk, 135 | num_heads, 136 | v_features, 137 | dtype=dtype, 138 | device=device, 139 | ) 140 | 141 | pair_channels = pair_value.shape[2] 142 | chunk_pair = torch.zeros( 143 | num_chunks, 144 | batch, 145 | query_chunk, 146 | num_heads, 147 | pair_channels, 148 | dtype=dtype, 149 | device=device, 150 | ) 151 | 152 | num_value_points = points_value.shape[-2] 153 | point_dim = points_value.shape[-1] 154 | chunk_points = torch.zeros( 155 | num_chunks, 156 | batch, 
157 | query_chunk, 158 | num_heads, 159 | num_value_points, 160 | point_dim, 161 | dtype=dtype, 162 | device=device, 163 | ) 164 | 165 | chunk_weights = torch.zeros( 166 | num_chunks, 167 | batch, 168 | query_chunk, 169 | num_heads, 170 | dtype=dtype, 171 | device=device, 172 | ) 173 | chunk_max = torch.zeros( 174 | num_chunks, 175 | batch, 176 | query_chunk, 177 | num_heads, 178 | dtype=dtype, 179 | device=device, 180 | ) 181 | 182 | for i in range(num_chunks): 183 | chunk_values[i], chunk_pair[i], chunk_points[i], chunk_weights[i], chunk_max[i] = chunk_scanner( 184 | i * key_chunk_size 185 | ) 186 | 187 | max_diffs = torch.exp(chunk_max - chunk_max.amax(dim=0)) 188 | 189 | all_weights = (max_diffs * chunk_weights).sum(dim=0).unsqueeze(dim=-1) 190 | 191 | all_values = (max_diffs.unsqueeze(dim=-1) * chunk_values).sum(dim=0) / all_weights 192 | all_pair = (max_diffs.unsqueeze(dim=-1) * chunk_pair).sum(dim=0) / all_weights 193 | all_points = (max_diffs.unsqueeze(dim=-1).unsqueeze(dim=-1) * chunk_points).sum(dim=0) / all_weights.unsqueeze(dim=-1) 194 | 195 | # transpose of an orthogonal matrix == inverse 196 | all_points = all_points - translations[:,:,None,None] 197 | all_points = torch.einsum('bnji,bnhpj->bnhpi', rotations, all_points) 198 | 199 | return all_values, all_pair, all_points 200 | 201 | def point_attention( 202 | query: torch.Tensor, 203 | key: torch.Tensor, 204 | value: torch.Tensor, 205 | pair: torch.Tensor, 206 | pair_value: torch.Tensor, 207 | rotations: torch.Tensor, 208 | translations: torch.Tensor, 209 | encoder_rotations: torch.Tensor, 210 | encoder_translations: torch.Tensor, 211 | points_query: torch.Tensor, 212 | points_key: torch.Tensor, 213 | points_value: torch.Tensor, 214 | mask: Optional[torch.Tensor] = None, 215 | weight_kv = None, 216 | weight_points = None, 217 | gamma: torch.Tensor = None, 218 | query_chunk_size: int = 1024, 219 | key_chunk_size: int = 4096, 220 | ): 221 | """Memory-efficient multi-head dot product attention. 222 | Inputs: 223 | * q, k, v: (b n h d) torch tensors 224 | * pair, pair_value: (b n h m), (b n c m) 225 | * rotations, translations, encoder_rotations, encoder_translations: 226 | * points_query, points_key, points_value: 227 | * mask: (b n) 228 | * query_chunk_size: int. 229 | * key_chunk_size: int. 
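      (assumed shapes, inferred from the einsums in `query_chunk_attention`:
       rotations / encoder_rotations: (b n 3 3); translations / encoder_translations: (b n 3);
       points_query, points_key, points_value: (b n h p 3))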
230 | Outputs: (b n h d) torch tensor (qk-weighted sum of v) 231 | """ 232 | batch, num_q, num_heads, q_features = query.shape 233 | 234 | def chunk_scanner(chunk_idx: int, _): 235 | query_chunk = dynamic_length_slice(query, chunk_idx, query_chunk_size) 236 | pair_chunk = dynamic_length_slice(pair, chunk_idx, query_chunk_size) 237 | pair_value_chunk = dynamic_length_slice(pair_value, chunk_idx, query_chunk_size) 238 | translations_chunk = dynamic_length_slice(translations, chunk_idx, query_chunk_size) 239 | rotations_chunk = dynamic_length_slice(rotations, chunk_idx, query_chunk_size) 240 | points_chunk = dynamic_length_slice(points_query, chunk_idx, query_chunk_size) 241 | 242 | return ( 243 | chunk_idx + query_chunk_size, 244 | query_chunk_attention( 245 | query_chunk, 246 | key, 247 | value, 248 | pair_chunk, 249 | pair_value_chunk, 250 | rotations_chunk, 251 | translations_chunk, 252 | encoder_rotations, 253 | encoder_translations, 254 | points_chunk, 255 | points_key, 256 | points_value, 257 | mask, 258 | key_chunk_size, 259 | weight_kv, 260 | weight_points, 261 | gamma, 262 | ), 263 | ) 264 | 265 | _, res = torch_scan( 266 | chunk_scanner, init=0, xs=None, length=int(math.ceil(num_q / query_chunk_size)) 267 | ) 268 | 269 | (context, pair_output, points_output) = res 270 | context = context.reshape(batch, num_q, num_heads, value.shape[-1]) 271 | pair_output = pair_output.reshape(batch, num_q, num_heads, pair_value.shape[2]) 272 | points_output = points_output.reshape(batch, num_q, num_heads, points_value.shape[-2], points_value.shape[-1]) 273 | 274 | return context, pair_output, points_output 275 | -------------------------------------------------------------------------------- /contact_pred/models.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel, BertConfig 2 | from transformers import PreTrainedModel, PretrainedConfig 3 | 4 | from .modules import PairRepresentation, CrossPairRepresentation 5 | from .utils import get_extended_attention_mask 6 | from .structure import Structure, CrossStructure, compute_weighted_FAPE, IPAConfig 7 | from .structure import compute_kabsch_RMSD, compute_weighted_RMSD, compute_residue_dist, compute_residue_CN_dist, compute_residue_CNC_angle 8 | from .structure import computeAllAtomCoordinates 9 | from .residue_constants import ( 10 | restype_rigid_group_default_frame, 11 | restype_atom14_to_rigid_group, 12 | restype_atom14_mask, 13 | restype_atom14_rigid_group_positions, 14 | restype_order_with_x 15 | ) 16 | 17 | import torch 18 | import torch.nn as nn 19 | from torch.utils.checkpoint import checkpoint 20 | import math 21 | 22 | import pickle, numpy as np # for dumping items 23 | 24 | class ProteinLigandConfig(PretrainedConfig): 25 | model_type = 'bert' # this is required for tokenizer selection 26 | 27 | def __init__( 28 | self, 29 | seq_config=BertConfig(), 30 | smiles_config=BertConfig(), 31 | n_cross_attention=3, 32 | linear_mem_attn=True, 33 | query_chunk_size_seq=512, 34 | key_chunk_size_seq=1024, 35 | query_chunk_size_smiles=512, 36 | key_chunk_size_smiles=1024, 37 | **kwargs 38 | ): 39 | 40 | self.smiles_config = smiles_config 41 | if isinstance(smiles_config, BertConfig): 42 | self.smiles_config = self.smiles_config.to_dict() 43 | 44 | self.seq_config = seq_config 45 | if isinstance(seq_config, BertConfig): 46 | self.seq_config = self.seq_config.to_dict() 47 | 48 | self.n_cross_attention = n_cross_attention 49 | 50 | # to estimate memory usage with deepspeed ZERO stage3, the 
larger of the two hidden dimensions 51 | self.hidden_size = self.seq_config['hidden_size'] 52 | 53 | self.linear_mem_attn = linear_mem_attn 54 | 55 | self.query_chunk_size_seq = query_chunk_size_seq 56 | self.key_chunk_size_seq = key_chunk_size_seq 57 | 58 | self.query_chunk_size_smiles = query_chunk_size_smiles 59 | self.key_chunk_size_smiles = key_chunk_size_smiles 60 | 61 | super().__init__(**kwargs) 62 | 63 | class ProteinLigandConfigStructure(ProteinLigandConfig): 64 | def __init__( 65 | self, 66 | seq_ipa_config = IPAConfig(), 67 | smiles_ipa_config = IPAConfig(), 68 | num_ipa_layers=8, 69 | num_rigid_groups=8, 70 | width_resnet=128, 71 | depth_resnet=5, 72 | num_embeddings=30, # >= number of amino acids 73 | num_atoms=14, # max number of heavy atoms in a residue 74 | enable_cross=True, 75 | seq_vocab=None, 76 | **kwargs 77 | ): 78 | 79 | self.seq_ipa_config = seq_ipa_config 80 | if isinstance(seq_ipa_config, IPAConfig): 81 | self.seq_ipa_config = self.seq_ipa_config.to_dict() 82 | 83 | self.smiles_ipa_config = smiles_ipa_config 84 | if isinstance(smiles_ipa_config, IPAConfig): 85 | self.smiles_ipa_config = self.smiles_ipa_config.to_dict() 86 | 87 | self.num_ipa_layers = num_ipa_layers 88 | self.num_rigid_groups = num_rigid_groups 89 | self.width_resnet = width_resnet 90 | self.depth_resnet = depth_resnet 91 | self.num_embeddings = num_embeddings 92 | self.num_atoms = num_atoms 93 | 94 | self.enable_cross = enable_cross 95 | self.seq_vocab = seq_vocab 96 | 97 | super().__init__(**kwargs) 98 | 99 | class StructurePrediction(PreTrainedModel): 100 | config_class = ProteinLigandConfigStructure 101 | supports_gradient_checkpointing = True 102 | main_input_name = "input_ids_1" # estimate FLOPs from the protein sequence, which typically has more tokens 103 | 104 | def __init__(self, config): 105 | super().__init__(config) 106 | 107 | self.pair_representation = PairRepresentation(config) 108 | self.structure = Structure(config) 109 | 110 | if config.enable_cross: 111 | self.cross_pair_representation = CrossPairRepresentation(config) 112 | self.cross_structure = CrossStructure(config) 113 | 114 | self.gradient_checkpointing = False 115 | self.enable_cross = config.enable_cross 116 | 117 | self.default_frame = None 118 | self.group_idx = None 119 | self.atom_mask = None 120 | self.lit_positions = None 121 | self.input_ids_to_aatype = None 122 | 123 | def forward( 124 | self, 125 | input_ids_1=None, 126 | inputs_embeds_1=None, 127 | attention_mask_1=None, 128 | input_ids_2=None, 129 | inputs_embeds_2=None, 130 | attention_mask_2=None, 131 | labels_receptor_frames_xyz=None, 132 | labels_receptor_frames_rot=None, 133 | labels_receptor_xyz=None, 134 | labels_ligand_frames_xyz=None, 135 | labels_ligand_frames_rot=None, 136 | labels_ligand_xyz=None, 137 | labels_ligand_token_mask=None, 138 | labels_receptor_token_mask=None, 139 | return_coordinates=True, 140 | return_dict=False, 141 | **kwargs, 142 | ): 143 | pair_representation_output = self.pair_representation( 144 | input_ids_1=input_ids_1, 145 | inputs_embeds_1=inputs_embeds_1, 146 | attention_mask_1=attention_mask_1, 147 | input_ids_2=input_ids_2, 148 | inputs_embeds_2=inputs_embeds_2, 149 | attention_mask_2=attention_mask_2, 150 | ) 151 | 152 | pair_representation_seq, pair_representation_smiles = pair_representation_output[:2] 153 | hidden_seq, hidden_smiles = pair_representation_output[2:4] 154 | 155 | hidden_smiles, hidden_seq, xyz_ligand, rot_ligand, xyz_receptor, rot_receptor, rotation_angles = self.structure( 156 | hidden_seq, 157 | 
hidden_smiles, 158 | attention_mask_1, 159 | attention_mask_2, 160 | pair_representation_seq, 161 | pair_representation_smiles, 162 | ) 163 | 164 | if labels_receptor_frames_xyz is not None or labels_ligand_xyz is not None: 165 | if labels_ligand_xyz is None or labels_receptor_frames_xyz is None or labels_ligand_token_mask is None or labels_receptor_token_mask is None: 166 | raise ValueError("Need both ligand and receptor coordinates.") 167 | 168 | # mask non-atom coordinates 169 | mask_receptor = attention_mask_1*labels_receptor_token_mask 170 | mask_ligand = attention_mask_2*labels_ligand_token_mask 171 | 172 | # auxiliary loss on Calpha + ligand 173 | loss_receptor = compute_kabsch_RMSD(labels_receptor_frames_xyz, xyz_receptor, mask_receptor, dclamp=None) 174 | loss_ligand = compute_kabsch_RMSD(labels_ligand_frames_xyz, xyz_ligand, mask_ligand, dclamp=None) 175 | #loss_receptor_dist = compute_residue_dist(xyz_receptor, mask_receptor, dclamp=None) 176 | #print(f'loss_receptor_dist 1: {loss_receptor_dist:.3f}') 177 | #loss = 0.5*(loss_receptor+loss_ligand) + loss_receptor_dist * 0.5 178 | #loss = (loss_receptor + loss_ligand + loss_receptor_dist) / 3 179 | loss = (loss_receptor + loss_ligand) / 2 180 | 181 | if self.enable_cross: 182 | pair_representation_output = self.cross_pair_representation( 183 | hidden_states_1=hidden_seq, 184 | hidden_states_2=hidden_smiles, 185 | attention_mask_1=attention_mask_1, 186 | attention_mask_2=attention_mask_2, 187 | ) 188 | pair_representation_cross, hidden_seq, hidden_smiles = pair_representation_output 189 | 190 | hidden_smiles, hidden_seq, xyz_ligand, rot_ligand, xyz_receptor, rot_receptor, rotation_angles = self.cross_structure( 191 | hidden_seq, 192 | hidden_smiles, 193 | attention_mask_1, 194 | attention_mask_2, 195 | rot_receptor, 196 | xyz_receptor, 197 | rot_ligand, 198 | xyz_ligand, 199 | pair_representation_cross, 200 | rotation_angles, 201 | ) 202 | 203 | # generate sidechain coordinates 204 | #rotation_angles[:] = 1 205 | #print(rotation_angles[0, 1:11]) 206 | receptor_feat = self.computeAllAtomCoordinates(input_ids_1, 207 | xyz_receptor, 208 | rot_receptor, 209 | rotation_angles, 210 | ) 211 | 212 | #print(receptor_feat[0, 2]) 213 | 214 | 215 | outputs = dict() 216 | if return_coordinates: 217 | outputs['ligand_frames_xyz'] = xyz_ligand 218 | outputs['ligand_frames_rot'] = rot_ligand 219 | outputs['receptor_frames_xyz'] = xyz_receptor 220 | outputs['receptor_frames_rot'] = rot_receptor 221 | 222 | if self.enable_cross: 223 | outputs['receptor_xyz'] = receptor_feat 224 | 225 | if not return_dict: 226 | outputs = tuple(outputs.values()) 227 | 228 | if len(outputs) == 1: 229 | outputs = outputs[0] 230 | 231 | if labels_receptor_frames_xyz is not None or labels_ligand_xyz is not None: 232 | if labels_ligand_xyz is None or labels_receptor_frames_xyz is None or labels_ligand_token_mask is None or labels_receptor_token_mask is None: 233 | raise ValueError("Need both ligand and receptor coordinates.") 234 | 235 | if self.enable_cross: 236 | # ligand frames with a single atom 237 | ligand_feat = xyz_ligand.unsqueeze(2) 238 | ligand_feat = torch.cat([ligand_feat, 239 | torch.zeros(*(ligand_feat.shape[:2] + (self.config.num_atoms-1,) + ligand_feat.shape[3:]), 240 | device=ligand_feat.device, dtype=ligand_feat.dtype)], 2) 241 | 242 | weight = torch.cat([mask_receptor,mask_ligand],1) 243 | labels_feat = torch.cat([labels_receptor_xyz, labels_ligand_xyz], 1) 244 | feat = torch.cat([receptor_feat, ligand_feat], 1) 245 | 246 | use_fape = False 247 | if 
use_fape: 248 | labels_frames_xyz = torch.cat([labels_receptor_frames_xyz, labels_ligand_frames_xyz], 1) 249 | labels_frames_rot = torch.cat([labels_receptor_frames_rot, labels_ligand_frames_rot], 1) 250 | frames_xyz = torch.cat([xyz_receptor, xyz_ligand], 1) 251 | frames_rot = torch.cat([rot_receptor, rot_ligand], 1) 252 | loss = (2*loss + compute_weighted_FAPE(labels_feat, labels_frames_xyz, labels_frames_rot, feat, frames_xyz, frames_rot, weight))/3 253 | else: 254 | # flatten atom coordinates 255 | non_nan = (~torch.any(torch.isnan(labels_feat),dim=-1)).type(torch.int64) 256 | weight = weight[:,:,None]*non_nan 257 | 258 | # normalize so that both molecules are weighted equally 259 | seq_len = mask_receptor.shape[1] 260 | norm_seq = torch.sum(weight[:,:seq_len], [-1,-2], keepdim=True) 261 | norm_smiles = torch.sum(weight[:,seq_len:], [-1,-2], keepdim=True) 262 | weight_seq = weight[:,:seq_len].type(feat.dtype) 263 | weight_smiles = weight[:,seq_len:].type(feat.dtype) 264 | weight_seq = torch.where(norm_seq > 0, weight_seq/norm_seq, weight_seq) 265 | weight_smiles = torch.where(norm_smiles > 0, weight_smiles/norm_smiles, weight_smiles) 266 | weight = torch.cat([weight_seq, weight_smiles], 1) 267 | 268 | labels_feat = torch.nan_to_num(labels_feat) 269 | labels_feat = labels_feat.reshape(*labels_feat.shape[:1], -1, *labels_feat.shape[-1:]) 270 | feat = feat.reshape(*feat.shape[:1], -1, *feat.shape[-1:]) 271 | weight = weight.reshape(*weight.shape[:1], -1) 272 | #weight_smiles = weight_smiles.reshape(*weight_smiles.shape[:1], -1) 273 | #weight_seq = weight_seq.reshape(*weight_seq.shape[:1], -1) 274 | 275 | loss_kabsch = compute_kabsch_RMSD(labels_feat, feat, weight, dclamp=None) 276 | loss_receptor_CN_dist = compute_residue_CN_dist(receptor_feat, mask_receptor, dclamp=None, losstype='bottom') 277 | loss_receptor_CNC_angle = compute_residue_CNC_angle(receptor_feat, mask_receptor) 278 | #print(f'loss_receptor_dist 2: {loss_receptor_dist:.3f}') 279 | loss = (2 * loss + 3 * loss_kabsch + loss_receptor_CN_dist) / 6 280 | # 1/6 ligand self, 1/6 protein self, 1/6 CN dist, 1/2 overall kabsch 281 | 282 | # dumper = {'input_ids_1': input_ids_1.cpu().detach().numpy(), 283 | # 'receptor_feat': receptor_feat.cpu().detach().numpy(), 284 | # 'labels_receptor_xyz': labels_receptor_xyz.cpu().detach().numpy(), 285 | # 'loss': loss.cpu().detach().numpy(), 286 | # 'loss_this': compute_kabsch_RMSD(labels_feat, feat, weight).cpu().detach().numpy(), 287 | # 'loss_receptor_CA': compute_kabsch_RMSD(labels_receptor_frames_xyz, xyz_receptor, mask_receptor).cpu().detach().numpy(), 288 | # 'loss_ligand': compute_kabsch_RMSD(labels_ligand_frames_xyz, xyz_ligand, mask_ligand).cpu().detach().numpy(),} 289 | # #'loss_receptor_all': compute_kabsch_RMSD(labels_receptor_xyz, receptor_feat, weight_seq).cpu().detach().numpy()} 290 | # pickle.dump(dumper, open(f'random_dump/dump_{np.random.randint(1000)}.pkl', 'wb')) 291 | # # Basically take input_id_1, receptor_feat, and loss 292 | 293 | return (loss, outputs) 294 | else: 295 | return outputs 296 | 297 | def computeAllAtomCoordinates(self, 298 | input_ids_1, 299 | xyz_receptor, 300 | rot_receptor, 301 | rotation_angles, 302 | ): 303 | self._init_residue_constants(rotation_angles.dtype, rotation_angles.device) 304 | aatypes = torch.tensor(self.input_ids_to_aatype[input_ids_1], device=input_ids_1.device)#, requires_grad=False) 305 | return computeAllAtomCoordinates( 306 | aatypes, 307 | xyz_receptor, 308 | rot_receptor, 309 | rotation_angles, 310 | self.default_frame, 311 | 
self.group_idx, 312 | self.atom_mask, 313 | self.lit_positions) 314 | 315 | def _init_residue_constants(self, float_dtype, device): 316 | if self.default_frame is None: 317 | self.default_frame = torch.tensor( 318 | restype_rigid_group_default_frame, 319 | dtype=float_dtype, 320 | device=device, 321 | requires_grad=False, 322 | ) 323 | if self.group_idx is None: 324 | self.group_idx = torch.tensor( 325 | restype_atom14_to_rigid_group, 326 | device=device, 327 | requires_grad=False, 328 | ) 329 | if self.atom_mask is None: 330 | self.atom_mask = torch.tensor( 331 | restype_atom14_mask, 332 | dtype=float_dtype, 333 | device=device, 334 | requires_grad=False, 335 | ) 336 | if self.lit_positions is None: 337 | self.lit_positions = torch.tensor( 338 | restype_atom14_rigid_group_positions, 339 | dtype=float_dtype, 340 | device=device, 341 | requires_grad=False, 342 | ) 343 | 344 | if self.input_ids_to_aatype is None: 345 | input_ids_to_aatype = torch.zeros(len(self.config.seq_vocab.keys())) 346 | for k in self.config.seq_vocab.keys(): 347 | if k in restype_order_with_x.keys(): 348 | input_ids_to_aatype[self.config.seq_vocab[k]] = restype_order_with_x[k] 349 | else: 350 | input_ids_to_aatype[self.config.seq_vocab[k]] = 20 351 | #print(f'{k} ({self.config.seq_vocab[k]}) -> {input_ids_to_aatype[self.config.seq_vocab[k]]}') 352 | 353 | self.input_ids_to_aatype = torch.tensor(input_ids_to_aatype, dtype=torch.long, device=device, requires_grad=False) 354 | #print(self.input_ids_to_aatype) 355 | 356 | def gradient_checkpointing_enable(self): 357 | self.gradient_checkpointing = True 358 | self.structure.gradient_checkpointing_enable() 359 | self.pair_representation.gradient_checkpointing_enable() 360 | 361 | def gradient_checkpointing_disable(self): 362 | self.gradient_checkpointing = False 363 | self.structure.gradient_checkpointing_disable() 364 | self.pair_representation.gradient_checkpointing_disable() 365 | 366 | def freeze_protein(self): 367 | self.pair_representation.freeze_protein() 368 | self.structure.freeze_protein() 369 | 370 | def freeze_ligand(self): 371 | self.pair_representation.freeze_ligand() 372 | self.structure.freeze_ligand() 373 | -------------------------------------------------------------------------------- /contact_pred/modules.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel, BertConfig 2 | from transformers.models.bert.modeling_bert import BertIntermediate, BertSelfOutput, BertOutput 3 | from transformers.modeling_utils import apply_chunking_to_forward 4 | 5 | from .structure import Structure, compute_weighted_RMSD, IPAConfig 6 | 7 | from .utils import get_extended_attention_mask 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | from torch.utils.checkpoint import checkpoint 13 | 14 | import math 15 | 16 | from .linear_mem_attn import attention 17 | 18 | class AttentionWithScoreOutput(nn.Module): 19 | def __init__(self, config, other_config=None): 20 | super().__init__() 21 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 22 | raise ValueError( 23 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 24 | f"heads ({config.num_attention_heads})" 25 | ) 26 | 27 | self.num_attention_heads = config.num_attention_heads 28 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 29 | self.all_head_size = self.num_attention_heads * 
-------------------------------------------------------------------------------- /contact_pred/modules.py: --------------------------------------------------------------------------------
1 | from transformers import BertModel, BertConfig
2 | from transformers.models.bert.modeling_bert import BertIntermediate, BertSelfOutput, BertOutput
3 | from transformers.modeling_utils import apply_chunking_to_forward
4 | 
5 | from .structure import Structure, compute_weighted_RMSD, IPAConfig
6 | 
7 | from .utils import get_extended_attention_mask
8 | 
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import functional as F
12 | from torch.utils.checkpoint import checkpoint
13 | 
14 | import math
15 | 
16 | from .linear_mem_attn import attention
17 | 
18 | class AttentionWithScoreOutput(nn.Module):
19 | def __init__(self, config, other_config=None):
20 | super().__init__()
21 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
22 | raise ValueError(
23 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
24 | f"heads ({config.num_attention_heads})"
25 | )
26 | 
27 | self.num_attention_heads = config.num_attention_heads
28 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
29 | self.all_head_size = self.num_attention_heads * self.attention_head_size
30 | 
31 | if other_config is None:
32 | other_config = config
33 | 
34 | self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
35 | self.key = nn.Linear(other_config.hidden_size, self.all_head_size, bias=False)
36 | self.value = nn.Linear(other_config.hidden_size, self.all_head_size, bias=False)
37 | 
38 | def transpose_for_scores(self, x):
39 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
40 | x = x.view(new_x_shape)
41 | return x.permute(0, 2, 1, 3)
42 | 
43 | def forward(
44 | self,
45 | hidden_states,
46 | attention_mask=None,
47 | encoder_hidden_states=None,
48 | encoder_attention_mask=None,
49 | **kwargs
50 | ):
51 | mixed_query_layer = self.query(hidden_states)
52 | 
53 | # If this is instantiated as a cross-attention module, the keys
54 | # and values come from an encoder; the attention mask needs to be
55 | # such that the encoder's padding tokens are not attended to.
56 | is_cross_attention = encoder_hidden_states is not None
57 | 
58 | if is_cross_attention:
59 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
60 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
61 | 
62 | attention_mask = encoder_attention_mask
63 | else:
64 | key_layer = self.transpose_for_scores(self.key(hidden_states))
65 | value_layer = self.transpose_for_scores(self.value(hidden_states))
66 | 
67 | query_layer = self.transpose_for_scores(mixed_query_layer)
68 | 
69 | # Take the dot product between "query" and "key" to get the raw attention scores.
70 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
71 | 
72 | # apply the mask; default to an additive no-op so the softmax below is always well defined
73 | inv_attention_mask = attention_mask if attention_mask is not None else 0.0
74 | if attention_mask is not None and attention_mask.dtype != torch.float32 and attention_mask.dtype != torch.float16:
75 | inv_attention_mask = get_extended_attention_mask(
76 | attention_mask,
77 | hidden_states.shape[:-1],
78 | hidden_states.device,
79 | hidden_states.dtype
80 | )
81 | 
82 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
83 | 
84 | # Normalize the attention scores to probabilities.
85 | attention_probs = nn.functional.softmax(attention_scores + inv_attention_mask, dim=-1)
86 | 
87 | context_layer = torch.matmul(attention_probs, value_layer)
88 | 
89 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
90 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
91 | context_layer = context_layer.view(new_context_layer_shape)
92 | 
93 | return context_layer, attention_scores
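# --- Editor's note (illustrative sketch, not part of the original file) ---
# Minimal shape check for AttentionWithScoreOutput; the config values are
# arbitrary assumptions, and the mask handling mirrors the module's own use of
# get_extended_attention_mask:
#
#   import torch
#   from transformers import BertConfig
#   cfg = BertConfig(hidden_size=64, num_attention_heads=4)
#   attn = AttentionWithScoreOutput(cfg)
#   h = torch.randn(2, 10, 64)                    # (batch, tokens, hidden)
#   mask = torch.ones(2, 10, dtype=torch.long)    # 1 = real token, 0 = padding
#   ctx, scores = attn(h, attention_mask=mask)
#   # ctx:    (2, 10, 64)   attention-weighted values
#   # scores: (2, 4, 10, 10) scaled pre-softmax logits, reused downstream
#   #         as a per-head pair representation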
94 | 
95 | class LinearMemAttentionWithScoreOutput(nn.Module):
96 | def __init__(self, config, other_config=None):
97 | super().__init__()
98 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
99 | raise ValueError(
100 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
101 | f"heads ({config.num_attention_heads})"
102 | )
103 | 
104 | self.num_attention_heads = config.num_attention_heads
105 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
106 | self.all_head_size = self.num_attention_heads * self.attention_head_size
107 | 
108 | if other_config is None:
109 | other_config = config
110 | 
111 | self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
112 | self.key = nn.Linear(other_config.hidden_size, self.all_head_size, bias=False)
113 | self.value = nn.Linear(other_config.hidden_size, self.all_head_size, bias=False)
114 | 
115 | def view_for_scores(self, x):
116 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
117 | return x.view(new_x_shape)
118 | 
119 | def forward(
120 | self,
121 | hidden_states,
122 | attention_mask=None,
123 | encoder_hidden_states=None,
124 | encoder_attention_mask=None,
125 | query_chunk_size=1024,
126 | key_chunk_size=4096,
127 | ):
128 | mixed_query_layer = self.query(hidden_states)
129 | 
130 | # If this is instantiated as a cross-attention module, the keys
131 | # and values come from an encoder; the attention mask needs to be
132 | # such that the encoder's padding tokens are not attended to.
133 | is_cross_attention = encoder_hidden_states is not None 134 | 135 | if is_cross_attention: 136 | key_layer = self.view_for_scores(self.key(encoder_hidden_states)) 137 | value_layer = self.view_for_scores(self.value(encoder_hidden_states)) 138 | 139 | attention_mask = encoder_attention_mask 140 | else: 141 | key_layer = self.view_for_scores(self.key(hidden_states)) 142 | value_layer = self.view_for_scores(self.value(hidden_states)) 143 | 144 | query_layer = self.view_for_scores(mixed_query_layer) 145 | 146 | context, attention_scores = attention( 147 | query_layer, 148 | key_layer, 149 | value_layer, 150 | attention_mask, 151 | query_chunk_size, 152 | key_chunk_size, 153 | ) 154 | 155 | context_layer_shape = context.size()[:-2] + (self.all_head_size,) 156 | context = context.view(context_layer_shape) 157 | 158 | attention_scores = attention_scores.permute(0,3,1,2) 159 | 160 | return context, attention_scores 161 | 162 | class AttentionBlock(nn.Module): 163 | def __init__( 164 | self, 165 | config, 166 | other_config=None, 167 | linear_mem=False, 168 | ): 169 | super().__init__() 170 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 171 | self.seq_len_dim = 1 172 | 173 | self.is_cross_attention = other_config is not None 174 | 175 | if linear_mem: 176 | if not self.is_cross_attention: 177 | self.attention = LinearMemAttentionWithScoreOutput(config) 178 | else: 179 | self.crossattention = LinearMemAttentionWithScoreOutput(config, other_config) 180 | else: 181 | if not self.is_cross_attention: 182 | self.attention = AttentionWithScoreOutput(config) 183 | else: 184 | self.crossattention = AttentionWithScoreOutput(config,other_config) 185 | 186 | if not self.is_cross_attention: 187 | self.self_output = BertSelfOutput(config) 188 | else: 189 | self.cross_output = BertSelfOutput(config) 190 | 191 | self.intermediate = BertIntermediate(config) 192 | self.output = BertOutput(config) 193 | 194 | def forward( 195 | self, 196 | hidden_states, 197 | attention_mask, 198 | query_chunk_size=1024, 199 | key_chunk_size=4096, 200 | encoder_hidden_states=None, 201 | encoder_attention_mask=None, 202 | ): 203 | if not self.is_cross_attention: 204 | attention_outputs = self.attention( 205 | hidden_states=hidden_states, 206 | attention_mask=attention_mask, 207 | query_chunk_size=query_chunk_size, 208 | key_chunk_size=key_chunk_size, 209 | ) 210 | attention_output = attention_outputs[0] 211 | score_outputs = attention_outputs[1:] # add cross attentions if we output attention weights 212 | 213 | hidden_states = self.self_output(attention_output, hidden_states) 214 | else: 215 | cross_attention_outputs = self.crossattention( 216 | hidden_states=hidden_states, 217 | encoder_hidden_states=encoder_hidden_states, 218 | encoder_attention_mask=encoder_attention_mask, 219 | query_chunk_size=query_chunk_size, 220 | key_chunk_size=key_chunk_size, 221 | ) 222 | attention_output = cross_attention_outputs[0] 223 | score_outputs = cross_attention_outputs[1:] # add cross attentions if we output attention weights 224 | 225 | hidden_states = self.cross_output(attention_output, hidden_states) 226 | 227 | # add cross-attn cache to positions 3,4 of present_key_value tuple 228 | layer_output = apply_chunking_to_forward( 229 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, hidden_states 230 | ) 231 | outputs = (layer_output,) + score_outputs 232 | 233 | return outputs 234 | 235 | def feed_forward_chunk(self, attention_output): 236 | intermediate_output = self.intermediate(attention_output) 237 
| layer_output = self.output(intermediate_output, attention_output)
238 | return layer_output
239 | 
240 | class EnsembleEmbedding(torch.nn.Module):
241 | def __init__(self, config, add_pooling_layer=False):
242 | super().__init__()
243 | 
244 | self.config = config
245 | 
246 | self.gradient_checkpointing = False
247 | 
248 | self.seq_model = BertModel(
249 | BertConfig.from_dict(config.seq_config),
250 | add_pooling_layer=add_pooling_layer,
251 | )
252 | 
253 | self.smiles_model = BertModel(
254 | BertConfig.from_dict(config.smiles_config),
255 | add_pooling_layer=add_pooling_layer,
256 | )
257 | 
258 | # the combined embedding width is the sum of the two models' hidden sizes
259 | self.hidden_size = self.seq_model.config.hidden_size + self.smiles_model.config.hidden_size
260 | 
261 | def gradient_checkpointing_enable(self):
262 | self.gradient_checkpointing = True
263 | 
264 | # checkpoint gradients for models that are not frozen
265 | if any([p.requires_grad for p in self.seq_model.parameters()]):
266 | self.seq_model.gradient_checkpointing_enable()
267 | 
268 | if any([p.requires_grad for p in self.smiles_model.parameters()]):
269 | self.smiles_model.gradient_checkpointing_enable()
270 | 
271 | def gradient_checkpointing_disable(self):
272 | self.gradient_checkpointing = False
273 | self.seq_model.gradient_checkpointing_disable()
274 | self.smiles_model.gradient_checkpointing_disable()
275 | 
276 | def load_pretrained(self, seq_model_name, smiles_model_name, add_pooling_layer=False):
277 | self.seq_model = BertModel.from_pretrained(seq_model_name,
278 | add_pooling_layer=add_pooling_layer,
279 | config=BertConfig.from_dict(self.config.seq_config),
280 | )
281 | self.smiles_model = BertModel.from_pretrained(smiles_model_name,
282 | add_pooling_layer=add_pooling_layer,
283 | config=BertConfig.from_dict(self.config.smiles_config),
284 | )
285 | 
286 | def freeze_protein(self):
287 | for param in self.seq_model.parameters():
288 | param.requires_grad = False
289 | 
290 | def freeze_ligand(self):
291 | for param in self.smiles_model.parameters():
292 | param.requires_grad = False
293 | 
294 | def forward(
295 | self,
296 | input_ids_1=None,
297 | inputs_embeds_1=None,
298 | attention_mask_1=None,
299 | input_ids_2=None,
300 | inputs_embeds_2=None,
301 | attention_mask_2=None,
302 | ):
303 | # embed amino acids, sharing the same model
304 | encoder_outputs = self.seq_model(
305 | input_ids=input_ids_1,
306 | inputs_embeds=inputs_embeds_1,
307 | attention_mask=attention_mask_1,
308 | )
309 | hidden_seq = encoder_outputs.last_hidden_state
310 | 
311 | # encode SMILES
312 | encoder_outputs = self.smiles_model(
313 | input_ids=input_ids_2,
314 | inputs_embeds=inputs_embeds_2,
315 | attention_mask=attention_mask_2,
316 | )
317 | hidden_smiles = encoder_outputs.last_hidden_state
318 | 
319 | # return both hidden-state tensors; downstream modules combine them
320 | return hidden_seq, hidden_smiles
321 | 
322 | class PairRepresentation(torch.nn.Module):
323 | def __init__(self, config):
324 | super().__init__()
325 | 
326 | self.embedding = EnsembleEmbedding(config, add_pooling_layer=False)
327 | 
328 | self.attn_seq = torch.nn.ModuleList([
329 | AttentionBlock(self.embedding.seq_model.config, linear_mem=config.linear_mem_attn)
330 | for i in range(config.n_cross_attention)])
331 | 
332 | self.attn_smiles = torch.nn.ModuleList([
333 | AttentionBlock(self.embedding.smiles_model.config, linear_mem=config.linear_mem_attn)
334 | for i in range(config.n_cross_attention)])
335 | 
336 | self.query_chunk_size_seq = config.query_chunk_size_seq
337 | 
self.key_chunk_size_seq = config.key_chunk_size_seq 338 | 339 | self.query_chunk_size_smiles = config.query_chunk_size_smiles 340 | self.key_chunk_size_smiles = config.key_chunk_size_smiles 341 | 342 | self.initial_norm_seq = nn.LayerNorm(self.embedding.seq_model.config.hidden_size) 343 | self.initial_norm_smiles = nn.LayerNorm(self.embedding.smiles_model.config.hidden_size) 344 | 345 | self.gradient_checkpointing=False 346 | 347 | def forward( 348 | self, 349 | input_ids_1=None, 350 | inputs_embeds_1=None, 351 | attention_mask_1=None, 352 | input_ids_2=None, 353 | inputs_embeds_2=None, 354 | attention_mask_2=None, 355 | return_dict=False, 356 | **kwargs, 357 | ): 358 | embedding = self.embedding( 359 | input_ids_1=input_ids_1, 360 | inputs_embeds_1=inputs_embeds_1, 361 | attention_mask_1=attention_mask_1, 362 | input_ids_2=input_ids_2, 363 | inputs_embeds_2=inputs_embeds_2, 364 | attention_mask_2=attention_mask_2, 365 | ) 366 | 367 | hidden_seq, hidden_smiles = embedding 368 | 369 | hidden_seq = self.initial_norm_seq(hidden_seq) 370 | hidden_smiles = self.initial_norm_smiles(hidden_smiles) 371 | 372 | for attn_1, attn_2 in zip(self.attn_seq, self.attn_smiles): 373 | # receptor 374 | if self.gradient_checkpointing: 375 | hidden_seq, attention_score_1 = checkpoint(attn_1, 376 | hidden_seq, 377 | attention_mask_1, 378 | self.query_chunk_size_seq, 379 | self.key_chunk_size_seq, 380 | ) 381 | else: 382 | hidden_seq, attention_score_1 = attn_1( 383 | hidden_states=hidden_seq, 384 | attention_mask=attention_mask_1, 385 | query_chunk_size=self.query_chunk_size_seq, 386 | key_chunk_size=self.key_chunk_size_seq, 387 | ) 388 | 389 | # ligand 390 | if self.gradient_checkpointing: 391 | hidden_smiles, attention_score_2 = checkpoint(attn_2, 392 | hidden_smiles, 393 | attention_mask_2, 394 | self.query_chunk_size_smiles, 395 | self.key_chunk_size_smiles, 396 | ) 397 | else: 398 | hidden_smiles, attention_score_2 = attn_2( 399 | hidden_states = hidden_smiles, 400 | attention_mask = attention_mask_2, 401 | query_chunk_size = self.query_chunk_size_smiles, 402 | key_chunk_size = self.key_chunk_size_smiles, 403 | ) 404 | 405 | pair_representation_seq = attention_score_1 406 | pair_representation_smiles = attention_score_2 407 | 408 | outputs = dict() 409 | 410 | outputs['pair_representation_seq'] = pair_representation_seq 411 | outputs['pair_representation_smiles'] = pair_representation_smiles 412 | 413 | outputs['hidden_seq'] = hidden_seq 414 | outputs['hidden_smiles'] = hidden_smiles 415 | 416 | if not return_dict: 417 | outputs = tuple(outputs.values()) 418 | 419 | return outputs 420 | 421 | def gradient_checkpointing_enable(self): 422 | self.gradient_checkpointing = True 423 | self.embedding.gradient_checkpointing_enable() 424 | 425 | def gradient_checkpointing_disable(self): 426 | self.gradient_checkpointing = False 427 | self.embedding.gradient_checkpointing_disable() 428 | 429 | def freeze_protein(self): 430 | self.embedding.freeze_protein() 431 | for param in self.attn_seq.parameters(): 432 | param.requires_grad = False 433 | for param in self.initial_norm_seq.parameters(): 434 | param.requires_grad = False 435 | 436 | def freeze_ligand(self): 437 | self.embedding.freeze_ligand() 438 | for param in self.attn_smiles.parameters(): 439 | param.requires_grad = False 440 | for param in self.initial_norm_smiles.parameters(): 441 | param.requires_grad = False 442 | 443 | class CrossPairRepresentation(torch.nn.Module): 444 | def __init__(self, config): 445 | super().__init__() 446 | 447 | self.attn_seq = 
torch.nn.ModuleList([ 448 | AttentionBlock(BertConfig.from_dict(config.seq_config), 449 | BertConfig.from_dict(config.smiles_config), 450 | linear_mem=config.linear_mem_attn) 451 | for i in range(config.n_cross_attention)]) 452 | 453 | self.attn_smiles = torch.nn.ModuleList([ 454 | AttentionBlock(BertConfig.from_dict(config.smiles_config), 455 | BertConfig.from_dict(config.seq_config), 456 | linear_mem=config.linear_mem_attn) 457 | for i in range(config.n_cross_attention)]) 458 | 459 | self.query_chunk_size_seq = config.query_chunk_size_seq 460 | self.key_chunk_size_seq = config.key_chunk_size_seq 461 | 462 | self.query_chunk_size_smiles = config.query_chunk_size_smiles 463 | self.key_chunk_size_smiles = config.key_chunk_size_smiles 464 | 465 | self.initial_norm_seq = nn.LayerNorm(config.seq_config['hidden_size']) 466 | self.initial_norm_smiles = nn.LayerNorm(config.smiles_config['hidden_size']) 467 | 468 | self.gradient_checkpointing=False 469 | 470 | def forward( 471 | self, 472 | hidden_states_1, 473 | hidden_states_2, 474 | attention_mask_1, 475 | attention_mask_2, 476 | return_dict=False, 477 | **kwargs, 478 | ): 479 | hidden_seq = self.initial_norm_seq(hidden_states_1) 480 | hidden_smiles = self.initial_norm_smiles(hidden_states_2) 481 | 482 | def cross(attn_1, attn_2, hidden_states_1, hidden_states_2, attention_mask_1, attention_mask_2, 483 | query_chunk_size_1, key_chunk_size_1, 484 | query_chunk_size_2, key_chunk_size_2 485 | ): 486 | attention_output_1 = attn_1( 487 | hidden_states = hidden_states_1, 488 | attention_mask = attention_mask_1, 489 | query_chunk_size = query_chunk_size_1, 490 | key_chunk_size = key_chunk_size_1, 491 | encoder_hidden_states = hidden_states_2, 492 | encoder_attention_mask = attention_mask_2, 493 | ) 494 | attention_output_2 = attn_2( 495 | hidden_states = hidden_states_2, 496 | attention_mask = attention_mask_2, 497 | query_chunk_size = query_chunk_size_2, 498 | key_chunk_size = key_chunk_size_2, 499 | encoder_hidden_states = hidden_states_1, 500 | encoder_attention_mask = attention_mask_1, 501 | ) 502 | 503 | # torch.utils.checkpoint does not support nested structures, concatenate the outputs 504 | output = attention_output_1 + attention_output_2 505 | 506 | return output 507 | 508 | for attn_1, attn_2 in zip(self.attn_seq, self.attn_smiles): 509 | if self.gradient_checkpointing: 510 | (hidden_seq, 511 | attention_score_1, 512 | hidden_smiles, 513 | attention_score_2) = checkpoint( 514 | cross, 515 | attn_1, 516 | attn_2, 517 | hidden_seq, 518 | hidden_smiles, 519 | attention_mask_1, 520 | attention_mask_2, 521 | self.query_chunk_size_seq, 522 | self.key_chunk_size_seq, 523 | self.query_chunk_size_smiles, 524 | self.key_chunk_size_smiles, 525 | ) 526 | else: 527 | (hidden_seq, 528 | attention_score_1, 529 | hidden_smiles, 530 | attention_score_2) = cross( 531 | attn_1, 532 | attn_2, 533 | hidden_seq, 534 | hidden_smiles, 535 | attention_mask_1, 536 | attention_mask_2, 537 | self.query_chunk_size_seq, 538 | self.key_chunk_size_seq, 539 | self.query_chunk_size_smiles, 540 | self.key_chunk_size_smiles, 541 | ) 542 | 543 | # concatenate attention heads 544 | pair_representation_cross = torch.cat((attention_score_1, torch.transpose(attention_score_2, 2, 3)), dim=1) 545 | 546 | outputs = dict() 547 | outputs['pair_representation_cross'] = pair_representation_cross 548 | outputs['hidden_seq'] = hidden_seq 549 | outputs['hidden_smiles'] = hidden_smiles 550 | 551 | if not return_dict: 552 | outputs = tuple(outputs.values()) 553 | 554 | return outputs 555 | 556 
| def gradient_checkpointing_enable(self):
557 | self.gradient_checkpointing = True
558 | # unlike PairRepresentation, this module owns no embedding submodule; the flag alone gates checkpointing
559 | 
560 | def gradient_checkpointing_disable(self):
561 | self.gradient_checkpointing = False
562 | # no embedding submodule to forward the call to (see gradient_checkpointing_enable)
563 | 
-------------------------------------------------------------------------------- /contact_pred/residue_constants.py: --------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | 
4 | def _make_rigid_transformation_4x4(ex, ey, translation):
5 | """Create a rigid 4x4 transformation matrix from two axes and a translation."""
6 | # Normalize ex.
7 | ex_normalized = ex / np.linalg.norm(ex)
8 | 
9 | # make ey perpendicular to ex
10 | ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
11 | ey_normalized /= np.linalg.norm(ey_normalized)
12 | 
13 | # compute ez as cross product
14 | eznorm = np.cross(ex_normalized, ey_normalized)
15 | m = np.stack(
16 | [ex_normalized, ey_normalized, eznorm, translation]
17 | ).transpose()
18 | m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
19 | return m
20 | 
21 | restypes = [
22 | "A", "R", "N", "D",
23 | "C", "Q", "E", "G",
24 | "H", "I", "L", "K",
25 | "M", "F", "P", "S",
26 | "T", "W", "Y", "V",
27 | ]
28 | 
29 | restype_order = {restype: i for i, restype in enumerate(restypes)}
30 | restype_num = len(restypes) # := 20.
31 | unk_restype_index = restype_num # Catch-all index for unknown restypes.
32 | 
33 | restypes_with_x = restypes + ["X"]
34 | restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
35 | 
36 | 
37 | 
38 | restype_1to3 = {
39 | "A": "ALA",
40 | "R": "ARG",
41 | "N": "ASN",
42 | "D": "ASP",
43 | "C": "CYS",
44 | "Q": "GLN",
45 | "E": "GLU",
46 | "G": "GLY",
47 | "H": "HIS",
48 | "I": "ILE",
49 | "L": "LEU",
50 | "K": "LYS",
51 | "M": "MET",
52 | "F": "PHE",
53 | "P": "PRO",
54 | "S": "SER",
55 | "T": "THR",
56 | "W": "TRP",
57 | "Y": "TYR",
58 | "V": "VAL",
59 | }
60 | 
61 | rigid_group_atom_positions = {
62 | "ALA": [
63 | ["N", 0, (-0.525, 1.363, 0.000)],
64 | ["CA", 0, (0.000, 0.000, 0.000)],
65 | ["C", 0, (1.526, -0.000, -0.000)],
66 | ["CB", 0, (-0.529, -0.774, -1.205)],
67 | ["O", 3, (0.627, 1.062, 0.000)],
68 | ],
69 | "ARG": [
70 | ["N", 0, (-0.524, 1.362, -0.000)],
71 | ["CA", 0, (0.000, 0.000, 0.000)],
72 | ["C", 0, (1.525, -0.000, -0.000)],
73 | ["CB", 0, (-0.524, -0.778, -1.209)],
74 | ["O", 3, (0.626, 1.062, 0.000)],
75 | ["CG", 4, (0.616, 1.390, -0.000)],
76 | ["CD", 5, (0.564, 1.414, 0.000)],
77 | ["NE", 6, (0.539, 1.357, -0.000)],
78 | ["NH1", 7, (0.206, 2.301, 0.000)],
79 | ["NH2", 7, (2.078, 0.978, -0.000)],
80 | ["CZ", 7, (0.758, 1.093, -0.000)],
81 | ],
82 | "ASN": [
83 | ["N", 0, (-0.536, 1.357, 0.000)],
84 | ["CA", 0, (0.000, 0.000, 0.000)],
85 | ["C", 0, (1.526, -0.000, -0.000)],
86 | ["CB", 0, (-0.531, -0.787, -1.200)],
87 | ["O", 3, (0.625, 1.062, 0.000)],
88 | ["CG", 4, (0.584, 1.399, 0.000)],
89 | ["ND2", 5, (0.593, -1.188, 0.001)],
90 | ["OD1", 5, (0.633, 1.059, 0.000)],
91 | ],
92 | "ASP": [
93 | ["N", 0, (-0.525, 1.362, -0.000)],
94 | ["CA", 0, (0.000, 0.000, 0.000)],
95 | ["C", 0, (1.527, 0.000, -0.000)],
96 | ["CB", 0, (-0.526, -0.778, -1.208)],
97 | ["O", 3, (0.626, 1.062, -0.000)],
98 | ["CG", 4, (0.593, 1.398, -0.000)],
99 | ["OD1", 5, (0.610, 1.091, 0.000)],
100 | ["OD2", 5, (0.592, -1.101, -0.003)],
101 | ],
102 | "CYS": [
103 | ["N", 0, (-0.522, 1.362, -0.000)],
104 | ["CA", 0, (0.000, 0.000, 0.000)],
105 | ["C", 0, (1.524, 0.000, 0.000)],
106 | ["CB", 0, (-0.519, -0.773, -1.212)], 107 | ["O", 3, (0.625, 1.062, -0.000)], 108 | ["SG", 4, (0.728, 1.653, 0.000)], 109 | ], 110 | "GLN": [ 111 | ["N", 0, (-0.526, 1.361, -0.000)], 112 | ["CA", 0, (0.000, 0.000, 0.000)], 113 | ["C", 0, (1.526, 0.000, 0.000)], 114 | ["CB", 0, (-0.525, -0.779, -1.207)], 115 | ["O", 3, (0.626, 1.062, -0.000)], 116 | ["CG", 4, (0.615, 1.393, 0.000)], 117 | ["CD", 5, (0.587, 1.399, -0.000)], 118 | ["NE2", 6, (0.593, -1.189, -0.001)], 119 | ["OE1", 6, (0.634, 1.060, 0.000)], 120 | ], 121 | "GLU": [ 122 | ["N", 0, (-0.528, 1.361, 0.000)], 123 | ["CA", 0, (0.000, 0.000, 0.000)], 124 | ["C", 0, (1.526, -0.000, -0.000)], 125 | ["CB", 0, (-0.526, -0.781, -1.207)], 126 | ["O", 3, (0.626, 1.062, 0.000)], 127 | ["CG", 4, (0.615, 1.392, 0.000)], 128 | ["CD", 5, (0.600, 1.397, 0.000)], 129 | ["OE1", 6, (0.607, 1.095, -0.000)], 130 | ["OE2", 6, (0.589, -1.104, -0.001)], 131 | ], 132 | "GLY": [ 133 | ["N", 0, (-0.572, 1.337, 0.000)], 134 | ["CA", 0, (0.000, 0.000, 0.000)], 135 | ["C", 0, (1.517, -0.000, -0.000)], 136 | ["O", 3, (0.626, 1.062, -0.000)], 137 | ], 138 | "HIS": [ 139 | ["N", 0, (-0.527, 1.360, 0.000)], 140 | ["CA", 0, (0.000, 0.000, 0.000)], 141 | ["C", 0, (1.525, 0.000, 0.000)], 142 | ["CB", 0, (-0.525, -0.778, -1.208)], 143 | ["O", 3, (0.625, 1.063, 0.000)], 144 | ["CG", 4, (0.600, 1.370, -0.000)], 145 | ["CD2", 5, (0.889, -1.021, 0.003)], 146 | ["ND1", 5, (0.744, 1.160, -0.000)], 147 | ["CE1", 5, (2.030, 0.851, 0.002)], 148 | ["NE2", 5, (2.145, -0.466, 0.004)], 149 | ], 150 | "ILE": [ 151 | ["N", 0, (-0.493, 1.373, -0.000)], 152 | ["CA", 0, (0.000, 0.000, 0.000)], 153 | ["C", 0, (1.527, -0.000, -0.000)], 154 | ["CB", 0, (-0.536, -0.793, -1.213)], 155 | ["O", 3, (0.627, 1.062, -0.000)], 156 | ["CG1", 4, (0.534, 1.437, -0.000)], 157 | ["CG2", 4, (0.540, -0.785, -1.199)], 158 | ["CD1", 5, (0.619, 1.391, 0.000)], 159 | ], 160 | "LEU": [ 161 | ["N", 0, (-0.520, 1.363, 0.000)], 162 | ["CA", 0, (0.000, 0.000, 0.000)], 163 | ["C", 0, (1.525, -0.000, -0.000)], 164 | ["CB", 0, (-0.522, -0.773, -1.214)], 165 | ["O", 3, (0.625, 1.063, -0.000)], 166 | ["CG", 4, (0.678, 1.371, 0.000)], 167 | ["CD1", 5, (0.530, 1.430, -0.000)], 168 | ["CD2", 5, (0.535, -0.774, 1.200)], 169 | ], 170 | "LYS": [ 171 | ["N", 0, (-0.526, 1.362, -0.000)], 172 | ["CA", 0, (0.000, 0.000, 0.000)], 173 | ["C", 0, (1.526, 0.000, 0.000)], 174 | ["CB", 0, (-0.524, -0.778, -1.208)], 175 | ["O", 3, (0.626, 1.062, -0.000)], 176 | ["CG", 4, (0.619, 1.390, 0.000)], 177 | ["CD", 5, (0.559, 1.417, 0.000)], 178 | ["CE", 6, (0.560, 1.416, 0.000)], 179 | ["NZ", 7, (0.554, 1.387, 0.000)], 180 | ], 181 | "MET": [ 182 | ["N", 0, (-0.521, 1.364, -0.000)], 183 | ["CA", 0, (0.000, 0.000, 0.000)], 184 | ["C", 0, (1.525, 0.000, 0.000)], 185 | ["CB", 0, (-0.523, -0.776, -1.210)], 186 | ["O", 3, (0.625, 1.062, -0.000)], 187 | ["CG", 4, (0.613, 1.391, -0.000)], 188 | ["SD", 5, (0.703, 1.695, 0.000)], 189 | ["CE", 6, (0.320, 1.786, -0.000)], 190 | ], 191 | "PHE": [ 192 | ["N", 0, (-0.518, 1.363, 0.000)], 193 | ["CA", 0, (0.000, 0.000, 0.000)], 194 | ["C", 0, (1.524, 0.000, -0.000)], 195 | ["CB", 0, (-0.525, -0.776, -1.212)], 196 | ["O", 3, (0.626, 1.062, -0.000)], 197 | ["CG", 4, (0.607, 1.377, 0.000)], 198 | ["CD1", 5, (0.709, 1.195, -0.000)], 199 | ["CD2", 5, (0.706, -1.196, 0.000)], 200 | ["CE1", 5, (2.102, 1.198, -0.000)], 201 | ["CE2", 5, (2.098, -1.201, -0.000)], 202 | ["CZ", 5, (2.794, -0.003, -0.001)], 203 | ], 204 | "PRO": [ 205 | ["N", 0, (-0.566, 1.351, -0.000)], 206 | ["CA", 0, (0.000, 0.000, 0.000)], 207 
| ["C", 0, (1.527, -0.000, 0.000)], 208 | ["CB", 0, (-0.546, -0.611, -1.293)], 209 | ["O", 3, (0.621, 1.066, 0.000)], 210 | ["CG", 4, (0.382, 1.445, 0.0)], 211 | # ['CD', 5, (0.427, 1.440, 0.0)], 212 | ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger 213 | ], 214 | "SER": [ 215 | ["N", 0, (-0.529, 1.360, -0.000)], 216 | ["CA", 0, (0.000, 0.000, 0.000)], 217 | ["C", 0, (1.525, -0.000, -0.000)], 218 | ["CB", 0, (-0.518, -0.777, -1.211)], 219 | ["O", 3, (0.626, 1.062, -0.000)], 220 | ["OG", 4, (0.503, 1.325, 0.000)], 221 | ], 222 | "THR": [ 223 | ["N", 0, (-0.517, 1.364, 0.000)], 224 | ["CA", 0, (0.000, 0.000, 0.000)], 225 | ["C", 0, (1.526, 0.000, -0.000)], 226 | ["CB", 0, (-0.516, -0.793, -1.215)], 227 | ["O", 3, (0.626, 1.062, 0.000)], 228 | ["CG2", 4, (0.550, -0.718, -1.228)], 229 | ["OG1", 4, (0.472, 1.353, 0.000)], 230 | ], 231 | "TRP": [ 232 | ["N", 0, (-0.521, 1.363, 0.000)], 233 | ["CA", 0, (0.000, 0.000, 0.000)], 234 | ["C", 0, (1.525, -0.000, 0.000)], 235 | ["CB", 0, (-0.523, -0.776, -1.212)], 236 | ["O", 3, (0.627, 1.062, 0.000)], 237 | ["CG", 4, (0.609, 1.370, -0.000)], 238 | ["CD1", 5, (0.824, 1.091, 0.000)], 239 | ["CD2", 5, (0.854, -1.148, -0.005)], 240 | ["CE2", 5, (2.186, -0.678, -0.007)], 241 | ["CE3", 5, (0.622, -2.530, -0.007)], 242 | ["NE1", 5, (2.140, 0.690, -0.004)], 243 | ["CH2", 5, (3.028, -2.890, -0.013)], 244 | ["CZ2", 5, (3.283, -1.543, -0.011)], 245 | ["CZ3", 5, (1.715, -3.389, -0.011)], 246 | ], 247 | "TYR": [ 248 | ["N", 0, (-0.522, 1.362, 0.000)], 249 | ["CA", 0, (0.000, 0.000, 0.000)], 250 | ["C", 0, (1.524, -0.000, -0.000)], 251 | ["CB", 0, (-0.522, -0.776, -1.213)], 252 | ["O", 3, (0.627, 1.062, -0.000)], 253 | ["CG", 4, (0.607, 1.382, -0.000)], 254 | ["CD1", 5, (0.716, 1.195, -0.000)], 255 | ["CD2", 5, (0.713, -1.194, -0.001)], 256 | ["CE1", 5, (2.107, 1.200, -0.002)], 257 | ["CE2", 5, (2.104, -1.201, -0.003)], 258 | ["OH", 5, (4.168, -0.002, -0.005)], 259 | ["CZ", 5, (2.791, -0.001, -0.003)], 260 | ], 261 | "VAL": [ 262 | ["N", 0, (-0.494, 1.373, -0.000)], 263 | ["CA", 0, (0.000, 0.000, 0.000)], 264 | ["C", 0, (1.527, -0.000, -0.000)], 265 | ["CB", 0, (-0.533, -0.795, -1.213)], 266 | ["O", 3, (0.627, 1.062, -0.000)], 267 | ["CG1", 4, (0.540, 1.429, -0.000)], 268 | ["CG2", 4, (0.533, -0.776, 1.203)], 269 | ], 270 | } 271 | 272 | # Format: The list for each AA type contains chi1, chi2, chi3, chi4 in 273 | # this order (or a relevant subset from chi1 onwards). ALA and GLY don't have 274 | # chi angles so their chi angle lists are empty. 275 | chi_angles_atoms = { 276 | "ALA": [], 277 | # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. 
278 | "ARG": [ 279 | ["N", "CA", "CB", "CG"], 280 | ["CA", "CB", "CG", "CD"], 281 | ["CB", "CG", "CD", "NE"], 282 | ["CG", "CD", "NE", "CZ"], 283 | ], 284 | "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], 285 | "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], 286 | "CYS": [["N", "CA", "CB", "SG"]], 287 | "GLN": [ 288 | ["N", "CA", "CB", "CG"], 289 | ["CA", "CB", "CG", "CD"], 290 | ["CB", "CG", "CD", "OE1"], 291 | ], 292 | "GLU": [ 293 | ["N", "CA", "CB", "CG"], 294 | ["CA", "CB", "CG", "CD"], 295 | ["CB", "CG", "CD", "OE1"], 296 | ], 297 | "GLY": [], 298 | "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]], 299 | "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]], 300 | "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], 301 | "LYS": [ 302 | ["N", "CA", "CB", "CG"], 303 | ["CA", "CB", "CG", "CD"], 304 | ["CB", "CG", "CD", "CE"], 305 | ["CG", "CD", "CE", "NZ"], 306 | ], 307 | "MET": [ 308 | ["N", "CA", "CB", "CG"], 309 | ["CA", "CB", "CG", "SD"], 310 | ["CB", "CG", "SD", "CE"], 311 | ], 312 | "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], 313 | "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]], 314 | "SER": [["N", "CA", "CB", "OG"]], 315 | "THR": [["N", "CA", "CB", "OG1"]], 316 | "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], 317 | "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], 318 | "VAL": [["N", "CA", "CB", "CG1"]], 319 | } 320 | 321 | chi_angles_mask = [ 322 | [0.0, 0.0, 0.0, 0.0], # ALA 323 | [1.0, 1.0, 1.0, 1.0], # ARG 324 | [1.0, 1.0, 0.0, 0.0], # ASN 325 | [1.0, 1.0, 0.0, 0.0], # ASP 326 | [1.0, 0.0, 0.0, 0.0], # CYS 327 | [1.0, 1.0, 1.0, 0.0], # GLN 328 | [1.0, 1.0, 1.0, 0.0], # GLU 329 | [0.0, 0.0, 0.0, 0.0], # GLY 330 | [1.0, 1.0, 0.0, 0.0], # HIS 331 | [1.0, 1.0, 0.0, 0.0], # ILE 332 | [1.0, 1.0, 0.0, 0.0], # LEU 333 | [1.0, 1.0, 1.0, 1.0], # LYS 334 | [1.0, 1.0, 1.0, 0.0], # MET 335 | [1.0, 1.0, 0.0, 0.0], # PHE 336 | [1.0, 1.0, 0.0, 0.0], # PRO 337 | [1.0, 0.0, 0.0, 0.0], # SER 338 | [1.0, 0.0, 0.0, 0.0], # THR 339 | [1.0, 1.0, 0.0, 0.0], # TRP 340 | [1.0, 1.0, 0.0, 0.0], # TYR 341 | [1.0, 0.0, 0.0, 0.0], # VAL 342 | ] 343 | 344 | 345 | # This mapping is used when we need to store atom data in a format that requires 346 | # fixed atom data size for every residue (e.g. a numpy array). 
347 | atom_types = [
348 | "N", "CA", "C", "CB", "O",
349 | "CG", "CG1", "CG2", "OG", "OG1",
350 | "SG", "CD", "CD1", "CD2", "ND1",
351 | "ND2", "OD1", "OD2", "SD", "CE",
352 | "CE1", "CE2", "CE3", "NE", "NE1",
353 | "NE2", "OE1", "OE2", "CH2", "NH1",
354 | "NH2", "OH", "CZ", "CZ2", "CZ3",
355 | "NZ", "OXT",
356 | ]
357 | atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
358 | 
359 | # A compact atom encoding with 14 columns
360 | restype_name_to_atom14_names = {
361 | "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
362 | "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", "",],
363 | "ASN": [ "N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", "",],
364 | "ASP": [ "N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", "",],
365 | "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
366 | "GLN": [ "N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", "",],
367 | "GLU": [ "N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", "",],
368 | "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
369 | "HIS": [ "N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", "",],
370 | "ILE": [ "N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", "",],
371 | "LEU": [ "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", "",],
372 | "LYS": [ "N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", "",],
373 | "MET": [ "N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", "",],
374 | "PHE": [ "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", "",],
375 | "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
376 | "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
377 | "THR": [ "N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", "",],
378 | "TRP": [ "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2",],
379 | "TYR": [ "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", "",],
380 | "VAL": [ "N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", "",],
381 | "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
382 | }
383 | 
384 | restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int64) # np.int was removed from NumPy; use an explicit integer dtype
385 | restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
386 | restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
387 | restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
388 | 
389 | 
390 | def _make_rigid_group_constants():
391 | """Fill the arrays above."""
392 | 
393 | for restype, restype_letter in enumerate(restypes):
394 | resname = restype_1to3[restype_letter]
395 | for atomname, group_idx, atom_position in rigid_group_atom_positions[
396 | resname
397 | ]:
398 | atomtype = atom_order[atomname]
399 | #restype_atom37_to_rigid_group[restype, atomtype] = group_idx
400 | #restype_atom37_mask[restype, atomtype] = 1
401 | #restype_atom37_rigid_group_positions[
402 | # restype, atomtype, :
403 | #] = atom_position
404 | 
405 | atom14idx = restype_name_to_atom14_names[resname].index(atomname)
406 | restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
407 | restype_atom14_mask[restype, atom14idx] = 1
408 | restype_atom14_rigid_group_positions[
409 | restype, atom14idx, :
410 | ] = atom_position
411 | 
412 | 
413 | for restype, restype_letter in enumerate(restypes):
414 | resname = restype_1to3[restype_letter] 415 | atom_positions = { 416 | name: np.array(pos) 417 | for name, _, pos in rigid_group_atom_positions[resname] 418 | } 419 | 420 | # backbone to backbone is the identity transform 421 | restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) 422 | 423 | # pre-omega-frame to backbone (currently dummy identity matrix) 424 | restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) 425 | 426 | # phi-frame to backbone 427 | mat = _make_rigid_transformation_4x4( 428 | ex=atom_positions["N"] - atom_positions["CA"], 429 | ey=np.array([1.0, 0.0, 0.0]), 430 | translation=atom_positions["N"], 431 | ) 432 | restype_rigid_group_default_frame[restype, 2, :, :] = mat 433 | 434 | # psi-frame to backbone 435 | mat = _make_rigid_transformation_4x4( 436 | ex=atom_positions["C"] - atom_positions["CA"], 437 | ey=atom_positions["CA"] - atom_positions["N"], 438 | translation=atom_positions["C"], 439 | ) 440 | restype_rigid_group_default_frame[restype, 3, :, :] = mat 441 | 442 | # chi1-frame to backbone 443 | if chi_angles_mask[restype][0]: 444 | base_atom_names = chi_angles_atoms[resname][0] 445 | base_atom_positions = [ 446 | atom_positions[name] for name in base_atom_names 447 | ] 448 | mat = _make_rigid_transformation_4x4( 449 | ex=base_atom_positions[2] - base_atom_positions[1], 450 | ey=base_atom_positions[0] - base_atom_positions[1], 451 | translation=base_atom_positions[2], 452 | ) 453 | restype_rigid_group_default_frame[restype, 4, :, :] = mat 454 | 455 | # chi2-frame to chi1-frame 456 | # chi3-frame to chi2-frame 457 | # chi4-frame to chi3-frame 458 | # luckily all rotation axes for the next frame start at (0,0,0) of the 459 | # previous frame 460 | for chi_idx in range(1, 4): 461 | if chi_angles_mask[restype][chi_idx]: 462 | axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] 463 | axis_end_atom_position = atom_positions[axis_end_atom_name] 464 | mat = _make_rigid_transformation_4x4( 465 | ex=axis_end_atom_position, 466 | ey=np.array([-1.0, 0.0, 0.0]), 467 | translation=axis_end_atom_position, 468 | ) 469 | restype_rigid_group_default_frame[ 470 | restype, 4 + chi_idx, :, : 471 | ] = mat 472 | 473 | 474 | _make_rigid_group_constants() 475 | 476 | #print('restype_atom14_to_rigid_group') 477 | #print(restype_atom14_to_rigid_group) 478 | #print('restype_atom14_mask') 479 | #print(restype_atom14_mask) 480 | #print('restype_atom14_rigid_group_positions') 481 | #print(restype_atom14_rigid_group_positions) 482 | #print('restype_rigid_group_default_frame') 483 | #print(restype_rigid_group_default_frame) 484 | -------------------------------------------------------------------------------- /contact_pred/structure.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from transformers import PretrainedConfig, BertConfig 5 | 6 | from .IPCA import InvariantPointCrossAttention, LinearMemInvariantPointCrossAttention 7 | from .building_blocks import MLP 8 | 9 | from torch.utils.checkpoint import checkpoint 10 | from packaging import version 11 | from .residue_constants import ( 12 | restype_rigid_group_default_frame, 13 | restype_atom14_to_rigid_group, 14 | restype_atom14_mask, 15 | restype_atom14_rigid_group_positions, 16 | ) 17 | 18 | 19 | 20 | class IPAConfig(PretrainedConfig): 21 | def __init__( 22 | self, 23 | bert_config = BertConfig(), 24 | num_ipa_heads=12, 25 | num_points=8, 26 | num_intermediate_layers=3, 27 | query_chunk_size=512, 28 | 
key_chunk_size=1024,
29 | **kwargs,
30 | ):
31 | 
32 | self.bert_config = bert_config
33 | if isinstance(bert_config, BertConfig):
34 | self.bert_config = self.bert_config.to_dict()
35 | 
36 | self.num_ipa_heads = num_ipa_heads
37 | self.num_points = num_points
38 | self.num_intermediate_layers = num_intermediate_layers # honor the constructor argument instead of hard-coding 3
39 | 
40 | self.query_chunk_size = query_chunk_size
41 | self.key_chunk_size = key_chunk_size
42 | 
43 | super().__init__(**kwargs)
44 | 
45 | def ijk_to_R(ijk):
46 | # make rotation matrix from vector component of unnormalized quaternion
47 | one = torch.ones(*(ijk.shape[:-1] + (1,)), dtype=ijk.dtype, device=ijk.device)
48 | rijk = torch.cat([one, ijk], dim=-1)
49 | 
50 | r, i, j, k = torch.unbind(rijk, -1)
51 | two_s = 2.0 / (rijk * rijk).sum(-1)
52 | 
53 | # convert to rotation matrix
54 | o = torch.stack(
55 | (
56 | 1 - two_s * (j * j + k * k),
57 | two_s * (i * j - k * r),
58 | two_s * (i * k + j * r),
59 | two_s * (i * j + k * r),
60 | 1 - two_s * (i * i + k * k),
61 | two_s * (j * k - i * r),
62 | two_s * (i * k - j * r),
63 | two_s * (j * k + i * r),
64 | 1 - two_s * (i * i + j * j),
65 | ),
66 | -1,
67 | )
68 | 
69 | return o.reshape(rijk.shape[:-1] + (3, 3))
70 | 
71 | class FrameTranslation(nn.Module):
72 | # Similar to Alg 23 currently without the quaternion part
73 | def __init__(self, config):
74 | super().__init__()
75 | self.linear = nn.Linear(config.bert_config['hidden_size'], 3) # x,y,z
76 | 
77 | def forward(self, hidden_states):
78 | translation = self.linear(hidden_states)
79 | 
80 | return translation
81 | 
82 | class FrameRotation(nn.Module):
83 | # Similar to Alg 23 currently without the quaternion part
84 | def __init__(self, config):
85 | super().__init__()
86 | self.linear = nn.Linear(config.bert_config['hidden_size'], 3) # i,j,k: the vector part of an unnormalized quaternion
87 | 
88 | def forward(self, hidden_states):
89 | ijk = self.linear(hidden_states)
90 | 
91 | return ijk_to_R(ijk)
92 | 
93 | def compute_weighted_RMSD(T_true, T_pred, weight, unit=10.0, dclamp=10.0, eps = 1e-4, reduction='mean'):
94 | dist = torch.sqrt(((T_pred - T_true)**2).sum(-1) + eps)
95 | weight = weight.type(dist.dtype)
96 | if dclamp is not None:
97 | dist = torch.clamp(dist, max=dclamp)
98 | dist = weight*dist
99 | 
100 | loss = torch.sum(dist, dim=-1) / weight.sum(dim=-1) / unit
101 | 
102 | # mini-batch reduction
103 | if reduction == 'mean':
104 | loss = torch.mean(loss)
105 | elif reduction == 'gmean':
106 | loss = torch.exp(torch.mean(torch.log(loss)))
107 | else:
108 | raise ValueError
109 | 
110 | return loss
111 | 
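# --- Editor's note (illustrative sketch, not part of the original file) ---
# The loss helpers in this file share one weighting convention: `weight` is a
# 0/1 token mask, per-token distances are clamped at `dclamp` (Angstrom), and
# the result is divided by `unit` so the loss stays O(1). A minimal call with
# made-up tensors:
#
#   import torch
#   T_true = torch.randn(2, 5, 3)                        # (batch, tokens, xyz)
#   T_pred = T_true + 0.1 * torch.randn(2, 5, 3)
#   mask = torch.ones(2, 5)
#   loss = compute_weighted_RMSD(T_true, T_pred, mask)   # scalar (mean over batch)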
112 | def compute_weighted_FAPE(x_true, t_true, R_true, x_pred, t_pred, R_pred, weight, unit=10.0, dclamp=10.0, eps = 1e-6, reduction='mean'):
113 | non_nan = (~torch.any(torch.isnan(x_true),dim=-1)).type(torch.int64)[:,:,None]
114 | x_true = torch.nan_to_num(x_true)
115 | 
116 | outer_true = torch.einsum('bijmk,bjkl->bijml',x_true[:,:,None] - t_true[:,None,:,None,:], R_true)
117 | outer_pred = torch.einsum('bijmk,bjkl->bijml',x_pred[:,:,None] - t_pred[:,None,:,None,:], R_pred)
118 | 
119 | dist = torch.sqrt(torch.sum((outer_pred - outer_true)**2, dim=-1) + eps)
120 | 
121 | # ignore non-occupied coordinates in ground truth (they are set to nan)
122 | dist = torch.where(non_nan==0, torch.zeros_like(dist), dist)
123 | weight = weight[:,:,None,None]*weight[:,None,:,None]*non_nan
124 | 
125 | weight = weight.type(dist.dtype)
126 | if dclamp is not None:
127 | dist = torch.clamp(dist, max=dclamp)
128 | dist = weight*dist
129 | 
130 | loss = dist.sum([-3,-2,-1]) / weight.sum([-3,-2,-1]) / unit
131 | 
132 | # mini-batch reduction
133 | if reduction == 'mean':
134 | loss = torch.mean(loss)
135 | elif reduction == 'gmean':
136 | loss = torch.exp(torch.mean(torch.log(loss)))
137 | else:
138 | raise ValueError
139 | 
140 | return loss
141 | 
142 | def compute_kabsch_RMSD(T_true, T_pred, weight, unit=10.0, dclamp=10.0, eps = 1e-4, reduction='mean'):
143 | # T_true, T_pred is (N_batch, N_max_len_of_batch, 3)
144 | # weight is the mask of size (N_batch, N_max_len_of_batch)
145 | 
146 | # First center with mean coordinate of meaningful tokens, which involves scaling up the mean coordinates w.r.t mean of weight.
147 | # And then we zero the non-meaningful ones - zero coordinates in SVD will not affect the rotation matrix
148 | 
149 | weight = weight.type(T_true.dtype)
150 | 
151 | T_true_cm = (T_true * weight[:, :, None]).mean(axis=1)[:,None,:] / (weight.mean(1)[:, None, None])
152 | T_pred_cm = (T_pred * weight[:, :, None]).mean(axis=1)[:,None,:] / (weight.mean(1)[:, None, None])
153 | 
154 | T_true_cen = T_true - T_true_cm
155 | T_pred_cen = T_pred - T_pred_cm
156 | 
157 | # Kabsch method
158 | C = torch.matmul(torch.transpose(T_pred_cen * weight[:, :, None], -2, -1), T_true_cen * weight[:, :, None])
159 | with torch.autocast('cuda',enabled=False):
160 | # svd doesn't support fp16/bf16 on AMD
161 | V, S, W = torch.linalg.svd(C.float())
162 | V_prime = V * torch.tensor([1, 1, -1], device=V.device, dtype=V.dtype)[None, :, None] # Alternate version of V if d < 0
163 | d = (torch.linalg.det(V) * torch.linalg.det(W)) < 0.0
164 | 
165 | V = torch.where(d[:, None, None], V_prime, V) # We have to broadcast d to select the right V's
166 | # Create Rotation matrix U
167 | U = torch.matmul(V, W)
168 | T_pred_align = torch.matmul(T_pred_cen, U) # Rotate P
169 | diff = T_pred_align - T_true_cen
170 | dist = torch.sqrt((diff**2).sum(-1) + eps) # (N_batch, N_max_len_of_batch)
171 | 
172 | if dclamp is not None:
173 | dist = torch.clamp(dist, max=dclamp)
174 | dist = weight*dist
175 | 
176 | loss = torch.sum(dist, dim=-1) / weight.sum(dim=-1) / unit
177 | 
178 | # mini-batch reduction
179 | if reduction == 'mean':
180 | loss = torch.mean(loss)
181 | elif reduction == 'gmean':
182 | loss = torch.exp(torch.mean(torch.log(loss)))
183 | else:
184 | raise ValueError
185 | 
186 | return loss
187 | 
188 | def compute_residue_dist(T_pred, weight, unit=10.0, dclamp=10.0, eps = 1e-4, reduction='mean'):
189 | # T_pred is predicted xyz of frames (N_batch, N_max_len_of_batch, 3)
190 | # weight is the mask of size (N_batch, N_max_len_of_batch)
191 | 
192 | ca_ca = 3.80209737096 # This is a constant, taken from OpenFold
193 | 
194 | weight = weight.type(T_pred.dtype)
195 | weight = weight[:, 1:] * weight[:, :-1]
196 | 
197 | diff = torch.square(torch.sqrt(((T_pred[:, 1:, :] - T_pred[:, :-1, :])**2).sum(-1)) - ca_ca)
198 | 
199 | if dclamp is not None:
200 | diff = torch.clamp(diff, max=dclamp) # clamp the deviation itself; dist is only defined on the next line
201 | dist = diff * weight
202 | 
203 | loss = torch.sum(dist, dim=-1) / weight.sum(dim=-1) / unit
204 | 
205 | # mini-batch reduction
206 | if reduction == 'mean':
207 | loss = torch.mean(loss)
208 | elif reduction == 'gmean':
209 | loss = torch.exp(torch.mean(torch.log(loss)))
210 | else:
211 | raise ValueError
212 | 
213 | return loss
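# --- Editor's note (illustrative sketch, not part of the original file) ---
# The next two helpers penalize deviations from ideal peptide-bond geometry. The
# 'bottom' variant is a flat-bottom L1: deviations within 12 sigma of the ideal
# C-N bond length cost nothing, anything beyond grows linearly. In isolation:
#
#   import torch
#   ideal, sigma = 1.3296, 0.0141
#   d = torch.tensor([1.33, 1.60])                             # C-N distances in Angstrom
#   penalty = torch.clamp((d - ideal).abs() - 12 * sigma, min=0.0)
#   # -> tensor([0.0000, 0.1012]); only the broken bond is penalized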
214 | 
215 | def compute_residue_CN_dist(T_pred, weight, unit=10.0, dclamp=10.0, eps = 1e-4, reduction='mean', losstype='bottom'):
216 | # T_pred is predicted features (atom coordinates) of shape (N_batch, N_max_len_of_batch, 14, 3)
217 | # weight is the mask of size (N_batch, N_max_len_of_batch)
218 | # Loss type 'bottom' is a flat-bottom L1 loss in the style of AlphaFold 2: max(|d - c_n| - 12 * c_n_s, 0)
219 | # Loss type 'square' is an L2 loss: (d - c_n)**2
220 | 
221 | c_n = 1.3296 # This is a constant, taken from OpenFold (19/20 of 1.329 and 1/20 of 1.341 [for proline])
222 | c_n_s = 0.0141 # Weighted average of the sigma of bond length
223 | 
224 | weight = weight.type(T_pred.dtype)
225 | weight = weight[:, 1:] * weight[:, :-1]
226 | 
227 | # Calculate dist between C of res N and N of res N+1
228 | if losstype == 'square':
229 | diff = torch.square(torch.sqrt(((T_pred[:, 1:, 0, :] - T_pred[:, :-1, 2, :])**2).sum(-1) + eps) - c_n)
230 | elif losstype == 'bottom':
231 | diff = torch.maximum(torch.abs(torch.sqrt(((T_pred[:, 1:, 0, :] - T_pred[:, :-1, 2, :])**2).sum(-1) + eps) - c_n) - 12 * c_n_s,
232 | torch.zeros_like(weight))
233 | 
234 | if dclamp is not None:
235 | diff = torch.clamp(diff, max=dclamp) # clamp the deviation itself; dist is only defined on the next line
236 | dist = diff * weight
237 | 
238 | loss = torch.sum(dist, dim=-1) / weight.sum(dim=-1) / unit
239 | 
240 | # mini-batch reduction
241 | if reduction == 'mean':
242 | loss = torch.mean(loss)
243 | elif reduction == 'gmean':
244 | loss = torch.exp(torch.mean(torch.log(loss)))
245 | else:
246 | raise ValueError
247 | 
248 | return loss
249 | 
250 | def compute_residue_CNC_angle(T_pred, weight, eps = 1e-4, reduction='mean'):
251 | # T_pred is predicted features (atom coordinates) of shape (N_batch, N_max_len_of_batch, 14, 3)
252 | # weight is the mask of size (N_batch, N_max_len_of_batch)
253 | # Flat-bottom L1 loss on the cosine of the C-N-CA bond angle (AlphaFold 2 style):
254 | # max(|cos_angle - cos_c_n_ca| - 12 * cos_c_n_ca_s, 0)
255 | 
256 | cos_c_n_ca = -0.51665 # This is a constant, taken from OpenFold (19/20 of -0.5203 and 1/20 of -0.0353 [for proline])
257 | cos_c_n_ca_s = 0.03509 # Weighted average of the sigma of bond angle
258 | 
259 | weight = weight.type(T_pred.dtype)
260 | weight = weight[:, 1:] * weight[:, :-1]
261 | 
262 | # Calculate the angle between C of res N, N of res N+1, and CA of res N+1
263 | diff = torch.maximum(torch.abs(angle_3point(T_pred[:, :-1, 2, :], T_pred[:, 1:, 0, :], T_pred[:, 1:, 1, :]) - cos_c_n_ca) - 12 * cos_c_n_ca_s,
264 | torch.zeros_like(weight))
265 | 
266 | dist = diff * weight
267 | 
268 | loss = torch.sum(dist, dim=-1) / weight.sum(dim=-1) # dimensionless: a clipped deviation of cosines, not radians
269 | 
270 | # mini-batch reduction
271 | if reduction == 'mean':
272 | loss = torch.mean(loss)
273 | elif reduction == 'gmean':
274 | loss = torch.exp(torch.mean(torch.log(loss)))
275 | else:
276 | raise ValueError
277 | 
278 | return loss
279 | 
280 | class PlacementIteration(nn.Module):
281 | # Similar to AF2 FoldIteration (Alg 20, lines 6 - 10),
282 | # only that we're placing the ligand with a rigid protein frame
283 | def __init__(self, config, other_config=None, aniso=True, linear_mem=False, **kwargs):
284 | super().__init__()
285 | 
286 | self.is_cross_attention = other_config is not None
287 | 
288 | if not self.is_cross_attention:
289 | # self attention
290 | if linear_mem:
291 | self.ipa = LinearMemInvariantPointCrossAttention(config, config, **kwargs, is_cross_attention=False)
292 | else:
293 | self.ipa = InvariantPointCrossAttention(config, config, **kwargs, is_cross_attention=False)
294 | 
295 | self.dropout_self = nn.Dropout(config.bert_config['hidden_dropout_prob'])
296 | self.norm_self = nn.LayerNorm(config.bert_config['hidden_size'])
297 | else:
298 | # cross attention
299 | if linear_mem:
300 | self.ipca = LinearMemInvariantPointCrossAttention(config, other_config, **kwargs)
301 | else:
302 | self.ipca = 
InvariantPointCrossAttention(config, other_config, **kwargs) 303 | 304 | self.dropout_cross = nn.Dropout(config.bert_config['hidden_dropout_prob']) 305 | self.norm_cross = nn.LayerNorm(config.bert_config['hidden_size']) 306 | 307 | self.intermediate = MLP(config.bert_config['hidden_size'], config.num_intermediate_layers) 308 | self.norm_intermediate = nn.LayerNorm(config.bert_config['hidden_size']) 309 | self.dropout_intermediate = nn.Dropout(config.bert_config['hidden_dropout_prob']) 310 | 311 | self.frame_translation = FrameTranslation(config) 312 | self.frame_rotation = FrameRotation(config) 313 | 314 | self.aniso = aniso 315 | 316 | def forward( 317 | self, 318 | hidden_states, 319 | attention_mask, 320 | pair_representation, 321 | rigid_rotations, 322 | rigid_translations, 323 | query_chunk_size = 1024, 324 | key_chunk_size = 4096, 325 | encoder_hidden_states=None, 326 | encoder_attention_mask=None, 327 | encoder_rigid_rotations=None, 328 | encoder_rigid_translations=None, 329 | ): 330 | if not self.is_cross_attention: 331 | ipa_output = self.ipa(hidden_states, 332 | attention_mask=attention_mask, 333 | pair_representation=pair_representation, 334 | rigid_rotations=rigid_rotations, 335 | rigid_translations=rigid_translations, 336 | query_chunk_size=query_chunk_size, 337 | key_chunk_size=key_chunk_size) 338 | 339 | hidden_states = self.norm_self(hidden_states + self.dropout_self(ipa_output)) 340 | else: 341 | ipca_output = self.ipca(hidden_states, 342 | encoder_hidden_states=encoder_hidden_states, 343 | encoder_attention_mask=encoder_attention_mask, 344 | pair_representation=pair_representation, 345 | rigid_rotations=rigid_rotations, 346 | rigid_translations=rigid_translations, 347 | encoder_rigid_rotations=encoder_rigid_rotations, 348 | encoder_rigid_translations=encoder_rigid_translations, 349 | query_chunk_size=query_chunk_size, 350 | key_chunk_size=key_chunk_size) 351 | 352 | hidden_states = self.norm_cross(hidden_states + self.dropout_cross(ipca_output)) 353 | 354 | # transition 355 | hidden_states = hidden_states + self.dropout_intermediate(self.intermediate(hidden_states)) 356 | hidden_states = self.norm_intermediate(hidden_states) 357 | 358 | # update spatial transforms 359 | if self.aniso: 360 | R = self.frame_rotation(hidden_states) 361 | else: 362 | # enforce rotation=const. 
363 | R = self.frame_rotation(torch.mean(hidden_states,dim=1,keepdim=True))
364 | 
365 | if version.parse(torch.__version__) < version.parse('1.8.0'):
366 | # WAR for missing broadcast in einsum
367 | R = R.repeat(1,hidden_states.shape[1],1,1)
368 | 
369 | rigid_translations = torch.einsum('bnij,bnj->bni', R, rigid_translations)
370 | rigid_rotations = torch.einsum('bnij,bnjk->bnik', R, rigid_rotations)
371 | 
372 | rigid_translations = rigid_translations + self.frame_translation(hidden_states)
373 | 
374 | return hidden_states, rigid_translations, rigid_rotations
375 | 
376 | class SidechainRotation(nn.Module):
377 | # ResNet (Algorithm 20, lines 11-14)
378 | def __init__(self, config):
379 | super().__init__()
380 | 
381 | self.initial_linear = nn.Linear(config.seq_config['hidden_size'], config.width_resnet)
382 | self.initial_relu = nn.ReLU()
383 | self.linear = nn.ModuleList([nn.Linear(config.width_resnet, config.width_resnet) for _ in range(config.depth_resnet-1)])
384 | self.relu = nn.ModuleList([nn.ReLU() for _ in range(config.depth_resnet)])
385 | self.final_linear = nn.Linear(config.width_resnet, config.num_rigid_groups-1)
386 | 
387 | def forward(self,
388 | hidden_states,
389 | rotation_angles):
390 | act = self.initial_relu(self.initial_linear(hidden_states))
391 | for linear, relu in zip(self.linear, self.relu):
392 | act = act + relu(linear(act))
393 | rotation_angles = rotation_angles + self.final_linear(act)
394 | 
395 | return rotation_angles
396 | 
397 | def make_rot_X(alpha):
398 | # convert to rotation matrix
399 | o = torch.stack(
400 | (
401 | torch.ones_like(alpha), torch.zeros_like(alpha), torch.zeros_like(alpha),
402 | torch.zeros_like(alpha), torch.cos(alpha), -torch.sin(alpha),
403 | torch.zeros_like(alpha), torch.sin(alpha), torch.cos(alpha),
404 | ),
405 | -1,
406 | )
407 | return o.reshape(alpha.shape + (3, 3))
408 | 
409 | def angle_3point(T1, T2, T3, eps = 1e-4):
410 | # T1 is C, T2 is N, and T3 is CA
411 | # Each is (N_batch, N_max_len_of_batch-1, 3)
412 | NtoC = T1 - T2
413 | CAtoN = T3 - T2
414 | angles = torch.einsum('bni, bni -> bn', NtoC, CAtoN) / torch.sqrt(((NtoC**2).sum(-1) + eps) * ((CAtoN**2).sum(-1) + eps)) # dot product over the shared xyz axis, normalized by the vector norms
415 | return angles # cosine of the C-N-CA angle, shape (N_batch, N_max_len_of_batch-1)
416 | 
417 | def computeAllAtomCoordinates(
418 | # This is merging the torsion_angles_to_frames and
419 | # frames_and_literature_positions_to_atom14_pos functions in OpenFold.
420 | 
421 | input_ids, # a form of aatypes, [*, N]
422 | frame_xyz, # [*, N, 3]
423 | frame_rot, # [*, N, 3, 3]
424 | rotation_angles, # [*, N, 7]
425 | default_frames, # intra-residue frames, from library, 4x4 matrices [21, 8, 4, 4]
426 | group_idx, # [21, 14]
427 | atom_mask, # [21, 14]
428 | lit_positions # relative locations based on intra-residue frames, [21, 14, 3]
429 | ):
430 | 
431 | 
432 | # Algorithm 24, but we use polar coordinates [0, 2*pi]
433 | # TODO maybe add a term for polar coordinates beyond this range
434 | 
435 | default_4x4 = default_frames[input_ids, ...] 
# [*, N, 8, 4, 4] 436 | default_r = default_4x4[..., :3, :3] # [*, N, 8, 3, 3] 437 | default_t = default_4x4[..., :3, 3] # [*, N, 8, 3] 438 | # in OpenFold default_r is a Rigid class which includes both the default_r and default_t here 439 | 440 | # [1, 1, 1], zeros is correct as we want bb_rot to be zero radian 441 | bb_rot = rotation_angles.new_zeros((((1,) * len(rotation_angles.shape)))) 442 | 443 | # [*, N, 8] 444 | rotation_angles = torch.cat([bb_rot.expand(*rotation_angles.shape[:-1], -1), rotation_angles], dim=-1) 445 | 446 | #print(rotation_angles) 447 | 448 | # TODO Generate local rotation matrix through make_rot_X and alpha 449 | all_rots = make_rot_X(rotation_angles) # [*, N, 8, 3, 3] 450 | 451 | #print(all_rots) 452 | 453 | # TODO Rotate default rotation with local rotation matrix 454 | all_frames_r = torch.einsum('bnaij, bnajk -> bnaik', default_r, all_rots) 455 | all_frames_t = default_t 456 | #all_frames_r = torch.einsum('bnaij, bnajk -> bnaik', all_rots, default_r) 457 | #all_frames_t = torch.einsum('bnaij, bnaj -> bnai', all_rots, default_t) 458 | 459 | #print(all_frames_r[:, 1:-1], all_frames_t[:, 1:-1]) 460 | 461 | #all_frames_to_bb_r = all_frames_r 462 | #all_frames_to_bb_t = all_frames_t 463 | 464 | # TODO Calculate all frames to back bone 465 | # change frame-to-frame to frame-to-backbone starting from chi2, kind of tedious but will do 466 | chi2_frame_to_frame_r = all_frames_r[:, :, 5] 467 | chi2_frame_to_frame_t = all_frames_t[:, :, 5] 468 | chi3_frame_to_frame_r = all_frames_r[:, :, 6] 469 | chi3_frame_to_frame_t = all_frames_t[:, :, 6] 470 | chi4_frame_to_frame_r = all_frames_r[:, :, 7] 471 | chi4_frame_to_frame_t = all_frames_t[:, :, 7] 472 | 473 | chi1_frame_to_bb_r = all_frames_r[:, :, 4] # Gives [*, N, 3, 3] 474 | chi1_frame_to_bb_t = all_frames_t[:, :, 4] # Gives [*, N, 3] 475 | 476 | #chi2_frame_to_bb_r = torch.einsum('bnij, bnjk -> bnik', chi1_frame_to_bb_r, chi2_frame_to_frame_r) 477 | #chi2_frame_to_bb_t = torch.einsum('bnj, bnij -> bni', chi1_frame_to_bb_t, chi2_frame_to_frame_r) + chi2_frame_to_frame_t 478 | #chi3_frame_to_bb_r = torch.einsum('bnij, bnjk -> bnik', chi2_frame_to_bb_r, chi3_frame_to_frame_r) 479 | #chi3_frame_to_bb_t = torch.einsum('bnj, bnij -> bni', chi2_frame_to_bb_t, chi3_frame_to_frame_r) + chi3_frame_to_frame_t 480 | #chi4_frame_to_bb_r = torch.einsum('bnij, bnjk -> bnik', chi3_frame_to_bb_r, chi4_frame_to_frame_r) 481 | #chi4_frame_to_bb_t = torch.einsum('bnj, bnij -> bni', chi3_frame_to_bb_t, chi4_frame_to_frame_r) + chi4_frame_to_frame_t 482 | #chi2_frame_to_bb_r = torch.einsum('bnij, bnjk -> bnik', chi2_frame_to_frame_r, chi1_frame_to_bb_r) 483 | #chi2_frame_to_bb_t = torch.einsum('bnij, bnj -> bni', chi2_frame_to_frame_r, chi1_frame_to_bb_t) + chi2_frame_to_frame_t 484 | #chi3_frame_to_bb_r = torch.einsum('bnij, bnjk -> bnik', chi3_frame_to_frame_r, chi2_frame_to_bb_r) 485 | #chi3_frame_to_bb_t = torch.einsum('bnij, bnj -> bni', chi3_frame_to_frame_r, chi2_frame_to_bb_t) + chi3_frame_to_frame_t 486 | #chi4_frame_to_bb_r = torch.einsum('bnij, bnjk -> bnik', chi4_frame_to_frame_r, chi3_frame_to_bb_r) 487 | #chi4_frame_to_bb_t = torch.einsum('bnij, bnj -> bni', chi4_frame_to_frame_r, chi3_frame_to_bb_t) + chi4_frame_to_frame_t 488 | chi2_frame_to_bb_r = torch.einsum('bnij, bnjk -> bnik', chi1_frame_to_bb_r, chi2_frame_to_frame_r) 489 | chi2_frame_to_bb_t = torch.einsum('bnij, bnj -> bni', chi1_frame_to_bb_r, chi2_frame_to_frame_t) + chi1_frame_to_bb_t 490 | chi3_frame_to_bb_r = torch.einsum('bnij, bnjk -> bnik', 
496 |     all_frames_to_bb_r = torch.cat(
497 |         [
498 |             all_frames_r[:, :, :5],
499 |             chi2_frame_to_bb_r.unsqueeze(2),
500 |             chi3_frame_to_bb_r.unsqueeze(2),
501 |             chi4_frame_to_bb_r.unsqueeze(2),
502 |         ],
503 |         dim=-3
504 |     )
505 |     all_frames_to_bb_t = torch.cat(
506 |         [
507 |             all_frames_t[:, :, :5],
508 |             chi2_frame_to_bb_t.unsqueeze(2),
509 |             chi3_frame_to_bb_t.unsqueeze(2),
510 |             chi4_frame_to_bb_t.unsqueeze(2),
511 |         ],
512 |         dim=-2
513 |     )
514 | 
515 |     # compose the backbone frame with the frame-to-backbone transforms to get global frames
516 |     all_frames_to_global_r = torch.einsum('bnij, bnajk -> bnaik', frame_rot, all_frames_to_bb_r)
517 |     all_frames_to_global_t = torch.einsum('bnij, bnaj -> bnai', frame_rot, all_frames_to_bb_t) + frame_xyz[:,:,None,:]
518 | 
521 |     # get masks and calculate all side-chain atom locations
522 |     # (this is a part of frames_and_literature_positions_to_atom14_pos)
523 | 
524 |     # [*, N, 14]
525 |     group_mask = group_idx[input_ids, ...]
526 | 
527 |     # [*, N, 14, 8]
528 |     group_mask = nn.functional.one_hot(
529 |         group_mask,
530 |         num_classes=default_frames.shape[-3]
531 |     )
532 | 
539 |     # [*, N, 14, 8, 3, 3]
540 |     t_atoms_to_global_r = all_frames_to_global_r[..., None, :, :, :] * group_mask[..., None, None]
542 |     # [*, N, 14, 8, 3]
543 |     t_atoms_to_global_t = all_frames_to_global_t[..., None, :, :] * group_mask[..., None]
544 | 
546 |     # [*, N, 14]
547 |     atom_mask = atom_mask[input_ids, ...]
548 | 
549 |     # [*, N, 14, 3]
550 |     lit_positions = lit_positions[input_ids, ...]
551 |     # [*, N, 14, 8, 3]
553 |     pred_positions = torch.einsum('bnakij, bnaj -> bnaki', t_atoms_to_global_r, lit_positions) + t_atoms_to_global_t
554 |     pred_positions = pred_positions * atom_mask[..., None, None]
555 |     # [*, N, 14, 3]
556 |     atom_feat = pred_positions.sum(-2)
558 | 
559 |     return atom_feat
560 | 
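# Illustrative sketch (not in the original file; the helper name is
# hypothetical): a minimal shape check for computeAllAtomCoordinates using
# dummy identity frames and an all-zero residue library.
def _demo_all_atom_shapes():
    B, N = 2, 7
    input_ids = torch.zeros(B, N, dtype=torch.long)
    frame_xyz = torch.zeros(B, N, 3)
    frame_rot = torch.eye(3).expand(B, N, 3, 3)
    rotation_angles = torch.zeros(B, N, 7)
    default_frames = torch.eye(4).expand(21, 8, 4, 4)
    group_idx = torch.zeros(21, 14, dtype=torch.long)
    atom_mask = torch.ones(21, 14)
    lit_positions = torch.zeros(21, 14, 3)
    atom14 = computeAllAtomCoordinates(input_ids, frame_xyz, frame_rot, rotation_angles,
                                       default_frames, group_idx, atom_mask, lit_positions)
    # one predicted position per atom14 slot
    assert atom14.shape == (B, N, 14, 3)
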
561 | class Structure(nn.Module):
562 |     # Algorithm 20
563 |     def __init__(self, config):
564 |         super().__init__()
565 | 
567 |         self.num_layers = config.num_ipa_layers
568 |         self.num_rigid_groups = config.num_rigid_groups
569 | 
570 |         seq_config = IPAConfig.from_dict(config.seq_ipa_config)
571 |         smiles_config = IPAConfig.from_dict(config.smiles_ipa_config)
572 | 
573 |         self.initial_norm_seq = nn.LayerNorm(seq_config.bert_config['hidden_size'])
574 |         self.initial_norm_smiles = nn.LayerNorm(smiles_config.bert_config['hidden_size'])
575 | 
576 |         self.receptor_self = torch.nn.ModuleList([PlacementIteration(seq_config, aniso=True, linear_mem=config.linear_mem_attn)
577 |             for _ in range(self.num_layers)])
578 |         self.ligand_self = torch.nn.ModuleList([PlacementIteration(smiles_config, aniso=False, linear_mem=config.linear_mem_attn)
579 |             for _ in range(self.num_layers)])
580 | 
581 |         self.sidechain_self = torch.nn.ModuleList([SidechainRotation(config) for _ in range(self.num_layers)])
582 | 
583 |         self.gradient_checkpointing = False
584 | 
585 |         # for linear mem point attention
586 |         self.query_chunk_size_receptor = seq_config.query_chunk_size
587 |         self.key_chunk_size_receptor = seq_config.key_chunk_size
588 | 
589 |         self.query_chunk_size_ligand = smiles_config.query_chunk_size
590 |         self.key_chunk_size_ligand = smiles_config.key_chunk_size
591 | 
592 |     def freeze_protein(self):
593 |         for param in self.initial_norm_seq.parameters():
594 |             param.requires_grad = False
595 |         for param in self.receptor_self.parameters():
596 |             param.requires_grad = False
597 |         for param in self.sidechain_self.parameters():
598 |             param.requires_grad = False
599 | 
600 |     def freeze_ligand(self):
601 |         for param in self.initial_norm_smiles.parameters():
602 |             param.requires_grad = False
603 |         for param in self.ligand_self.parameters():
604 |             param.requires_grad = False
605 | 
606 |     def forward(
607 |         self,
608 |         hidden_states_1,
609 |         hidden_states_2,
610 |         attention_mask_1,
611 |         attention_mask_2,
612 |         pair_representation_seq,
613 |         pair_representation_smiles
614 |     ):
615 | 
616 |         hidden_seq = self.initial_norm_seq(hidden_states_1)
617 |         hidden_smiles = self.initial_norm_smiles(hidden_states_2)
618 | 
619 |         # "black-hole" initialization
620 |         translations_receptor = torch.zeros(hidden_seq.size()[:2]+(3, ),
621 |             device=hidden_seq.device, dtype=hidden_seq.dtype)
622 |         rotations_receptor = torch.eye(3,
623 |             device=hidden_seq.device,
624 |             dtype=hidden_seq.dtype).repeat(hidden_seq.size()[0], hidden_seq.size()[1], 1, 1)
625 | 
626 |         translations_ligand = torch.zeros(hidden_smiles.size()[:2]+(3, ),
627 |             device=hidden_smiles.device,dtype=hidden_smiles.dtype)
628 |         rotations_ligand = torch.eye(3,
629 |             device=hidden_smiles.device,
630 |             dtype=hidden_smiles.dtype).repeat(hidden_smiles.size()[0], hidden_smiles.size()[1], 1, 1)
631 | 
632 |         # side chain rotations
633 |         rotation_angles = torch.zeros(hidden_seq.size()[:2]+(self.num_rigid_groups-1,), device=hidden_seq.device, dtype=hidden_seq.dtype)
634 | 
635 |         # self interactions
636 |         for (ligand_update, receptor_update, 
sidechain_rotation) in zip(self.ligand_self, self.receptor_self, self.sidechain_self): 637 | # receptor 638 | if self.gradient_checkpointing: 639 | hidden_seq, translations_receptor, rotations_receptor = checkpoint(receptor_update, 640 | hidden_seq, 641 | attention_mask_1, 642 | pair_representation_seq, 643 | rotations_receptor, 644 | translations_receptor, 645 | self.query_chunk_size_receptor, 646 | self.key_chunk_size_receptor, 647 | ) 648 | else: 649 | hidden_seq, translations_receptor, rotations_receptor = receptor_update( 650 | hidden_states=hidden_seq, 651 | attention_mask=attention_mask_1, 652 | pair_representation=pair_representation_seq, 653 | rigid_rotations=rotations_receptor, 654 | rigid_translations=translations_receptor, 655 | query_chunk_size=self.query_chunk_size_receptor, 656 | key_chunk_size=self.key_chunk_size_receptor, 657 | ) 658 | 659 | # update internal coordinates for protein 660 | rotation_angles = sidechain_rotation(hidden_seq, rotation_angles) 661 | # ligand 662 | if self.gradient_checkpointing: 663 | hidden_smiles, translations_ligand, rotations_ligand = checkpoint(ligand_update, 664 | hidden_smiles, 665 | attention_mask_2, 666 | pair_representation_smiles, 667 | rotations_ligand, 668 | translations_ligand, 669 | self.query_chunk_size_ligand, 670 | self.key_chunk_size_ligand, 671 | ) 672 | else: 673 | hidden_smiles, translations_ligand, rotations_ligand = ligand_update( 674 | hidden_states=hidden_smiles, 675 | attention_mask=attention_mask_2, 676 | pair_representation=pair_representation_smiles, 677 | rigid_rotations=rotations_ligand, 678 | rigid_translations=translations_ligand, 679 | query_chunk_size=self.query_chunk_size_ligand, 680 | key_chunk_size=self.key_chunk_size_ligand, 681 | ) 682 | 683 | return hidden_smiles, hidden_seq, translations_ligand, rotations_ligand, translations_receptor, rotations_receptor, rotation_angles 684 | 685 | def gradient_checkpointing_enable(self): 686 | self.gradient_checkpointing = True 687 | 688 | def gradient_checkpointing_disable(self): 689 | self.gradient_checkpointing = False 690 | 691 | class CrossStructure(nn.Module): 692 | # Alg 20 693 | def __init__(self, config): 694 | super().__init__() 695 | self.num_layers = config.num_ipa_layers 696 | self.num_rigid_groups = config.num_rigid_groups 697 | 698 | seq_config = IPAConfig.from_dict(config.seq_ipa_config) 699 | smiles_config = IPAConfig.from_dict(config.smiles_ipa_config) 700 | 701 | self.initial_norm_seq = nn.LayerNorm(seq_config.bert_config['hidden_size']) 702 | self.initial_norm_smiles = nn.LayerNorm(smiles_config.bert_config['hidden_size']) 703 | 704 | self.receptor_cross = torch.nn.ModuleList([PlacementIteration(seq_config, smiles_config, aniso=True, linear_mem=config.linear_mem_attn) 705 | for _ in range(self.num_layers)]) 706 | self.ligand_cross = torch.nn.ModuleList([PlacementIteration(smiles_config, seq_config, aniso=False, linear_mem=config.linear_mem_attn) 707 | for _ in range(self.num_layers)]) 708 | 709 | self.sidechain_cross = torch.nn.ModuleList([SidechainRotation(config) for _ in range(self.num_layers)]) 710 | 711 | self.gradient_checkpointing = False 712 | 713 | # for linear mem point attention 714 | self.query_chunk_size_receptor = seq_config.query_chunk_size 715 | self.key_chunk_size_receptor = seq_config.key_chunk_size 716 | 717 | self.query_chunk_size_ligand = smiles_config.query_chunk_size 718 | self.key_chunk_size_ligand = smiles_config.key_chunk_size 719 | 720 | def forward( 721 | self, 722 | hidden_states_1, 723 | hidden_states_2, 724 | 
attention_mask_1, 725 | attention_mask_2, 726 | rotations_receptor, 727 | translations_receptor, 728 | rotations_ligand, 729 | translations_ligand, 730 | pair_representation_cross, 731 | rotation_angles, 732 | ): 733 | 734 | hidden_seq = self.initial_norm_seq(hidden_states_1) 735 | hidden_smiles = self.initial_norm_smiles(hidden_states_2) 736 | 737 | def cross(update_1, 738 | update_2, 739 | hidden_states_1, 740 | hidden_states_2, 741 | attention_mask_1, 742 | attention_mask_2, 743 | pair_representation_cross, 744 | rigid_rotations_1, 745 | rigid_translations_1, 746 | rigid_rotations_2, 747 | rigid_translations_2, 748 | query_chunk_size_1, 749 | key_chunk_size_1, 750 | query_chunk_size_2, 751 | key_chunk_size_2, 752 | sidechain_rotation, 753 | rotation_angles): 754 | output_1 = update_1( 755 | hidden_states=hidden_states_1, 756 | attention_mask=attention_mask_1, 757 | pair_representation=pair_representation_cross, 758 | rigid_rotations=rigid_rotations_1, 759 | rigid_translations=rigid_translations_1, 760 | query_chunk_size=query_chunk_size_1, 761 | key_chunk_size=key_chunk_size_1, 762 | encoder_hidden_states=hidden_states_2, 763 | encoder_attention_mask=attention_mask_2, 764 | encoder_rigid_rotations=rigid_rotations_2, 765 | encoder_rigid_translations=rigid_translations_2, 766 | ) 767 | 768 | # update internal coordinates for protein 769 | rotation_angles = sidechain_rotation(output_1[0], rotation_angles) 770 | 771 | output_2 = update_2( 772 | hidden_states=hidden_states_2, 773 | attention_mask=attention_mask_2, 774 | pair_representation=pair_representation_cross.transpose(-2,-1), 775 | rigid_rotations=rigid_rotations_2, 776 | rigid_translations=rigid_translations_2, 777 | query_chunk_size=query_chunk_size_2, 778 | key_chunk_size=key_chunk_size_2, 779 | encoder_hidden_states=hidden_states_1, 780 | encoder_attention_mask=attention_mask_1, 781 | encoder_rigid_rotations=rigid_rotations_1, 782 | encoder_rigid_translations=rigid_translations_1, 783 | ) 784 | return output_1 + output_2 + (rotation_angles, ) 785 | 786 | for (ligand_update, receptor_update, sidechain_rotation) in zip(self.ligand_cross, self.receptor_cross, self.sidechain_cross): 787 | if self.gradient_checkpointing: 788 | cross_output = checkpoint( 789 | cross, 790 | receptor_update, 791 | ligand_update, 792 | hidden_seq, 793 | hidden_smiles, 794 | attention_mask_1, 795 | attention_mask_2, 796 | pair_representation_cross, 797 | rotations_receptor, 798 | translations_receptor, 799 | rotations_ligand, 800 | translations_ligand, 801 | self.query_chunk_size_receptor, 802 | self.key_chunk_size_receptor, 803 | self.query_chunk_size_ligand, 804 | self.key_chunk_size_ligand, 805 | sidechain_rotation, 806 | rotation_angles 807 | ) 808 | else: 809 | cross_output = cross( 810 | receptor_update, 811 | ligand_update, 812 | hidden_seq, 813 | hidden_smiles, 814 | attention_mask_1, 815 | attention_mask_2, 816 | pair_representation_cross, 817 | rotations_receptor, 818 | translations_receptor, 819 | rotations_ligand, 820 | translations_ligand, 821 | self.query_chunk_size_receptor, 822 | self.key_chunk_size_receptor, 823 | self.query_chunk_size_ligand, 824 | self.key_chunk_size_ligand, 825 | sidechain_rotation, 826 | rotation_angles 827 | ) 828 | hidden_seq, translations_receptor, rotations_receptor, \ 829 | hidden_smiles, translations_ligand, rotations_ligand, \ 830 | rotation_angles = cross_output 831 | 832 | return hidden_smiles, hidden_seq, translations_ligand, rotations_ligand, translations_receptor, rotations_receptor, rotation_angles 833 | 
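    # The `cross` closure above bundles one receptor update, one side-chain
    # ResNet step, and one ligand update into a single function so that, when
    # gradient checkpointing is enabled, torch.utils.checkpoint.checkpoint()
    # treats the whole receptor/ligand exchange as one re-computable segment:
    # only the segment's inputs and outputs are kept in memory, and its
    # intermediate activations are re-computed during the backward pass.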
834 | def gradient_checkpointing_enable(self): 835 | self.gradient_checkpointing = True 836 | 837 | def gradient_checkpointing_disable(self): 838 | self.gradient_checkpointing = False 839 | -------------------------------------------------------------------------------- /contact_pred/training_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright 2020 The Microsoft DeepSpeed Team 3 | ''' 4 | 5 | import torch 6 | import logging 7 | import os 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | def mpi_discovery(distributed_port=29005, verbose=True): 12 | """ 13 | Discovery MPI environment via mpi4py and map to relevant torch.distributed state 14 | """ 15 | from mpi4py import MPI 16 | import subprocess 17 | comm = MPI.COMM_WORLD 18 | rank = comm.Get_rank() 19 | world_size = comm.Get_size() 20 | 21 | master_addr = None 22 | if rank == 0: 23 | hostname_cmd = ["hostname -I"] 24 | result = subprocess.check_output(hostname_cmd, shell=True) 25 | master_addr = result.decode('utf-8').split()[-1] 26 | master_addr = comm.bcast(master_addr, root=0) 27 | 28 | # Determine local rank by assuming hostnames are unique 29 | proc_name = MPI.Get_processor_name() 30 | all_procs = comm.allgather(proc_name) 31 | local_rank = sum([i == proc_name for i in all_procs[:rank]]) 32 | 33 | os.environ['RANK'] = str(rank) 34 | os.environ['WORLD_SIZE'] = str(world_size) 35 | os.environ['LOCAL_RANK'] = str(local_rank) 36 | os.environ['MASTER_ADDR'] = master_addr 37 | os.environ['MASTER_PORT'] = str(distributed_port) 38 | 39 | if verbose: 40 | logger.info( 41 | "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" 42 | .format(os.environ['RANK'], 43 | os.environ['LOCAL_RANK'], 44 | os.environ['WORLD_SIZE'], 45 | os.environ['MASTER_ADDR'], 46 | os.environ['MASTER_PORT'])) 47 | 48 | if torch.distributed.is_initialized(): 49 | assert torch.distributed.get_rank() == rank, "MPI rank {} does not match torch rank {}".format( 50 | rank, torch.distributed.get_rank()) 51 | assert torch.distributed.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( 52 | world_size, torch.distributed.get_world_size()) 53 | -------------------------------------------------------------------------------- /contact_pred/utils.py: -------------------------------------------------------------------------------- 1 | def get_extended_attention_mask(attention_mask, input_shape, device, dtype): 2 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 3 | # ourselves in which case we just need to make it broadcastable to all heads. 4 | if attention_mask.dim() == 3: 5 | extended_attention_mask = attention_mask[:, None, :, :] 6 | elif attention_mask.dim() == 2: 7 | extended_attention_mask = attention_mask[:, None, None, :] 8 | else: 9 | raise ValueError( 10 | f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" 11 | ) 12 | 13 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 14 | # masked positions, this operation will create a tensor which is 0.0 for 15 | # positions we want to attend and -10000.0 for masked positions. 16 | # Since we are adding it to the raw scores before the softmax, this is 17 | # effectively the same as removing these entirely. 
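    # Worked example (illustrative): attention_mask = [1., 1., 0.] becomes
    # extended_attention_mask = [0., 0., -10000.] after the two lines below,
    # so adding it to the raw attention scores leaves real tokens unchanged
    # and drives the masked token's post-softmax weight to ~0.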
18 | extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility 19 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 20 | return extended_attention_mask 21 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: /ccs/proj/bif136/crusher-env 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=4.5=1_gnu 7 | - ca-certificates=2022.4.26=h06a4308_0 8 | - certifi=2021.10.8=py38h06a4308_2 9 | - ld_impl_linux-64=2.35.1=h7274673_9 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=9.3.0=h5101ec6_17 12 | - libgomp=9.3.0=h5101ec6_17 13 | - libstdcxx-ng=9.3.0=hd4cf53a_17 14 | - ncurses=6.3=h7f8727e_2 15 | - openssl=1.1.1o=h7f8727e_0 16 | - pip=21.2.4=py38h06a4308_0 17 | - python=3.8.13=h12debd9_0 18 | - readline=8.1.2=h7f8727e_1 19 | - setuptools=61.2.0=py38h06a4308_0 20 | - sqlite=3.38.3=hc218d9a_0 21 | - tk=8.6.11=h1ccaba5_1 22 | - wheel=0.37.1=pyhd3eb1b0_0 23 | - xz=5.2.5=h7f8727e_1 24 | - zlib=1.2.12=h7f8727e_2 25 | - pip: 26 | - absl-py==1.0.0 27 | - aiohttp==3.8.1 28 | - aiosignal==1.2.0 29 | - async-timeout==4.0.2 30 | - attrs==21.4.0 31 | - cachetools==5.0.0 32 | - charset-normalizer==2.0.12 33 | - click==8.1.3 34 | - datasets==2.2.0 35 | - dill==0.3.4 36 | - filelock==3.6.0 37 | - frozenlist==1.3.0 38 | - fsspec==2022.3.0 39 | - google-auth==2.6.6 40 | - google-auth-oauthlib==0.4.6 41 | - grpcio==1.46.0 42 | - huggingface-hub==0.5.1 43 | - idna==3.3 44 | - importlib-metadata==4.11.3 45 | - joblib==1.1.0 46 | - markdown==3.3.7 47 | - multidict==6.0.2 48 | - multiprocess==0.70.12.2 49 | - numpy==1.22.3 50 | - oauthlib==3.2.0 51 | - packaging==21.3 52 | - pandas==1.4.2 53 | - protobuf==3.20.1 54 | - pyarrow==8.0.0 55 | - pyasn1==0.4.8 56 | - pyasn1-modules==0.2.8 57 | - pyparsing==3.0.9 58 | - python-dateutil==2.8.2 59 | - pytz==2022.1 60 | - pyyaml==6.0 61 | - regex==2022.4.24 62 | - requests==2.27.1 63 | - requests-oauthlib==1.3.1 64 | - responses==0.18.0 65 | - rsa==4.8 66 | - sacremoses==0.0.53 67 | - scikit-learn==1.0.2 68 | - scipy==1.8.0 69 | - six==1.16.0 70 | - tensorboard==2.9.0 71 | - tensorboard-data-server==0.6.1 72 | - tensorboard-plugin-wit==1.8.1 73 | - tensorboardx==2.5 74 | - threadpoolctl==3.1.0 75 | - tokenizers==0.12.1 76 | - torch==1.12.0a0+gitb30c027 77 | - tqdm==4.64.0 78 | - transformers==4.18.0 79 | - typing-extensions==4.2.0 80 | - urllib3==1.26.9 81 | - werkzeug==2.1.2 82 | - xxhash==3.0.0 83 | - yarl==1.7.2 84 | - zipp==3.8.0 85 | prefix: /ccs/proj/bif136/crusher-env 86 | -------------------------------------------------------------------------------- /eval_notebooks/eval_structure.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f112b673-987f-41f1-a602-7d6a25871132", 6 | "metadata": {}, 7 | "source": [ 8 | "**TwoFold_DL - inference**" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "91169a46-2387-4c69-82d3-3d8de1c4b863", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "fatal: destination path 'TwoFold_DL' already exists and is not an empty directory.\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "! 
git clone https://github.com/ORNL/TwoFold_DL" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "id": "871ca9b1-f28e-45ed-b433-6bfa0b26c30c", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import sys\n", 37 | "sys.path.append('TwoFold_DL/')" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "id": "d50ef6cb-5611-48c7-8451-636b88e1a5a8", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import logging\n", 48 | "logging.disable(logging.INFO)\n", 49 | "logging.disable(logging.WARNING)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 5, 55 | "id": "19a1eafd-8a94-44b7-8562-2e132c5223be", 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Requirement already satisfied: pip in /opt/conda/lib/python3.11/site-packages (23.2.1)\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "! pip install --upgrade pip" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "702b34ac-5af6-432e-8c0b-7113db540610", 74 | "metadata": { 75 | "scrolled": true 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "! conda install -q -c conda-forge -y Rust\n", 80 | "! pip install -q datasets\n", 81 | "! pip install -q transformers==4.18.0\n", 82 | "! pip install -q huggingface_hub\n", 83 | "! pip install -q rdkit\n", 84 | "! pip install -q biopython" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 11, 90 | "id": "922e2a06-ac1d-4674-b590-db25f95e3dc9", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "import torch\n", 95 | "from torch.utils.data import Dataset\n", 96 | "from huggingface_hub import hf_hub_download" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 12, 102 | "id": "aad9fba7-a015-403d-b722-2e70e0afa76b", 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stderr", 107 | "output_type": "stream", 108 | "text": [ 109 | "2023-09-26 16:54:47.911545: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 110 | "2023-09-26 16:54:48.236835: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 111 | "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 112 | "2023-09-26 16:54:49.208838: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "import datasets\n", 118 | "from datasets import load_dataset\n", 119 | "from transformers import AutoTokenizer, AutoConfig, Trainer\n", 120 | "from transformers import EvalPrediction\n", 121 | "from transformers import TrainingArguments\n", 122 | "\n", 123 | "from tokenizers import Regex\n", 124 | "from tokenizers import pre_tokenizers\n", 125 | "from tokenizers import normalizers\n", 126 | "from tokenizers.normalizers import Replace\n", 127 | "\n", 128 | "from tokenizers.pre_tokenizers import BertPreTokenizer\n", 129 | "from tokenizers.pre_tokenizers import Digits\n", 130 | "from tokenizers.pre_tokenizers import Sequence\n", 131 | "from tokenizers.pre_tokenizers import WhitespaceSplit\n", 132 | "from tokenizers.pre_tokenizers import Split\n", 133 | "\n", 134 | "from sklearn.metrics import mean_squared_error, mean_absolute_error\n", 135 | "from sklearn.preprocessing import StandardScaler\n", 136 | " \n", 137 | "from sklearn.metrics import precision_recall_curve, roc_curve\n", 138 | "import pandas as pd\n", 139 | "import numpy as np\n", 140 | "import json\n", 141 | "import re\n", 142 | "import tqdm\n", 143 | "import os\n", 144 | "import rdkit" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 13, 150 | "id": "27513c30-3ba6-45f2-b2f0-7ab283006246", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "from contact_pred.models import StructurePrediction, ProteinLigandConfigStructure\n", 155 | "from contact_pred.structure import IPAConfig\n", 156 | "from contact_pred.data_utils import StructurePredictionPipeline" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 14, 162 | "id": "564ae875-4ba7-4743-9cdc-07da169be7b9", 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "assert torch.cuda.is_available()" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 16, 172 | "id": "dd2ea5cd-790d-4ec6-80d7-fab0cbed9dda", 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "data": { 177 | "application/vnd.jupyter.widget-view+json": { 178 | "model_id": "87613baf34604c68b9cea26fa935d4d2", 179 | "version_major": 2, 180 | "version_minor": 0 181 | }, 182 | "text/plain": [ 183 | "Downloading: 0%| | 0.00/361 [00:00>?|\\*|\\$|\\%[0-9]{2}|[0-9])\"\"\"), behavior='isolated')])" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 20, 334 | "id": "5c027266-bef0-43a4-b057-6f044bc01c64", 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "checkpoint_path = hf_hub_download('djh992/TwoFold_DL_GB2022', 'pytorch_model.bin')\n", 339 | "config_json = hf_hub_download('djh992/TwoFold_DL_GB2022', 'config.json')\n", 340 | "config = ProteinLigandConfigStructure(json.load(open(config_json, 'r')))\n", 341 | "config.seq_config = seq_config.to_dict()\n", 342 | "config.smiles_config = smiles_config.to_dict()\n", 343 | "config.seq_vocab = seq_tokenizer.get_vocab()\n", 344 | 
"seq_ipa_config = IPAConfig(bert_config=seq_config.to_dict(),\n", 345 | " num_ipa_heads=seq_config.num_attention_heads)\n", 346 | "smiles_ipa_config = IPAConfig(bert_config=smiles_config.to_dict(),\n", 347 | " num_ipa_heads=smiles_config.num_attention_heads)\n", 348 | "config.seq_ipa_config = seq_ipa_config.to_dict()\n", 349 | "config.smiles_ipa_config = smiles_ipa_config.to_dict()\n", 350 | "model = StructurePrediction(config=config)\n", 351 | "checkpoint = torch.load(checkpoint_path)\n", 352 | "model.load_state_dict(checkpoint,strict=True)\n", 353 | "\n", 354 | "del checkpoint\n", 355 | "\n", 356 | "pipeline = StructurePredictionPipeline(\n", 357 | " model,\n", 358 | " seq_tokenizer=seq_tokenizer,\n", 359 | " smiles_tokenizer=smiles_tokenizer,\n", 360 | " device=0,\n", 361 | " batch_size=1)" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "id": "951f4189-07cf-494c-8d15-9437ecb69665", 367 | "metadata": {}, 368 | "source": [ 369 | "### Inference" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 22, 375 | "id": "8aef0732-5989-4cc5-b794-d7b106c38e42", 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "class ProteinLigandDataset(Dataset):\n", 380 | " def __init__(self, dataset, smiles_name='smiles', seq_name='seq'):\n", 381 | " self.dataset = dataset\n", 382 | " self.seq_name = seq_name\n", 383 | " self.smiles_name = smiles_name\n", 384 | "\n", 385 | " def __getitem__(self, idx):\n", 386 | " try:\n", 387 | " item = self.dataset[idx]\n", 388 | " except:\n", 389 | " item = self.dataset.iloc[idx]\n", 390 | " \n", 391 | " try:\n", 392 | " # make canonical\n", 393 | " smiles_canonical = str(Chem.MolToSmiles(Chem.MolFromSmiles(item[self.smiles_name])))\n", 394 | " except:\n", 395 | " smiles_canonical = str(item[self.smiles_name])\n", 396 | " \n", 397 | " result = {'ligand': smiles_canonical, \n", 398 | "# result = {'ligand': '', \n", 399 | " 'protein': item[self.seq_name]}\n", 400 | " \n", 401 | " return result\n", 402 | "\n", 403 | " def __len__(self):\n", 404 | " return len(self.dataset)" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 23, 410 | "id": "e02ff342-9912-4160-9e86-6da0344bda80", 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "from rdkit import Chem\n", 415 | "smi_4mds = Chem.MolToSmiles(Chem.MolFromMolFile('TwoFold_DL/examples/4mds_23H_ligand.sdf'))" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 50, 421 | "id": "07f45375-bc48-4c01-9743-9d937e5f4927", 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "# Andrii's Mpro1-199\n", 426 | "df = pd.DataFrame({'seq': [\n", 427 | " #'SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDVVYCPRHVICTSEDMLNPNYEDLLIRKSNHNFLVQAGNVQLRVIGHSMQNCVLKLKVDTANPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNFTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGNFYGPFVDRQTAQAAGTDTT' # Mpro1-199\n", 428 | "# 'SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDVVYCPRHVICTSEDMLNPNYEDLLIRKSNHNFLVQAGNVQLRVIGHSMQNCVLKLKVDTANPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNFTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGNFYGPFVDRQTAQAAGTDTTITVNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCASLKELLQNGMNGRTILGSALLEDEFTPFDVVRQCSGVTFQ' #Full Mpro\n", 429 | " 
'SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDTVYCPRHVICTAEDMLNPNYEDLLIRKSNHSFLVQAGNVQLRVIGHSMQNCLLRLKVDTSNPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNHTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGKFYGPFVDRQTAQAAGTDTTITLNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCAALKELLQNGMNGRTILGSTILEDEFTPFDVVRQCSGASGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDTVYCPRHVICTAEDMLNPNYEDLLIRKSNHSFLVQAGNVQLRVIGHSMQNCLLRLKVDTSNPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNHTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGKFYGPFVDRQTAQAAGTDTTITLNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCAALKELLQNGMNGRTILGSTILEDEFTPFDVVRQCSGA' # 4mds' # 4mds\n", 430 | " #'VNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCASLKELLQNGMNGRTILGSALLEDEFTPFDVVRQCSGVTFQ' # Mpro200-306\n", 431 | "], \n", 432 | " #'smiles': ['']})\n", 433 | " 'smiles': smi_4mds})" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 51, 439 | "id": "c47934d6-0431-427f-ac03-6b7503e991bf", 440 | "metadata": {}, 441 | "outputs": [ 442 | { 443 | "data": { 444 | "text/plain": [ 445 | "0 606\n", 446 | "Name: seq, dtype: int64" 447 | ] 448 | }, 449 | "execution_count": 51, 450 | "metadata": {}, 451 | "output_type": "execute_result" 452 | } 453 | ], 454 | "source": [ 455 | "df['seq'].str.len()" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 52, 461 | "id": "5c57d1c3-8f7b-48a0-9ec2-117bb0494fcb", 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [ 465 | "dataset = ProteinLigandDataset(df)" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 53, 471 | "id": "21bf2097-2757-4200-abee-c07bd15b9db4", 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [ 475 | "pipeline.model.enable_cross = True" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 54, 481 | "id": "c9635f28-9fe8-4cd3-bf83-901424c6d2a6", 482 | "metadata": {}, 483 | "outputs": [ 484 | { 485 | "name": "stderr", 486 | "output_type": "stream", 487 | "text": [ 488 | "/opt/conda/lib/python3.11/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n", 489 | " warnings.warn(\"None of the inputs have requires_grad=True. 
Gradients will be None\")\n", 490 | "/gpfs/alpine/stf006/proj-shared/ngoav/TwoFold_DL/TwoFold_DL/contact_pred/models.py:304: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", 491 | " aatypes = torch.tensor(self.input_ids_to_aatype[input_ids_1], device=input_ids_1.device)#, requires_grad=False)\n" 492 | ] 493 | } 494 | ], 495 | "source": [ 496 | "output = list(pipeline(dataset))\n", 497 | "pred = output[0]" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": 55, 503 | "id": "af080077-8076-4e7c-8ce7-3ab2147000a5", 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "from contact_pred.residue_constants import restype_name_to_atom14_names, restype_1to3\n", 508 | "def write_pdb_no_ref(f, seq, feat):\n", 509 | " k = 0\n", 510 | " resid = 1\n", 511 | " i = 1\n", 512 | " for s in seq:\n", 513 | " res = restype_1to3[s]\n", 514 | " for idx, atom in enumerate(restype_name_to_atom14_names[res]):\n", 515 | " if atom != '':\n", 516 | " xyz = feat[0,k+1,idx]\n", 517 | " write_pdb_line(f,'ATOM', str(i), atom, res, 'A', str(resid), *xyz, 1.0, 1.0, atom[0])\n", 518 | " i+=1\n", 519 | " k+=1\n", 520 | " resid+=1\n", 521 | "\n", 522 | "def write_Calpha_no_ref(f, seq, feat):\n", 523 | " k = 0\n", 524 | " resid = 1\n", 525 | " i = 1\n", 526 | " for s in seq:\n", 527 | " res = restype_1to3[s]\n", 528 | " xyz = feat[0,k+1]\n", 529 | " write_pdb_line(f,'ATOM', str(i), 'CA', res, 'A', str(resid), *xyz, 1.0, 1.0, 'C')\n", 530 | " i+=1\n", 531 | " k+=1\n", 532 | " resid+=1" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 56, 538 | "id": "d597817d-23b6-4866-8332-1587e194cc17", 539 | "metadata": {}, 540 | "outputs": [], 541 | "source": [ 542 | "def write_pdb_line(f,*j):\n", 543 | " j = list(j)\n", 544 | " j[0] = j[0].ljust(6)#atom#6s\n", 545 | " j[1] = j[1].rjust(5)#aomnum#5d\n", 546 | " j[2] = j[2].center(4)#atomname$#4s\n", 547 | " j[3] = j[3].ljust(3)#resname#1s\n", 548 | " j[4] = j[4].rjust(1) #Astring\n", 549 | " j[5] = j[5].rjust(4) #resnum\n", 550 | " j[6] = str('%8.3f' % (float(j[6]))).rjust(8) #x\n", 551 | " j[7] = str('%8.3f' % (float(j[7]))).rjust(8)#y\n", 552 | " j[8] = str('%8.3f' % (float(j[8]))).rjust(8) #z\\\n", 553 | " j[9] =str('%6.2f'%(float(j[9]))).rjust(6)#occ\n", 554 | " j[10]=str('%6.2f'%(float(j[10]))).ljust(6)#temp\n", 555 | " j[11]=j[11].rjust(12)#elname\n", 556 | " f.write(\"%s%s %s %s %s%s %s%s%s%s%s%s\\n\"% (j[0],j[1],j[2],j[3],j[4],j[5],j[6],j[7],j[8],j[9],j[10],j[11]))\n", 557 | " \n", 558 | "#with open(f'TwoFold_DL/examples/pred_Mpro1-199_ligand_7s3s.pdb','w') as f:\n", 559 | "with open(f'TwoFold_DL/examples/pred_Mpro_monomer_ligand_4mds.pdb','w') as f:\n", 560 | " feat = pred['receptor_xyz']\n", 561 | " write_pdb_no_ref(f, df['seq'][0], feat)\n", 562 | " #feat = pred['receptor_frames_xyz']\n", 563 | " #write_Calpha_no_ref(f, df['seq'][0], feat)" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": 67, 569 | "id": "cd7f788b-9d41-475b-9d19-b921ffa52695", 570 | "metadata": {}, 571 | "outputs": [ 572 | { 573 | "name": "stderr", 574 | "output_type": "stream", 575 | "text": [ 576 | "[17:12:24] Molecule does not have explicit Hs. 
Consider calling AddHs()\n" 577 | ] 578 | } 579 | ], 580 | "source": [ 581 | "# update molecule coordinates using prediction\n", 582 | "from rdkit.Geometry import Point3D\n", 583 | "from rdkit import Chem\n", 584 | "from rdkit.Chem import AllChem\n", 585 | "from utils.token_coords import get_token_coords\n", 586 | "\n", 587 | "smi, ligand_xyz_ref, tokens, atom_map = get_token_coords(mol) \n", 588 | "mol = Chem.MolFromSmiles(smi_4mds)\n", 589 | "AllChem.EmbedMolecule(mol)\n", 590 | "conf = mol.GetConformer()\n", 591 | "for i, xyz in enumerate(pred['ligand_frames_xyz'].squeeze(0)[1:-1]):\n", 592 | " idx = atom_map[i]\n", 593 | "\n", 594 | " if idx is not None:\n", 595 | " conf.SetAtomPosition(idx,Point3D(*xyz.astype(np.double)))\n", 596 | "\n", 597 | "with Chem.SDWriter('TwoFold_DL/examples/ligand_pred_4mds_dimer.sdf') as w:\n", 598 | " w.write(mol)" 599 | ] 600 | } 601 | ], 602 | "metadata": { 603 | "kernelspec": { 604 | "display_name": "OLCF-CUDA11 (ipykernel)", 605 | "language": "python", 606 | "name": "python3" 607 | }, 608 | "language_info": { 609 | "codemirror_mode": { 610 | "name": "ipython", 611 | "version": 3 612 | }, 613 | "file_extension": ".py", 614 | "mimetype": "text/x-python", 615 | "name": "python", 616 | "nbconvert_exporter": "python", 617 | "pygments_lexer": "ipython3", 618 | "version": "3.11.4" 619 | } 620 | }, 621 | "nbformat": 4, 622 | "nbformat_minor": 5 623 | } 624 | -------------------------------------------------------------------------------- /examples/4mds_23H_ligand.sdf: -------------------------------------------------------------------------------- 1 | 2 | RDKit 3D 3 | 4 | 38 41 0 0 0 0 0 0 0 0999 V2000 5 | -13.4120 16.2900 -23.5140 O 0 0 0 0 0 0 0 0 0 0 0 0 6 | -14.0390 15.2460 -23.7520 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -15.3490 15.2130 -23.4560 N 0 0 0 0 0 0 0 0 0 0 0 0 8 | -16.0720 16.4110 -22.9370 C 0 0 1 0 0 0 0 0 0 0 0 0 9 | -16.1430 17.4800 -24.0070 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -16.3870 17.1740 -25.1720 O 0 0 0 0 0 0 0 0 0 0 0 0 11 | -15.9120 18.7320 -23.6000 N 0 0 0 0 0 0 0 0 0 0 0 0 12 | -15.9860 19.9760 -24.3950 C 0 0 2 0 0 0 0 0 0 0 0 0 13 | -17.4610 20.2090 -24.7780 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | -15.1350 19.9790 -25.6550 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | -13.7230 19.4530 -25.4880 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | -15.5390 21.0950 -23.4390 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | -17.4300 16.1180 -22.3810 C 0 0 0 0 0 0 0 0 0 0 0 0 18 | -18.5740 16.2110 -23.1740 C 0 0 0 0 0 0 0 0 0 0 0 0 19 | -19.6060 15.8440 -22.3420 C 0 0 0 0 0 0 0 0 0 0 0 0 20 | -19.0890 15.5500 -21.0960 C 0 0 0 0 0 0 0 0 0 0 0 0 21 | -17.7550 15.7230 -21.1520 N 0 0 0 0 0 0 0 0 0 0 0 0 22 | -16.7980 15.5150 -20.0350 C 0 0 0 0 0 0 0 0 0 0 0 0 23 | -16.1030 14.0340 -23.7420 C 0 0 0 0 0 0 0 0 0 0 0 0 24 | -17.0070 14.0780 -24.7930 C 0 0 0 0 0 0 0 0 0 0 0 0 25 | -17.8040 12.9730 -25.1160 C 0 0 0 0 0 0 0 0 0 0 0 0 26 | -15.9970 12.8560 -23.0010 C 0 0 0 0 0 0 0 0 0 0 0 0 27 | -16.7770 11.7370 -23.3370 C 0 0 0 0 0 0 0 0 0 0 0 0 28 | -17.6880 11.7860 -24.3940 C 0 0 0 0 0 0 0 0 0 0 0 0 29 | -18.4070 10.6270 -24.6370 N 0 0 0 0 0 0 0 0 0 0 0 0 30 | -19.3800 10.3850 -25.5220 C 0 0 0 0 0 0 0 0 0 0 0 0 31 | -20.0530 9.0580 -25.6450 C 0 0 0 0 0 0 0 0 0 0 0 0 32 | -19.7660 11.2790 -26.2610 O 0 0 0 0 0 0 0 0 0 0 0 0 33 | -13.2970 14.0520 -24.3380 C 0 0 0 0 0 0 0 0 0 0 0 0 34 | -12.0180 14.5470 -24.7740 N 0 0 0 0 0 0 0 0 0 0 0 0 35 | -10.9070 14.3230 -24.0180 N 0 0 0 0 0 0 0 0 0 0 0 0 36 | -9.8140 14.9430 -24.5630 N 0 0 0 0 0 0 0 0 0 0 0 0 37 | -10.2810 15.5850 -25.6660 C 0 0 0 0 0 0 0 0 0 0 0 0 38 | -11.6590 
15.3490 -25.7880 C 0 0 0 0 0 0 0 0 0 0 0 0 39 | -9.5830 16.3760 -26.5720 C 0 0 0 0 0 0 0 0 0 0 0 0 40 | -10.3020 16.9400 -27.6400 C 0 0 0 0 0 0 0 0 0 0 0 0 41 | -11.6800 16.7030 -27.7630 C 0 0 0 0 0 0 0 0 0 0 0 0 42 | -12.3870 15.9140 -26.8480 C 0 0 0 0 0 0 0 0 0 0 0 0 43 | 2 1 2 0 44 | 3 2 1 0 45 | 4 3 1 0 46 | 4 5 1 6 47 | 6 5 2 0 48 | 7 5 1 0 49 | 8 7 1 0 50 | 8 9 1 6 51 | 10 8 1 0 52 | 11 10 1 0 53 | 12 8 1 0 54 | 13 4 1 0 55 | 14 13 2 0 56 | 15 14 1 0 57 | 16 15 2 0 58 | 17 13 1 0 59 | 17 16 1 0 60 | 18 17 1 0 61 | 19 3 1 0 62 | 20 19 2 0 63 | 21 20 1 0 64 | 22 19 1 0 65 | 23 22 2 0 66 | 24 23 1 0 67 | 24 21 2 0 68 | 25 24 1 0 69 | 26 25 1 0 70 | 27 26 1 0 71 | 28 26 2 0 72 | 29 2 1 0 73 | 30 29 1 0 74 | 31 30 1 0 75 | 32 31 2 0 76 | 33 32 1 0 77 | 34 30 1 0 78 | 34 33 2 0 79 | 35 33 1 0 80 | 36 35 2 0 81 | 37 36 1 0 82 | 38 34 1 0 83 | 38 37 2 0 84 | M END 85 | $$$$ 86 | -------------------------------------------------------------------------------- /examples/7s3s_860_ligand.sdf: -------------------------------------------------------------------------------- 1 | 2 | RDKit 3D 3 | 4 | 21 23 0 0 0 0 0 0 0 0999 V2000 5 | 11.0640 -1.0840 22.8290 C 0 0 0 0 0 0 0 0 0 0 0 0 6 | 5.3440 0.9600 19.2190 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | 5.7680 0.7790 17.8880 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | 7.4860 -0.5950 18.5200 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | 5.6470 0.4810 21.5970 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | 4.5650 1.2720 21.8790 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | 3.8540 1.9170 20.8480 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | 4.2350 1.7690 19.5420 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | 12.7040 0.2900 23.9760 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | 9.6340 -1.4610 22.5700 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | 11.3780 -0.0650 23.7480 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | 8.9780 -0.6040 21.4900 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | 12.0870 -1.7220 22.1480 C 0 0 0 0 0 0 0 0 0 0 0 0 18 | 13.4020 -1.3600 22.3830 C 0 0 0 0 0 0 0 0 0 0 0 0 19 | 13.7260 -0.3570 23.2830 C 0 0 0 0 0 0 0 0 0 0 0 0 20 | 7.1720 -0.5010 19.8550 C 0 0 0 0 0 0 0 0 0 0 0 0 21 | 6.0670 0.3030 20.2550 C 0 0 0 0 0 0 0 0 0 0 0 0 22 | 6.7910 0.0370 17.5470 N 0 0 0 0 0 0 0 0 0 0 0 0 23 | 7.9040 -1.1450 20.8580 N 0 0 0 0 0 0 0 0 0 0 0 0 24 | 9.4200 0.5080 21.2300 O 0 0 0 0 0 0 0 0 0 0 0 0 25 | 14.6660 -2.2000 21.5000 Cl 0 0 0 0 0 0 0 0 0 0 0 0 26 | 3 2 2 0 27 | 6 5 2 0 28 | 7 6 1 0 29 | 8 7 2 0 30 | 8 2 1 0 31 | 10 1 1 0 32 | 11 9 2 0 33 | 11 1 1 0 34 | 12 10 1 0 35 | 13 1 2 0 36 | 14 13 1 0 37 | 15 14 2 0 38 | 15 9 1 0 39 | 16 4 1 0 40 | 17 2 1 0 41 | 17 16 2 0 42 | 17 5 1 0 43 | 18 4 2 0 44 | 18 3 1 0 45 | 19 16 1 0 46 | 19 12 1 0 47 | 20 12 2 0 48 | 21 14 1 0 49 | M END 50 | $$$$ 51 | -------------------------------------------------------------------------------- /examples/ligand_pred_7s3s.sdf: -------------------------------------------------------------------------------- 1 | 2 | RDKit 3D 3 | 4 | 21 23 0 0 0 0 0 0 0 0999 V2000 5 | 10.4148 -6.9854 23.0121 C 0 0 0 0 0 0 0 0 0 0 0 0 6 | 8.7561 -0.5440 23.0525 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | 10.1725 0.1217 23.0572 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | 11.1595 -1.6544 23.0034 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | 7.8817 -2.1460 23.0366 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | 6.7178 -1.6829 23.0643 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | 6.3008 -0.5434 23.0794 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | 7.2401 0.0993 23.0779 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | 9.0543 -8.4141 22.9812 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | 11.1442 -5.8892 23.0245 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | 9.9158 -7.3864 22.9935 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | 10.9186 -4.6593 22.9888 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | 9.8806 -7.4330 23.0059 C 0 0 0 0 0 0 
0 0 0 0 0 0 18 | 9.2832 -8.5275 22.9838 C 0 0 0 0 0 0 0 0 0 0 0 0 19 | 8.8322 -8.6806 22.9853 C 0 0 0 0 0 0 0 0 0 0 0 0 20 | 9.9939 -2.2815 22.9972 C 0 0 0 0 0 0 0 0 0 0 0 0 21 | 8.8485 -1.5255 23.0266 C 0 0 0 0 0 0 0 0 0 0 0 0 22 | 11.5732 -0.6031 23.0328 N 0 0 0 0 0 0 0 0 0 0 0 0 23 | 10.4631 -3.4821 22.9818 N 0 0 0 0 0 0 0 0 0 0 0 0 24 | 10.8859 -4.4627 22.9659 O 0 0 0 0 0 0 0 0 0 0 0 0 25 | 8.9101 -8.8559 22.9682 Cl 0 0 0 0 0 0 0 0 0 0 0 0 26 | 3 2 2 0 27 | 6 5 2 0 28 | 7 6 1 0 29 | 8 7 2 0 30 | 8 2 1 0 31 | 10 1 1 0 32 | 11 9 2 0 33 | 11 1 1 0 34 | 12 10 1 0 35 | 13 1 2 0 36 | 14 13 1 0 37 | 15 14 2 0 38 | 15 9 1 0 39 | 16 4 1 0 40 | 17 2 1 0 41 | 17 16 2 0 42 | 17 5 1 0 43 | 18 4 2 0 44 | 18 3 1 0 45 | 19 16 1 0 46 | 19 12 1 0 47 | 20 12 2 0 48 | 21 14 1 0 49 | M END 50 | $$$$ 51 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | import transformers 5 | from transformers import BertModel, BertTokenizer, AutoTokenizer 6 | from transformers import PreTrainedModel, BertConfig 7 | from transformers import Trainer, TrainingArguments 8 | from transformers import BatchEncoding 9 | from transformers import EvalPrediction 10 | 11 | from transformers import get_scheduler 12 | from transformers.trainer_pt_utils import get_parameter_names 13 | 14 | from tokenizers.pre_tokenizers import BertPreTokenizer 15 | from tokenizers.pre_tokenizers import Digits 16 | from tokenizers.pre_tokenizers import Sequence 17 | from tokenizers.pre_tokenizers import WhitespaceSplit 18 | from tokenizers.pre_tokenizers import Split 19 | 20 | from tokenizers import Regex 21 | from tokenizers import pre_tokenizers 22 | from tokenizers import normalizers 23 | from tokenizers.normalizers import Replace 24 | 25 | from dataclasses import dataclass, field 26 | from enum import Enum 27 | 28 | from transformers import HfArgumentParser 29 | from transformers.trainer_utils import get_last_checkpoint 30 | from transformers.modeling_utils import unwrap_model 31 | from transformers.trainer_utils import set_seed 32 | 33 | import datasets 34 | from torch.utils.data import IterableDataset 35 | 36 | from typing import List 37 | 38 | import os 39 | import json 40 | from tqdm.auto import tqdm 41 | 42 | import torch.distributed as dist 43 | 44 | import numpy as np 45 | import random 46 | 47 | from contact_pred.data_utils import EnsembleDataCollatorWithPadding 48 | from contact_pred.models import StructurePrediction 49 | from contact_pred.models import ProteinLigandConfigStructure 50 | from contact_pred.structure import IPAConfig 51 | 52 | import webdataset as wd 53 | 54 | logger = logging.getLogger(__name__) 55 | 56 | def save_json(content, path, indent=4, **json_dump_kwargs): 57 | with open(path, "w") as f: 58 | json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs) 59 | 60 | def handle_metrics(split, metrics, output_dir): 61 | """ 62 | Log and save metrics 63 | Args: 64 | - split: one of train, val, test 65 | - metrics: metrics dict 66 | - output_dir: where to save the metrics 67 | """ 68 | 69 | logger.info(f"***** {split} metrics *****") 70 | for key in sorted(metrics.keys()): 71 | logger.info(f" {key} = {metrics[key]}") 72 | save_json(metrics, os.path.join(output_dir, f"{split}_results.json")) 73 | 74 | @dataclass 75 | class ModelArguments: 76 | model_type: str = field( 77 | default='bert', 78 | metadata = {'choices': ['bert','regex']}, 79 | ) 80 | 81 | 
seq_model_name: str = field( 82 | default=None 83 | ) 84 | 85 | smiles_model_dir: str = field( 86 | default=None 87 | ) 88 | 89 | smiles_tokenizer_dir: str = field( 90 | default=None 91 | ) 92 | 93 | linear_mem_attn: bool = field( 94 | default=True 95 | ) 96 | 97 | max_seq_length: int = field( 98 | default=2048 99 | ) 100 | 101 | max_smiles_length: int = field( 102 | default=512 103 | ) 104 | 105 | n_cross_attn: int = field( 106 | default=3 107 | ) 108 | 109 | n_ipa: int = field( 110 | default=8 111 | ) 112 | 113 | @dataclass 114 | class DataArguments: 115 | train_dataset: str = field( 116 | default=None 117 | ) 118 | 119 | train_size: int = field( 120 | default=None 121 | ) 122 | 123 | test_dataset: str = field( 124 | default=None 125 | ) 126 | 127 | pretrained_model: str = field( 128 | default=None 129 | ) 130 | 131 | freeze_protein: bool = field( 132 | default=False 133 | ) 134 | 135 | freeze_ligand: bool = field( 136 | default=False 137 | ) 138 | 139 | enable_cross: bool = field( 140 | default=True 141 | ) 142 | 143 | def main(): 144 | from contact_pred.training_utils import mpi_discovery 145 | required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 146 | 147 | auto_mpi_discovery = False 148 | try: 149 | import mpi4py 150 | auto_mpi_discovery = True 151 | except: 152 | logger.info("mpi4py not found, skipping MPI discovery") 153 | pass 154 | 155 | if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)): 156 | logger.info("Not using torchrun, attempting to detect MPI environment...") 157 | mpi_discovery() 158 | 159 | parser = HfArgumentParser([TrainingArguments,ModelArguments, DataArguments]) 160 | 161 | training_args, model_args, data_args = parser.parse_args_into_dataclasses() 162 | 163 | if 'LOCAL_RANK' in os.environ: 164 | training_args.local_rank = int(os.environ["LOCAL_RANK"]) 165 | 166 | # error out when there are unused parameters 167 | training_args.ddp_find_unused_parameters=False 168 | 169 | smiles_tokenizer_directory = model_args.smiles_tokenizer_dir 170 | smiles_model_directory = model_args.smiles_model_dir 171 | tokenizer_config = json.load(open(smiles_tokenizer_directory+'/config.json','r')) 172 | 173 | smiles_tokenizer = AutoTokenizer.from_pretrained(smiles_tokenizer_directory, **tokenizer_config) 174 | 175 | if model_args.model_type == 'regex': 176 | smiles_tokenizer.backend_tokenizer.pre_tokenizer = Sequence([WhitespaceSplit(),Split(Regex(r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""), behavior='isolated')]) 177 | 178 | normalizer = normalizers.Sequence([Replace(Regex('[UZOB]'),'X'),Replace(Regex('\s'),'')]) 179 | pre_tokenizer = pre_tokenizers.Split(Regex(''),behavior='isolated') 180 | seq_tokenizer = AutoTokenizer.from_pretrained(model_args.seq_model_name, do_lower_case=False) 181 | seq_tokenizer.backend_tokenizer.normalizer = normalizer 182 | seq_tokenizer.backend_tokenizer.pre_tokenizer = pre_tokenizer 183 | 184 | 185 | 186 | max_seq_length = model_args.max_seq_length 187 | max_smiles_length = min(smiles_tokenizer.model_max_length, model_args.max_smiles_length) 188 | 189 | train_size = data_args.train_size 190 | import glob 191 | train = wd.WebDataset(glob.glob(data_args.train_dataset + '/part-*.tar'), resampled=True).shuffle(1000).decode('torch').with_epoch(train_size) 192 | test = wd.WebDataset(glob.glob(data_args.test_dataset + '/part-*.tar'), nodesplitter=lambda src: (s for s in src)).shuffle(1000).decode('torch') 193 | 194 | # on-the-fly 
tokenization 195 | def encode(item): 196 | if data_args.freeze_protein: 197 | item['seq.txt'] = '' 198 | item['receptor_frames_xyz'] = np.empty((0,3)) 199 | item['receptor_frames_rot'] = np.empty((0,9)) 200 | item['receptor_xyz'] = np.empty((0,model.config.num_atoms,3)) 201 | 202 | seq_encodings = seq_tokenizer(item['seq.txt'], 203 | return_offsets_mapping=False, 204 | truncation=True, 205 | max_length=max_seq_length) 206 | 207 | item['input_ids_1'] = torch.tensor(seq_encodings['input_ids']) 208 | item['attention_mask_1'] = torch.tensor(seq_encodings['attention_mask']) 209 | 210 | if data_args.freeze_ligand: 211 | item['smiles.txt'] = '' 212 | item['ligand_xyz'] = np.empty((0,3)) 213 | 214 | smiles_encodings = smiles_tokenizer(item['smiles.txt'], 215 | max_length=max_smiles_length, 216 | truncation=True) 217 | 218 | item['input_ids_2'] = torch.tensor(smiles_encodings['input_ids']) 219 | item['attention_mask_2'] = torch.tensor(smiles_encodings['attention_mask']) 220 | 221 | return item 222 | 223 | ensemble_collator = EnsembleDataCollatorWithPadding(smiles_tokenizer, seq_tokenizer) 224 | 225 | def collator_with_label_padding(features): 226 | if not isinstance(features[0], (dict, BatchEncoding)): 227 | features = [vars(f) for f in features] 228 | 229 | # Special handling for labels. 230 | first = list(features[0].keys()) 231 | 232 | if 'ligand_xyz_2d' in first: 233 | for f in features: 234 | f.pop('ligand_xyz_2d') 235 | 236 | for k in first: 237 | if k.startswith('labels_'): 238 | continue 239 | if 'xyz' in k or 'rot' in k: 240 | if k.startswith('receptor'): 241 | # max len in batch 242 | max_len = max([len(f['attention_mask_1']) for f in features]) 243 | else: 244 | max_len = max([len(f['attention_mask_2']) for f in features]) 245 | 246 | # pad with nan, also account for [CLS] and [SEP] 247 | for f in features: 248 | try: 249 | if k == 'receptor_xyz': 250 | label = torch.tensor(np.pad(f[k][:max_len-2], ((1,max_len-1-len(f[k][:max_len-2])), (0,0), (0,0)), 251 | constant_values=None).astype(np.float64)).type(torch.get_default_dtype()) 252 | else: 253 | label = torch.tensor(np.pad(f[k][:max_len-2], ((1,max_len-1-len(f[k][:max_len-2])), (0,0)), 254 | constant_values=None).astype(np.float64)).type(torch.get_default_dtype()) 255 | except: 256 | print('Error padding inputs', k, f[k]) 257 | raise 258 | 259 | # keep nan positions for loss calculation 260 | # NOTE: a bug in pytorch requires this to be an int64 tensor, otherwise process will hang 261 | non_nans = (~torch.any(torch.isnan(label),dim=-1)).type(torch.int64) 262 | 263 | if k.endswith('_rot'): 264 | label = label.reshape(label.shape[:-1] + (3,3)) 265 | label[non_nans==0,:,:] = torch.eye(3) 266 | elif k != 'receptor_xyz': 267 | label = torch.nan_to_num(label) 268 | 269 | if k == 'receptor_frames_xyz': 270 | f['labels_receptor_token_mask'] = non_nans 271 | elif k == 'ligand_xyz': 272 | f['labels_ligand_token_mask'] = non_nans 273 | 274 | num_feat = model.config.num_atoms 275 | feat = label.unsqueeze(1) 276 | feat = torch.cat([feat, torch.ones(*(feat.shape[:1] + (num_feat-1,) + feat.shape[2:]), 277 | device=feat.device, dtype=feat.dtype)], 1) 278 | feat[:,1:,:] = float('nan') 279 | 280 | f['labels_ligand_frames_xyz'] = label 281 | f['labels_ligand_frames_rot'] = torch.eye(3, device=label.device, dtype=label.dtype).repeat(label.size()[0], 1, 1) 282 | label = feat 283 | 284 | elif k == 'receptor_frames_rot': 285 | pass 286 | 287 | f['labels_'+k] = label 288 | f.pop(k) 289 | 290 | # process the remaining fields 291 | batch = 
ensemble_collator(features) 292 | 293 | return batch 294 | 295 | class MyDataset(IterableDataset): 296 | def __init__(self, dataset): 297 | self.dataset = dataset 298 | 299 | def __iter__(self): 300 | class Transform: 301 | def __init__(self, dataset): 302 | self.dataset = dataset 303 | 304 | def __iter__(self): 305 | self.iterator = iter(self.dataset) 306 | return self 307 | 308 | def __next__(self): 309 | item = next(self.iterator) 310 | 311 | if 'lig_xyz.pyd' in item: 312 | item['ligand_xyz'] = item.pop('lig_xyz.pyd') 313 | 314 | if 'rec_xyz.pyd' in item: 315 | item['receptor_frames_xyz'] = item.pop('rec_xyz.pyd') 316 | item['receptor_frames_rot'] = item.pop('rec_r.pyd') 317 | item['receptor_xyz'] = item.pop('rec_feat.pyd')[..., :-1, :] 318 | 319 | item = encode(item) 320 | 321 | if data_args.freeze_ligand and 'xyz.pyd' in item: 322 | item['receptor_frames_xyz'] = item.pop('xyz.pyd') 323 | item['receptor_frames_rot'] = item.pop('r.pyd') 324 | item['receptor_xyz'] = item.pop('feat.pyd')[..., :-1, :] 325 | 326 | if data_args.freeze_protein and 'xyz.pyd' in item: 327 | item['ligand_xyz'] = item.pop('xyz.pyd') 328 | 329 | return item 330 | 331 | return iter(Transform(self.dataset)) 332 | 333 | class FromIterableDataset: 334 | def __init__(self, iterable_dataset): 335 | self.dataset = list(iterable_dataset) 336 | 337 | def __getitem__(self, i): 338 | return self.dataset[i] 339 | 340 | def __len__(self): 341 | return len(self.dataset) 342 | 343 | last_checkpoint = None 344 | if os.path.isdir(training_args.output_dir): 345 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 346 | 347 | seq_config = BertConfig.from_pretrained(model_args.seq_model_name) 348 | 349 | smiles_config = BertConfig.from_pretrained(smiles_model_directory) 350 | 351 | seq_ipa_config = IPAConfig(seq_config, 352 | num_ipa_heads=seq_config.num_attention_heads, 353 | ) 354 | smiles_ipa_config = IPAConfig(smiles_config, 355 | num_ipa_heads=smiles_config.num_attention_heads, 356 | ) 357 | 358 | config = ProteinLigandConfigStructure( 359 | seq_config=seq_config, 360 | smiles_config=smiles_config, 361 | n_cross_attention=model_args.n_cross_attn, 362 | seq_ipa_config=seq_ipa_config, 363 | smiles_ipa_config=smiles_ipa_config, 364 | num_ipa_layers=model_args.n_ipa, 365 | linear_mem_attn=model_args.linear_mem_attn, 366 | enable_cross=data_args.enable_cross, 367 | seq_vocab=seq_tokenizer.get_vocab() 368 | ) 369 | 370 | # uniform seed for model weight initialization 371 | set_seed(training_args.seed) 372 | 373 | # instantiate model 374 | model = StructurePrediction(config) 375 | 376 | if not data_args.pretrained_model: 377 | # only load pretrained sequence embeddings 378 | model.pair_representation.embedding.load_pretrained(model_args.seq_model_name, 379 | model_args.smiles_model_dir) 380 | else: 381 | if torch.distributed.get_rank() == 0: 382 | print('Loading pre-trained checkpoint {}'.format(data_args.pretrained_model)) 383 | pretrained_checkpoint = torch.load(data_args.pretrained_model, 384 | torch.device('cuda:{}'.format(training_args.local_rank))) 385 | model.load_state_dict(pretrained_checkpoint, strict=False) 386 | 387 | if data_args.freeze_protein: 388 | model.freeze_protein() 389 | if data_args.freeze_ligand: 390 | model.freeze_ligand() 391 | 392 | training_args.label_names = ['labels_receptor_frames_xyz', 393 | 'labels_receptor_frames_rot', 394 | 'labels_receptor_xyz', 395 | 'labels_ligand_xyz', 396 | 'labels_ligand_frames_xyz', 397 | 'labels_ligand_frames_rot', 398 | 'labels_ligand_token_mask', 399 | 
'labels_receptor_token_mask'] 400 | training_args.remove_unused_columns = False 401 | 402 | # create optimizer, only for parameters which require gradients 403 | forbidden_parameter_names = ['bias'] 404 | 405 | # exclude the linear scaling layers producing physical units 406 | forbidden_parameter_names += ['frame_translation.linear.'] 407 | 408 | decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm]) 409 | decay_parameters = [name for name in decay_parameters if not any([s in name for s in forbidden_parameter_names])] 410 | optimizer_grouped_parameters = [ 411 | { 412 | "params": [p for n, p in model.named_parameters() if n in decay_parameters and p.requires_grad], 413 | "weight_decay": training_args.weight_decay, 414 | }, 415 | { 416 | "params": [p for n, p in model.named_parameters() if n not in decay_parameters and p.requires_grad], 417 | "weight_decay": 0.0, 418 | }, 419 | ] 420 | 421 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) 422 | optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 423 | 424 | lr_scheduler = get_scheduler( 425 | training_args.lr_scheduler_type, 426 | optimizer=optimizer, 427 | num_warmup_steps=training_args.get_warmup_steps(training_args.max_steps), 428 | num_training_steps=training_args.max_steps, 429 | ) 430 | 431 | train_dataset = MyDataset(train) 432 | val_dataset = FromIterableDataset(MyDataset(test)) 433 | 434 | trainer = Trainer( 435 | model=model, 436 | args=training_args, # training arguments, defined above 437 | train_dataset=train_dataset, # training dataset 438 | eval_dataset=val_dataset, # evaluation dataset 439 | data_collator=collator_with_label_padding, 440 | optimizers=(optimizer, lr_scheduler), 441 | ) 442 | 443 | # save model configuration 444 | if trainer.is_world_process_zero(): 445 | with open(os.path.join(training_args.output_dir,'config.json'),'w') as f: 446 | json.dump(config.to_dict(), f) 447 | 448 | all_metrics = {} 449 | logger.info("*** Train ***") 450 | train_result = trainer.train(resume_from_checkpoint=last_checkpoint) 451 | 452 | trainer.save_model(training_args.output_dir) 453 | metrics = train_result.metrics 454 | 455 | if trainer.is_world_process_zero(): 456 | handle_metrics("train", metrics, training_args.output_dir) 457 | all_metrics.update(metrics) 458 | 459 | trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) 460 | save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json")) 461 | 462 | if __name__ == "__main__": 463 | main() 464 | -------------------------------------------------------------------------------- /train/.gitignore: -------------------------------------------------------------------------------- 1 | slurm-*.* 2 | finetune.* 3 | logs_* 4 | results_* 5 | *cache* 6 | -------------------------------------------------------------------------------- /train/setup_summit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export NENSEMBLE=1 4 | 5 | module load gcc/9.3.0 6 | module load open-ce/1.5.0-py39-0 7 | conda activate /autofs/nccs-svm1_proj/bif136/summit-env 8 | 9 | export HF_DATASETS_CACHE=/tmp 10 | export TRANSFORMERS_CACHE=/tmp 11 | export TRANSFORMERS_OFFLINE=1 12 | export HF_DATASETS_OFFLINE=1 13 | 14 | export OMP_NUM_THREADS=1 15 | export PYTHONUNBUFFERED=1 16 | export TOKENIZERS_PARALLELISM=false 17 | 18 | export NNODES=`echo $LSB_MCPU_HOSTS | awk '{for (j=3; j <= NF; j+=2) { print $j }}' | wc -l` 19 | 
export NNODES=$((${NNODES}/${NENSEMBLE})) 20 | export NWORKERS=$((${NNODES}*6)) 21 | 22 | export NCROSS=3 23 | export LR=3e-5 24 | export PER_DEVICE_BATCH_SIZE=1 25 | export CLUSTER=summit 26 | export BATCH_SIZE=$((${PER_DEVICE_BATCH_SIZE}*${NWORKERS})) 27 | 28 | export TORCH_DISTRIBUTED_DEBUG=INFO 29 | 30 | export ENSEMBLE_ID=1 31 | 32 | export GLOBAL_ID=$(((${LSB_JOBINDEX}-1)*${NENSEMBLE}+${ENSEMBLE_ID})) 33 | export ID_STR=${CLUSTER}_bs${BATCH_SIZE}_lr${LR}_ncross_${NCROSS}_${GLOBAL_ID} 34 | 35 | export LD_PRELOAD="${OLCF_GCC_ROOT}/lib64/libstdc++.so.6" 36 | 37 | -------------------------------------------------------------------------------- /train/structure_summit.lsf: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | #BSUB -P ABC123 3 | #BSUB -W 2:00 4 | #BSUB -nnodes 16 5 | #BSUB -q batch 6 | #BSUB -J "structure[1]" 7 | #BSUB -o structure.o%J 8 | #BSUB -e structure.e%J 9 | 10 | export NENSEMBLE=1 11 | 12 | module load gcc/9.3.0 13 | module load open-ce/1.5.0-py39-0 14 | conda activate /autofs/nccs-svm1_proj/bif136/summit-env 15 | 16 | export HF_DATASETS_CACHE=/tmp 17 | export TRANSFORMERS_CACHE=/tmp 18 | export TRANSFORMERS_OFFLINE=1 19 | export HF_DATASETS_OFFLINE=1 20 | 21 | export OMP_NUM_THREADS=1 22 | export PYTHONUNBUFFERED=1 23 | export TOKENIZERS_PARALLELISM=false 24 | 25 | export NNODES=`echo $LSB_MCPU_HOSTS | awk '{for (j=3; j <= NF; j+=2) { print $j }}' | wc -l` 26 | export NNODES=$((${NNODES}/${NENSEMBLE})) 27 | export NWORKERS=$((${NNODES}*6)) 28 | 29 | export TRAIN_DATASET='/path/to/train/set' 30 | export TEST_DATASET='/path/to/test/set' 31 | 32 | export NCROSS=3 33 | export LR=1e-5 34 | export PER_DEVICE_BATCH_SIZE=1 35 | export CLUSTER=summit 36 | export BATCH_SIZE=$((${PER_DEVICE_BATCH_SIZE}*${NWORKERS})) 37 | 38 | export TORCH_DISTRIBUTED_DEBUG=INFO 39 | 40 | for ENSEMBLE_ID in `seq 1 ${NENSEMBLE}`; do 41 | export GLOBAL_ID=$(((${LSB_JOBINDEX}-1)*${NENSEMBLE}+${ENSEMBLE_ID})) 42 | export ID_STR=${CLUSTER}_bs${BATCH_SIZE}_lr${LR}_ncross_${NCROSS}_${GLOBAL_ID} 43 | 44 | final_checkpoint=`ls -t ./results_${ID_STR} | grep "checkpoint" | head -1 | awk 'BEGIN {FS="-"} {printf $2}'` 45 | echo "Removing ./results_${ID_STR}/checkpoint-$final_checkpoint to ensure we progress in training" 46 | rm -r ./results_${ID_STR}/checkpoint-$final_checkpoint 47 | 48 | export LD_PRELOAD="${OLCF_GCC_ROOT}/lib64/libstdc++.so.6" 49 | jsrun -n ${NNODES} -g 6 -a 6 -c 42 python ../train.py \ 50 | --smiles_tokenizer_dir='/gpfs/alpine/world-shared/med106/blnchrd/models/bert_large_plus_clean_regex/tokenizer'\ 51 | --smiles_model_dir='/gpfs/alpine/world-shared/med106/blnchrd/automatedmutations/pretraining/run/job_86neeM/output'\ 52 | --model_type='regex' \ 53 | --seq_model_name='Rostlab/prot_bert_bfd'\ 54 | --train_dataset=${TRAIN_DATASET}\ 55 | --test_dataset=${TEST_DATASET}\ 56 | --train_size=13753\ 57 | --n_cross_attn=${NCROSS}\ 58 | --output_dir=./results_${ID_STR}\ 59 | --max_steps=150000\ 60 | --per_device_train_batch_size=${PER_DEVICE_BATCH_SIZE}\ 61 | --per_device_eval_batch_size=${PER_DEVICE_BATCH_SIZE}\ 62 | --learning_rate=${LR}\ 63 | --weight_decay=0.01\ 64 | --logging_dir=./logs_${ID_STR}\ 65 | --logging_steps=1\ 66 | --lr_scheduler_type=constant_with_warmup\ 67 | --evaluation_strategy="steps"\ 68 | --eval_steps=100\ 69 | --gradient_accumulation_steps=1\ 70 | --fp16=False\ 71 | --save_strategy="steps"\ 72 | --save_steps=100\ 73 | --warmup_steps=10\ 74 | --optim=adafactor\ 75 | --ignore_data_skip\ 76 | --gradient_checkpointing\ 
77 | --seed=$((42+${GLOBAL_ID})) & 78 | done 79 | wait 80 | -------------------------------------------------------------------------------- /unit_test/ATTNvsLMATTN.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import sys, os 5 | sys.path.insert(1, os.path.join(sys.path[0], '../contact_pred/')) 6 | from transformers import BertModel, BertConfig 7 | from transformers.models.bert.modeling_bert import BertAttention, BertSelfAttention 8 | from transformers import PreTrainedModel, PretrainedConfig 9 | from linear_mem_attn import attention 10 | from structure import IPAConfig 11 | from modules import LinearMemAttentionWithScoreOutput 12 | 13 | torch.manual_seed(42) 14 | 15 | if __name__ == "__main__": 16 | 17 | # The attention_probs_dropout_prob defaults to 0.1 in BertConfig 18 | # but in linear memory attention there is no dropout layer 19 | # (i.e. attention_probs_dropout_prob is always 0) 20 | # So we set that to 0 to compare the results 21 | seq_config = BertConfig(attention_probs_dropout_prob=0) 22 | 23 | npts=256 24 | 25 | query = torch.rand(1, npts, seq_config.hidden_size) 26 | key = value = query 27 | print(query.shape) 28 | 29 | # Regular memory attention 30 | ATTN = BertSelfAttention(seq_config) 31 | out_ATTN = ATTN(hidden_states = query, encoder_hidden_states = key) 32 | print('Output from original attention') 33 | print(out_ATTN, out_ATTN[0].shape) 34 | print() 35 | 36 | # Linear memory attention 37 | LMATTN = LinearMemAttentionWithScoreOutput(seq_config) 38 | LMATTN.query = ATTN.query 39 | LMATTN.key = ATTN.key 40 | LMATTN.value = ATTN.value 41 | 42 | # Choose a smaller chunk size to test if chunking actually works 43 | out_LMATTN = LMATTN(hidden_states = query, encoder_hidden_states = key, query_chunk_size=16, key_chunk_size=64) 44 | print('Output from linear memory attention') 45 | print(out_LMATTN, out_LMATTN[0].shape) 46 | -------------------------------------------------------------------------------- /unit_test/IPCAvsLMIPCA.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import sys, os 5 | sys.path.insert(1, os.path.join(sys.path[0], '../contact_pred/')) 6 | from IPCA import InvariantPointCrossAttention, LinearMemInvariantPointCrossAttention 7 | from structure import IPAConfig 8 | 9 | torch.manual_seed(42) 10 | 11 | if __name__ == "__main__": 12 | seq_config = IPAConfig() 13 | smiles_config = IPAConfig() 14 | 15 | npts = 256 16 | 17 | query = torch.rand(1, npts, seq_config.bert_config['hidden_size']) 18 | key = value = query 19 | 20 | pair = torch.rand(1, 21 | seq_config.bert_config['num_attention_heads'] + smiles_config.bert_config['num_attention_heads'], 22 | npts, 23 | npts) 24 | 25 | initial_receptor_xyz = torch.rand(query.size()[:2]+(3, ))#, device=hidden_seq.device, dtype=hidden_seq.dtype) 26 | initial_receptor_rot = torch.rand(3, 3).unsqueeze(0).unsqueeze(0).repeat(query.size()[0], query.size()[1], 1, 1) 27 | 28 | translations_receptor = initial_receptor_xyz 29 | rotations_receptor = initial_receptor_rot 30 | 31 | translations_ligand = torch.rand(query.size()[:2]+(3, )) 32 | rotations_ligand = torch.rand(3, 3).unsqueeze(0).unsqueeze(0).repeat(query.size()[0], query.size()[1], 1, 1) 33 | 34 | IPCA = InvariantPointCrossAttention(seq_config, smiles_config) 35 | LMIPCA = LinearMemInvariantPointCrossAttention(seq_config, smiles_config) 36 | 37 | # use the same weights for both layers 38 | 39 | 
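# Tying every learnable parameter below makes the two modules compute the same
# function, so any discrepancy between their outputs isolates a difference in
# the chunked vs. full attention arithmetic itself.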
LMIPCA.query = IPCA.query 40 | LMIPCA.key = IPCA.key 41 | LMIPCA.value = IPCA.value 42 | LMIPCA.query_point = IPCA.query_point 43 | LMIPCA.key_point = IPCA.key_point 44 | LMIPCA.value_point = IPCA.value_point 45 | LMIPCA.head_weight = IPCA.head_weight 46 | LMIPCA.pair_attention = IPCA.pair_attention 47 | LMIPCA.output_layer = IPCA.output_layer 48 | 49 | out_IPCA = IPCA(hidden_states = query, encoder_hidden_states = key, pair_representation = pair, 50 | rigid_rotations = rotations_receptor, rigid_translations = translations_receptor, 51 | encoder_rigid_rotations = rotations_ligand, encoder_rigid_translations = translations_ligand) 52 | print('Output from original IPCA') 53 | print(out_IPCA, out_IPCA.shape) 54 | print() 55 | out_LMIPCA = LMIPCA(hidden_states = query, encoder_hidden_states = key, pair_representation = pair, 56 | rigid_rotations = rotations_receptor, rigid_translations = translations_receptor, 57 | encoder_rigid_rotations = rotations_ligand, encoder_rigid_translations = translations_ligand, 58 | query_chunk_size = 16, key_chunk_size = 64) # set chunk sizes < npts to make sure chunking actually happens 59 | print('Output from linear memory IPCA') 60 | print(out_LMIPCA, out_LMIPCA.shape) 61 | -------------------------------------------------------------------------------- /unit_test/README.md: -------------------------------------------------------------------------------- 1 | 2 | IPCAvsLMIPCA.py: Tests the two implementations of the same layer on the same input. If the test passes, 3 | their outputs match. 4 | 5 | ```python 6 | python IPCAvsLMIPCA.py 7 | 8 | Output from original IPCA 9 | tensor([[[-0.1077, -0.0912, -0.0562, ..., -0.1033, 0.1245, 0.2865], 10 | [-0.1094, -0.0763, -0.0503, ..., -0.1108, 0.1063, 0.2816], 11 | [-0.0981, -0.0843, -0.0585, ..., -0.1037, 0.1161, 0.2811], 12 | ..., 13 | [-0.1076, -0.0802, -0.0562, ..., -0.0972, 0.1244, 0.2913], 14 | [-0.1035, -0.0892, -0.0643, ..., -0.1153, 0.1109, 0.2809], 15 | [-0.1043, -0.0743, -0.0568, ..., -0.1010, 0.1181, 0.2886]]], 16 | grad_fn=<...>) torch.Size([1, 256, 768]) 17 | 18 | Output from linear memory IPCA 19 | tensor([[[-0.1077, -0.0912, -0.0562, ..., -0.1033, 0.1245, 0.2865], 20 | [-0.1094, -0.0763, -0.0503, ..., -0.1108, 0.1063, 0.2816], 21 | [-0.0981, -0.0843, -0.0585, ..., -0.1037, 0.1161, 0.2811], 22 | ..., 23 | [-0.1076, -0.0802, -0.0562, ..., -0.0972, 0.1244, 0.2913], 24 | [-0.1035, -0.0892, -0.0643, ..., -0.1153, 0.1109, 0.2809], 25 | [-0.1043, -0.0743, -0.0568, ..., -0.1010, 0.1181, 0.2886]]], 26 | grad_fn=<...>) torch.Size([1, 256, 768]) 27 | ``` 28 | 29 | 30 | ATTNvsLMATTN.py: Tests the two implementations of the BertSelfAttention layer on the same input, 31 | with dropout disabled for consistency. If the test passes, the outputs match; a numerical check is sketched below.
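Rather than comparing the printed tensors by eye, the agreement can also be checked numerically. A minimal sketch (hypothetical, not part of the test scripts; it assumes `out_ATTN` and `out_LMATTN` were computed as in ATTNvsLMATTN.py, where each result is a tuple whose first element is the context tensor):

```python
import torch

# element-wise comparison within a small tolerance; bit-for-bit equality is not
# guaranteed, since the chunked implementation accumulates the softmax-weighted
# values in a different order
assert torch.allclose(out_ATTN[0], out_LMATTN[0], atol=1e-5), 'attention outputs diverge'
print('outputs match within tolerance')
```

For IPCAvsLMIPCA.py the same pattern applies, comparing `out_IPCA` and `out_LMIPCA` directly (those are plain tensors rather than tuples).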
32 | 33 | ``` 34 | python ATTNvsLMATTN.py 35 | 36 | Output from original attention 37 | (tensor([[[-0.0304, 0.0459, -0.2800, ..., 0.0776, -0.0980, -0.6149], 38 | [-0.0310, 0.0453, -0.2807, ..., 0.0777, -0.0980, -0.6151], 39 | [-0.0305, 0.0448, -0.2808, ..., 0.0781, -0.0976, -0.6154], 40 | ..., 41 | [-0.0309, 0.0459, -0.2802, ..., 0.0773, -0.0980, -0.6152], 42 | [-0.0310, 0.0451, -0.2805, ..., 0.0781, -0.0981, -0.6150], 43 | [-0.0310, 0.0447, -0.2805, ..., 0.0781, -0.0978, -0.6150]]], 44 | grad_fn=<...>),) torch.Size([1, 256, 768]) 45 | 46 | Output from linear memory attention 47 | (tensor([[[-0.0304, 0.0459, -0.2800, ..., 0.0776, -0.0980, -0.6149], 48 | [-0.0310, 0.0453, -0.2807, ..., 0.0777, -0.0980, -0.6151], 49 | [-0.0305, 0.0448, -0.2808, ..., 0.0781, -0.0976, -0.6154], 50 | ..., 51 | [-0.0309, 0.0459, -0.2802, ..., 0.0773, -0.0980, -0.6152], 52 | [-0.0310, 0.0451, -0.2805, ..., 0.0781, -0.0981, -0.6150], 53 | [-0.0310, 0.0447, -0.2805, ..., 0.0781, -0.0978, -0.6150]]], 54 | grad_fn=<...>),) torch.Size([1, 256, 768]) 55 | ``` 56 | -------------------------------------------------------------------------------- /unit_test/test_kabsch_rmsd.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import sys, os 5 | sys.path.insert(1, os.path.join(sys.path[0], '../contact_pred/')) 6 | from structure import compute_kabsch_RMSD 7 | 8 | a = torch.Tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]], [[0, 0, 0], [1, 0, 0], [2, 0, 0]]]) 9 | b = torch.Tensor([[[0, 0, 0], [999, 0, 0], [2, 0, 0]], [[0, 0, 0], [0, 3, 0], [0, 6, 0]]]) 10 | w = torch.Tensor([[1, 0, 1], [1, 1, 1]])  # the zero weight excludes the outlier point (999) in the first batch from the RMSD 11 | #print(a, b, a.shape, b.shape) 12 | #print(w, w.shape) 13 | 14 | print(compute_kabsch_RMSD(a, b, weight=w)) 15 | -------------------------------------------------------------------------------- /utils/token_coords.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | import numpy as np 3 | 4 | import re 5 | 6 | # all punctuation 7 | punctuation_regex = r"""(\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])""" 8 | 9 | # tokenization regex (Schwaller) 10 | molecule_regex = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])""" 11 | 12 | def get_token_coords(mol): 13 | smi = Chem.MolToSmiles(mol) 14 | 15 | # position of atoms in SMILES (not counting punctuation) 16 | atom_order = [int(s) for s in list(filter(None,re.sub(r'[\[\]]','',mol.GetProp("_smilesAtomOutputOrder")).split(',')))] 17 | 18 | # tokenize the SMILES 19 | tokens = list(filter(None, re.split(molecule_regex, smi))) 20 | 21 | # remove punctuation 22 | masked_tokens = [re.sub(punctuation_regex,'',s) for s in tokens] 23 | 24 | k = 0 25 | token_pos = [] 26 | atom_idx = [] 27 | for token in masked_tokens: 28 | if token != '': 29 | token_pos.append(tuple(mol.GetConformer().GetAtomPosition(atom_order[k]))) 30 | atom_idx.append(atom_order[k]) 31 | k += 1 32 | else:  # punctuation token, no atom coordinates 33 | token_pos.append([np.nan, np.nan, np.nan]) 34 | atom_idx.append(None) 35 | 36 | return smi, token_pos, tokens, atom_idx 37 | --------------------------------------------------------------------------------