├── .gitignore ├── requirements.txt ├── e1 ├── get_e1_weights.py └── tokenizer.json ├── licenses ├── fastesm_license.txt ├── e1_license.txt └── ankh_license.txt ├── esm2 └── get_esm2_weights.py ├── exp.ipynb ├── wip └── t5 │ ├── transformers_flex_attn.py │ ├── t5_attention.py │ ├── t5_flex_attention.py │ └── test_t5_flex_attention.py ├── test_scripts ├── test_precision_difference.py ├── test_difference_esmfast.py ├── test_attentions.py ├── test_embedding.py ├── test_compliance_esmc.py ├── test_compliance_esm2.py ├── test_throughput_esmfast.py ├── test_contact_maps.py └── test_throughput.py ├── esm_plusplus └── get_esmc_weights.py ├── README.md ├── update_HF.py ├── readmes ├── fastesm2_readme.md ├── e1_readme.md ├── fastesm_650_readme.md ├── esm_plusplus_large_readme.md └── esm_plusplus_small_readme.md ├── pooler.py └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | *.png 4 | *.pth 5 | *.pt 6 | *.safetensors 7 | *.json 8 | !e1/*.json 9 | *.bin 10 | /results_classification_lora 11 | /results_regression_lora 12 | .qodo 13 | *.db 14 | /esm 15 | *.nbc 16 | *.nbi 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.6.0 2 | matplotlib>=3.5.0 3 | transformers>=4.47.0 4 | numpy>=1.26.2 5 | einops 6 | esm 7 | datasets>=2.14.0 8 | scikit-learn>=1.0.0 9 | scipy>=1.7.0 10 | seaborn>=0.12.0 11 | peft>=0.5.0 12 | accelerate>=1.1.0 13 | networkx -------------------------------------------------------------------------------- /e1/get_e1_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from huggingface_hub import login 3 | 4 | from e1.modeling_e1 import E1ForMaskedLM, E1Config 5 | 6 | 7 | model_dict = { 8 | 'Profluent-E1-150M': 'Profluent-Bio/E1-150m', 9 | 'Profluent-E1-300M': 'Profluent-Bio/E1-300m', 10 | 'Profluent-E1-600M': 'Profluent-Bio/E1-600m', 11 | } 12 | 13 | 14 | if __name__ == "__main__": 15 | # py -m e1.get_e1_weights 16 | login() 17 | for model_name in model_dict: 18 | config = E1Config.from_pretrained(model_dict[model_name]) 19 | config.auto_map = { 20 | "AutoConfig": "modeling_e1.E1Config", 21 | "AutoModel": "modeling_e1.E1Model", 22 | "AutoModelForMaskedLM": "modeling_e1.E1ForMaskedLM", 23 | "AutoModelForSequenceClassification": "modeling_e1.E1ForSequenceClassification", 24 | "AutoModelForTokenClassification": "modeling_e1.E1ForTokenClassification" 25 | } 26 | model = E1ForMaskedLM.from_pretrained(model_dict[model_name], dtype=torch.bfloat16) 27 | model.push_to_hub('Synthyra/' + model_name) 28 | -------------------------------------------------------------------------------- /licenses/fastesm_license.txt: -------------------------------------------------------------------------------- 1 | License for FastESM models, from the ESM2 repo https://github.com/facebookresearch/esm 2 | 3 | MIT License 4 | 5 | Copyright (c) Meta Platforms, Inc. and affiliates. 
6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /esm2/get_esm2_weights.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from huggingface_hub import login 4 | from transformers import EsmForMaskedLM 5 | from modeling_fastesm import FastEsmForMaskedLM, FastEsmConfig 6 | 7 | 8 | model_dict = { 9 | # Synthyra/ESM2-8M 10 | 'ESM2-8M': 'facebook/esm2_t6_8M_UR50D', 11 | # Synthyra/ESM2-35M 12 | 'ESM2-35M': 'facebook/esm2_t12_35M_UR50D', 13 | # Synthyra/ESM2-150M 14 | 'ESM2-150M': 'facebook/esm2_t30_150M_UR50D', 15 | # Synthyra/ESM2-650M 16 | 'ESM2-650M': 'facebook/esm2_t33_650M_UR50D', 17 | # Synthyra/ESM2-3B 18 | 'ESM2-3B': 'facebook/esm2_t36_3B_UR50D', 19 | } 20 | 21 | 22 | if __name__ == "__main__": 23 | #login() 24 | for model_name in model_dict: 25 | config = FastEsmConfig.from_pretrained(model_dict[model_name]) 26 | config.auto_map = { 27 | "AutoConfig": "modeling_fastesm.FastEsmConfig", 28 | "AutoModel": "modeling_fastesm.FastEsmModel", 29 | "AutoModelForMaskedLM": "modeling_fastesm.FastEsmForMaskedLM", 30 | "AutoModelForSequenceClassification": "modeling_fastesm.FastEsmForSequenceClassification", 31 | "AutoModelForTokenClassification": "modeling_fastesm.FastEsmForTokenClassification" 32 | } 33 | config.tie_word_embeddings = False 34 | original_model = EsmForMaskedLM.from_pretrained(model_dict[model_name]) 35 | model = FastEsmForMaskedLM(config=config).from_pretrained(model_dict[model_name], config=config) 36 | model.lm_head.load_state_dict(original_model.lm_head.state_dict()) 37 | model.lm_head = copy.deepcopy(model.lm_head) 38 | model.push_to_hub('Synthyra/' + model_name) 39 | 40 | for name1, param1 in model.named_parameters(): 41 | for name2, param2 in original_model.named_parameters(): 42 | if name1 == name2: 43 | assert param1.shape == param2.shape, f'{name1} {param1.shape} != {name2} {param2.shape}' 44 | assert torch.equal(param1.data.clone(), param2.data.clone()), f'{name1} {name2}' 45 | 46 | -------------------------------------------------------------------------------- /exp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from modeling_esm_plusplus import ESMplusplusModel" 10 | ] 11 | }, 12 | { 13 | "cell_type": 
"code", 14 | "execution_count": 3, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "torch.Size([2, 11, 1152])\n" 22 | ] 23 | }, 24 | { 25 | "data": { 26 | "text/plain": [ 27 | "36" 28 | ] 29 | }, 30 | "execution_count": 3, 31 | "metadata": {}, 32 | "output_type": "execute_result" 33 | } 34 | ], 35 | "source": [ 36 | "model = ESMplusplusModel.from_pretrained('Synthyra/ESMplusplus_large')\n", 37 | "tokenizer = model.tokenizer\n", 38 | "\n", 39 | "sequences = ['MPRTEIN', 'MSEQWENCE']\n", 40 | "tokenized = tokenizer(sequences, padding=True, return_tensors='pt')\n", 41 | "\n", 42 | "# tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training\n", 43 | "\n", 44 | "output = model(**tokenized) # get all hidden states with output_hidden_states=True\n", 45 | "print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960)\n", 46 | "#print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)\n", 47 | "\n", 48 | "output = model(**tokenized, output_attentions=True)\n", 49 | "att = output.attentions\n", 50 | "len(att) # 30, one for each layer, size (batch_size, num_heads, seq_len, seq_len) each" 51 | ] 52 | } 53 | ], 54 | "metadata": { 55 | "kernelspec": { 56 | "display_name": "Python 3", 57 | "language": "python", 58 | "name": "python3" 59 | }, 60 | "language_info": { 61 | "codemirror_mode": { 62 | "name": "ipython", 63 | "version": 3 64 | }, 65 | "file_extension": ".py", 66 | "mimetype": "text/x-python", 67 | "name": "python", 68 | "nbconvert_exporter": "python", 69 | "pygments_lexer": "ipython3", 70 | "version": "3.11.8" 71 | } 72 | }, 73 | "nbformat": 4, 74 | "nbformat_minor": 2 75 | } 76 | -------------------------------------------------------------------------------- /wip/t5/transformers_flex_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple 3 | from torch.nn.attention.flex_attention import flex_attention 4 | 5 | 6 | def flex_attention_forward( 7 | query: torch.Tensor, 8 | key: torch.Tensor, 9 | value: torch.Tensor, 10 | attention_mask: Optional[torch.Tensor], 11 | scaling: Optional[float] = None, 12 | softcap: Optional[float] = None, 13 | head_mask: Optional[torch.Tensor] = None, 14 | **kwargs, 15 | ) -> Tuple[torch.Tensor, torch.Tensor]: 16 | causal_mask = attention_mask 17 | if causal_mask is not None: 18 | causal_mask = causal_mask[:, :, :, : key.shape[-2]] 19 | 20 | def causal_mod(score, b, h, q_idx, kv_idx): 21 | if softcap is not None: 22 | score = softcap * torch.tanh(score / softcap) 23 | if causal_mask is not None: 24 | score = score + causal_mask[b][0][q_idx][kv_idx] 25 | if head_mask is not None: 26 | score = score + head_mask[b][h][0][0] 27 | return score 28 | 29 | attn_output, attention_weights = flex_attention( 30 | query, 31 | key, 32 | value, 33 | score_mod=causal_mod, 34 | enable_gqa=True, 35 | scale=scaling, 36 | # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. 37 | # For simplification, we thus always return it as no additional computations are introduced. 
38 | return_lse=True, 39 | ) 40 | # lse is returned in float32 41 | attention_weights = attention_weights.to(value.dtype) 42 | attn_output = attn_output.transpose(1, 2).contiguous() 43 | 44 | return attn_output, attention_weights 45 | 46 | 47 | if __name__ == "__main__": 48 | # py -m wip.t5.transformers_flex_attn 49 | batch_size, seq_len, n_heads, head_dim = 1, 10, 12, 64 50 | 51 | query = torch.randn(batch_size, n_heads, seq_len, head_dim) 52 | key = torch.randn(batch_size, n_heads, seq_len, head_dim) 53 | value = torch.randn(batch_size, n_heads, seq_len, head_dim) 54 | attention_mask = torch.randint(0, 2, (batch_size, seq_len)) 55 | scaling = 1.0 56 | softcap = 2.0 57 | head_mask = None 58 | 59 | attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool() 60 | 61 | attn_output, attention_weights = flex_attention_forward( 62 | query, 63 | key, 64 | value, 65 | attention_mask, 66 | scaling, 67 | softcap, 68 | head_mask, 69 | ) 70 | print(attn_output.shape) 71 | print(attention_weights.shape) 72 | -------------------------------------------------------------------------------- /test_scripts/test_precision_difference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import random 4 | import argparse 5 | from huggingface_hub import login 6 | from transformers import AutoModelForMaskedLM 7 | from tqdm.auto import tqdm 8 | 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model_path', type=str, default='Synthyra/ESMplusplus_small') 12 | parser.add_argument('--token', type=str, default=None) 13 | args = parser.parse_args() 14 | 15 | if args.token: 16 | login(args.token) 17 | 18 | model_path = args.model_path 19 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY" 20 | length = 128 21 | seq_count = 1000 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | 24 | 25 | def generate_random_sequence(length: int) -> str: 26 | return 'M' + "".join(random.choices(canonical_amino_acids, k=length-3)) 27 | 28 | 29 | # Generate sequences first 30 | sequences = [generate_random_sequence(length) for _ in range(seq_count)] 31 | 32 | 33 | # Get base model outputs 34 | base_outputs = [] 35 | model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True).to(device) 36 | tokenizer = model.tokenizer 37 | with torch.no_grad(): 38 | for seq in tqdm(sequences): 39 | input = tokenizer(seq, return_tensors="pt").to(device) 40 | embeddings = model(**input).last_hidden_state.cpu() 41 | base_outputs.append(embeddings) 42 | model.cpu() 43 | del model 44 | torch.cuda.empty_cache() 45 | 46 | 47 | # Get fp16 outputs 48 | fp16_mse = 0 49 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).to(device) 50 | with torch.no_grad(): 51 | for i, seq in tqdm(enumerate(sequences), total=len(sequences)): 52 | input = tokenizer(seq, return_tensors="pt").to(device) 53 | fp16_output = model(**input).last_hidden_state.float().cpu() 54 | fp16_mse += F.mse_loss(base_outputs[i], fp16_output).item() 55 | model.cpu() 56 | del model 57 | torch.cuda.empty_cache() 58 | 59 | 60 | # Get bfloat16 outputs 61 | bf16_mse = 0 62 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.bfloat16, trust_remote_code=True).to(device) 63 | with torch.no_grad(): 64 | for i, seq in tqdm(enumerate(sequences), total=len(sequences)): 65 | input = tokenizer(seq, return_tensors="pt").to(device) 66 | bf16_output = 
model(**input).last_hidden_state.float().cpu() 67 | bf16_mse += F.mse_loss(base_outputs[i], bf16_output).item() 68 | model.cpu() 69 | del model 70 | torch.cuda.empty_cache() 71 | 72 | fp16_mse /= seq_count 73 | bf16_mse /= seq_count 74 | 75 | print(f"Average MSE for FP16: {fp16_mse:.8f}") 76 | print(f"Average MSE for BF16: {bf16_mse:.8f}") 77 | print(f"{'FP16' if fp16_mse < bf16_mse else 'BF16'} has lower MSE") 78 | -------------------------------------------------------------------------------- /test_scripts/test_difference_esmfast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import random 4 | import argparse 5 | from huggingface_hub import login 6 | from transformers import AutoModelForMaskedLM, EsmTokenizer 7 | from tqdm.auto import tqdm 8 | 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model_path', type=str, default='lhallee/synthyra_esm2_650_mlm') 12 | parser.add_argument('--token', type=str, default=None) 13 | args = parser.parse_args() 14 | 15 | if args.token: 16 | login(args.token) 17 | 18 | model_path = args.model_path 19 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY" 20 | length = 128 21 | seq_count = 1000 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | 24 | 25 | def generate_random_sequence(length: int) -> str: 26 | return 'M' + "".join(random.choices(canonical_amino_acids, k=length)) 27 | 28 | 29 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D') 30 | # Generate sequences first 31 | sequences = [generate_random_sequence(length) for _ in range(seq_count)] 32 | inputs = [tokenizer(seq, return_tensors="pt").to(device) for seq in sequences] 33 | 34 | 35 | # Get base model outputs 36 | base_outputs = [] 37 | model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True).to(device) 38 | with torch.no_grad(): 39 | for input_batch in tqdm(inputs): 40 | base_outputs.append(model(**input_batch, output_hidden_states=True).hidden_states[-1].cpu()) 41 | model.cpu() 42 | del model 43 | torch.cuda.empty_cache() 44 | 45 | 46 | # Get fp16 outputs 47 | fp16_mse = 0 48 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).to(device) 49 | with torch.no_grad(): 50 | for i, input_batch in tqdm(enumerate(inputs), total=len(inputs)): 51 | fp16_output = model(**input_batch, output_hidden_states=True).hidden_states[-1].float().cpu() 52 | fp16_mse += F.mse_loss(base_outputs[i], fp16_output).item() 53 | model.cpu() 54 | del model 55 | torch.cuda.empty_cache() 56 | 57 | 58 | # Get bfloat16 outputs 59 | bf16_mse = 0 60 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.bfloat16, trust_remote_code=True).to(device) 61 | with torch.no_grad(): 62 | for i, input_batch in tqdm(enumerate(inputs), total=len(inputs)): 63 | bf16_output = model(**input_batch, output_hidden_states=True).hidden_states[-1].float().cpu() 64 | bf16_mse += F.mse_loss(base_outputs[i], bf16_output).item() 65 | model.cpu() 66 | del model 67 | torch.cuda.empty_cache() 68 | 69 | fp16_mse /= seq_count 70 | bf16_mse /= seq_count 71 | 72 | print(f"Average MSE for FP16: {fp16_mse:.8f}") 73 | print(f"Average MSE for BF16: {bf16_mse:.8f}") 74 | print(f"{'FP16' if fp16_mse < bf16_mse else 'BF16'} has lower MSE") 75 | -------------------------------------------------------------------------------- /esm_plusplus/get_esmc_weights.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import copy 4 | from functools import cache 5 | from pathlib import Path 6 | from huggingface_hub import snapshot_download, login 7 | 8 | from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM, ESMplusplusConfig 9 | 10 | 11 | @staticmethod 12 | @cache 13 | def data_root(model: str): 14 | if "INFRA_PROVIDER" in os.environ: 15 | return Path("") 16 | # Try to download from huggingface if it doesn't exist 17 | if model.startswith("esmc-300"): 18 | path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12")) 19 | elif model.startswith("esmc-600"): 20 | path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12")) 21 | else: 22 | raise ValueError(f"{model=} is an invalid model name.") 23 | return path 24 | 25 | 26 | def ESMplusplus_300M(device: torch.device | str = "cpu"): 27 | with torch.device(device): 28 | config = ESMplusplusConfig( 29 | hidden_size=960, 30 | num_attention_heads=15, 31 | num_hidden_layers=30, 32 | ) 33 | model = ESMplusplusForMaskedLM(config) 34 | state_dict = torch.load( 35 | data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth", 36 | map_location=device, 37 | ) 38 | model.load_state_dict(state_dict) 39 | return model 40 | 41 | 42 | def ESMplusplus_600M(device: torch.device | str = "cpu"): 43 | with torch.device(device): 44 | config = ESMplusplusConfig( 45 | hidden_size=1152, 46 | num_attention_heads=18, 47 | num_hidden_layers=36, 48 | ) 49 | model = ESMplusplusForMaskedLM(config) 50 | state_dict = torch.load( 51 | data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth", 52 | map_location=device, 53 | ) 54 | model.load_state_dict(state_dict) 55 | return model 56 | 57 | 58 | if __name__ == "__main__": 59 | #login() 60 | 61 | model_dict = { 62 | # Synthyra/ESM++small 63 | 'Synthyra/ESMplusplus_small': ESMplusplus_300M, 64 | # Synthyra/ESM++large 65 | 'Synthyra/ESMplusplus_large': ESMplusplus_600M, 66 | } 67 | 68 | 69 | for model_path, model_fn in model_dict.items(): 70 | model = model_fn() 71 | model.sequence_head = copy.deepcopy(model.sequence_head) 72 | model.config.auto_map = { 73 | "AutoConfig": "modeling_esm_plusplus.ESMplusplusConfig", 74 | "AutoModel": "modeling_esm_plusplus.ESMplusplusModel", 75 | "AutoModelForMaskedLM": "modeling_esm_plusplus.ESMplusplusForMaskedLM", 76 | "AutoModelForSequenceClassification": "modeling_esm_plusplus.ESMplusplusForSequenceClassification", 77 | "AutoModelForTokenClassification": "modeling_esm_plusplus.ESMplusplusForTokenClassification" 78 | } 79 | model.push_to_hub(model_path) 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastPLMs 2 | 3 | (project logo image) 4 | 5 | FastPLMs is an open-source effort to increase the efficiency of pretrained protein language models, switching out native attention implementations for Flash or Flex attention. 6 | 7 | All models can be loaded from Huggingface 🤗 transformers via `AutoModel`; this repository does not need to be cloned for most use cases. 8 | 9 | ## Supported models 10 | The currently supported models can be found [here](https://huggingface.co/collections/Synthyra/pretrained-plms-675351ecc050f63baedd77de). 11 | 12 | ## Suggestions 13 | Have suggestions, comments, or requests? Found a bug? Open a GitHub issue and we'll respond soon.
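## Quick start
Since every supported checkpoint ships its modeling code on the Hub, loading only requires `AutoModel` with `trust_remote_code=True`. A minimal sketch (the checkpoint and sequences below are illustrative; any model from the collection above works the same way):

```python
import torch
from transformers import AutoModel

# Any supported checkpoint works here, e.g. Synthyra/ESMplusplus_small or Synthyra/ESM2-650M
model = AutoModel.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True).eval()
tokenizer = model.tokenizer  # the models carry their own tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state  # (batch_size, seq_len, hidden_size)
print(embeddings.shape)
```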
14 | 15 | ## Embed entire datasets with no new code 16 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take. 17 | 18 | Example: 19 | ```python 20 | embedding_dict = model.embed_dataset( 21 | sequences=[ 22 | 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences 23 | ], 24 | batch_size=2, # adjust for your GPU memory 25 | max_len=512, # adjust for your needs 26 | full_embeddings=False, # if True, no pooling is performed 27 | embed_dtype=torch.float32, # cast to what dtype you want 28 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together 29 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets 30 | sql=False, # if True, embeddings will be stored in SQLite database 31 | sql_db_path='embeddings.db', 32 | save=True, # if True, embeddings will be saved as a .pth file 33 | save_path='embeddings.pth', 34 | ) 35 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql 36 | ``` 37 | 38 | ``` 39 | model.embed_dataset() 40 | Args: 41 | sequences: List of protein sequences 42 | batch_size: Batch size for processing 43 | max_len: Maximum sequence length 44 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) 45 | pooling_types: List of pooling types ('mean', 'cls'); multiple types are concatenated together 46 | num_workers: Number of workers for data loading, 0 for the main process 47 | sql: Whether to store embeddings in SQLite database - will be stored in float32 48 | sql_db_path: Path to SQLite database 49 | 50 | Returns: 51 | Dictionary mapping sequences to embeddings, or None if sql=True 52 | 53 | Note: 54 | - If sql=True, embeddings can only be stored in float32 55 | - sql is ideal if you need to stream a very large dataset for training in real-time (see the read-back sketch at the end of this README) 56 | - save=True is ideal if you can store the entire embedding dictionary in RAM 57 | - If sql=True, the SQLite database is used regardless of whether save is True or False 58 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences 59 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing 60 | ``` 61 | 62 | ## Upcoming releases 63 | A Fast version of ANKH is in progress. It is functional but still uses native attention; we are waiting for bias gradient support in [FlexAttention](https://pytorch.org/blog/flexattention/).
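## Reading embeddings back from SQLite
For reference, a minimal sketch of reading pooled embeddings back out of the SQLite database written by `embed_dataset(sql=True)`. The table and column names (`embeddings`, `sequence`, `embedding`) mirror those used in `test_scripts/test_embedding.py`, and vectors are stored as float32 bytes:

```python
import sqlite3
import numpy as np

conn = sqlite3.connect('embeddings.db')
c = conn.cursor()
# Look up the stored embedding for a single sequence (sequence string is illustrative)
c.execute('SELECT embedding FROM embeddings WHERE sequence = ?', ('MALWMRLLPLLALLALWGPDPAAA',))
row = c.fetchone()
if row is not None:
    embedding = np.frombuffer(row[0], dtype=np.float32)
    print(embedding.shape)
conn.close()
```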
64 | -------------------------------------------------------------------------------- /e1/tokenizer.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.0", 3 | "truncation": null, 4 | "padding": { 5 | "strategy": "BatchLongest", 6 | "direction": "Right", 7 | "pad_to_multiple_of": null, 8 | "pad_id": 0, 9 | "pad_type_id": 0, 10 | "pad_token": "" 11 | }, 12 | "added_tokens": [ 13 | { 14 | "id": 0, 15 | "content": "", 16 | "single_word": false, 17 | "lstrip": false, 18 | "rstrip": false, 19 | "normalized": false, 20 | "special": true 21 | }, 22 | { 23 | "id": 1, 24 | "content": "", 25 | "single_word": false, 26 | "lstrip": false, 27 | "rstrip": false, 28 | "normalized": false, 29 | "special": true 30 | }, 31 | { 32 | "id": 2, 33 | "content": "", 34 | "single_word": false, 35 | "lstrip": false, 36 | "rstrip": false, 37 | "normalized": false, 38 | "special": true 39 | }, 40 | { 41 | "id": 3, 42 | "content": "", 43 | "single_word": false, 44 | "lstrip": false, 45 | "rstrip": false, 46 | "normalized": false, 47 | "special": true 48 | }, 49 | { 50 | "id": 4, 51 | "content": "", 52 | "single_word": false, 53 | "lstrip": false, 54 | "rstrip": false, 55 | "normalized": false, 56 | "special": true 57 | }, 58 | { 59 | "id": 5, 60 | "content": "?", 61 | "single_word": false, 62 | "lstrip": false, 63 | "rstrip": false, 64 | "normalized": false, 65 | "special": true 66 | } 67 | ], 68 | "normalizer": null, 69 | "pre_tokenizer": { 70 | "type": "ByteLevel", 71 | "add_prefix_space": false, 72 | "trim_offsets": true, 73 | "use_regex": true 74 | }, 75 | "post_processor": { 76 | "type": "ByteLevel", 77 | "add_prefix_space": true, 78 | "trim_offsets": true, 79 | "use_regex": true 80 | }, 81 | "decoder": { 82 | "type": "ByteLevel", 83 | "add_prefix_space": true, 84 | "trim_offsets": true, 85 | "use_regex": true 86 | }, 87 | "model": { 88 | "type": "BPE", 89 | "dropout": null, 90 | "unk_token": "X", 91 | "continuing_subword_prefix": null, 92 | "end_of_word_suffix": null, 93 | "fuse_unk": false, 94 | "byte_fallback": false, 95 | "vocab": { 96 | "": 0, 97 | "": 1, 98 | "": 2, 99 | "": 3, 100 | "": 4, 101 | "?": 5, 102 | "1": 6, 103 | "2": 7, 104 | "A": 8, 105 | "B": 9, 106 | "C": 10, 107 | "D": 11, 108 | "E": 12, 109 | "F": 13, 110 | "G": 14, 111 | "H": 15, 112 | "I": 16, 113 | "J": 17, 114 | "K": 18, 115 | "L": 19, 116 | "M": 20, 117 | "N": 21, 118 | "O": 22, 119 | "P": 23, 120 | "Q": 24, 121 | "R": 25, 122 | "S": 26, 123 | "T": 27, 124 | "U": 28, 125 | "V": 29, 126 | "W": 30, 127 | "X": 31, 128 | "Y": 32, 129 | "Z": 33 130 | }, 131 | "merges": [] 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /test_scripts/test_attentions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from datasets import load_dataset 4 | from esm_plusplus.modeling_esm_plusplus import ESMplusplusModel 5 | from tqdm import tqdm 6 | 7 | 8 | def test_attention_outputs(model, tokenizer, seqs, batch_size=4, tolerances=[1e-3, 1e-5, 1e-7, 1e-9]): 9 | """ 10 | Test if hidden states are the same with and without attention output at different tolerance levels. 
11 | 12 | Args: 13 | model: The model to test 14 | tokenizer: The tokenizer to use 15 | seqs: List of sequences to process 16 | batch_size: Batch size for processing 17 | tolerances: List of tolerance values to test with torch.allclose 18 | 19 | Returns: 20 | dict: Results for each tolerance level 21 | """ 22 | results = {tol: True for tol in tolerances} 23 | max_diff = 0.0 24 | 25 | with torch.no_grad(): 26 | for i in tqdm(range(0, len(seqs), batch_size), desc='Testing attention outputs'): 27 | batch_seqs = seqs[i:i+batch_size] 28 | 29 | # Tokenize the batch 30 | tokenized = tokenizer(batch_seqs, padding=True, return_tensors='pt') 31 | tokenized = {k: v.to(model.device) for k, v in tokenized.items()} 32 | 33 | # Get output without attention 34 | output_no_att = model(**tokenized).last_hidden_state.detach().cpu() 35 | 36 | # Get output with attention 37 | output_with_att = model(**tokenized, output_attentions=True).last_hidden_state.detach().cpu() 38 | 39 | # Calculate maximum difference 40 | diff = (output_no_att - output_with_att).abs().max().item() 41 | max_diff = max(max_diff, diff) 42 | print(max_diff) 43 | 44 | # Check for NaN or infinite values 45 | has_nan_or_inf_no_att = torch.isnan(output_no_att).any() or torch.isinf(output_no_att).any() 46 | has_nan_or_inf_with_att = torch.isnan(output_with_att).any() or torch.isinf(output_with_att).any() 47 | if has_nan_or_inf_no_att or has_nan_or_inf_with_att: 48 | print(f"WARNING: Found NaN or infinite values in the outputs! No att: {has_nan_or_inf_no_att}, With att: {has_nan_or_inf_with_att}") 49 | 50 | # Test different tolerance levels 51 | for tol in tolerances: 52 | if not torch.allclose(output_no_att, output_with_att, atol=tol): 53 | results[tol] = False 54 | 55 | return results, max_diff 56 | 57 | 58 | if __name__ == '__main__': 59 | # py -m test_scripts.test_attentions 60 | parser = argparse.ArgumentParser(description='Test attention outputs in ESM++ models') 61 | parser.add_argument('--model', type=str, default='Synthyra/ESMplusplus_small', help='Model to test') 62 | parser.add_argument('--num_samples', type=int, default=100, help='Number of samples to test') 63 | parser.add_argument('--batch_size', type=int, default=4, help='Batch size for processing') 64 | args = parser.parse_args() 65 | 66 | # Load data 67 | seqs = load_dataset('Synthyra/NEGATOME', split='manual_stringent').filter(lambda x: len(x['SeqA']) <= 256).select(range(args.num_samples))['SeqA'] 68 | seqs = list(set(seqs)) 69 | # Load model 70 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 71 | model = ESMplusplusModel.from_pretrained(args.model).to(device) 72 | tokenizer = model.tokenizer 73 | 74 | # Define tolerance levels to test 75 | tolerances = [1e-2, 1e-4, 1e-6, 1e-8] 76 | 77 | # Run tests 78 | print(f"Testing model: {args.model}") 79 | print(f"Device: {device}") 80 | print(f"Testing {len(seqs)} sequences with batch size {args.batch_size}") 81 | 82 | results, max_diff = test_attention_outputs( 83 | model, 84 | tokenizer, 85 | seqs, 86 | batch_size=args.batch_size, 87 | tolerances=tolerances 88 | ) 89 | 90 | # Report results 91 | print("\nTest Results:") 92 | print(f"Maximum absolute difference: {max_diff:.10e}") 93 | print("\nTolerance tests:") 94 | for tol in sorted(tolerances): 95 | status = "PASSED" if results[tol] else "FAILED" 96 | print(f" Tolerance {tol:.0e}: {status}") 97 | 98 | # Overall result 99 | if all(results.values()): 100 | print("\nAll tests PASSED! 
Hidden states are identical with and without attention output.") 101 | else: 102 | min_passing_tol = min([tol for tol, passed in results.items() if passed], default=None) 103 | if min_passing_tol: 104 | print(f"\nTest PASSED at tolerance {min_passing_tol:.0e} and above.") 105 | else: 106 | print("\nAll tests FAILED. Hidden states differ significantly when output_attentions is True vs False.") -------------------------------------------------------------------------------- /update_HF.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import HfApi, login 2 | 3 | 4 | FAST_ESM_MODELS = [ 5 | 'Synthyra/ESM2-8M', 6 | 'Synthyra/ESM2-35M', 7 | 'Synthyra/ESM2-150M', 8 | 'Synthyra/ESM2-650M', 9 | 'Synthyra/ESM2-3B', 10 | 'Synthyra/FastESM2_650' 11 | ] 12 | 13 | ESM_PLUSPLUS_MODELS = [ 14 | 'Synthyra/ESMplusplus_small', 15 | 'Synthyra/ESMplusplus_large', 16 | ] 17 | 18 | E1_MODELS = [ 19 | 'Synthyra/Profluent-E1-150M', 20 | 'Synthyra/Profluent-E1-300M', 21 | 'Synthyra/Profluent-E1-600M', 22 | ] 23 | 24 | ANKH_MODELS = [ 25 | 'Synthyra/ANKH_base', 26 | 'Synthyra/ANKH_large', 27 | 'Synthyra/ANKH2_large' 28 | ] 29 | 30 | 31 | if __name__ == "__main__": 32 | # py -m update_HF 33 | import argparse 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--token', type=str, default=None) 36 | args = parser.parse_args() 37 | 38 | if args.token: 39 | login(token=args.token) 40 | 41 | api = HfApi() 42 | 43 | for path in FAST_ESM_MODELS: 44 | print(path.lower()) 45 | api.upload_file( 46 | path_or_fileobj="esm2/modeling_fastesm.py", 47 | path_in_repo="modeling_fastesm.py", 48 | repo_id=path, 49 | repo_type="model", 50 | ) 51 | # Upload license file for FastESM models 52 | api.upload_file( 53 | path_or_fileobj="LICENSE", 54 | path_in_repo="LICENSE", 55 | repo_id=path, 56 | repo_type="model", 57 | ) 58 | if 'esm2' in path.lower(): 59 | api.upload_file( 60 | path_or_fileobj="readmes/fastesm2_readme.md", 61 | path_in_repo="README.md", 62 | repo_id=path, 63 | repo_type="model", 64 | ) 65 | 66 | if 'fastesm' in path.lower(): 67 | api.upload_file( 68 | path_or_fileobj="readmes/fastesm_650_readme.md", 69 | path_in_repo="README.md", 70 | repo_id=path, 71 | repo_type="model", 72 | ) 73 | 74 | 75 | for path in ESM_PLUSPLUS_MODELS: 76 | print(path) 77 | api.upload_file( 78 | path_or_fileobj="esm_plusplus/modeling_esm_plusplus.py", 79 | path_in_repo="modeling_esm_plusplus.py", 80 | repo_id=path, 81 | repo_type="model", 82 | ) 83 | if path == 'Synthyra/ESMplusplus_small': 84 | api.upload_file( 85 | path_or_fileobj="readmes/esm_plusplus_small_readme.md", 86 | path_in_repo="README.md", 87 | repo_id=path, 88 | repo_type="model", 89 | ) 90 | # Upload license file for ESM++ small model 91 | api.upload_file( 92 | path_or_fileobj="LICENSE", 93 | path_in_repo="LICENSE", 94 | repo_id=path, 95 | repo_type="model", 96 | ) 97 | 98 | if path == 'Synthyra/ESMplusplus_large': 99 | api.upload_file( 100 | path_or_fileobj="readmes/esm_plusplus_large_readme.md", 101 | path_in_repo="README.md", 102 | repo_id=path, 103 | repo_type="model", 104 | ) 105 | # Upload license file for ESM++ large model 106 | api.upload_file( 107 | path_or_fileobj="LICENSE", 108 | path_in_repo="LICENSE", 109 | repo_id=path, 110 | repo_type="model", 111 | ) 112 | 113 | 114 | for path in E1_MODELS: 115 | print(path) 116 | api.upload_file( 117 | path_or_fileobj="e1/modeling_e1.py", 118 | path_in_repo="modeling_e1.py", 119 | repo_id=path, 120 | repo_type="model", 121 | ) 122 | # Upload license file for 
FastESM models 123 | api.upload_file( 124 | path_or_fileobj="LICENSE", 125 | path_in_repo="LICENSE", 126 | repo_id=path, 127 | repo_type="model", 128 | ) 129 | api.upload_file( 130 | path_or_fileobj="readmes/e1_readme.md", 131 | path_in_repo="README.md", 132 | repo_id=path, 133 | repo_type="model", 134 | ) 135 | api.upload_file( 136 | path_or_fileobj="e1/tokenizer.json", 137 | path_in_repo="tokenizer.json", 138 | repo_id=path, 139 | repo_type="model", 140 | ) 141 | 142 | 143 | # Add code to upload files for ANKH models 144 | for path in ANKH_MODELS: 145 | print(path) 146 | # Upload license file for ANKH models 147 | api.upload_file( 148 | path_or_fileobj="LICENSE", 149 | path_in_repo="LICENSE", 150 | repo_id=path, 151 | repo_type="model", 152 | ) 153 | -------------------------------------------------------------------------------- /test_scripts/test_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import sqlite3 5 | 6 | from typing import List 7 | from transformers import AutoModel 8 | from esm2.modeling_fastesm import FastEsmModel 9 | from esm_plusplus.modeling_esm_plusplus import ESMplusplusModel 10 | from e1.modeling_e1 import E1Model 11 | 12 | 13 | CANONICAL_AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY" 14 | 15 | 16 | def generate_random_sequence(length: int) -> str: 17 | return 'M' + "".join(random.choices(CANONICAL_AMINO_ACIDS, k=length)) 18 | 19 | 20 | if __name__ == "__main__": 21 | # py -m test_scripts.test_embedding 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | def test_embedding(model, sequences: List[str], name: str): 25 | embeddings = model.embed_dataset( 26 | sequences=sequences, 27 | tokenizer=model.tokenizer if hasattr(model, 'tokenizer') else None, 28 | sql=False, # return dictionary of sequences and embeddings 29 | save=False, 30 | ) 31 | 32 | count = 0 33 | for k, v in embeddings.items(): 34 | print(k) 35 | print(v.dtype, v.shape) 36 | count += 1 37 | if count > 5: 38 | break 39 | 40 | embeddings = model.embed_dataset( 41 | sequences=sequences, 42 | tokenizer=model.tokenizer if hasattr(model, 'tokenizer') else None, 43 | full_embeddings=True, 44 | sql=False, # return dictionary of sequences and embeddings 45 | save=False, 46 | ) 47 | 48 | count = 0 49 | for k, v in embeddings.items(): 50 | print(k) 51 | print(v.dtype, v.shape) 52 | count += 1 53 | if count > 5: 54 | break 55 | 56 | db_path = f'embeddings_{name}.db' 57 | _ = model.embed_dataset( 58 | sequences=sequences, 59 | tokenizer=model.tokenizer if hasattr(model, 'tokenizer') else None, 60 | pooling_types=['cls', 'mean'], 61 | sql=True, 62 | sql_db_path=db_path, 63 | save=True, 64 | ) 65 | 66 | # Verify database contents 67 | conn = sqlite3.connect(db_path) 68 | c = conn.cursor() 69 | 70 | # Check number of sequences 71 | c.execute('SELECT COUNT(*) FROM embeddings') 72 | db_count = c.fetchone()[0] 73 | print(f"\nNumber of sequences in database: {db_count}") 74 | 75 | count = 0 76 | for seq in sequences: 77 | c.execute('SELECT embedding FROM embeddings WHERE sequence = ?', (seq,)) 78 | result = c.fetchone() 79 | assert result is not None, f"Sequence {seq} not found in database" 80 | if count < 10: 81 | embedding = np.frombuffer(result[0], dtype=np.float32) 82 | print(seq) 83 | print(f"Embedding shape: {embedding.shape}") 84 | count += 1 85 | 86 | # Make sure to close the connection before attempting to delete the file 87 | c.close() 88 | conn.close() 89 | 90 | print("Testing E1 
model...") 91 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)] 92 | model = E1Model.from_pretrained("Synthyra/Profluent-E1-150M", dtype=torch.bfloat16).to(device) 93 | print(model) 94 | test_embedding(model, sequences, 'e1') 95 | 96 | print("Testing FastESM model...") 97 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)] 98 | model = FastEsmModel.from_pretrained("Synthyra/ESM2-8M", dtype=torch.float16).to(device) 99 | print(model) 100 | test_embedding(model, sequences, 'fastesm') 101 | 102 | print("Testing ESM++ model...") 103 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)] 104 | model = ESMplusplusModel.from_pretrained("Synthyra/ESMplusplus_small", dtype=torch.float16).to(device) 105 | print(model) 106 | test_embedding(model, sequences, 'esmplusplus') 107 | 108 | print("Testing E1 model with AutoModel...") 109 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)] 110 | model = AutoModel.from_pretrained("Synthyra/Profluent-E1-150M", dtype=torch.bfloat16, trust_remote_code=True).to(device) 111 | print(model) 112 | test_embedding(model, sequences, 'e1_auto') 113 | 114 | print("Testing FastESM model with AutoModel...") 115 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)] 116 | model = AutoModel.from_pretrained("Synthyra/ESM2-8M", dtype=torch.float16, trust_remote_code=True).to(device) 117 | print(model) 118 | test_embedding(model, sequences, 'fastesm_auto') 119 | 120 | print("Testing ESM++ model with AutoModel...") 121 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)] 122 | model = AutoModel.from_pretrained("Synthyra/ESMplusplus_small", dtype=torch.float16, trust_remote_code=True).to(device) 123 | print(model) 124 | test_embedding(model, sequences, 'esmplusplus_auto') -------------------------------------------------------------------------------- /test_scripts/test_compliance_esmc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import random 4 | import argparse 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import contextlib 8 | from huggingface_hub import login 9 | from tqdm.auto import tqdm 10 | 11 | from esm.pretrained import ESMC_600M_202412, ESMC_300M_202412 12 | from esm.sdk.api import ESMProtein, LogitsConfig 13 | 14 | from esm_plusplus.modeling_esm_plusplus import ESMplusplus_300M 15 | 16 | 17 | """ 18 | Testing if ESM++ outputs are compliant with ESMC outputs 19 | """ 20 | 21 | def set_seed(seed: int): 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | torch.backends.cudnn.deterministic = True 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--model_path', type=str, default='Synthyra/ESMplusplus_small') 32 | parser.add_argument('--token', type=str, default=None) 33 | args = parser.parse_args() 34 | 35 | 36 | if __name__ == "__main__": 37 | # py -m test_scripts.test_compliance_esmc 38 | if args.token: 39 | login(args.token) 40 | 41 | model_path = args.model_path 42 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY" 43 | length = 128 44 | seq_count = 10 45 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 46 | 47 | set_seed(42) 48 | 49 | def generate_random_sequence(length: int) -> str: 50 | return 'M' + 
"".join(random.choices(canonical_amino_acids, k=length-3)) 51 | 52 | 53 | sequences = [generate_random_sequence(length) for _ in range(seq_count)] 54 | 55 | 56 | if 'small' in model_path: 57 | esmc = ESMC_300M_202412(device=device, use_flash_attn=False) 58 | else: 59 | esmc = ESMC_600M_202412(device=device, use_flash_attn=False) 60 | 61 | 62 | # Get esmc model outputs 63 | base_outputs = [] 64 | base_logits = [] 65 | with ( 66 | torch.no_grad(), 67 | torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) 68 | if device.type == "cuda" else contextlib.nullcontext() 69 | ): 70 | for seq in tqdm(sequences): 71 | protein = ESMProtein(sequence=seq) 72 | protein_tensor = esmc.encode(protein) 73 | logits_result = esmc.logits( 74 | protein_tensor, LogitsConfig(sequence=True, return_embeddings=True) 75 | ) 76 | embeddings = logits_result.embeddings 77 | logits = logits_result.logits.sequence 78 | manual_logits = esmc.sequence_head(embeddings) 79 | 80 | assert (manual_logits == logits).all(), "Logits are not equal" 81 | base_outputs.append(embeddings.float().cpu()) 82 | base_logits.append(logits.float().cpu()) 83 | esmc.cpu() 84 | del esmc 85 | torch.cuda.empty_cache() 86 | 87 | 88 | # Get plusplus outputs 89 | total_mse_embeddings = 0 90 | total_mse_logits = 0 91 | model = ESMplusplus_300M(device=device) 92 | tokenizer = model.tokenizer 93 | with ( 94 | torch.no_grad(), 95 | torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) 96 | if device.type == "cuda" else contextlib.nullcontext() 97 | ): 98 | for i, seq in tqdm(enumerate(sequences), total=len(sequences)): 99 | input = tokenizer(seq, return_tensors="pt").to(device) 100 | outputs = model(**input) 101 | embeddings = outputs.last_hidden_state.float().cpu() 102 | logits = outputs.logits.float().cpu() 103 | 104 | # Compare embeddings 105 | mse_embeddings = F.mse_loss(base_outputs[i], embeddings).item() 106 | # Compare logits 107 | mse_logits = F.mse_loss(base_logits[i], logits).item() 108 | 109 | if mse_embeddings > 0.01 or mse_logits > 0.01: 110 | print(f"Sequence {i}:") 111 | print(f" Embeddings MSE: {mse_embeddings:.8f}") 112 | print(f" Logits MSE: {mse_logits:.8f}") 113 | 114 | # Find positions where tensors differ 115 | diff_embeddings = torch.abs(base_outputs[i] - embeddings) 116 | diff_logits = torch.abs(base_logits[i] - logits) 117 | 118 | # plot diffs 119 | plt.figure(figsize=(12, 5)) 120 | plt.subplot(1, 2, 1) 121 | plt.imshow(diff_embeddings[0].detach().numpy()) 122 | plt.title("Embeddings Difference") 123 | 124 | plt.subplot(1, 2, 2) 125 | plt.imshow(diff_logits[0].detach().numpy()) 126 | plt.title("Logits Difference") 127 | plt.show() 128 | 129 | total_mse_embeddings += mse_embeddings 130 | total_mse_logits += mse_logits 131 | model.cpu() 132 | del model 133 | torch.cuda.empty_cache() 134 | 135 | print(f"Average Embeddings MSE: {total_mse_embeddings / seq_count}") 136 | print(f"Average Logits MSE: {total_mse_logits / seq_count}") 137 | -------------------------------------------------------------------------------- /test_scripts/test_compliance_esm2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import random 4 | import argparse 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from datasets import load_dataset 8 | from huggingface_hub import login 9 | from tqdm.auto import tqdm 10 | from transformers import EsmForMaskedLM 11 | 12 | from esm2.modeling_fastesm import FastEsmForMaskedLM 13 
| 14 | 15 | """ 16 | Testing if FastESM outputs are compliant with ESM2 outputs 17 | """ 18 | 19 | def set_seed(seed: int): 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | torch.backends.cudnn.deterministic = True 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--token', type=str, default=None) 30 | args = parser.parse_args() 31 | 32 | 33 | if __name__ == "__main__": 34 | # py -m test_scripts.test_compliance_esm2 35 | if args.token: 36 | login(args.token) 37 | 38 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY" 39 | length = 1024 40 | seq_count = 10 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | 43 | set_seed(42) 44 | 45 | sequences = load_dataset('lhallee/ccds_human_512', split='train')['AA'][:seq_count] 46 | sequences = [seq.replace('L', '') for seq in sequences] 47 | 48 | esm2 = EsmForMaskedLM.from_pretrained('facebook/esm2_t33_650M_UR50D').to(device).eval() 49 | #fastesm = FastEsmForMaskedLM.from_pretrained('facebook/esm2_t33_650M_UR50D').to(device) 50 | #fastesm.lm_head.load_state_dict(esm2.lm_head.state_dict()) 51 | fastesm = FastEsmForMaskedLM.from_pretrained('Synthyra/ESM2-650M').to(device).eval() 52 | tokenizer = fastesm.tokenizer 53 | 54 | # Get esmc model outputs 55 | base_outputs = [] 56 | base_logits = [] 57 | with torch.no_grad(): 58 | for seq in tqdm(sequences): 59 | input = tokenizer(seq, return_tensors="pt").to(device) 60 | outputs = esm2(**input, output_hidden_states=True) 61 | embeddings = outputs.hidden_states[-1].float().cpu() 62 | logits = outputs.logits.float().cpu() 63 | base_outputs.append(embeddings) 64 | base_logits.append(logits) 65 | esm2.cpu() 66 | del esm2 67 | torch.cuda.empty_cache() 68 | 69 | # Get plusplus outputs 70 | total_mse_embeddings = 0 71 | total_mse_logits = 0 72 | total_max_diff_embeddings = 0 73 | total_max_diff_logits = 0 74 | total_accuracy = 0 75 | 76 | with torch.no_grad(): 77 | for i, seq in tqdm(enumerate(sequences), total=len(sequences)): 78 | input = tokenizer(seq, return_tensors="pt").to(device) 79 | outputs = fastesm(**input, output_attentions=False) 80 | embeddings = outputs.last_hidden_state.float().cpu() 81 | logits = outputs.logits.float().cpu() 82 | 83 | # Compare embeddings 84 | mse_embeddings = F.mse_loss(base_outputs[i], embeddings).item() 85 | # Compare logits 86 | mse_logits = F.mse_loss(base_logits[i], logits).item() 87 | max_diff_embeddings = torch.max(torch.abs(base_outputs[i] - embeddings)).item() 88 | max_diff_logits = torch.max(torch.abs(base_logits[i] - logits)).item() 89 | 90 | # Calculate accuracy of argmaxed logits 91 | base_argmax = torch.argmax(base_logits[i], dim=-1) 92 | fastesm_argmax = torch.argmax(logits, dim=-1) 93 | accuracy = (base_argmax == fastesm_argmax).float().mean().item() 94 | if mse_embeddings > 0.01 or mse_logits > 0.1: 95 | print(f"Sequence {i}:") 96 | print(f" Embeddings MSE: {mse_embeddings:.8f}") 97 | print(f" Logits MSE: {mse_logits:.8f}") 98 | print(f" Argmax Accuracy: {accuracy:.6f}") 99 | 100 | # Find positions where tensors differ 101 | diff_embeddings = torch.abs(base_outputs[i] - embeddings) 102 | diff_logits = torch.abs(base_logits[i] - logits) 103 | 104 | # plot diffs 105 | plt.figure(figsize=(12, 5)) 106 | plt.subplot(1, 2, 1) 107 | plt.imshow(diff_embeddings[0].detach().numpy()) 108 | plt.title("Embeddings Difference") 109 | 110 | plt.subplot(1, 2, 2) 111 | plt.imshow(diff_logits[0].detach().numpy()) 112 | 
plt.title("Logits Difference") 113 | plt.show() 114 | 115 | total_mse_embeddings += mse_embeddings 116 | total_mse_logits += mse_logits 117 | total_max_diff_embeddings += max_diff_embeddings 118 | total_max_diff_logits += max_diff_logits 119 | total_accuracy += accuracy 120 | fastesm.cpu() 121 | del fastesm 122 | torch.cuda.empty_cache() 123 | 124 | print(f"Average Embeddings MSE: {total_mse_embeddings / seq_count}") 125 | print(f"Average Logits MSE: {total_mse_logits / seq_count}") 126 | print(f"Average Max Diff Embeddings: {total_max_diff_embeddings / seq_count}") 127 | print(f"Average Max Diff Logits: {total_max_diff_logits / seq_count}") 128 | print(f"Average Argmax Accuracy: {total_accuracy / seq_count:.6f}") -------------------------------------------------------------------------------- /readmes/fastesm2_readme.md: -------------------------------------------------------------------------------- 1 | --- 2 | library_name: transformers 3 | tags: [] 4 | --- 5 | 6 | # NOTE 7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git) 8 | 9 | # FastESM 10 | FastESM is a Huggingface compatible plug in version of ESM2 rewritten with a newer PyTorch attention implementation. 11 | 12 | Load any ESM2 models into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance. 13 | 14 | Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned. 15 | Various other optimizations also make the base implementation slightly different than the one in transformers. 16 | 17 | ## Use with 🤗 transformers 18 | 19 | ### Supported models 20 | ```python 21 | model_dict = { 22 | # Synthyra/ESM2-8M 23 | 'ESM2-8M': 'facebook/esm2_t6_8M_UR50D', 24 | # Synthyra/ESM2-35M 25 | 'ESM2-35M': 'facebook/esm2_t12_35M_UR50D', 26 | # Synthyra/ESM2-150M 27 | 'ESM2-150M': 'facebook/esm2_t30_150M_UR50D', 28 | # Synthyra/ESM2-650M 29 | 'ESM2-650M': 'facebook/esm2_t33_650M_UR50D', 30 | # Synthyra/ESM2-3B 31 | 'ESM2-3B': 'facebook/esm2_t36_3B_UR50D', 32 | } 33 | ``` 34 | 35 | ### For working with embeddings 36 | ```python 37 | import torch 38 | from transformers import AutoModel, AutoTokenizer 39 | 40 | model_path = 'Synthyra/ESM2-8M' 41 | model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval() 42 | tokenizer = model.tokenizer 43 | 44 | sequences = ['MPRTEIN', 'MSEQWENCE'] 45 | tokenized = tokenizer(sequences, padding=True, return_tensors='pt') 46 | with torch.no_grad(): 47 | embeddings = model(**tokenized).last_hidden_state 48 | 49 | print(embeddings.shape) # (2, 11, 1280) 50 | ``` 51 | 52 | ### For working with sequence logits 53 | ```python 54 | import torch 55 | from transformers import AutoModelForMaskedLM, AutoTokenizer 56 | 57 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval() 58 | with torch.no_grad(): 59 | logits = model(**tokenized).logits 60 | 61 | print(logits.shape) # (2, 11, 33) 62 | ``` 63 | 64 | ### For working with attention maps 65 | ```python 66 | import torch 67 | from transformers import AutoModel, AutoTokenizer 68 | 69 | model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval() 70 | with torch.no_grad(): 71 | attentions = model(**tokenized, output_attentions).attentions # tuples of (batch_size, num_heads, seq_len, seq_len) 72 | 73 | 
print(attentions[-1].shape) # (2, 20, 11, 11) 74 | ``` 75 | 76 | ### Contact prediction 77 | Because we can output attentions using the naive attention implementation, the contact prediction is also supported 78 | ```python 79 | with torch.no_grad(): 80 | contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy() # (seq_len, seq_len) 81 | ``` 82 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/9707OSXZ3Wdgn0Ni-55T-.png) 83 | 84 | ## Embed entire datasets with no new code 85 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take. 86 | 87 | Example: 88 | ```python 89 | embedding_dict = model.embed_dataset( 90 | sequences=[ 91 | 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences 92 | ], 93 | tokenizer=model.tokenizer, 94 | batch_size=2, # adjust for your GPU memory 95 | max_len=512, # adjust for your needs 96 | full_embeddings=False, # if True, no pooling is performed 97 | embed_dtype=torch.float32, # cast to what dtype you want 98 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together 99 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets 100 | sql=False, # if True, embeddings will be stored in SQLite database 101 | sql_db_path='embeddings.db', 102 | save=True, # if True, embeddings will be saved as a .pth file 103 | save_path='embeddings.pth', 104 | ) 105 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql 106 | ``` 107 | 108 | ``` 109 | model.embed_dataset() 110 | Args: 111 | sequences: List of protein sequences 112 | batch_size: Batch size for processing 113 | max_len: Maximum sequence length 114 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) 115 | pooling_type: Type of pooling ('mean' or 'cls') 116 | num_workers: Number of workers for data loading, 0 for the main process 117 | sql: Whether to store embeddings in SQLite database - will be stored in float32 118 | sql_db_path: Path to SQLite database 119 | 120 | Returns: 121 | Dictionary mapping sequences to embeddings, or None if sql=True 122 | 123 | Note: 124 | - If sql=True, embeddings can only be stored in float32 125 | - sql is ideal if you need to stream a very large dataset for training in real-time 126 | - save=True is ideal if you can store the entire embedding dictionary in RAM 127 | - sql will be used if it is True and save is True or False 128 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences 129 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing 130 | ``` 131 | 132 | 133 | ### Citation 134 | If you use any of this implementation or work please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper). 
135 | ``` 136 | @misc {FastPLMs, 137 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.}, 138 | title = { FastPLMs: Fast, efficient, protien language model inference from Huggingface AutoModel.}, 139 | year = {2024}, 140 | url = { https://huggingface.co/Synthyra/ESMplusplus_small }, 141 | DOI = { 10.57967/hf/3726 }, 142 | publisher = { Hugging Face } 143 | } 144 | ``` -------------------------------------------------------------------------------- /test_scripts/test_throughput_esmfast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import random 4 | import argparse 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | from huggingface_hub import login 8 | from transformers import EsmForMaskedLM, AutoModelForMaskedLM, EsmTokenizer 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--model_path', type=str, default='Synthyra/FastESM2_650') 13 | parser.add_argument('--token', type=str, default=None) 14 | args = parser.parse_args() 15 | 16 | if args.token: 17 | login(args.token) 18 | 19 | model_path = args.model_path 20 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY" 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | 24 | def generate_random_sequence(length: int) -> str: 25 | return 'M' + "".join(random.choices(canonical_amino_acids, k=length-1)) 26 | 27 | 28 | def time_model(model, inputs, warmup=10): 29 | model.eval() 30 | with torch.no_grad(): 31 | # Warmup 32 | for _ in range(warmup): 33 | _ = model(**inputs[0]) 34 | 35 | start_time = time.time() 36 | for input_batch in inputs: 37 | _ = model(**input_batch) 38 | return time.time() - start_time 39 | 40 | 41 | def get_gpu_memory(): 42 | torch.cuda.synchronize() 43 | return torch.cuda.max_memory_allocated() / 1024**2 # Convert to MB 44 | 45 | 46 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D') 47 | 48 | # Test different sequence lengths and batch sizes 49 | #lengths = [128, 256, 512, 1024, 2048] 50 | #batch_sizes = [1, 4, 16, 32] 51 | lengths = [8, 16] 52 | batch_sizes = [1, 2] 53 | 54 | 55 | results = [] 56 | 57 | # Generate all test sequences first 58 | all_test_inputs = {} 59 | for length in lengths: 60 | for batch_size in batch_sizes: 61 | print(f"\nGenerating sequences for length={length}, batch_size={batch_size}") 62 | all_sequences = [] 63 | for _ in range(100): 64 | batch_sequences = [generate_random_sequence(length) for _ in range(batch_size)] 65 | all_sequences.append(batch_sequences) 66 | inputs = [tokenizer(sequences, padding=True, return_tensors="pt").to(device) for sequences in all_sequences] 67 | all_test_inputs[(length, batch_size)] = inputs 68 | 69 | # Test ESM model in fp32 70 | print("\nTesting ESM model in FP32...") 71 | model = EsmForMaskedLM.from_pretrained('facebook/esm2_t33_650M_UR50D').to(device) 72 | for length in lengths: 73 | for batch_size in batch_sizes: 74 | print(f"\nTesting length={length}, batch_size={batch_size}") 75 | inputs = all_test_inputs[(length, batch_size)] 76 | 77 | torch.cuda.reset_peak_memory_stats() 78 | esm_time = time_model(model, inputs) 79 | esm_memory = get_gpu_memory() 80 | 81 | results.append({ 82 | 'Length': length, 83 | 'Batch Size': batch_size, 84 | 'ESM Time': esm_time, 85 | 'ESM Memory': esm_memory 86 | }) 87 | print(f"ESM FP32 time: {esm_time:.2f}s, memory: {esm_memory:.0f}MB") 88 | torch.cuda.empty_cache() 89 | 90 | model.cpu() 91 | del model 92 | torch.cuda.empty_cache() 93 | 94 | # Test FastESM in 
fp16 95 | print("\nTesting FastESM model in FP16...") 96 | model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True, dtype=torch.float16).to(device) 97 | for i, (length, batch_size) in enumerate([(l,b) for l in lengths for b in batch_sizes]): 98 | print(f"\nTesting length={length}, batch_size={batch_size}") 99 | inputs = all_test_inputs[(length, batch_size)] 100 | 101 | torch.cuda.reset_peak_memory_stats() 102 | fast_time = time_model(model, inputs) 103 | fast_memory = get_gpu_memory() 104 | 105 | results[i].update({ 106 | 'FastESM Time': fast_time, 107 | 'FastESM Memory': fast_memory, 108 | 'Speedup': results[i]['ESM Time']/fast_time 109 | }) 110 | print(f"FastESM FP16 time: {fast_time:.2f}s, memory: {fast_memory:.0f}MB") 111 | print(f"Speedup: {results[i]['Speedup']:.2f}x") 112 | torch.cuda.empty_cache() 113 | 114 | model.cpu() 115 | del model 116 | torch.cuda.empty_cache() 117 | 118 | # Create plots 119 | plt.figure(figsize=(15, 10)) 120 | 121 | # Speedup heatmap 122 | plt.subplot(221) 123 | speedup_data = [[r['Speedup'] for r in results if r['Length']==l] for l in lengths] 124 | sns.heatmap(speedup_data, 125 | xticklabels=batch_sizes, 126 | yticklabels=lengths, 127 | annot=True, 128 | fmt='.2f', 129 | cmap='viridis') 130 | plt.title('Speedup (ESM/FastESM)') 131 | plt.xlabel('Batch Size') 132 | plt.ylabel('Sequence Length') 133 | 134 | # Absolute times line plot 135 | plt.subplot(222) 136 | for length in lengths: 137 | length_results = [r for r in results if r['Length']==length] 138 | plt.plot([r['Batch Size'] for r in length_results], 139 | [r['FastESM Time'] for r in length_results], 140 | label=f'Length {length}', 141 | marker='o') 142 | 143 | plt.xlabel('Batch Size') 144 | plt.ylabel('Time (s)') 145 | plt.title('FastESM Processing Time') 146 | plt.legend() 147 | plt.xscale('log') 148 | plt.yscale('log') 149 | 150 | # ESM Memory heatmap 151 | plt.subplot(223) 152 | memory_data = [[r['ESM Memory'] for r in results if r['Length']==l] for l in lengths] 153 | sns.heatmap(memory_data, 154 | xticklabels=batch_sizes, 155 | yticklabels=lengths, 156 | annot=True, 157 | fmt='.0f', 158 | cmap='viridis') 159 | plt.title('ESM FP32 Memory Usage (MB)') 160 | plt.xlabel('Batch Size') 161 | plt.ylabel('Sequence Length') 162 | 163 | # FastESM Memory heatmap 164 | plt.subplot(224) 165 | memory_data = [[r['FastESM Memory'] for r in results if r['Length']==l] for l in lengths] 166 | sns.heatmap(memory_data, 167 | xticklabels=batch_sizes, 168 | yticklabels=lengths, 169 | annot=True, 170 | fmt='.0f', 171 | cmap='viridis') 172 | plt.title('FastESM FP16 Memory Usage (MB)') 173 | plt.xlabel('Batch Size') 174 | plt.ylabel('Sequence Length') 175 | 176 | plt.tight_layout() 177 | plt.savefig('throughput_results.png') 178 | plt.close() 179 | 180 | print("\nPlot saved as throughput_results.png") 181 | -------------------------------------------------------------------------------- /readmes/e1_readme.md: -------------------------------------------------------------------------------- 1 | --- 2 | library_name: transformers 3 | tags: [] 4 | --- 5 | 6 | # NOTE 7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git) 8 | 9 | # Profluent-E1 10 | [Synthyra's version of Profluent-E1](https://github.com/Synthyra/Profluent-E1-300M) is a faithful implementation of Profluent's [E1](https://www.profluent.bio/showcase/e1) models ([license](https://github.com/Profluent-AI/E1/tree/main?tab=License-1-ov-file)) that integrates Huggingface 
AutoModel compatibility and convenient embedding functionality. 11 | 12 | 13 | ## Use with 🤗 transformers 14 | ### Supported models 15 | ```python 16 | model_dict = { 17 | # Synthyra/Profluent-E1-150M 18 | 'Profluent-E1-150M': 'Profluent-Bio/E1-150m', 19 | # Synthyra/Profluent-E1-300M 20 | 'Profluent-E1-300M': 'Profluent-Bio/E1-300m', 21 | # Synthyra/Profluent-E1-600M 22 | 'Profluent-E1-600M': 'Profluent-Bio/E1-600m', 23 | } 24 | ``` 25 | 26 | ```python 27 | import torch 28 | from transformers import AutoModelForMaskedLM 29 | 30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 31 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/Profluent-E1-150M', trust_remote_code=True, dtype=torch.bfloat16).eval().to(device) 32 | 33 | sequences = ['MPRTEIN', 'MSEQWENCE'] 34 | batch = model.prep_tokens.get_batch_kwargs(sequences, device=device) 35 | 36 | output = model(**batch) # get all hidden states with output_hidden_states=True 37 | print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 34) 38 | print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 768) 39 | print(output.loss) # language modeling loss if you passed labels 40 | #print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple) 41 | #print(output.attentions) # all attention matrices if you passed output_attentions=True (in tuple) 42 | ``` 43 | 44 | Our E1 implementation also supports sequence and token level classification tasks like ESM2. Simply pass the number of labels during initialization. 45 | 46 | ```python 47 | from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification 48 | 49 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/Profluent-E1-150M', num_labels=2, trust_remote_code=True) 50 | logits = model(**batch, labels=labels).logits 51 | print(logits.shape) # (batch_size, num_labels), (2, 2) 52 | ``` 53 | 54 | E1 weights were trained in bf16 and are in bf16 by default. You can load them in the precision of your choosing by leveraging the dtype parameter: 55 | ```python 56 | import torch 57 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/Profluent-E1-150M', trust_remote_code=True, dtype=torch.float) # fp32 58 | ``` 59 | 60 | ## Embed entire datasets with no new code 61 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take. 62 | 63 | Example: 64 | ```python 65 | embedding_dict = model.embed_dataset( 66 | sequences=[ 67 | 'MALWMRLLPLLALLALWGPDPAAA', ... 
# list of protein sequences 68 | ], 69 | batch_size=2, # adjust for your GPU memory 70 | max_len=512, # adjust for your needs 71 | full_embeddings=False, # if True, no pooling is performed 72 | embed_dtype=torch.float32, # cast to what dtype you want 73 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together 74 | sql=False, # if True, embeddings will be stored in SQLite database 75 | sql_db_path='embeddings.db', 76 | save=True, # if True, embeddings will be saved as a .pth file 77 | save_path='embeddings.pth', 78 | ) 79 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql 80 | ``` 81 | 82 | ``` 83 | model.embed_dataset() 84 | Args: 85 | sequences: List of protein sequences 86 | batch_size: Batch size for processing 87 | max_len: Maximum sequence length 88 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) 89 | pooling_type: Type of pooling ('mean' or 'cls') 90 | sql: Whether to store embeddings in SQLite database - will be stored in float32 91 | sql_db_path: Path to SQLite database 92 | 93 | Returns: 94 | Dictionary mapping sequences to embeddings, or None if sql=True 95 | 96 | Note: 97 | - If sql=True, embeddings can only be stored in float32 98 | - sql is ideal if you need to stream a very large dataset for training in real-time 99 | - save=True is ideal if you can store the entire embedding dictionary in RAM 100 | - sql will be used if it is True and save is True or False 101 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences 102 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing 103 | ``` 104 | 105 | ## Fine-tuning with 🤗 peft 106 | ```python 107 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/Profluent-E1-150M', num_labels=2, trust_remote_code=True) 108 | # these modules handle E1 attention layers 109 | target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] 110 | 111 | lora_config = LoraConfig( 112 | r=8, # choose lora parameters to your liking 113 | lora_alpha=16, 114 | lora_dropout=0.01, 115 | bias="none", 116 | target_modules=target_modules, 117 | ) 118 | 119 | # Apply LoRA to the model 120 | model = get_peft_model(model, lora_config) 121 | 122 | # Unfreeze the classifier head 123 | for param in model.classifier.parameters(): 124 | param.requires_grad = True 125 | ``` 126 | 127 | For a more thourough example of fine-tuning, check out our example script [here](https://github.com/Synthyra/FastPLMs/blob/main/fine_tuning_example.py). 128 | 129 | 130 | ### Citation 131 | If you use any of this implementation or work please cite the following DOI and Profluent's paper. 
132 | 133 | ``` 134 | @misc {FastPLMs, 135 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.}, 136 | title = { FastPLMs: Fast, efficient, protien language model inference from Huggingface AutoModel.}, 137 | year = {2024}, 138 | url = { https://huggingface.co/Synthyra/ESMplusplus_small }, 139 | DOI = { 10.57967/hf/3726 }, 140 | publisher = { Hugging Face } 141 | } 142 | ``` 143 | 144 | ``` 145 | @article{Jain_Beazer_Ruffolo_Bhatnagar_Madani_2025, 146 | title={E1: Retrieval-Augmented Protein Encoder Models}, 147 | url={https://www.biorxiv.org/content/early/2025/11/13/2025.11.12.688125}, 148 | DOI={10.1101/2025.11.12.688125}, 149 | journal={bioRxiv}, 150 | publisher={Cold Spring Harbor Laboratory}, 151 | author={Jain, Sarthak and Beazer, Joel and Ruffolo, Jeffrey A and Bhatnagar, Aadyot and Madani, Ali}, 152 | year={2025} 153 | } 154 | ``` -------------------------------------------------------------------------------- /readmes/fastesm_650_readme.md: -------------------------------------------------------------------------------- 1 | --- 2 | library_name: transformers 3 | tags: [] 4 | --- 5 | 6 | # NOTE 7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git) 8 | 9 | # FastESM 10 | FastESM is a Huggingface compatible plug in version of ESM2 rewritten with a newer PyTorch attention implementation. 11 | 12 | Load any ESM2 models into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance. 13 | 14 | Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned. 15 | Various other optimizations also make the base implementation slightly different than the one in transformers. 16 | 17 | # FastESM2-650 18 | 19 | ## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context 20 | To enhance the weights with longer context and better fp16 support, we trained ESM2-650 50000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) up to sequence length of **2048**. 
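If you want to mirror this continued pre-training objective on your own data, below is a minimal, illustrative sketch of the 20% masking step (the `mask_for_mlm` helper and its argument names are ours, not part of this repository): roughly 20% of non-special tokens are replaced with the mask token, and all unmasked positions in `labels` are set to -100 so they are ignored by the masked language modeling loss.

```python
import torch

def mask_for_mlm(input_ids: torch.Tensor, mask_token_id: int, special_token_ids: list, mask_prob: float = 0.2):
    """Illustrative 20% MLM masking (a sketch, not the exact training code)."""
    labels = input_ids.clone()
    special = torch.isin(input_ids, torch.tensor(special_token_ids, device=input_ids.device))
    # choose ~20% of the non-special positions to mask
    masked = (torch.rand(input_ids.shape, device=input_ids.device) < mask_prob) & ~special
    labels[~masked] = -100             # loss is only computed on masked positions
    corrupted = input_ids.clone()
    corrupted[masked] = mask_token_id  # replace selected positions with the mask token
    return corrupted, labels

# usage sketch, assuming `tokenized = tokenizer(sequences, padding=True, return_tensors='pt')`:
# input_ids, labels = mask_for_mlm(tokenized['input_ids'], tokenizer.mask_token_id, tokenizer.all_special_ids)
# loss = model(input_ids=input_ids, attention_mask=tokenized['attention_mask'], labels=labels).loss
```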
21 | 22 | ## Use with 🤗 transformers 23 | 24 | ### For working with embeddings 25 | ```python 26 | import torch 27 | from transformers import AutoModel, AutoTokenizer 28 | 29 | model_path = 'Synthyra/FastESM2_650' 30 | model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval() 31 | tokenizer = model.tokenizer 32 | 33 | sequences = ['MPRTEIN', 'MSEQWENCE'] 34 | tokenized = tokenizer(sequences, padding=True, return_tensors='pt') 35 | with torch.no_grad(): 36 | embeddings = model(**tokenized).last_hidden_state 37 | 38 | print(embeddings.shape) # (2, 11, 1280) 39 | ``` 40 | 41 | ### For working with sequence logits 42 | ```python 43 | import torch 44 | from transformers import AutoModelForMaskedLM, AutoTokenizer 45 | 46 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval() 47 | with torch.no_grad(): 48 | logits = model(**tokenized).logits 49 | 50 | print(logits.shape) # (2, 11, 33) 51 | ``` 52 | 53 | ### For working with attention maps 54 | ```python 55 | import torch 56 | from transformers import AutoModel, AutoTokenizer 57 | 58 | model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval() 59 | with torch.no_grad(): 60 | attentions = model(**tokenized, output_attentions=True).attentions # tuple of (batch_size, num_heads, seq_len, seq_len) 61 | 62 | print(attentions[-1].shape) # (2, 20, 11, 11) 63 | ``` 64 | 65 | ## Embed entire datasets with no new code 66 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take. 67 | 68 | Example: 69 | ```python 70 | embedding_dict = model.embed_dataset( 71 | sequences=[ 72 | 'MALWMRLLPLLALLALWGPDPAAA', ... 
# list of protein sequences 73 | ], 74 | tokenizer=model.tokenizer, 75 | batch_size=2, # adjust for your GPU memory 76 | max_len=512, # adjust for your needs 77 | full_embeddings=False, # if True, no pooling is performed 78 | embed_dtype=torch.float32, # cast to what dtype you want 79 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together 80 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets 81 | sql=False, # if True, embeddings will be stored in SQLite database 82 | sql_db_path='embeddings.db', 83 | save=True, # if True, embeddings will be saved as a .pth file 84 | save_path='embeddings.pth', 85 | ) 86 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql 87 | ``` 88 | 89 | ``` 90 | model.embed_dataset() 91 | Args: 92 | sequences: List of protein sequences 93 | batch_size: Batch size for processing 94 | max_len: Maximum sequence length 95 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) 96 | pooling_types: List of pooling types ('mean', 'cls', etc.); more than one will be concatenated together 97 | num_workers: Number of workers for data loading, 0 for the main process 98 | sql: Whether to store embeddings in SQLite database - will be stored in float32 99 | sql_db_path: Path to SQLite database 100 | 101 | Returns: 102 | Dictionary mapping sequences to embeddings, or None if sql=True 103 | 104 | Note: 105 | - If sql=True, embeddings can only be stored in float32 106 | - sql is ideal if you need to stream a very large dataset for training in real-time 107 | - save=True is ideal if you can store the entire embedding dictionary in RAM 108 | - sql will be used if it is True and save is True or False 109 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences 110 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing 111 | ``` 112 | 113 | ## Model probes 114 | We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well. 115 | 116 | The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression tasks are averaged between Spearman rho and R2. 117 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/d1Xi6k1Q4-9By_MtzTvdV.png) 118 | 119 | ## Comparison of half precisions 120 | Presumably because we trained in mixed-precision fp16, fp16 outputs are closer to the fp32 weights than bf16 outputs. Therefore, we recommend loading in fp16. 121 | 122 | When summing the MSE of 1000 sequences vs. the fp32 weights: 123 | 124 | Average MSE for FP16: 0.00000140 125 | 126 | Average MSE for BF16: 0.00004125 127 | 128 | ### Inference speed 129 | We look at various ESM models and their throughput on an H100. FastESM is over twice as fast as ESM2-650 with longer sequences. Requires PyTorch 2.5+ for the most savings, see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). 
130 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/PvaBGfuJXEW2v_WLkt63y.png) 131 | 132 | ### Citation 133 | If you use any of this implementation or work please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper). 134 | ``` 135 | @misc {FastPLMs, 136 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.}, 137 | title = { FastPLMs: Fast, efficient, protien language model inference from Huggingface AutoModel.}, 138 | year = {2024}, 139 | url = { https://huggingface.co/Synthyra/ESMplusplus_small }, 140 | DOI = { 10.57967/hf/3726 }, 141 | publisher = { Hugging Face } 142 | } 143 | ``` 144 | -------------------------------------------------------------------------------- /pooler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import networkx as nx 4 | from typing import Optional, List 5 | 6 | 7 | class Pooler: 8 | def __init__(self, pooling_types: List[str]): 9 | self.pooling_types = pooling_types 10 | self.pooling_options = { 11 | 'mean': self.mean_pooling, 12 | 'max': self.max_pooling, 13 | 'norm': self.norm_pooling, 14 | 'median': self.median_pooling, 15 | 'std': self.std_pooling, 16 | 'var': self.var_pooling, 17 | 'cls': self.cls_pooling, 18 | 'parti': self._pool_parti, 19 | } 20 | 21 | def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: 22 | maxed_attentions = torch.max(attentions, dim=1)[0] 23 | return maxed_attentions 24 | 25 | def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"): 26 | # Run PageRank on the attention matrix converted to a graph. 27 | # Raises exceptions if the graph doesn't match the token sequence or has no edges. 28 | # Returns the PageRank scores for each token node. 29 | G = self._convert_to_graph(attention_matrix) 30 | if G.number_of_nodes() != attention_matrix.shape[0]: 31 | raise Exception( 32 | f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") 33 | if G.number_of_edges() == 0: 34 | raise Exception(f"You don't seem to have any attention edges left in the graph.") 35 | 36 | return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) 37 | 38 | def _convert_to_graph(self, matrix): 39 | # Convert a matrix (e.g., attention scores) to a directed graph using networkx. 40 | # Each element in the matrix represents a directed edge with a weight. 
41 | G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) 42 | return G 43 | 44 | def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None): 45 | # Remove keys where attention_mask is 0 46 | if attention_mask is not None: 47 | for k in list(dict_importance.keys()): 48 | if attention_mask[k] == 0: 49 | del dict_importance[k] 50 | 51 | #dict_importance[0] # remove cls 52 | #dict_importance[-1] # remove eos 53 | total = sum(dict_importance.values()) 54 | return np.array([v / total for _, v in dict_importance.items()]) 55 | 56 | def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) 57 | maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() 58 | # emb is (b, L, d), maxed_attentions is (b, L, L) 59 | emb_pooled = [] 60 | for e, a, mask in zip(emb, maxed_attentions, attention_mask): 61 | dict_importance = self._page_rank(a) 62 | importance_weights = self._calculate_importance_weights(dict_importance, mask) 63 | num_tokens = int(mask.sum().item()) 64 | emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) 65 | pooled = torch.tensor(np.array(emb_pooled)) 66 | return pooled 67 | 68 | def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) 69 | if attention_mask is None: 70 | return emb.mean(dim=1) 71 | else: 72 | attention_mask = attention_mask.unsqueeze(-1) 73 | return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) 74 | 75 | def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) 76 | if attention_mask is None: 77 | return emb.max(dim=1).values 78 | else: 79 | attention_mask = attention_mask.unsqueeze(-1) 80 | return (emb * attention_mask).max(dim=1).values 81 | 82 | def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) 83 | if attention_mask is None: 84 | return emb.norm(dim=1, p=2) 85 | else: 86 | attention_mask = attention_mask.unsqueeze(-1) 87 | return (emb * attention_mask).norm(dim=1, p=2) 88 | 89 | def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) 90 | if attention_mask is None: 91 | return emb.median(dim=1).values 92 | else: 93 | attention_mask = attention_mask.unsqueeze(-1) 94 | return (emb * attention_mask).median(dim=1).values 95 | 96 | def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) 97 | if attention_mask is None: 98 | return emb.std(dim=1) 99 | else: 100 | # Compute variance correctly over non-masked positions, then take sqrt 101 | var = self.var_pooling(emb, attention_mask, **kwargs) 102 | return torch.sqrt(var) 103 | 104 | def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) 105 | if attention_mask is None: 106 | return emb.var(dim=1) 107 | else: 108 | # Correctly compute variance over only non-masked positions 109 | attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) 110 | # Compute mean over non-masked positions 111 | mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) 112 | mean = mean.unsqueeze(1) # (b, 1, d) 113 | # Compute squared differences from mean, only over non-masked positions 114 | squared_diff = (emb - mean) ** 2 # (b, L, d) 115 | 
# Sum squared differences over non-masked positions and divide by count 116 | var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) 117 | return var 118 | 119 | def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) 120 | return emb[:, 0, :] 121 | 122 | def __call__( 123 | self, 124 | emb: torch.Tensor, 125 | attention_mask: Optional[torch.Tensor] = None, 126 | attentions: Optional[torch.Tensor] = None 127 | ): # [mean, max] 128 | final_emb = [] 129 | for pooling_type in self.pooling_types: 130 | final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d) 131 | return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) 132 | 133 | 134 | if __name__ == "__main__": 135 | # py -m pooler 136 | pooler = Pooler(pooling_types=['max', 'parti']) 137 | 138 | batch_size = 8 139 | seq_len = 64 140 | hidden_size = 128 141 | num_layers = 12 142 | emb = torch.randn(batch_size, seq_len, hidden_size) 143 | attentions = torch.randn(batch_size, num_layers, seq_len, seq_len) 144 | attention_mask = torch.ones(batch_size, seq_len) 145 | 146 | y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions) 147 | print(y.shape) 148 | -------------------------------------------------------------------------------- /test_scripts/test_contact_maps.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import torch 5 | import os 6 | import random 7 | import requests 8 | import tempfile 9 | import sys 10 | from Bio.PDB import PDBParser, PPBuilder 11 | 12 | from esm2.modeling_fastesm import FastEsmModel 13 | 14 | 15 | def download_random_pdb(): 16 | """ 17 | Download a random protein chain PDB file. 18 | 19 | Returns: 20 | str: Path to the downloaded PDB file. 21 | """ 22 | example_pdbs = ["1AKE"] 23 | 24 | # Select a random PDB ID 25 | pdb_id = random.choice(example_pdbs) 26 | print(f"Selected random PDB ID: {pdb_id}") 27 | 28 | # Create a temporary file to store the PDB 29 | temp_file = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) 30 | temp_file_path = temp_file.name 31 | temp_file.close() 32 | 33 | # Download the PDB file 34 | url = f"https://files.rcsb.org/download/{pdb_id}.pdb" 35 | response = requests.get(url) 36 | 37 | if response.status_code == 200: 38 | with open(temp_file_path, 'wb') as f: 39 | f.write(response.content) 40 | print(f"Downloaded PDB file to: {temp_file_path}") 41 | return temp_file_path 42 | else: 43 | raise Exception(f"Failed to download PDB file: {response.status_code}") 44 | 45 | 46 | def parse_pdb(pdb_file): 47 | """ 48 | Parse a PDB file and extract the protein sequence and CA atom coordinates. 49 | 50 | Parameters: 51 | pdb_file (str): Path to the PDB file. 52 | 53 | Returns: 54 | tuple: (sequence (str), coords (np.ndarray of shape (L, 3))) 55 | """ 56 | parser = PDBParser(QUIET=True) 57 | structure = parser.get_structure("protein", pdb_file) 58 | ppb = PPBuilder() 59 | 60 | # Assume a single protein chain; take the first polypeptide found. 61 | for pp in ppb.build_peptides(structure): 62 | sequence = str(pp.get_sequence()) 63 | coords = [] 64 | for residue in pp: 65 | # Only add the CA atom if available. 
66 | if 'CA' in residue: 67 | coords.append(residue['CA'].get_coord()) 68 | if len(coords) == 0: 69 | raise ValueError("No CA atoms found in the polypeptide.") 70 | return sequence, np.array(coords) 71 | 72 | raise ValueError("No polypeptide chains were found in the PDB file.") 73 | 74 | 75 | def compute_distance_matrix(coords): 76 | """ 77 | Compute the pairwise Euclidean distance matrix from a set of coordinates. 78 | 79 | Parameters: 80 | coords (np.ndarray): Array of shape (L, 3) where L is the number of residues. 81 | 82 | Returns: 83 | np.ndarray: A matrix of shape (L, L) containing distances. 84 | """ 85 | diff = coords[:, None, :] - coords[None, :, :] 86 | dist_matrix = np.sqrt(np.sum(diff**2, axis=-1)) 87 | 88 | return dist_matrix 89 | 90 | 91 | def get_esm_contact_map(sequence): 92 | """ 93 | Use the ESM model to predict a contact map for the given protein sequence. 94 | 95 | Parameters: 96 | sequence (str): Amino acid sequence. 97 | 98 | Returns: 99 | np.ndarray: A 2D array (L x L) with contact probabilities. 100 | """ 101 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 102 | model_path = "Synthyra/ESM2-650M" 103 | model = FastEsmModel.from_pretrained(model_path).eval().to(device) 104 | tokenizer = model.tokenizer 105 | 106 | inputs = tokenizer(sequence, return_tensors="pt") 107 | inputs = {key: value.to(device) for key, value in inputs.items()} 108 | with torch.no_grad(): 109 | contact_map = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"]) 110 | print(contact_map.shape) 111 | contact_map = contact_map.squeeze().cpu().numpy() 112 | print(contact_map.shape) 113 | return contact_map 114 | 115 | 116 | def plot_maps(true_contact_map, predicted_contact_map, pdb_file): 117 | """ 118 | Generate two subplots: 119 | 1. ESM predicted contact map. 120 | 2. True contact map from the PDB (binary, thresholded). 121 | 122 | Parameters: 123 | true_contact_map (np.ndarray): Binary (0/1) contact map from PDB. 124 | predicted_contact_map (np.ndarray): Predicted contact probabilities from ESM. 125 | pdb_file (str): Path to the PDB file, used to generate output filename. 126 | """ 127 | fig, axs = plt.subplots(1, 2, figsize=(12, 6)) 128 | 129 | # Plot the ESM-predicted contact map. 130 | im0 = axs[0].imshow(predicted_contact_map, cmap='RdYlBu_r', aspect='equal') 131 | axs[0].set_title("Predicted contact probabilities") 132 | axs[0].set_xlabel("Residue index") 133 | axs[0].set_ylabel("Residue index") 134 | fig.colorbar(im0, ax=axs[0], fraction=0.046, pad=0.04) 135 | 136 | # Plot the true contact map (binary contacts). 137 | im1 = axs[1].imshow(true_contact_map, cmap='RdYlBu_r', aspect='equal') 138 | axs[1].set_title("True contacts (PDB, threshold = 8 Å)") 139 | axs[1].set_xlabel("Residue index") 140 | axs[1].set_ylabel("Residue index") 141 | fig.colorbar(im1, ax=axs[1], fraction=0.046, pad=0.04) 142 | 143 | plt.tight_layout() 144 | 145 | # Generate output filename from PDB filename 146 | pdb_name = os.path.splitext(os.path.basename(pdb_file))[0] 147 | output_file = f"contact_maps_{pdb_name}.png" 148 | plt.savefig(output_file, dpi=300, bbox_inches='tight') 149 | plt.close() 150 | 151 | 152 | def main(): 153 | # py tests/test_contact_maps.py 154 | parser = argparse.ArgumentParser( 155 | description="Extract protein sequence and compute contact maps from a PDB file using ESM predictions." 156 | ) 157 | parser.add_argument("--pdb_file", type=str, help="Path to the PDB file of the protein. 
If not provided, a random PDB will be downloaded.", default=None) 158 | parser.add_argument( 159 | "--threshold", 160 | type=float, 161 | default=8.0, 162 | help="Distance threshold (in Å) for defining true contacts (default: 8.0 Å)." 163 | ) 164 | args = parser.parse_args() 165 | 166 | # If no PDB file is provided, download a random one 167 | if args.pdb_file is None: 168 | pdb_file = download_random_pdb() 169 | else: 170 | pdb_file = args.pdb_file 171 | 172 | try: 173 | # Parse the PDB file. 174 | sequence, coords = parse_pdb(pdb_file) 175 | print("Extracted Protein Sequence:") 176 | print(sequence) 177 | 178 | # Compute the pairwise distance matrix. 179 | dist_matrix = compute_distance_matrix(coords) 180 | 181 | # Create a binary contact map from the distance matrix using the threshold. 182 | true_contact_map = (dist_matrix < args.threshold).astype(float) 183 | 184 | # Get the predicted contact map from the ESM model. 185 | predicted_contact_map = get_esm_contact_map(sequence) 186 | 187 | # Check that the dimensions agree. 188 | if predicted_contact_map.shape[0] != true_contact_map.shape[0]: 189 | print("Warning: The predicted contact map and true contact map have different dimensions.") 190 | 191 | # Plot the maps. 192 | plot_maps(true_contact_map, predicted_contact_map, pdb_file) 193 | 194 | print(f"Contact maps saved to: contact_maps_{os.path.splitext(os.path.basename(pdb_file))[0]}.png") 195 | 196 | finally: 197 | # Clean up the temporary file if we downloaded a random PDB 198 | if args.pdb_file is None and os.path.exists(pdb_file): 199 | os.remove(pdb_file) 200 | print(f"Removed temporary PDB file: {pdb_file}") 201 | 202 | 203 | if __name__ == '__main__': 204 | main() 205 | 206 | 207 | 208 | -------------------------------------------------------------------------------- /test_scripts/test_throughput.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import random 4 | import argparse 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | from huggingface_hub import login 8 | from transformers import AutoModelForMaskedLM, EsmTokenizer 9 | from esm.models.esmc import ESMC 10 | from esm.sdk.api import ESMProtein, LogitsConfig 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model_paths', nargs='+', type=str, default=[ 15 | #'facebook/esm2_t6_8M_UR50D', 16 | 'Synthyra/FastESM2_650', 17 | 'facebook/esm2_t12_35M_UR50D', 18 | 'facebook/esm2_t30_150M_UR50D', 19 | 'facebook/esm2_t33_650M_UR50D', 20 | 'esmc_300m', # esmc model 21 | 'esmc_600m', # esmc model 22 | 'Synthyra/ESMplusplus_small', 23 | 'Synthyra/ESMplusplus_large' 24 | ]) 25 | parser.add_argument('--token', type=str, default=None) 26 | parser.add_argument('--test', action='store_true', help='Generate random results for testing') 27 | args = parser.parse_args() 28 | 29 | if args.token: 30 | login(args.token) 31 | 32 | model_paths = args.model_paths 33 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY" 34 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 35 | 36 | 37 | class ESMCForEmbedding(torch.nn.Module): 38 | def __init__(self, esm): 39 | super().__init__() 40 | self.esm = esm 41 | 42 | def forward(self, seq): 43 | protein = ESMProtein(sequence=seq) 44 | protein_tensor = self.esm.encode(protein) 45 | embeddings = self.esm.logits( 46 | protein_tensor, LogitsConfig(sequence=True, return_embeddings=True) 47 | ).embeddings.cpu() 48 | return embeddings 49 | 50 | 51 | def generate_random_sequence(length: 
int) -> str: 52 | return 'M' + "".join(random.choices(canonical_amino_acids, k=length-3)) 53 | 54 | 55 | def generate_batch_sequences(length: int, batch_size: int, num_batches: int = 100) -> list: 56 | all_sequences = [] 57 | for _ in range(num_batches): 58 | batch_sequences = [generate_random_sequence(length) for _ in range(batch_size)] 59 | all_sequences.append(batch_sequences) 60 | return all_sequences 61 | 62 | 63 | def time_model(model, inputs, warmup=4): 64 | model.eval() 65 | with torch.no_grad(): 66 | # Warmup 67 | for _ in range(warmup): 68 | _ = model(**inputs[0]) 69 | 70 | start_time = time.time() 71 | for input_batch in inputs: 72 | _ = model(**input_batch) 73 | return time.time() - start_time 74 | 75 | 76 | def time_model_esmc(model, sequences, warmup=10): 77 | model.eval() 78 | with torch.no_grad(): 79 | # Warmup 80 | for _ in range(warmup): 81 | for seq in sequences[0]: 82 | _ = model(seq) 83 | 84 | start_time = time.time() 85 | for batch in sequences: 86 | for seq in batch: 87 | _ = model(seq) 88 | return time.time() - start_time 89 | 90 | 91 | def get_gpu_memory(): 92 | torch.cuda.synchronize() 93 | return torch.cuda.max_memory_allocated() / 1024**2 # Convert to MB 94 | 95 | 96 | # Test different sequence lengths and batch sizes 97 | lengths = [32, 64, 128, 256, 512, 1024, 2048] 98 | batch_sizes = [1, 2, 4, 8, 16, 32] 99 | num_batches = 16 100 | results = [] 101 | 102 | if not args.test: 103 | # Generate all test sequences first 104 | all_sequences = {} 105 | for length in lengths: 106 | for batch_size in batch_sizes: 107 | print(f"\nGenerating sequences for length={length}, batch_size={batch_size}") 108 | all_sequences[(length, batch_size)] = generate_batch_sequences(length, batch_size, num_batches) 109 | 110 | # Test each model 111 | for model_path in model_paths: 112 | print(f"\nTesting {model_path}...") 113 | if 'esmc' in model_path.lower(): 114 | esm = ESMC.from_pretrained(model_path, device=device).to(device) 115 | model = ESMCForEmbedding(esm).to(device) 116 | tokenizer = None 117 | elif 'synthyra' in model_path.lower(): 118 | model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True, dtype=torch.float16).to(device) 119 | tokenizer = model.tokenizer 120 | else: 121 | model = AutoModelForMaskedLM.from_pretrained(model_path).to(device) 122 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D') 123 | 124 | for length in lengths: 125 | for batch_size in batch_sizes: 126 | print(f"\nTesting length={length}, batch_size={batch_size}") 127 | sequences = all_sequences[(length, batch_size)] 128 | 129 | torch.cuda.reset_peak_memory_stats() 130 | if isinstance(model, ESMCForEmbedding): 131 | model_time = time_model_esmc(model, sequences) 132 | else: 133 | inputs = [tokenizer(batch_seq, padding=True, return_tensors="pt").to(device) for batch_seq in sequences] 134 | model_time = time_model(model, inputs) 135 | model_memory = get_gpu_memory() 136 | 137 | results.append({ 138 | 'Model': model_path, 139 | 'Length': length, 140 | 'Batch Size': batch_size, 141 | 'Time': model_time, 142 | 'Memory': model_memory 143 | }) 144 | print(f"Time: {model_time:.2f}s, memory: {model_memory:.0f}MB") 145 | torch.cuda.empty_cache() 146 | 147 | model.cpu() 148 | del model 149 | torch.cuda.empty_cache() 150 | else: 151 | # Generate random test results 152 | for model_path in model_paths: 153 | for length in lengths: 154 | for batch_size in batch_sizes: 155 | # Generate random time between 0.1 and 10 seconds, scaling with length and batch size 156 | model_time = 
random.uniform(0.1, 10) * (length/2) * (batch_size/1) 157 | # Generate random memory between 100 and 5000 MB, scaling with length and batch size 158 | model_memory = random.uniform(100, 5000) * (length/2) * (batch_size/1) 159 | 160 | results.append({ 161 | 'Model': model_path, 162 | 'Length': length, 163 | 'Batch Size': batch_size, 164 | 'Time': model_time, 165 | 'Memory': model_memory 166 | }) 167 | print(f"Generated random - Time: {model_time:.2f}s, memory: {model_memory:.0f}MB") 168 | 169 | # Save results to CSV 170 | df = pd.DataFrame(results) 171 | df.to_csv('model_benchmarks.csv', index=False) 172 | 173 | # Create visualization for throughput 174 | num_batch_sizes = len(batch_sizes) 175 | plt.figure(figsize=(15, 5 * num_batch_sizes)) 176 | 177 | for i, batch_size in enumerate(batch_sizes): 178 | plt.subplot(num_batch_sizes, 1, i + 1) 179 | for model_path in model_paths: 180 | model_results = [(r['Length'], r['Time']) for r in results 181 | if r['Model'] == model_path and r['Batch Size'] == batch_size] 182 | if model_results: 183 | lengths, times = zip(*model_results) 184 | throughput = [batch_size * len * num_batches / time for len, time in zip(lengths, times)] 185 | plt.plot(lengths, throughput, marker='o', label=model_path) 186 | 187 | plt.title(f'Model Throughput vs Sequence Length (Batch Size = {batch_size})') 188 | plt.xlabel('Sequence Length') 189 | plt.ylabel('Throughput (tokens/second)') 190 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') 191 | plt.grid(True) 192 | 193 | plt.tight_layout() 194 | plt.savefig('model_throughput.png', bbox_inches='tight', dpi=300) 195 | plt.close() 196 | 197 | # Create visualization for memory usage 198 | plt.figure(figsize=(15, 5 * num_batch_sizes)) 199 | 200 | for i, batch_size in enumerate(batch_sizes): 201 | plt.subplot(num_batch_sizes, 1, i + 1) 202 | for model_path in model_paths: 203 | model_results = [(r['Length'], r['Memory']) for r in results 204 | if r['Model'] == model_path and r['Batch Size'] == batch_size] 205 | if model_results: 206 | lengths, memory = zip(*model_results) 207 | plt.plot(lengths, memory, marker='o', label=model_path) 208 | 209 | plt.title(f'GPU Memory Usage vs Sequence Length (Batch Size = {batch_size})') 210 | plt.xlabel('Sequence Length') 211 | plt.ylabel('Memory Usage (MB)') 212 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') 213 | plt.grid(True) 214 | 215 | plt.tight_layout() 216 | plt.savefig('model_memory.png', bbox_inches='tight', dpi=300) 217 | plt.close() 218 | -------------------------------------------------------------------------------- /readmes/esm_plusplus_large_readme.md: -------------------------------------------------------------------------------- 1 | --- 2 | library_name: transformers 3 | tags: [] 4 | --- 5 | 6 | # NOTE 7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git) 8 | 9 | # ESM++ 10 | [ESM++](https://github.com/Synthyra/ESMplusplus) is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) ([license](https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement)) that allows for batching and standard Huggingface compatibility without requiring the ESM Python package. 11 | The large version corresponds to the 600 million parameter version of ESMC. 
12 | 13 | 14 | ## Use with 🤗 transformers 15 | ```python 16 | from transformers import AutoModelForMaskedLM 17 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True) 18 | tokenizer = model.tokenizer 19 | 20 | sequences = ['MPRTEIN', 'MSEQWENCE'] 21 | tokenized = tokenizer(sequences, padding=True, return_tensors='pt') 22 | 23 | # tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training 24 | 25 | output = model(**tokenized) # get all hidden states with output_hidden_states=True 26 | print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64) 27 | print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 1152) 28 | print(output.loss) # language modeling loss if you passed labels 29 | #print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple) 30 | ``` 31 | 32 | ESM++ also supports sequence and token level classification tasks like ESM2. Simply pass the number of labels during initialization. 33 | 34 | ```python 35 | from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification 36 | 37 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True) 38 | logits = model(**tokenized).logits 39 | print(logits.shape) # (batch_size, num_labels), (2, 2) 40 | ``` 41 | 42 | ESM++ weights are fp32 by default. You can load them in fp16 or bf16 like this: 43 | ```python 44 | import torch 45 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True, dtype=torch.float16) # or torch.bfloat16 46 | ``` 47 | 48 | ## Embed entire datasets with no new code 49 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take. 50 | 51 | Example: 52 | ```python 53 | embedding_dict = model.embed_dataset( 54 | sequences=[ 55 | 'MALWMRLLPLLALLALWGPDPAAA', ... 
# list of protein sequences 56 | ], 57 | tokenizer=model.tokenizer, 58 | batch_size=2, # adjust for your GPU memory 59 | max_len=512, # adjust for your needs 60 | full_embeddings=False, # if True, no pooling is performed 61 | embed_dtype=torch.float32, # cast to what dtype you want 62 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together 63 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets 64 | sql=False, # if True, embeddings will be stored in SQLite database 65 | sql_db_path='embeddings.db', 66 | save=True, # if True, embeddings will be saved as a .pth file 67 | save_path='embeddings.pth', 68 | ) 69 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql 70 | ``` 71 | 72 | ``` 73 | model.embed_dataset() 74 | Args: 75 | sequences: List of protein sequences 76 | batch_size: Batch size for processing 77 | max_len: Maximum sequence length 78 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) 79 | pooling_type: Type of pooling ('mean' or 'cls') 80 | num_workers: Number of workers for data loading, 0 for the main process 81 | sql: Whether to store embeddings in SQLite database - will be stored in float32 82 | sql_db_path: Path to SQLite database 83 | 84 | Returns: 85 | Dictionary mapping sequences to embeddings, or None if sql=True 86 | 87 | Note: 88 | - If sql=True, embeddings can only be stored in float32 89 | - sql is ideal if you need to stream a very large dataset for training in real-time 90 | - save=True is ideal if you can store the entire embedding dictionary in RAM 91 | - sql will be used if it is True and save is True or False 92 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences 93 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing 94 | ``` 95 | 96 | ## Fine-tuning with 🤗 peft 97 | ```python 98 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True) 99 | # these modules handle ESM++ and ESM2 attention layers 100 | target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"] 101 | 102 | lora_config = LoraConfig( 103 | r=8, # choose lora parameters to your liking 104 | lora_alpha=16, 105 | lora_dropout=0.01, 106 | bias="none", 107 | target_modules=target_modules, 108 | ) 109 | 110 | # Apply LoRA to the model 111 | model = get_peft_model(model, lora_config) 112 | 113 | # Unfreeze the classifier head 114 | for param in model.classifier.parameters(): 115 | param.requires_grad = True 116 | ``` 117 | 118 | For a more thourough example of fine-tuning, check out our example script [here](https://github.com/Synthyra/FastPLMs/blob/main/fine_tuning_example.py). 119 | 120 | 121 | ## Returning attention maps 122 | Usually F.scaled_dot_product_attention is used for the attention calculations, which is much faster than native PyTorch. However, it cannot return attention maps. 123 | ESM++ has the option to ```output_attentions```, which will calculate attention manually. This is much slower, so do not use unless you need the attention maps. 
124 | 125 | ```python 126 | output = model(**tokenized, output_attentions=True) 127 | att = output.attentions 128 | len(att) # 33, one for each layer, size (batch_size, num_heads, seq_len, seq_len) each 129 | ``` 130 | 131 | ## Comparison across floating-point precision and implementations 132 | We measured the difference in last hidden states between the fp32 weights and their fp16 or bf16 counterparts. We find that fp16 is closer to the fp32 outputs, so we recommend loading in fp16. 133 | Please note that the ESM package also loads ESMC in fp32 but casts to bf16 by default, which has its share of advantages and disadvantages in inference / training - so load whichever you like for half precision. 134 | 135 | Average MSE for FP16: 0.00000003 136 | 137 | Average MSE for BF16: 0.00000122 138 | 139 | We also measured the difference between the outputs of ESM++ vs. ESMC (both in bfloat16) on 1000 random sequences to ensure compliance with the ESM package. 140 | 141 | Average MSE of last hidden state: 2.46e-09 142 | 143 | You can load the weights from the ESM package instead of transformers by replacing .from_pretrained(...) with .from_pretrained_esm('esmc_600m') 144 | 145 | ## Model probes 146 | We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. ESMC (and thus ESM++) performs very well. 147 | 148 | The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression tasks are averaged between Spearman rho and R2. 149 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/uRAHYQcwkbgajylTIFbUb.png) 150 | 151 | ## Inference speeds 152 | We look at various ESM models and their throughput on an H100. The efficient batching added in ESM++ significantly improves throughput over ESMC, and ESM++ is also faster than ESMC at batch size one. ESM++ small is even faster than ESM2-35M with long sequences! The most gains will be seen with PyTorch > 2.5 on Linux machines. 153 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/Lu6nWB9Fc-7YTql3Z1hVB.png) 154 | 155 | ### Citation 156 | If you use any of this implementation or work please cite it (as well as the ESMC preprint). 
157 | 158 | ``` 159 | @misc {FastPLMs, 160 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.}, 161 | title = { FastPLMs: Fast, efficient, protien language model inference from Huggingface AutoModel.}, 162 | year = {2024}, 163 | url = { https://huggingface.co/Synthyra/ESMplusplus_small }, 164 | DOI = { 10.57967/hf/3726 }, 165 | publisher = { Hugging Face } 166 | } 167 | ``` -------------------------------------------------------------------------------- /readmes/esm_plusplus_small_readme.md: -------------------------------------------------------------------------------- 1 | --- 2 | library_name: transformers 3 | tags: [] 4 | --- 5 | 6 | # NOTE 7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git) 8 | 9 | # ESM++ 10 | [ESM++](https://github.com/Synthyra/ESMplusplus) is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) ([license](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement)) that allows for batching and standard Huggingface compatibility without requiring the ESM Python package. 11 | The small version corresponds to the 300 million parameter version of ESMC. 12 | 13 | 14 | ## Use with 🤗 transformers 15 | ```python 16 | from transformers import AutoModelForMaskedLM 17 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True) 18 | tokenizer = model.tokenizer 19 | 20 | sequences = ['MPRTEIN', 'MSEQWENCE'] 21 | tokenized = tokenizer(sequences, padding=True, return_tensors='pt') 22 | 23 | # tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training 24 | 25 | output = model(**tokenized) # get all hidden states with output_hidden_states=True 26 | print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64) 27 | print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960) 28 | print(output.loss) # language modeling loss if you passed labels 29 | #print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple) 30 | ``` 31 | 32 | ESM++ also supports sequence and token level classification tasks like ESM2. Simply pass the number of labels during initialization. 33 | 34 | ```python 35 | from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification 36 | 37 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True) 38 | logits = model(**tokenized).logits 39 | print(logits.shape) # (batch_size, num_labels), (2, 2) 40 | ``` 41 | 42 | ESM++ weights are fp32 by default. You can load them in fp16 or bf16 like this: 43 | ```python 44 | import torch 45 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True, dtype=torch.float16) # or torch.bfloat16 46 | ``` 47 | 48 | ## Embed entire datasets with no new code 49 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take. 50 | 51 | Example: 52 | ```python 53 | embedding_dict = model.embed_dataset( 54 | sequences=[ 55 | 'MALWMRLLPLLALLALWGPDPAAA', ... 
# list of protein sequences 56 | ], 57 | tokenizer=model.tokenizer, 58 | batch_size=2, # adjust for your GPU memory 59 | max_len=512, # adjust for your needs 60 | full_embeddings=False, # if True, no pooling is performed 61 | embed_dtype=torch.float32, # cast to what dtype you want 62 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together 63 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets 64 | sql=False, # if True, embeddings will be stored in SQLite database 65 | sql_db_path='embeddings.db', 66 | save=True, # if True, embeddings will be saved as a .pth file 67 | save_path='embeddings.pth', 68 | ) 69 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql 70 | ``` 71 | 72 | ``` 73 | model.embed_dataset() 74 | Args: 75 | sequences: List of protein sequences 76 | batch_size: Batch size for processing 77 | max_len: Maximum sequence length 78 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) 79 | pooling_type: Type of pooling ('mean' or 'cls') 80 | num_workers: Number of workers for data loading, 0 for the main process 81 | sql: Whether to store embeddings in SQLite database - will be stored in float32 82 | sql_db_path: Path to SQLite database 83 | 84 | Returns: 85 | Dictionary mapping sequences to embeddings, or None if sql=True 86 | 87 | Note: 88 | - If sql=True, embeddings can only be stored in float32 89 | - sql is ideal if you need to stream a very large dataset for training in real-time 90 | - save=True is ideal if you can store the entire embedding dictionary in RAM 91 | - sql will be used if it is True and save is True or False 92 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences 93 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing 94 | ``` 95 | 96 | ## Fine-tuning with 🤗 peft 97 | ```python 98 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True) 99 | # these modules handle ESM++ and ESM2 attention layers 100 | target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"] 101 | 102 | lora_config = LoraConfig( 103 | r=8, # choose lora parameters to your liking 104 | lora_alpha=16, 105 | lora_dropout=0.01, 106 | bias="none", 107 | target_modules=target_modules, 108 | ) 109 | 110 | # Apply LoRA to the model 111 | model = get_peft_model(model, lora_config) 112 | 113 | # Unfreeze the classifier head 114 | for param in model.classifier.parameters(): 115 | param.requires_grad = True 116 | ``` 117 | 118 | For a more thourough example of fine-tuning, check out our example script [here](https://github.com/Synthyra/FastPLMs/blob/main/fine_tuning_example.py). 119 | 120 | 121 | ## Returning attention maps 122 | Usually F.scaled_dot_product_attention is used for the attention calculations, which is much faster than native PyTorch. However, it cannot return attention maps. 123 | ESM++ has the option to ```output_attentions```, which will calculate attention manually. This is much slower, so do not use unless you need the attention maps. 
124 | 125 | ```python 126 | output = model(**tokenized, output_attentions=True) 127 | att = output.attentions 128 | len(att) # 30, one for each layer, size (batch_size, num_heads, seq_len, seq_len) each 129 | ``` 130 | 131 | ## Comparison across floating-point precision and implementations 132 | We measured the difference of the last hidden states of the fp32 weights vs. fp16 or bf16. We find that the fp16 is closer to the fp32 outputs, so we recommend loading in fp16. 133 | Please note that the ESM package also loads ESMC in fp32 but casts to bf16 by default, which has its share of advantages and disadvantages in inference / training - so load whichever you like for half precision. 134 | 135 | Average MSE FP32 vs. FP16: 0.00000003 136 | 137 | Average MSE FP32 vs. BF16: 0.00000140 138 | 139 | We also measured the difference between the outputs of ESM++ vs. ESMC (both in bfloat16) on 1000 random sequences to ensure compliance with the ESM package. 140 | 141 | Average MSE of last hidden state: 7.74e-10 142 | 143 | You can load the weights from the ESM package instead of transformers by replacing .from_pretrained(...) to .from_pretrained_esm('esmc_300m') 144 | 145 | ## Model probes 146 | We employ linear probing techniques on various PLMs and standard datasets, similar our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. ESMC (and thus ESM++) perform very well. 147 | 148 | The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression tasks are averaged between Spearman rho and R2. 149 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/2zyUZeHyOgCR_twvPF2Wy.png) 150 | 151 | ## Inference speeds 152 | We look at various ESM models and their throughput on an H100. Adding efficient batching between ESMC and ESM++ significantly improves the throughput, although ESM++ is also faster than ESMC for batch size one. ESM++ small is even faster than ESM2-35M with long sequences! 153 | The most gains will be seen with PyTorch > 2.5 on linux machines. 154 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/RfLRSchFivdsqJrWMh4bo.png) 155 | 156 | ### Citation 157 | If you use any of this implementation or work please cite it (as well as the ESMC preprint). 158 | 159 | ``` 160 | @misc {FastPLMs, 161 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.}, 162 | title = { FastPLMs: Fast, efficient, protien language model inference from Huggingface AutoModel.}, 163 | year = {2024}, 164 | url = { https://huggingface.co/Synthyra/ESMplusplus_small }, 165 | DOI = { 10.57967/hf/3726 }, 166 | publisher = { Hugging Face } 167 | } 168 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | PLEASE NOTE THE APACHE LICENSE ONLY APPLIES TO THE CODE IN THE FastPLMs GITHUB AND ASSOCIATED HUGGINGFACE REPOSITORIES, NOT NECESSARILY THE MODEL WEIGHTS. THOSE LICENSES CAN BE FOUND HERE https://github.com/Synthyra/FastPLMs/tree/main/licenses 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 
10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 
181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /wip/t5/t5_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import math 4 | from typing import Optional 5 | from torch import nn 6 | from transformers import T5Config 7 | from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class T5AttentionTransformers(nn.Module): 14 | def __init__( 15 | self, 16 | config: T5Config, 17 | has_relative_attention_bias=False, 18 | layer_idx: Optional[int] = None, 19 | ): 20 | super().__init__() 21 | self.is_decoder = config.is_decoder 22 | self.has_relative_attention_bias = has_relative_attention_bias 23 | self.relative_attention_num_buckets = config.relative_attention_num_buckets 24 | self.relative_attention_max_distance = config.relative_attention_max_distance 25 | self.d_model = config.d_model 26 | self.key_value_proj_dim = config.d_kv 27 | self.n_heads = config.num_heads 28 | self.dropout = config.dropout_rate 29 | self.inner_dim = self.n_heads * self.key_value_proj_dim 30 | self.layer_idx = layer_idx 31 | if layer_idx is None and self.is_decoder: 32 | logger.warning_once( 33 | f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " 34 | "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " 35 | "when creating this class." 
36 | ) 37 | 38 | # Mesh TensorFlow initialization to avoid scaling before softmax 39 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) 40 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) 41 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) 42 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) 43 | 44 | if self.has_relative_attention_bias: 45 | self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) 46 | self.pruned_heads = set() 47 | self.gradient_checkpointing = False 48 | 49 | def prune_heads(self, heads): 50 | if len(heads) == 0: 51 | return 52 | heads, index = find_pruneable_heads_and_indices( 53 | heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads 54 | ) 55 | # Prune linear layers 56 | self.q = prune_linear_layer(self.q, index) 57 | self.k = prune_linear_layer(self.k, index) 58 | self.v = prune_linear_layer(self.v, index) 59 | self.o = prune_linear_layer(self.o, index, dim=1) 60 | # Update hyper params 61 | self.n_heads = self.n_heads - len(heads) 62 | self.inner_dim = self.key_value_proj_dim * self.n_heads 63 | self.pruned_heads = self.pruned_heads.union(heads) 64 | 65 | @staticmethod 66 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): 67 | """ 68 | Adapted from Mesh Tensorflow: 69 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 70 | 71 | Translate relative position to a bucket number for relative attention. The relative position is defined as 72 | memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to 73 | position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for 74 | small absolute relative_position and larger buckets for larger absolute relative_positions. All relative 75 | positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
76 | This should allow for more graceful generalization to longer sequences than the model has been trained on 77 | 78 | Args: 79 | relative_position: an int32 Tensor 80 | bidirectional: a boolean - whether the attention is bidirectional 81 | num_buckets: an integer 82 | max_distance: an integer 83 | 84 | Returns: 85 | a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) 86 | """ 87 | relative_buckets = 0 88 | if bidirectional: 89 | num_buckets //= 2 90 | relative_buckets += (relative_position > 0).to(torch.long) * num_buckets 91 | relative_position = torch.abs(relative_position) 92 | else: 93 | relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) 94 | # now relative_position is in the range [0, inf) 95 | 96 | # half of the buckets are for exact increments in positions 97 | max_exact = num_buckets // 2 98 | is_small = relative_position < max_exact 99 | 100 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 101 | relative_position_if_large = max_exact + ( 102 | torch.log(relative_position.float() / max_exact) 103 | / math.log(max_distance / max_exact) 104 | * (num_buckets - max_exact) 105 | ).to(torch.long) 106 | relative_position_if_large = torch.min( 107 | relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) 108 | ) 109 | 110 | relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) 111 | return relative_buckets 112 | 113 | def compute_bias(self, query_length, key_length, device=None, cache_position=None): 114 | """Compute binned relative position bias""" 115 | if device is None: 116 | device = self.relative_attention_bias.weight.device 117 | if cache_position is None: 118 | context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] 119 | else: 120 | context_position = cache_position[:, None].to(device) 121 | memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] 122 | relative_position = memory_position - context_position # shape (query_length, key_length) 123 | relative_position_bucket = self._relative_position_bucket( 124 | relative_position, # shape (query_length, key_length) 125 | bidirectional=(not self.is_decoder), 126 | num_buckets=self.relative_attention_num_buckets, 127 | max_distance=self.relative_attention_max_distance, 128 | ) 129 | values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) 130 | values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) 131 | return values 132 | 133 | def forward( 134 | self, 135 | hidden_states, 136 | mask=None, 137 | key_value_states=None, 138 | position_bias=None, 139 | past_key_value=None, 140 | layer_head_mask=None, 141 | query_length=None, 142 | use_cache=False, 143 | output_attentions=False, 144 | cache_position=None, 145 | ): 146 | """ 147 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
148 | """ 149 | # Input is (batch_size, seq_length, dim) 150 | # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) 151 | batch_size, seq_length = hidden_states.shape[:2] 152 | 153 | # if key_value_states are provided this layer is used as a cross-attention layer for the decoder 154 | is_cross_attention = key_value_states is not None 155 | 156 | query_states = self.q(hidden_states) 157 | query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 158 | 159 | if cache_position is None: 160 | if past_key_value is not None: 161 | logger.warning_once( 162 | f"{self.__class__.__name__} forward called without cache_position when using cache, which might result in errors. " 163 | "Please provide a cache_position when calling this function. " 164 | "See 'Best Practices for Generation with Cache' in the docs for more information. " 165 | "Assuming cache position starts at 0." 166 | ) 167 | cache_position = torch.arange(seq_length) 168 | 169 | if past_key_value is not None: 170 | is_updated = past_key_value.is_updated.get(self.layer_idx) 171 | if is_cross_attention: 172 | # after the first generated id, we can subsequently re-use all key/value_states from cache 173 | curr_past_key_value = past_key_value.cross_attention_cache 174 | else: 175 | curr_past_key_value = past_key_value.self_attention_cache 176 | 177 | current_states = key_value_states if is_cross_attention else hidden_states 178 | if is_cross_attention and past_key_value is not None and is_updated: 179 | # reuse k,v, cross_attentions 180 | key_states = curr_past_key_value.key_cache[self.layer_idx] 181 | value_states = curr_past_key_value.value_cache[self.layer_idx] 182 | else: 183 | key_states = self.k(current_states) 184 | value_states = self.v(current_states) 185 | key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 186 | value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 187 | 188 | if past_key_value is not None: 189 | # save all key/value_states to cache to be re-used for fast auto-regressive generation 190 | cache_position = cache_position if not is_cross_attention else None 191 | key_states, value_states = curr_past_key_value.update( 192 | key_states, value_states, self.layer_idx, {"cache_position": cache_position} 193 | ) 194 | # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls 195 | if is_cross_attention: 196 | past_key_value.is_updated[self.layer_idx] = True 197 | 198 | # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 199 | scores = torch.matmul(query_states, key_states.transpose(3, 2)) 200 | 201 | if position_bias is None: 202 | key_length = key_states.shape[-2] 203 | # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) 204 | real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 205 | if not self.has_relative_attention_bias: 206 | position_bias = torch.zeros( 207 | (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype 208 | ) 209 | if self.gradient_checkpointing and self.training: 210 | position_bias.requires_grad = True 211 | else: 212 | position_bias = self.compute_bias( 213 | real_seq_length, key_length, device=scores.device, cache_position=cache_position 214 | ) 215 | position_bias = position_bias[:, :, 
-seq_length:, :] 216 | 217 | if mask is not None: 218 | causal_mask = mask[:, :, :, : key_states.shape[-2]] 219 | position_bias = position_bias + causal_mask 220 | 221 | if self.pruned_heads: 222 | mask = torch.ones(position_bias.shape[1]) 223 | mask[list(self.pruned_heads)] = 0 224 | position_bias_masked = position_bias[:, mask.bool()] 225 | else: 226 | position_bias_masked = position_bias 227 | 228 | scores += position_bias_masked 229 | 230 | # (batch_size, n_heads, seq_length, key_length) 231 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) 232 | attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 233 | 234 | # Mask heads if we want to 235 | if layer_head_mask is not None: 236 | attn_weights = attn_weights * layer_head_mask 237 | 238 | attn_output = torch.matmul(attn_weights, value_states) 239 | 240 | attn_output = attn_output.transpose(1, 2).contiguous() 241 | attn_output = attn_output.view(batch_size, -1, self.inner_dim) 242 | attn_output = self.o(attn_output) 243 | 244 | outputs = (attn_output, past_key_value, position_bias) 245 | 246 | if output_attentions: 247 | outputs = outputs + (attn_weights,) 248 | return outputs -------------------------------------------------------------------------------- /wip/t5/t5_flex_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO 3 | Future addition to current T5 PLM models to use much more efficient flex attention 4 | Waiting for official support of learned relative position embeddings in the score_mode 5 | https://github.com/pytorch-labs/attention-gym/issues/20 6 | https://pytorch.org/blog/flexattention/ 7 | https://github.com/pytorch-labs/attention-gym/pull/84/commits/4d045e172474e2b964e3961d794177ceca1549f8 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | from typing import Optional, Tuple, Union 12 | from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer 13 | from torch.nn.attention.flex_attention import flex_attention 14 | from .t5_attention import T5AttentionTransformers 15 | 16 | 17 | class T5FlexAttentionMixin: 18 | """ 19 | A mixin class to enable flex attention in T5 models. 20 | 21 | This mixin replaces the standard attention mechanism in T5Attention with flex attention, 22 | preserving the position bias mechanism that T5 uses. 23 | """ 24 | 25 | def __init__(self, compile_flex=None, softcap=None): 26 | """ 27 | Initialize the T5FlexAttentionMixin. 28 | 29 | Args: 30 | compile_flex: Whether to compile the flex attention function. 31 | kernel_options: Optional kernel options for flex attention. 
32 | """ 33 | self.compile_flex = compile_flex and (torch.cuda.is_available() or torch.backends.mps.is_available()) 34 | self.flex_attention_fn = torch.compile(flex_attention) if self.compile_flex else flex_attention 35 | self.softcap = softcap 36 | 37 | def _apply_flex_attention(self, query, key, value, position_bias=None, attention_mask=None, head_mask=None): 38 | # adapted from https://github.com/huggingface/transformers/blob/752ef3fd4e70869626ec70657a770a85c0ad9219/src/transformers/integrations/flex_attention.py 39 | def score_mod(score, b, h, q_idx, kv_idx): 40 | if self.softcap is not None: 41 | score = self.softcap * torch.tanh(score / self.softcap) 42 | if position_bias is not None: 43 | score = score + position_bias[0][h][q_idx][kv_idx] 44 | if attention_mask is not None: 45 | score = score + attention_mask[b][0][q_idx][kv_idx] 46 | if head_mask is not None: 47 | score = score + head_mask[b][h][q_idx][kv_idx] 48 | return score 49 | 50 | # Apply flex attention 51 | attn_output, attention_weights = self.flex_attention_fn( 52 | query=query, 53 | key=key, 54 | value=value, 55 | score_mod=score_mod, 56 | enable_gqa=True, 57 | scale=1.0, 58 | return_lse=True, 59 | ) 60 | return attn_output, attention_weights 61 | 62 | 63 | class T5FlexAttention(nn.Module, T5FlexAttentionMixin): 64 | """ 65 | Drop-in replacement for T5Attention that uses flex attention. 66 | 67 | This class preserves the interface of T5Attention but uses flex attention 68 | for the core attention computation. 69 | """ 70 | 71 | def __init__(self, config, has_relative_attention_bias=False, layer_idx=None, compile_flex=False): 72 | nn.Module.__init__(self) 73 | T5FlexAttentionMixin.__init__(self, compile_flex=compile_flex) 74 | 75 | self.is_decoder = config.is_decoder 76 | self.has_relative_attention_bias = has_relative_attention_bias 77 | self.layer_idx = layer_idx 78 | 79 | self.relative_attention_num_buckets = config.relative_attention_num_buckets 80 | self.relative_attention_max_distance = config.relative_attention_max_distance 81 | self.d_model = config.d_model 82 | self.key_value_proj_dim = config.d_kv 83 | self.n_heads = config.num_heads 84 | self.dropout = config.dropout_rate 85 | self.inner_dim = self.n_heads * self.key_value_proj_dim 86 | 87 | # Regular T5 projection layers 88 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) 89 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) 90 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) 91 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) 92 | 93 | if self.has_relative_attention_bias: 94 | self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) 95 | 96 | self.pruned_heads = set() 97 | 98 | def prune_heads(self, heads): 99 | # Implementation similar to T5Attention.prune_heads 100 | if len(heads) == 0: 101 | return 102 | heads, index = find_pruneable_heads_and_indices( 103 | heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads 104 | ) 105 | # Prune linear layers 106 | self.q = prune_linear_layer(self.q, index) 107 | self.k = prune_linear_layer(self.k, index) 108 | self.v = prune_linear_layer(self.v, index) 109 | self.o = prune_linear_layer(self.o, index, dim=1) 110 | # Update hyper params 111 | self.n_heads = self.n_heads - len(heads) 112 | self.inner_dim = self.key_value_proj_dim * self.n_heads 113 | self.pruned_heads = self.pruned_heads.union(heads) 114 | 115 | @staticmethod 116 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, 
max_distance=128): 117 | """ 118 | Adapted from T5Attention._relative_position_bucket 119 | """ 120 | relative_buckets = 0 121 | if bidirectional: 122 | num_buckets //= 2 123 | relative_buckets += (relative_position > 0).to(torch.long) * num_buckets 124 | relative_position = torch.abs(relative_position) 125 | else: 126 | relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) 127 | 128 | # half of the buckets are for exact increments in positions 129 | max_exact = num_buckets // 2 130 | is_small = relative_position < max_exact 131 | 132 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 133 | relative_position_if_large = max_exact + ( 134 | torch.log(relative_position.float() / max_exact) 135 | / torch.log(torch.tensor(max_distance / max_exact, device=relative_position.device)) 136 | * (num_buckets - max_exact) 137 | ).to(torch.long) 138 | relative_position_if_large = torch.min( 139 | relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) 140 | ) 141 | 142 | relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) 143 | return relative_buckets 144 | 145 | def compute_bias(self, query_length, key_length, device=None, cache_position=None): 146 | """ 147 | Compute binned relative position bias, same as in T5Attention 148 | """ 149 | if device is None: 150 | device = self.relative_attention_bias.weight.device 151 | 152 | if cache_position is None: 153 | context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] 154 | else: 155 | context_position = cache_position[:, None].to(device) 156 | 157 | memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] 158 | relative_position = memory_position - context_position # shape (query_length, key_length) 159 | 160 | relative_position_bucket = self._relative_position_bucket( 161 | relative_position, # shape (query_length, key_length) 162 | bidirectional=(not self.is_decoder), 163 | num_buckets=self.relative_attention_num_buckets, 164 | max_distance=self.relative_attention_max_distance, 165 | ) 166 | 167 | values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) 168 | values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) 169 | return values 170 | 171 | def forward( 172 | self, 173 | hidden_states, 174 | mask=None, 175 | key_value_states=None, 176 | position_bias=None, 177 | past_key_value=None, 178 | layer_head_mask=None, 179 | query_length=None, 180 | use_cache=False, 181 | output_attentions=False, 182 | cache_position=None, 183 | ): 184 | """ 185 | Forward pass, similar to T5Attention but using flex attention 186 | """ 187 | batch_size, seq_length = hidden_states.shape[:2] 188 | 189 | # if key_value_states are provided this layer is used as a cross-attention layer for the decoder 190 | is_cross_attention = key_value_states is not None 191 | 192 | # Project hidden states to query vectors 193 | query_states = self.q(hidden_states) 194 | query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 195 | 196 | # Handle past key values for caching 197 | if past_key_value is not None: 198 | is_updated = past_key_value.is_updated.get(self.layer_idx) 199 | if is_cross_attention: 200 | # after the first generated id, we can subsequently re-use all key/value_states from cache 201 | curr_past_key_value = 
past_key_value.cross_attention_cache 202 | else: 203 | curr_past_key_value = past_key_value.self_attention_cache 204 | 205 | # Get current states for key and value 206 | current_states = key_value_states if is_cross_attention else hidden_states 207 | 208 | # Reuse cached key/value states if available and updated 209 | if is_cross_attention and past_key_value is not None and is_updated: 210 | # reuse k,v, cross_attentions 211 | key_states = curr_past_key_value.key_cache[self.layer_idx] 212 | value_states = curr_past_key_value.value_cache[self.layer_idx] 213 | else: 214 | # Compute new key and value states 215 | key_states = self.k(current_states) 216 | value_states = self.v(current_states) 217 | key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 218 | value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 219 | 220 | # Update cache if needed 221 | if past_key_value is not None: 222 | # save all key/value_states to cache to be re-used for fast auto-regressive generation 223 | cache_position = cache_position if not is_cross_attention else None 224 | key_states, value_states = curr_past_key_value.update( 225 | key_states, value_states, self.layer_idx, {"cache_position": cache_position} 226 | ) 227 | # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls 228 | if is_cross_attention: 229 | past_key_value.is_updated[self.layer_idx] = True 230 | 231 | # Compute position bias if not provided 232 | if position_bias is None: 233 | key_length = key_states.shape[-2] 234 | # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) 235 | real_seq_length = query_length if query_length is not None else (cache_position[-1] + 1 if cache_position is not None else seq_length) 236 | 237 | if not self.has_relative_attention_bias: 238 | position_bias = torch.zeros( 239 | (1, self.n_heads, seq_length, key_length), device=query_states.device, dtype=query_states.dtype 240 | ) 241 | if getattr(self, 'gradient_checkpointing', False) and self.training: 242 | position_bias.requires_grad = True 243 | else: 244 | position_bias = self.compute_bias( 245 | real_seq_length, key_length, device=query_states.device, cache_position=cache_position 246 | ) 247 | position_bias = position_bias[:, :, -seq_length:, :] 248 | 249 | # Add mask to position bias if provided 250 | if mask is not None: 251 | causal_mask = mask[:, :, :, : key_states.shape[-2]] 252 | position_bias = position_bias + causal_mask 253 | 254 | # Apply pruned heads if any 255 | if self.pruned_heads: 256 | mask = torch.ones(position_bias.shape[1]) 257 | mask[list(self.pruned_heads)] = 0 258 | position_bias_masked = position_bias[:, mask.bool()] 259 | else: 260 | position_bias_masked = position_bias 261 | 262 | # Apply flex attention 263 | attn_output, attention_weights = self._apply_flex_attention( 264 | query=query_states, 265 | key=key_states, 266 | value=value_states, 267 | position_bias=position_bias_masked, 268 | attention_mask=mask, 269 | head_mask=layer_head_mask 270 | ) 271 | 272 | # Reshape output and apply output projection 273 | attn_output = attn_output.transpose(1, 2).contiguous() 274 | attn_output = attn_output.view(batch_size, -1, self.inner_dim) 275 | attn_output = self.o(attn_output) 276 | 277 | # Prepare outputs 278 | outputs = (attn_output, past_key_value, position_bias) 279 | if output_attentions: 280 | outputs = outputs + (attention_weights,) 281 | return outputs 282 | 283 | 
284 | def replace_t5_attention_with_flex(model, compile_flex=False): 285 | """ 286 | Replace all T5Attention modules in a T5 model with T5FlexAttention. 287 | 288 | Args: 289 | model: A T5 model instance 290 | 291 | Returns: 292 | The modified model with flex attention 293 | """ 294 | # Recursively replace all T5Attention modules 295 | for name, module in model.named_children(): 296 | if isinstance(module, T5AttentionTransformers): 297 | # Create a new T5FlexAttention with the same parameters 298 | flex_attn = T5FlexAttention( 299 | config=model.config, 300 | has_relative_attention_bias=module.has_relative_attention_bias, 301 | layer_idx=module.layer_idx, 302 | compile_flex=compile_flex 303 | ) 304 | 305 | # Copy weights 306 | flex_attn.q.weight.data = module.q.weight.data.clone() 307 | flex_attn.k.weight.data = module.k.weight.data.clone() 308 | flex_attn.v.weight.data = module.v.weight.data.clone() 309 | flex_attn.o.weight.data = module.o.weight.data.clone() 310 | 311 | if module.has_relative_attention_bias: 312 | flex_attn.relative_attention_bias.weight.data = module.relative_attention_bias.weight.data.clone() 313 | 314 | # Replace the module 315 | setattr(model, name, flex_attn) 316 | else: 317 | # Recursively process child modules 318 | replace_t5_attention_with_flex(module) 319 | 320 | return model 321 | 322 | 323 | 324 | if __name__ == "__main__": 325 | # py -m wip.t5.t5_flex_attention 326 | from transformers import T5Config 327 | 328 | config = T5Config() 329 | 330 | attention_layer = T5FlexAttention(config=config, has_relative_attention_bias=True, layer_idx=0, compile_flex=False) 331 | hidden_states = torch.randn(1, 10, config.d_model) 332 | 333 | output = attention_layer(hidden_states) 334 | 335 | print(output) 336 | -------------------------------------------------------------------------------- /licenses/e1_license.txt: -------------------------------------------------------------------------------- 1 | Your use of the Profluent-E1 model code is governed by the Apache License 2.0, while your use of the Profluent-E1 model weights and the full release of the Profluent-E1 model is governed by a similarly permissive license with additional attribution requirements - see the NOTICE file for details. You can use, share, and modify Profluent-E1 for free, but you must follow our ATTRIBUTION guidelines to give credit, include the license when you share and follow some other basic rules. Profluent is not responsible for what you build, and may terminate your rights to use Profluent-E1 if you breach the license. 2 | 3 | Code in src/E1/model/flash_attention_utils.py is adapted from flash-attention project under BSD-3-Clause license. 4 | 5 | Profluent-E1 Notice File 6 | ------------------------ 7 | 8 | Copyright 2025 Profluent Bio Inc. 9 | 10 | Licensed under the Profluent-E1 Clickthrough License Agreement (the “Agreement”); you may not use 11 | this file except in compliance with the Agreement. Unless required by applicable law or agreed to 12 | in writing, software distributed under the Agreement is distributed on an "AS IS" BASIS, WITHOUT 13 | WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the Agreement for the specific 14 | language governing permissions and limitations under the Agreement and the Attribution Guidelines 15 | for more information about attribution requirements for use of Profluent-E1. 
You may obtain a copy 16 | of the Agreement at https://github.com/Profluent-AI/E1/blob/main/LICENSE and a copy of the Attribution 17 | Guidelines at https://github.com/Profluent-AI/E1/blob/main/ATTRIBUTION, each as may be updated or 18 | amended from time to time. 19 | 20 | Profluent-E1 Clickthrough License Agreement 21 | ------------------------------------------- 22 | 23 | Please read this Profluent-E1 Clickthrough License Agreement (the “Agreement”) carefully before using Profluent-E1 (as defined below), which is offered by Profluent Bio Inc. (“Profluent”). 24 | 25 | By downloading Profluent-E1, or otherwise using Profluent-E1 in any manner, You agree that You have read and agree to be bound by the terms of this Agreement. If You are accessing Profluent-E1 on behalf of an organization or entity, You represent and warrant that You are authorized to enter into this Agreement on that organization's or entity's behalf and bind them to the terms of this Agreement (in which case, the references to “You” and “Your” in this Agreement, except for in this sentence, refer to that organization or entity). Use of Profluent-E1 and all other Profluent-E1 IP is expressly conditioned upon Your assent to all terms of this Agreement, including the Attribution Guidelines incorporated herein, to the exclusion of all other terms. 26 | 27 | 1. Definitions. 28 | 29 | 1.1 “AAA Rules” shall mean the Commercial Arbitration Rules and Mediation Procedures of the American Arbitration Association (“AAA”). 30 | 31 | 1.2 “Claim” shall mean any claim (including any tort claim), cause of action, and/or dispute under, arising out of, or relating to this Agreement. 32 | 33 | 1.3 “Contribution” shall mean any work of authorship, including the original version of Profluent-E1 and any modifications or additions to that Profluent-E1 or Derivative Works thereof, that is intentionally submitted to Profluent for inclusion in Profluent-E1 by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to Profluent or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, Profluent for the purpose of discussing and improving Profluent-E1, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 34 | 35 | 1.4 “Contributor” shall mean Profluent and any individual or Legal Entity on behalf of whom a Contribution has been received by Profluent and subsequently incorporated within Profluent-E1. 36 | 37 | 1.5 “Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) Profluent-E1 and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this Agreement, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, Profluent-E1 and Derivative Works thereof. 38 | 39 | 1.6 “GitHub Page” shall mean the page made available at https://github.com/Profluent-AI/E1, as may be updated and amended from time to time. 
40 | 41 | 1.7 “Hugging Face Pages” shall mean the pages made available at https://huggingface.co/Profluent-Bio/E1-600m, https://huggingface.co/Profluent-Bio/E1-300m and https://huggingface.co/Profluent-Bio/E1-150m, each as may be updated and amended from time to time. 42 | 43 | 1.8 “Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 44 | 45 | 1.9 “Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 46 | 47 | 1.10 “Profluent-E1” shall mean the Profluent-E1 Model Code, Profluent-E1 Model Weights and all software, algorithms, machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed or made available to You by Profluent on the GitHub Page, each as may be updated and amended from time to time, whether in Source or Object form. 48 | 49 | 1.11 “Profluent-E1 Model Code” shall mean the code and data for Profluent-E1 made available to You at https://github.com/Profluent-AI/E1, as may be updated and amended from time to time. 50 | 51 | 1.12 “Profluent-E1 Model Weights” shall mean the trained model weights for Profluent-E1 made available to You on one or more of the Hugging Face Pages, as may be updated and amended from time to time, including all model weights that are directly or indirectly accessed or copied from the Hugging Face Pages by You or any third party. 52 | 53 | 1.13 “Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 54 | 55 | 1.14 “You” or “Your” shall mean an individual entering into this Agreement or the organization or Legal Entity on whose behalf such individual is entering into this Agreement. 56 | 57 | 2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute Profluent-E1 and such Derivative Works in Source or Object form. 58 | 59 | 3. Grant of Patent License. Subject to the terms and conditions of this Agreement, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer Profluent-E1, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with Profluent-E1 to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that Profluent-E1 or a Contribution incorporated within Profluent-E1 constitutes direct or contributory patent infringement, then any patent licenses granted to You under this Agreement for Profluent-E1 shall terminate as of the date such litigation is filed. 60 | 61 | 4. Use of Profluent-E1 Model Code. Your access to and use of the Profluent-E1 Model Code separate and apart from Profluent-E1 and the Profluent-E1 Model Weights is subject to the Apache License, Version 2.0 made available at https://www.apache.org/licenses/LICENSE-2.0, as may be updated and amended from time to time. 62 | 63 | 5. Redistribution. You may reproduce and distribute copies of Profluent-E1 or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 64 | 65 | (a) You must give any other recipients of Profluent-E1 or Derivative Works a copy of this Agreement; and 66 | 67 | (b) You must cause any modified files to carry prominent notices stating that You changed the files; and 68 | 69 | (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of Profluent-E1, excluding those notices that do not pertain to any part of the Derivative Works; and 70 | 71 | (d) If Profluent-E1 includes a "NOTICE" text file as part of its distribution or there is otherwise a “NOTICE” text file available on the GitHub Page, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify this Agreement. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from Profluent-E1, provided that such additional attribution notices cannot be construed as modifying this Agreement; and 72 | 73 | (e) You must comply with the Attribution Guidelines and provide attribution to Profluent-E1 in accordance therewith. 74 | 75 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of Profluent-E1 otherwise complies with the conditions stated in this License. 76 | 77 | 6. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in Profluent-E1 by You to Profluent shall be under the terms and conditions of this Agreement, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Profluent regarding such Contributions. 78 | 79 | 7. Trademarks. 
This Agreement does not grant permission to use any Profluent trade names, trademarks, service marks, or product names, except as required for reasonable and customary use in describing the origin of Profluent-E1, reproducing the content of the NOTICE file and complying with the Attribution Guidelines. 80 | 81 | 8. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Profluent provides Profluent-E1 (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing Profluent-E1 and assume any risks associated with Your exercise of permissions under this Agreement. FOR THE AVOIDANCE OF DOUBT AND NOTWITHSTANDING ANYTHING TO THE CONTRARY, YOU ACKNOWLEDGE AND AGREE THAT PROFLUENT IS NOT RESPONSIBLE OR LIABLE FOR ANYTHING YOU BUILD, CREATE, DEVELOP OR DERIVE FROM YOUR USE OF PROFLUENT-E1, INCLUDING ANY DERIVATIVE WORKS. 82 | 83 | 9. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use Profluent-E1 (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 84 | 85 | 10. Accepting Warranty or Additional Liability. While redistributing Profluent-E1 or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this Agreement. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 86 | 87 | 11. Term and Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to Profluent-E1 and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Profluent may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of Profluent-E1. All provisions of this Agreement which by their nature should survive termination shall survive termination, including, without limitation, ownership provisions, warranty disclaimers, indemnity obligations, limitations of liability and provisions regarding dispute resolution. 88 | 89 | 12. General. 
90 | 91 | 12.1 This Agreement constitutes the entire agreement between You and Profluent relating to the subject matter hereof and supersedes all proposals, understandings, or discussions, whether written or oral, relating to the subject matter of this Agreement and all past dealing or industry custom. The failure of either party to enforce its rights under this Agreement at any time for any period shall not be construed as a waiver of such rights. 92 | 93 | 12.2 Profluent may amend or modify this Agreement from time to time and will use reasonable efforts to provide You with notice of any material changes that may negatively impact Your use of Profluent-E1 through the GitHub Page or through another means made available to You. No other changes, modifications or waivers to this Agreement will be effective unless in writing and signed by both parties. 94 | 95 | 12.3 You and Profluent are independent contractors, and nothing herein shall be deemed to constitute either party as the agent or representative of the other or both parties as joint venturers or partners for any purpose. 96 | 97 | 12.4 You shall comply with the U.S. Foreign Corrupt Practices Act and all applicable export laws, restrictions and regulations of the U.S. Department of Commerce, and any other applicable U.S. and foreign authority. 98 | 99 | 12.5 This Agreement and the rights and obligations herein may not be assigned or transferred, in whole or in part, by You without the prior written consent of Profluent. Any assignment in violation of this provision is void. Profluent may freely assign or transfer this Agreement, in whole or in part. This Agreement shall be binding upon, and inure to the benefit of, the successors and permitted assigns of the parties. 100 | 101 | 12.6 This Agreement shall be governed by and construed under the Federal Arbitration Act, applicable federal law, and the laws of the State of California and the United States without regard to conflicts of laws provisions thereof, and without regard to the Uniform Computer Information Transactions Act. 102 | 103 | 12.7 If You have a Claim, You agree to provide Profluent with at least sixty (60) days' written notice so that the parties may attempt to resolve the Claim internally before requesting arbitration. In case of any Claims or disputes between You and Profluent that cannot be resolved through informal internal discussions, You agree that, subject to the provisions of this section, any and all Claims will be settled by final and binding arbitration in accordance with the AAA Rules by an arbitrator selected pursuant to the AAA Rules; provided, however, that Claims arising out of or relating to Your violations of Profluent's intellectual property rights, including copyright infringement, patent infringement, trademark infringement, or efforts to use Profluent-E1 in unauthorized ways will not be subject to such obligation for settlement by final and binding arbitration, and such claims will instead be brought in the state and federal courts of Alameda County, California. The arbitrator's decision and award will be non-appealable and may be entered in, and will be enforceable in, any court of competent jurisdiction. The arbitration will take place in Alameda County, California. Each party will bear its own costs of arbitration unless the arbitrator directs otherwise. The arbitrator may not award relief or damages in excess of or contrary to what this Agreement provides, or order consolidation or arbitration on a class-wide or representative basis. 
104 | 105 | 12.8 If any provision of this Agreement is held to be invalid, illegal or unenforceable in any respect, that provision shall be limited or eliminated to the minimum extent necessary so that this Agreement otherwise remains in full force and effect and enforceable. 106 | -------------------------------------------------------------------------------- /licenses/ankh_license.txt: -------------------------------------------------------------------------------- 1 | License for ANKH models, from their repo https://github.com/agemagician/Ankh?tab=License-1-ov-file 2 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 3 | 4 | Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 5 | 6 | Using Creative Commons Public Licenses 7 | 8 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 9 | 10 | Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors : wiki.creativecommons.org/Considerations_for_licensors 11 | 12 | Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. 
More considerations for the public : wiki.creativecommons.org/Considerations_for_licensees 13 | 14 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 15 | 16 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 17 | 18 | Section 1 – Definitions. 19 | 20 | a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 21 | b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 22 | c. BY-NC-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License. 23 | d. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 24 | e. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 25 | f. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 26 | g. License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. 27 | h. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 28 | i. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 29 | j. Licensor means the individual(s) or entity(ies) granting rights under this Public License. 30 | k. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. 
For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 31 | l. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 32 | m. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 33 | n. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 34 | Section 2 – Scope. 35 | 36 | a. License grant. 37 | Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 38 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 39 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 40 | Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 41 | Term. The term of this Public License is specified in Section 6(a). 42 | Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 43 | Downstream recipients. 44 | A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 45 | B. Additional offer from the Licensor – Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter's License You apply. 46 | C. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 47 | No endorsement. 
Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 48 | b. Other rights. 49 | Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 50 | Patent and trademark rights are not licensed under this Public License. 51 | To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 52 | Section 3 – License Conditions. 53 | 54 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 55 | 56 | a. Attribution. 57 | If You Share the Licensed Material (including in modified form), You must: 58 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 59 | 60 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 61 | ii. a copyright notice; 62 | iii. a notice that refers to this Public License; 63 | iv. a notice that refers to the disclaimer of warranties; 64 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 65 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 66 | 67 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 68 | 69 | You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 70 | If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 71 | b. ShareAlike.In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 72 | The Adapter's License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 73 | You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 74 | You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 
75 | Section 4 – Sui Generis Database Rights. 76 | 77 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 78 | 79 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 80 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 81 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 82 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 83 | Section 5 – Disclaimer of Warranties and Limitation of Liability. 84 | 85 | a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You. 86 | b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You. 87 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 88 | Section 6 – Term and Termination. 89 | 90 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 91 | 92 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 93 | 94 | automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 95 | upon express reinstatement by the Licensor. 96 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 97 | 98 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 99 | 100 | d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 101 | 102 | Section 7 – Other Terms and Conditions. 103 | 104 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 105 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 106 | Section 8 – Interpretation. 107 | 108 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 109 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 110 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 111 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 112 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 113 | 114 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /wip/t5/test_t5_flex_attention.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | import time 5 | import matplotlib.pyplot as plt 6 | from transformers import T5Config, T5EncoderModel 7 | from .t5_flex_attention import T5FlexAttention, replace_t5_attention_with_flex 8 | from .t5_attention import T5AttentionTransformers 9 | 10 | 11 | def test_attention_layer_equivalence( 12 | batch_size=2, 13 | seq_length=16, 14 | d_model=32, 15 | num_heads=4, 16 | d_kv=8, 17 | seed=42, 18 | device=None, 19 | measure_time=False, 20 | num_warmup=10, 21 | num_timing_runs=10, 22 | compile_flex=False, 23 | ): 24 | """ 25 | Test the equivalence between T5Attention and T5FlexAttention layers. 
26 | 27 | Args: 28 | batch_size: Batch size for test inputs 29 | seq_length: Sequence length for test inputs 30 | d_model: Model dimension 31 | num_heads: Number of attention heads 32 | d_kv: Dimension of key and value vectors 33 | seed: Random seed for reproducibility 34 | device: Device to run the test on 35 | measure_time: Whether to measure execution time 36 | num_warmup: Number of warmup runs before timing 37 | num_timing_runs: Number of runs to average timing over 38 | compile_flex: Whether to compile the flex attention module 39 | 40 | Returns: 41 | max_diff: Maximum absolute difference between outputs 42 | mean_diff: Mean absolute difference between outputs 43 | std_time: Average execution time for standard attention (if measure_time=True) 44 | flex_time: Average execution time for flex attention (if measure_time=True) 45 | """ 46 | print(f"\n{'='*20} Testing Attention Layer Equivalence {'='*20}") 47 | print(f"Compile flex: {compile_flex}") 48 | 49 | # Set random seed for reproducibility 50 | torch.manual_seed(seed) 51 | np.random.seed(seed) 52 | 53 | # Create a simple T5 config 54 | config = T5Config( 55 | d_model=d_model, 56 | d_kv=d_kv, 57 | num_heads=num_heads, 58 | is_decoder=False, 59 | use_cache=False, 60 | ) 61 | 62 | # Create standard T5Attention and T5FlexAttention layers 63 | std_attention = T5AttentionTransformers(config, has_relative_attention_bias=True).to(device) 64 | flex_attention = T5FlexAttention(config, has_relative_attention_bias=True, compile_flex=compile_flex).to(device) 65 | 66 | # Copy weights from standard attention to flex attention 67 | flex_attention.q.weight.data = std_attention.q.weight.data.clone() 68 | flex_attention.k.weight.data = std_attention.k.weight.data.clone() 69 | flex_attention.v.weight.data = std_attention.v.weight.data.clone() 70 | flex_attention.o.weight.data = std_attention.o.weight.data.clone() 71 | flex_attention.relative_attention_bias.weight.data = std_attention.relative_attention_bias.weight.data.clone() 72 | 73 | # Create random input 74 | hidden_states = torch.randn(batch_size, seq_length, d_model).to(device) 75 | 76 | # Set both models to eval mode 77 | std_attention.eval() 78 | flex_attention.eval() 79 | 80 | # Timing measurements 81 | std_time = 0 82 | flex_time = 0 83 | 84 | if measure_time: 85 | # Warmup runs 86 | print(f"Performing {num_warmup} warmup runs...") 87 | for _ in range(num_warmup): 88 | with torch.no_grad(): 89 | _ = std_attention(hidden_states)[0] 90 | _ = flex_attention(hidden_states)[0] 91 | 92 | # Timing runs for standard attention 93 | print(f"Measuring standard attention over {num_timing_runs} runs...") 94 | torch.cuda.synchronize() if device.type == 'cuda' else None 95 | start_time = time.time() 96 | for _ in range(num_timing_runs): 97 | with torch.no_grad(): 98 | _ = std_attention(hidden_states)[0] 99 | torch.cuda.synchronize() if device.type == 'cuda' else None 100 | std_time = (time.time() - start_time) / num_timing_runs 101 | 102 | # Timing runs for flex attention 103 | print(f"Measuring flex attention over {num_timing_runs} runs...") 104 | torch.cuda.synchronize() if device.type == 'cuda' else None 105 | start_time = time.time() 106 | for _ in range(num_timing_runs): 107 | with torch.no_grad(): 108 | _ = flex_attention(hidden_states)[0] 109 | torch.cuda.synchronize() if device.type == 'cuda' else None 110 | flex_time = (time.time() - start_time) / num_timing_runs 111 | 112 | print(f"Standard attention average time: {std_time*1000:.4f} ms") 113 | print(f"Flex attention average time: 
{flex_time*1000:.4f} ms") 114 | print(f"Speedup: {std_time/flex_time:.2f}x") 115 | 116 | # Forward pass for correctness check 117 | with torch.no_grad(): 118 | std_output = std_attention(hidden_states)[0] 119 | flex_output = flex_attention(hidden_states)[0] 120 | 121 | # Calculate differences 122 | abs_diff = torch.abs(std_output - flex_output) 123 | max_diff = torch.max(abs_diff).item() 124 | mean_diff = torch.mean(abs_diff).item() 125 | 126 | print(f"Max absolute difference: {max_diff:.8f}") 127 | print(f"Mean absolute difference: {mean_diff:.8f}") 128 | 129 | if measure_time: 130 | return max_diff, mean_diff, std_time, flex_time 131 | else: 132 | return max_diff, mean_diff 133 | 134 | 135 | def test_model_equivalence( 136 | model_name, 137 | batch_size=2, 138 | seq_length=16, 139 | seed=42, 140 | device=None, 141 | measure_time=False, 142 | num_warmup=10, 143 | num_timing_runs=10, 144 | compile_flex=False, 145 | ): 146 | """ 147 | Test the equivalence between a standard T5EncoderModel and one with flex attention. 148 | 149 | Args: 150 | model_name: Name or path of the pretrained model 151 | batch_size: Batch size for test inputs 152 | seq_length: Sequence length for test inputs 153 | seed: Random seed for reproducibility 154 | device: Device to run the test on 155 | measure_time: Whether to measure execution time 156 | num_warmup: Number of warmup runs before timing 157 | num_timing_runs: Number of runs to average timing over 158 | compile_flex: Whether to compile the flex attention module 159 | Returns: 160 | max_diff: Maximum absolute difference between outputs 161 | mean_diff: Mean absolute difference between outputs 162 | std_time: Average execution time for standard model (if measure_time=True) 163 | flex_time: Average execution time for flex model (if measure_time=True) 164 | """ 165 | print(f"\n{'='*20} Testing Model Equivalence {'='*20}") 166 | print(f"Using pretrained model: {model_name}") 167 | print(f"Compile flex: {compile_flex}") 168 | 169 | # Set random seed for reproducibility 170 | torch.manual_seed(seed) 171 | np.random.seed(seed) 172 | 173 | # Load the pretrained model 174 | std_model = T5EncoderModel.from_pretrained(model_name).to(device) 175 | 176 | # Create a copy with flex attention 177 | flex_model = T5EncoderModel.from_pretrained(model_name).to(device) 178 | flex_model = replace_t5_attention_with_flex(flex_model, compile_flex=compile_flex) 179 | 180 | # Create random input IDs 181 | input_ids = torch.randint(0, std_model.config.vocab_size, (batch_size, seq_length)).to(device) 182 | attention_mask = torch.ones_like(input_ids).to(device) 183 | 184 | # Set both models to eval mode 185 | std_model.eval() 186 | flex_model.eval() 187 | 188 | # Timing measurements 189 | std_time = 0 190 | flex_time = 0 191 | 192 | if measure_time: 193 | # Warmup runs 194 | print(f"Performing {num_warmup} warmup runs...") 195 | for _ in range(num_warmup): 196 | with torch.no_grad(): 197 | _ = std_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 198 | _ = flex_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 199 | 200 | # Timing runs for standard model 201 | print(f"Measuring standard model over {num_timing_runs} runs...") 202 | torch.cuda.synchronize() if device.type == 'cuda' else None 203 | start_time = time.time() 204 | for _ in range(num_timing_runs): 205 | with torch.no_grad(): 206 | _ = std_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 207 | torch.cuda.synchronize() if device.type == 'cuda' else 
None 208 | std_time = (time.time() - start_time) / num_timing_runs 209 | 210 | # Timing runs for flex model 211 | print(f"Measuring flex model over {num_timing_runs} runs...") 212 | torch.cuda.synchronize() if device.type == 'cuda' else None 213 | start_time = time.time() 214 | for _ in range(num_timing_runs): 215 | with torch.no_grad(): 216 | _ = flex_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 217 | torch.cuda.synchronize() if device.type == 'cuda' else None 218 | flex_time = (time.time() - start_time) / num_timing_runs 219 | 220 | print(f"Standard model average time: {std_time*1000:.4f} ms") 221 | print(f"Flex model average time: {flex_time*1000:.4f} ms") 222 | print(f"Speedup: {std_time/flex_time:.2f}x") 223 | 224 | # Forward pass for correctness check 225 | with torch.no_grad(): 226 | std_output = std_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 227 | flex_output = flex_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 228 | 229 | # Calculate differences 230 | abs_diff = torch.abs(std_output - flex_output) 231 | max_diff = torch.max(abs_diff).item() 232 | mean_diff = torch.mean(abs_diff).item() 233 | 234 | print(f"Max absolute difference in last_hidden_state: {max_diff:.8f}") 235 | print(f"Mean absolute difference in last_hidden_state: {mean_diff:.8f}") 236 | 237 | if measure_time: 238 | return max_diff, mean_diff, std_time, flex_time 239 | else: 240 | return max_diff, mean_diff 241 | 242 | 243 | if __name__ == "__main__": 244 | # py -m wip.t5.test_t5_flex_attention 245 | # py -m wip.t5.test_t5_flex_attention --measure_time --seq_length_range --num_timing_runs 100 246 | parser = argparse.ArgumentParser(description="Test T5 Flex Attention equivalence") 247 | parser.add_argument("--model_name", type=str, default="Synthyra/ANKH_base", 248 | help="Pretrained model name or path") 249 | parser.add_argument("--batch_size", type=int, default=2, 250 | help="Batch size for test inputs (default: 2)") 251 | parser.add_argument("--seq_length", type=int, default=16, 252 | help="Sequence length for test inputs (default: 16)") 253 | parser.add_argument("--seq_length_range", action="store_true", 254 | help="Test a range of sequence lengths from 8 to 2048") 255 | parser.add_argument("--seed", type=int, default=42, 256 | help="Random seed for reproducibility (default: 42)") 257 | parser.add_argument("--tolerance", type=float, default=1e-5, 258 | help="Tolerance for differences (default: 1e-5)") 259 | parser.add_argument("--measure_time", action="store_true", 260 | help="Measure execution time") 261 | parser.add_argument("--num_warmup", type=int, default=10, 262 | help="Number of warmup runs before timing (default: 10)") 263 | parser.add_argument("--num_timing_runs", type=int, default=10, 264 | help="Number of runs to average timing over (default: 10)") 265 | parser.add_argument("--compile_flex", action="store_true", 266 | help="Compile flex attention") 267 | args = parser.parse_args() 268 | 269 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 270 | print(f"Running tests on device: {device}") 271 | 272 | if args.seq_length_range and args.measure_time: 273 | # Test a range of sequence lengths 274 | seq_lengths = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] 275 | std_times = [] 276 | flex_times = [] 277 | max_diffs = [] 278 | 279 | print(f"\n{'='*20} Testing Sequence Length Range {'='*20}") 280 | 281 | for seq_length in seq_lengths: 282 | print(f"\nTesting sequence length: {seq_length}") 283 
| 284 | # Test attention layer equivalence 285 | layer_max_diff, layer_mean_diff, layer_std_time, layer_flex_time = test_attention_layer_equivalence( 286 | batch_size=args.batch_size, 287 | seq_length=seq_length, 288 | seed=args.seed, 289 | device=device, 290 | measure_time=True, 291 | num_warmup=args.num_warmup, 292 | num_timing_runs=args.num_timing_runs, 293 | compile_flex=args.compile_flex 294 | ) 295 | 296 | std_times.append(layer_std_time * 1000) # Convert to ms 297 | flex_times.append(layer_flex_time * 1000) # Convert to ms 298 | max_diffs.append(layer_max_diff) 299 | 300 | # Create plot for timing results 301 | plt.figure(figsize=(12, 8)) 302 | 303 | # Plot timing results 304 | plt.subplot(2, 1, 1) 305 | plt.plot(seq_lengths, std_times, 'o-', label='Standard Attention') 306 | plt.plot(seq_lengths, flex_times, 'o-', label='Flex Attention') 307 | plt.xscale('log', base=2) 308 | plt.yscale('log') 309 | plt.xlabel('Sequence Length') 310 | plt.ylabel('Time (ms)') 311 | plt.title('Attention Layer Execution Time vs Sequence Length') 312 | plt.grid(True, which="both", ls="--") 313 | plt.legend() 314 | 315 | # Plot speedup 316 | plt.subplot(2, 1, 2) 317 | speedups = [std/flex for std, flex in zip(std_times, flex_times)] 318 | plt.plot(seq_lengths, speedups, 'o-', color='green') 319 | plt.xscale('log', base=2) 320 | plt.xlabel('Sequence Length') 321 | plt.ylabel('Speedup (x)') 322 | plt.title('Flex Attention Speedup vs Sequence Length') 323 | plt.grid(True, which="both", ls="--") 324 | 325 | plt.tight_layout() 326 | plt.savefig('sequence_length_timing.png') 327 | print(f"\nPlot saved to sequence_length_timing.png") 328 | 329 | # Print results in table format 330 | print(f"\n{'='*60}") 331 | print(f"{'Sequence Length':^15} | {'Standard (ms)':^15} | {'Flex (ms)':^15} | {'Speedup':^10}") 332 | print(f"{'-'*60}") 333 | for i, seq_len in enumerate(seq_lengths): 334 | print(f"{seq_len:^15} | {std_times[i]:^15.4f} | {flex_times[i]:^15.4f} | {speedups[i]:^10.2f}") 335 | print(f"{'='*60}") 336 | 337 | else: 338 | # Test attention layer equivalence 339 | if args.measure_time: 340 | layer_max_diff, layer_mean_diff, layer_std_time, layer_flex_time = test_attention_layer_equivalence( 341 | batch_size=args.batch_size, 342 | seq_length=args.seq_length, 343 | seed=args.seed, 344 | device=device, 345 | measure_time=True, 346 | num_warmup=args.num_warmup, 347 | num_timing_runs=args.num_timing_runs, 348 | compile_flex=args.compile_flex 349 | ) 350 | else: 351 | layer_max_diff, layer_mean_diff = test_attention_layer_equivalence( 352 | batch_size=args.batch_size, 353 | seq_length=args.seq_length, 354 | seed=args.seed, 355 | device=device, 356 | compile_flex=args.compile_flex 357 | ) 358 | 359 | # Test model equivalence 360 | if args.measure_time: 361 | model_max_diff, model_mean_diff, model_std_time, model_flex_time = test_model_equivalence( 362 | model_name=args.model_name, 363 | batch_size=args.batch_size, 364 | seq_length=args.seq_length, 365 | seed=args.seed, 366 | device=device, 367 | measure_time=True, 368 | num_warmup=args.num_warmup, 369 | num_timing_runs=args.num_timing_runs, 370 | compile_flex=args.compile_flex 371 | ) 372 | else: 373 | model_max_diff, model_mean_diff = test_model_equivalence( 374 | model_name=args.model_name, 375 | batch_size=args.batch_size, 376 | seq_length=args.seq_length, 377 | seed=args.seed, 378 | device=device, 379 | compile_flex=args.compile_flex 380 | ) 381 | 382 | # Check if differences are within tolerance 383 | print(f"\n{'='*20} Results {'='*20}") 384 | print(f"Tolerance 
threshold: {args.tolerance}") 385 | 386 | if layer_max_diff <= args.tolerance: 387 | print(f"✅ Attention layer test PASSED: Max diff {layer_max_diff:.8f} <= {args.tolerance}") 388 | else: 389 | print(f"❌ Attention layer test FAILED: Max diff {layer_max_diff:.8f} > {args.tolerance}") 390 | 391 | if model_max_diff <= args.tolerance: 392 | print(f"✅ Model test PASSED: Max diff {model_max_diff:.8f} <= {args.tolerance}") 393 | else: 394 | print(f"❌ Model test FAILED: Max diff {model_max_diff:.8f} > {args.tolerance}") 395 | 396 | # Print timing summary if measured 397 | if args.measure_time: 398 | print(f"\n{'='*20} Timing Summary {'='*20}") 399 | print(f"Attention Layer (compile_flex={args.compile_flex}):") 400 | print(f" Standard: {layer_std_time*1000:.4f} ms") 401 | print(f" Flex: {layer_flex_time*1000:.4f} ms") 402 | print(f" Speedup: {layer_std_time/layer_flex_time:.2f}x") 403 | 404 | print(f"\nFull Model (compile_flex={args.compile_flex}):") 405 | print(f" Standard: {model_std_time*1000:.4f} ms") 406 | print(f" Flex: {model_flex_time*1000:.4f} ms") 407 | print(f" Speedup: {model_std_time/model_flex_time:.2f}x") --------------------------------------------------------------------------------