├── .gitignore
├── requirements.txt
├── e1
│   ├── get_e1_weights.py
│   └── tokenizer.json
├── licenses
│   ├── fastesm_license.txt
│   ├── e1_license.txt
│   └── ankh_license.txt
├── esm2
│   └── get_esm2_weights.py
├── exp.ipynb
├── wip
│   └── t5
│       ├── transformers_flex_attn.py
│       ├── t5_attention.py
│       ├── t5_flex_attention.py
│       └── test_t5_flex_attention.py
├── test_scripts
│   ├── test_precision_difference.py
│   ├── test_difference_esmfast.py
│   ├── test_attentions.py
│   ├── test_embedding.py
│   ├── test_compliance_esmc.py
│   ├── test_compliance_esm2.py
│   ├── test_throughput_esmfast.py
│   ├── test_contact_maps.py
│   └── test_throughput.py
├── esm_plusplus
│   └── get_esmc_weights.py
├── README.md
├── update_HF.py
├── readmes
│   ├── fastesm2_readme.md
│   ├── e1_readme.md
│   ├── fastesm_650_readme.md
│   ├── esm_plusplus_large_readme.md
│   └── esm_plusplus_small_readme.md
├── pooler.py
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | *.pyc
3 | *.png
4 | *.pth
5 | *.pt
6 | *.safetensors
7 | *.json
8 | !e1/*.json
9 | *.bin
10 | /results_classification_lora
11 | /results_regression_lora
12 | .qodo
13 | *.db
14 | /esm
15 | *.nbc
16 | *.nbi
17 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.6.0
2 | matplotlib>=3.5.0
3 | transformers>=4.47.0
4 | numpy>=1.26.2
5 | einops
6 | esm
7 | datasets>=2.14.0
8 | scikit-learn>=1.0.0
9 | scipy>=1.7.0
10 | seaborn>=0.12.0
11 | peft>=0.5.0
12 | accelerate>=1.1.0
13 | networkx
--------------------------------------------------------------------------------
/e1/get_e1_weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from huggingface_hub import login
3 |
4 | from e1.modeling_e1 import E1ForMaskedLM, E1Config
5 |
6 |
7 | model_dict = {
8 | 'Profluent-E1-150M': 'Profluent-Bio/E1-150m',
9 | 'Profluent-E1-300M': 'Profluent-Bio/E1-300m',
10 | 'Profluent-E1-600M': 'Profluent-Bio/E1-600m',
11 | }
12 |
13 |
14 | if __name__ == "__main__":
15 | # py -m e1.get_e1_weights
16 | login()
17 | for model_name in model_dict:
18 | config = E1Config.from_pretrained(model_dict[model_name])
19 | config.auto_map = {
20 | "AutoConfig": "modeling_e1.E1Config",
21 | "AutoModel": "modeling_e1.E1Model",
22 | "AutoModelForMaskedLM": "modeling_e1.E1ForMaskedLM",
23 | "AutoModelForSequenceClassification": "modeling_e1.E1ForSequenceClassification",
24 | "AutoModelForTokenClassification": "modeling_e1.E1ForTokenClassification"
25 | }
26 | model = E1ForMaskedLM.from_pretrained(model_dict[model_name], dtype=torch.bfloat16)
27 | model.push_to_hub('Synthyra/' + model_name)
28 |
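29 | # Usage sketch (not executed by this script): once pushed, the auto_map above lets the models be
30 | # loaded directly through transformers with remote code enabled, e.g.
31 | #   from transformers import AutoModelForMaskedLM
32 | #   model = AutoModelForMaskedLM.from_pretrained('Synthyra/Profluent-E1-150M', trust_remote_code=True)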
--------------------------------------------------------------------------------
/licenses/fastesm_license.txt:
--------------------------------------------------------------------------------
1 | License for FastESM models, from the ESM2 repo https://github.com/facebookresearch/esm
2 |
3 | MIT License
4 |
5 | Copyright (c) Meta Platforms, Inc. and affiliates.
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
--------------------------------------------------------------------------------
/esm2/get_esm2_weights.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from huggingface_hub import login
4 | from transformers import EsmForMaskedLM
5 | from modeling_fastesm import FastEsmForMaskedLM, FastEsmConfig
6 |
7 |
8 | model_dict = {
9 | # Synthyra/ESM2-8M
10 | 'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
11 | # Synthyra/ESM2-35M
12 | 'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
13 | # Synthyra/ESM2-150M
14 | 'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
15 | # Synthyra/ESM2-650M
16 | 'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
17 | # Synthyra/ESM2-3B
18 | 'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
19 | }
20 |
21 |
22 | if __name__ == "__main__":
23 | #login()
24 | for model_name in model_dict:
25 | config = FastEsmConfig.from_pretrained(model_dict[model_name])
26 | config.auto_map = {
27 | "AutoConfig": "modeling_fastesm.FastEsmConfig",
28 | "AutoModel": "modeling_fastesm.FastEsmModel",
29 | "AutoModelForMaskedLM": "modeling_fastesm.FastEsmForMaskedLM",
30 | "AutoModelForSequenceClassification": "modeling_fastesm.FastEsmForSequenceClassification",
31 | "AutoModelForTokenClassification": "modeling_fastesm.FastEsmForTokenClassification"
32 | }
33 | config.tie_word_embeddings = False
34 | original_model = EsmForMaskedLM.from_pretrained(model_dict[model_name])
35 |         model = FastEsmForMaskedLM.from_pretrained(model_dict[model_name], config=config)
36 | model.lm_head.load_state_dict(original_model.lm_head.state_dict())
37 | model.lm_head = copy.deepcopy(model.lm_head)
38 | model.push_to_hub('Synthyra/' + model_name)
39 |
40 | for name1, param1 in model.named_parameters():
41 | for name2, param2 in original_model.named_parameters():
42 | if name1 == name2:
43 | assert param1.shape == param2.shape, f'{name1} {param1.shape} != {name2} {param2.shape}'
44 | assert torch.equal(param1.data.clone(), param2.data.clone()), f'{name1} {name2}'
45 |
46 |
--------------------------------------------------------------------------------
/exp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from modeling_esm_plusplus import ESMplusplusModel"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 3,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "torch.Size([2, 11, 1152])\n"
22 | ]
23 | },
24 | {
25 | "data": {
26 | "text/plain": [
27 | "36"
28 | ]
29 | },
30 | "execution_count": 3,
31 | "metadata": {},
32 | "output_type": "execute_result"
33 | }
34 | ],
35 | "source": [
36 | "model = ESMplusplusModel.from_pretrained('Synthyra/ESMplusplus_large')\n",
37 | "tokenizer = model.tokenizer\n",
38 | "\n",
39 | "sequences = ['MPRTEIN', 'MSEQWENCE']\n",
40 | "tokenized = tokenizer(sequences, padding=True, return_tensors='pt')\n",
41 | "\n",
42 | "# tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training\n",
43 | "\n",
44 | "output = model(**tokenized) # get all hidden states with output_hidden_states=True\n",
45 |     "print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), here (2, 11, 1152)\n",
46 | "#print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)\n",
47 | "\n",
48 | "output = model(**tokenized, output_attentions=True)\n",
49 | "att = output.attentions\n",
50 |     "len(att) # 36, one for each layer, size (batch_size, num_heads, seq_len, seq_len) each"
51 | ]
52 | }
53 | ],
54 | "metadata": {
55 | "kernelspec": {
56 | "display_name": "Python 3",
57 | "language": "python",
58 | "name": "python3"
59 | },
60 | "language_info": {
61 | "codemirror_mode": {
62 | "name": "ipython",
63 | "version": 3
64 | },
65 | "file_extension": ".py",
66 | "mimetype": "text/x-python",
67 | "name": "python",
68 | "nbconvert_exporter": "python",
69 | "pygments_lexer": "ipython3",
70 | "version": "3.11.8"
71 | }
72 | },
73 | "nbformat": 4,
74 | "nbformat_minor": 2
75 | }
76 |
--------------------------------------------------------------------------------
/wip/t5/transformers_flex_attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional, Tuple
3 | from torch.nn.attention.flex_attention import flex_attention
4 |
5 |
6 | def flex_attention_forward(
7 | query: torch.Tensor,
8 | key: torch.Tensor,
9 | value: torch.Tensor,
10 | attention_mask: Optional[torch.Tensor],
11 | scaling: Optional[float] = None,
12 | softcap: Optional[float] = None,
13 | head_mask: Optional[torch.Tensor] = None,
14 | **kwargs,
15 | ) -> Tuple[torch.Tensor, torch.Tensor]:
16 | causal_mask = attention_mask
17 | if causal_mask is not None:
18 | causal_mask = causal_mask[:, :, :, : key.shape[-2]]
19 |
20 | def causal_mod(score, b, h, q_idx, kv_idx):
21 | if softcap is not None:
22 | score = softcap * torch.tanh(score / softcap)
23 | if causal_mask is not None:
24 | score = score + causal_mask[b][0][q_idx][kv_idx]
25 | if head_mask is not None:
26 | score = score + head_mask[b][h][0][0]
27 | return score
28 |
29 | attn_output, attention_weights = flex_attention(
30 | query,
31 | key,
32 | value,
33 | score_mod=causal_mod,
34 | enable_gqa=True,
35 | scale=scaling,
36 | # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
37 | # For simplification, we thus always return it as no additional computations are introduced.
38 | return_lse=True,
39 | )
40 | # lse is returned in float32
41 | attention_weights = attention_weights.to(value.dtype)
42 | attn_output = attn_output.transpose(1, 2).contiguous()
43 |
44 | return attn_output, attention_weights
45 |
46 |
47 | if __name__ == "__main__":
48 | # py -m wip.t5.transformers_flex_attn
49 | batch_size, seq_len, n_heads, head_dim = 1, 10, 12, 64
50 |
51 | query = torch.randn(batch_size, n_heads, seq_len, head_dim)
52 | key = torch.randn(batch_size, n_heads, seq_len, head_dim)
53 | value = torch.randn(batch_size, n_heads, seq_len, head_dim)
54 | attention_mask = torch.randint(0, 2, (batch_size, seq_len))
55 | scaling = 1.0
56 | softcap = 2.0
57 | head_mask = None
58 |
59 |     attention_mask = (1.0 - attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).float()) * torch.finfo(query.dtype).min  # additive mask: 0.0 where attended, large negative where padded
60 |
61 | attn_output, attention_weights = flex_attention_forward(
62 | query,
63 | key,
64 | value,
65 | attention_mask,
66 | scaling,
67 | softcap,
68 | head_mask,
69 | )
70 | print(attn_output.shape)
71 | print(attention_weights.shape)
72 |
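73 | # Sketch (commented out, not exercised by this script): padding could instead be handled with a
74 | # block mask, which lets FlexAttention skip fully-masked blocks rather than adding the mask inside
75 | # score_mod. Assumes the same 0/1 padding mask of shape (batch_size, seq_len) as above.
76 | #
77 | #   from torch.nn.attention.flex_attention import create_block_mask
78 | #
79 | #   padding = torch.randint(0, 2, (batch_size, seq_len)).bool()
80 | #
81 | #   def padding_mask_mod(b, h, q_idx, kv_idx):
82 | #       return padding[b, kv_idx]
83 | #
84 | #   block_mask = create_block_mask(padding_mask_mod, batch_size, None, seq_len, seq_len)
85 | #   attn_output = flex_attention(query, key, value, block_mask=block_mask, scale=scaling)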
--------------------------------------------------------------------------------
/test_scripts/test_precision_difference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import random
4 | import argparse
5 | from huggingface_hub import login
6 | from transformers import AutoModelForMaskedLM
7 | from tqdm.auto import tqdm
8 |
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--model_path', type=str, default='Synthyra/ESMplusplus_small')
12 | parser.add_argument('--token', type=str, default=None)
13 | args = parser.parse_args()
14 |
15 | if args.token:
16 | login(args.token)
17 |
18 | model_path = args.model_path
19 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
20 | length = 128
21 | seq_count = 1000
22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23 |
24 |
25 | def generate_random_sequence(length: int) -> str:
26 | return 'M' + "".join(random.choices(canonical_amino_acids, k=length-3))
27 |
28 |
29 | # Generate sequences first
30 | sequences = [generate_random_sequence(length) for _ in range(seq_count)]
31 |
32 |
33 | # Get base model outputs
34 | base_outputs = []
35 | model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True).to(device)
36 | tokenizer = model.tokenizer
37 | with torch.no_grad():
38 | for seq in tqdm(sequences):
39 | input = tokenizer(seq, return_tensors="pt").to(device)
40 | embeddings = model(**input).last_hidden_state.cpu()
41 | base_outputs.append(embeddings)
42 | model.cpu()
43 | del model
44 | torch.cuda.empty_cache()
45 |
46 |
47 | # Get fp16 outputs
48 | fp16_mse = 0
49 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).to(device)
50 | with torch.no_grad():
51 | for i, seq in tqdm(enumerate(sequences), total=len(sequences)):
52 | input = tokenizer(seq, return_tensors="pt").to(device)
53 | fp16_output = model(**input).last_hidden_state.float().cpu()
54 | fp16_mse += F.mse_loss(base_outputs[i], fp16_output).item()
55 | model.cpu()
56 | del model
57 | torch.cuda.empty_cache()
58 |
59 |
60 | # Get bfloat16 outputs
61 | bf16_mse = 0
62 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.bfloat16, trust_remote_code=True).to(device)
63 | with torch.no_grad():
64 | for i, seq in tqdm(enumerate(sequences), total=len(sequences)):
65 | input = tokenizer(seq, return_tensors="pt").to(device)
66 | bf16_output = model(**input).last_hidden_state.float().cpu()
67 | bf16_mse += F.mse_loss(base_outputs[i], bf16_output).item()
68 | model.cpu()
69 | del model
70 | torch.cuda.empty_cache()
71 |
72 | fp16_mse /= seq_count
73 | bf16_mse /= seq_count
74 |
75 | print(f"Average MSE for FP16: {fp16_mse:.8f}")
76 | print(f"Average MSE for BF16: {bf16_mse:.8f}")
77 | print(f"{'FP16' if fp16_mse < bf16_mse else 'BF16'} has lower MSE")
78 |
--------------------------------------------------------------------------------
/test_scripts/test_difference_esmfast.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import random
4 | import argparse
5 | from huggingface_hub import login
6 | from transformers import AutoModelForMaskedLM, EsmTokenizer
7 | from tqdm.auto import tqdm
8 |
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--model_path', type=str, default='lhallee/synthyra_esm2_650_mlm')
12 | parser.add_argument('--token', type=str, default=None)
13 | args = parser.parse_args()
14 |
15 | if args.token:
16 | login(args.token)
17 |
18 | model_path = args.model_path
19 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
20 | length = 128
21 | seq_count = 1000
22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23 |
24 |
25 | def generate_random_sequence(length: int) -> str:
26 | return 'M' + "".join(random.choices(canonical_amino_acids, k=length))
27 |
28 |
29 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
30 | # Generate sequences first
31 | sequences = [generate_random_sequence(length) for _ in range(seq_count)]
32 | inputs = [tokenizer(seq, return_tensors="pt").to(device) for seq in sequences]
33 |
34 |
35 | # Get base model outputs
36 | base_outputs = []
37 | model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True).to(device)
38 | with torch.no_grad():
39 | for input_batch in tqdm(inputs):
40 | base_outputs.append(model(**input_batch, output_hidden_states=True).hidden_states[-1].cpu())
41 | model.cpu()
42 | del model
43 | torch.cuda.empty_cache()
44 |
45 |
46 | # Get fp16 outputs
47 | fp16_mse = 0
48 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).to(device)
49 | with torch.no_grad():
50 | for i, input_batch in tqdm(enumerate(inputs), total=len(inputs)):
51 | fp16_output = model(**input_batch, output_hidden_states=True).hidden_states[-1].float().cpu()
52 | fp16_mse += F.mse_loss(base_outputs[i], fp16_output).item()
53 | model.cpu()
54 | del model
55 | torch.cuda.empty_cache()
56 |
57 |
58 | # Get bfloat16 outputs
59 | bf16_mse = 0
60 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.bfloat16, trust_remote_code=True).to(device)
61 | with torch.no_grad():
62 | for i, input_batch in tqdm(enumerate(inputs), total=len(inputs)):
63 | bf16_output = model(**input_batch, output_hidden_states=True).hidden_states[-1].float().cpu()
64 | bf16_mse += F.mse_loss(base_outputs[i], bf16_output).item()
65 | model.cpu()
66 | del model
67 | torch.cuda.empty_cache()
68 |
69 | fp16_mse /= seq_count
70 | bf16_mse /= seq_count
71 |
72 | print(f"Average MSE for FP16: {fp16_mse:.8f}")
73 | print(f"Average MSE for BF16: {bf16_mse:.8f}")
74 | print(f"{'FP16' if fp16_mse < bf16_mse else 'BF16'} has lower MSE")
75 |
--------------------------------------------------------------------------------
/esm_plusplus/get_esmc_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import copy
4 | from functools import cache
5 | from pathlib import Path
6 | from huggingface_hub import snapshot_download, login
7 |
8 | from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM, ESMplusplusConfig
9 |
10 |
11 | # cached helper (adapted from the ESM repo) for locating or downloading ESMC weights
12 | @cache
13 | def data_root(model: str):
14 | if "INFRA_PROVIDER" in os.environ:
15 | return Path("")
16 |     # Try to download from Hugging Face if the weights are not already present locally
17 | if model.startswith("esmc-300"):
18 | path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
19 | elif model.startswith("esmc-600"):
20 | path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
21 | else:
22 | raise ValueError(f"{model=} is an invalid model name.")
23 | return path
24 |
25 |
26 | def ESMplusplus_300M(device: torch.device | str = "cpu"):
27 | with torch.device(device):
28 | config = ESMplusplusConfig(
29 | hidden_size=960,
30 | num_attention_heads=15,
31 | num_hidden_layers=30,
32 | )
33 | model = ESMplusplusForMaskedLM(config)
34 | state_dict = torch.load(
35 | data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
36 | map_location=device,
37 | )
38 | model.load_state_dict(state_dict)
39 | return model
40 |
41 |
42 | def ESMplusplus_600M(device: torch.device | str = "cpu"):
43 | with torch.device(device):
44 | config = ESMplusplusConfig(
45 | hidden_size=1152,
46 | num_attention_heads=18,
47 | num_hidden_layers=36,
48 | )
49 | model = ESMplusplusForMaskedLM(config)
50 | state_dict = torch.load(
51 | data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
52 | map_location=device,
53 | )
54 | model.load_state_dict(state_dict)
55 | return model
56 |
57 |
58 | if __name__ == "__main__":
59 | #login()
60 |
61 | model_dict = {
62 | # Synthyra/ESM++small
63 | 'Synthyra/ESMplusplus_small': ESMplusplus_300M,
64 | # Synthyra/ESM++large
65 | 'Synthyra/ESMplusplus_large': ESMplusplus_600M,
66 | }
67 |
68 |
69 | for model_path, model_fn in model_dict.items():
70 | model = model_fn()
71 | model.sequence_head = copy.deepcopy(model.sequence_head)
72 | model.config.auto_map = {
73 | "AutoConfig": "modeling_esm_plusplus.ESMplusplusConfig",
74 | "AutoModel": "modeling_esm_plusplus.ESMplusplusModel",
75 | "AutoModelForMaskedLM": "modeling_esm_plusplus.ESMplusplusForMaskedLM",
76 | "AutoModelForSequenceClassification": "modeling_esm_plusplus.ESMplusplusForSequenceClassification",
77 | "AutoModelForTokenClassification": "modeling_esm_plusplus.ESMplusplusForTokenClassification"
78 | }
79 | model.push_to_hub(model_path)
80 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FastPLMs
2 |
3 |
4 |
5 | FastPLMs is an open-source effort to increase the efficiency of pretrained protein language models by swapping native attention implementations for Flash or Flex attention.
6 |
7 | All models can be loaded from Huggingface 🤗 transformers via `AutoModel`; this repository does not need to be cloned for most use cases.
8 |
9 | ## Supported models
10 | The currently supported models can be found [here](https://huggingface.co/collections/Synthyra/pretrained-plms-675351ecc050f63baedd77de).
11 |
12 | ## Suggestions
13 | Have suggestions, comments, or requests? Found a bug? Open a GitHub issue and we'll respond soon.
14 |
15 | ## Embed entire datasets with no new code
16 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted by length to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual time it will take.
17 |
18 | Example:
19 | ```python
20 | embedding_dict = model.embed_dataset(
21 | sequences=[
22 | 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
23 | ],
24 | batch_size=2, # adjust for your GPU memory
25 | max_len=512, # adjust for your needs
26 | full_embeddings=False, # if True, no pooling is performed
27 | embed_dtype=torch.float32, # cast to what dtype you want
28 |     pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
29 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
30 | sql=False, # if True, embeddings will be stored in SQLite database
31 | sql_db_path='embeddings.db',
32 | save=True, # if True, embeddings will be saved as a .pth file
33 | save_path='embeddings.pth',
34 | )
35 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
36 | ```
37 |
38 | ```
39 | model.embed_dataset()
40 | Args:
41 | sequences: List of protein sequences
42 | batch_size: Batch size for processing
43 | max_len: Maximum sequence length
44 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
45 |     pooling_types: List of pooling types to apply ('mean', 'cls'); multiple types are concatenated
46 | num_workers: Number of workers for data loading, 0 for the main process
47 | sql: Whether to store embeddings in SQLite database - will be stored in float32
48 | sql_db_path: Path to SQLite database
49 |
50 | Returns:
51 | Dictionary mapping sequences to embeddings, or None if sql=True
52 |
53 | Note:
54 | - If sql=True, embeddings can only be stored in float32
55 | - sql is ideal if you need to stream a very large dataset for training in real-time
56 | - save=True is ideal if you can store the entire embedding dictionary in RAM
57 | - sql will be used if it is True and save is True or False
58 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
59 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
60 | ```
61 |
62 | ## Upcoming releases
63 | A Fast version of ANKH is in progress. It is functional but still uses native attention; we are waiting for bias gradient support in [FlexAttention](https://pytorch.org/blog/flexattention/).
64 |
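65 | Below is a minimal, illustrative sketch of the T5-style relative position bias hook that this support would enable (names and shapes are placeholders, not the final ANKH implementation):
66 | 
67 | ```python
68 | import torch
69 | from torch.nn.attention.flex_attention import flex_attention
70 | 
71 | batch, heads, seq_len, head_dim = 2, 8, 16, 64
72 | q = torch.randn(batch, heads, seq_len, head_dim)
73 | k = torch.randn(batch, heads, seq_len, head_dim)
74 | v = torch.randn(batch, heads, seq_len, head_dim)
75 | # learned T5-style bias, one value per (head, query position, key position)
76 | rel_bias = torch.randn(heads, seq_len, seq_len, requires_grad=True)
77 | 
78 | def t5_bias(score, b, h, q_idx, kv_idx):
79 |     # add the learned relative position bias to the raw attention score
80 |     return score + rel_bias[h, q_idx, kv_idx]
81 | 
82 | out = flex_attention(q, k, v, score_mod=t5_bias)
83 | ```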
--------------------------------------------------------------------------------
/e1/tokenizer.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "1.0",
3 | "truncation": null,
4 | "padding": {
5 | "strategy": "BatchLongest",
6 | "direction": "Right",
7 | "pad_to_multiple_of": null,
8 | "pad_id": 0,
9 | "pad_type_id": 0,
10 | "pad_token": ""
11 | },
12 | "added_tokens": [
13 | {
14 | "id": 0,
15 | "content": "",
16 | "single_word": false,
17 | "lstrip": false,
18 | "rstrip": false,
19 | "normalized": false,
20 | "special": true
21 | },
22 | {
23 | "id": 1,
24 | "content": "",
25 | "single_word": false,
26 | "lstrip": false,
27 | "rstrip": false,
28 | "normalized": false,
29 | "special": true
30 | },
31 | {
32 | "id": 2,
33 | "content": "",
34 | "single_word": false,
35 | "lstrip": false,
36 | "rstrip": false,
37 | "normalized": false,
38 | "special": true
39 | },
40 | {
41 | "id": 3,
42 | "content": "",
43 | "single_word": false,
44 | "lstrip": false,
45 | "rstrip": false,
46 | "normalized": false,
47 | "special": true
48 | },
49 | {
50 | "id": 4,
51 | "content": "",
52 | "single_word": false,
53 | "lstrip": false,
54 | "rstrip": false,
55 | "normalized": false,
56 | "special": true
57 | },
58 | {
59 | "id": 5,
60 | "content": "?",
61 | "single_word": false,
62 | "lstrip": false,
63 | "rstrip": false,
64 | "normalized": false,
65 | "special": true
66 | }
67 | ],
68 | "normalizer": null,
69 | "pre_tokenizer": {
70 | "type": "ByteLevel",
71 | "add_prefix_space": false,
72 | "trim_offsets": true,
73 | "use_regex": true
74 | },
75 | "post_processor": {
76 | "type": "ByteLevel",
77 | "add_prefix_space": true,
78 | "trim_offsets": true,
79 | "use_regex": true
80 | },
81 | "decoder": {
82 | "type": "ByteLevel",
83 | "add_prefix_space": true,
84 | "trim_offsets": true,
85 | "use_regex": true
86 | },
87 | "model": {
88 | "type": "BPE",
89 | "dropout": null,
90 | "unk_token": "X",
91 | "continuing_subword_prefix": null,
92 | "end_of_word_suffix": null,
93 | "fuse_unk": false,
94 | "byte_fallback": false,
95 | "vocab": {
96 | "": 0,
97 | "": 1,
98 | "": 2,
99 | "": 3,
100 | "": 4,
101 | "?": 5,
102 | "1": 6,
103 | "2": 7,
104 | "A": 8,
105 | "B": 9,
106 | "C": 10,
107 | "D": 11,
108 | "E": 12,
109 | "F": 13,
110 | "G": 14,
111 | "H": 15,
112 | "I": 16,
113 | "J": 17,
114 | "K": 18,
115 | "L": 19,
116 | "M": 20,
117 | "N": 21,
118 | "O": 22,
119 | "P": 23,
120 | "Q": 24,
121 | "R": 25,
122 | "S": 26,
123 | "T": 27,
124 | "U": 28,
125 | "V": 29,
126 | "W": 30,
127 | "X": 31,
128 | "Y": 32,
129 | "Z": 33
130 | },
131 | "merges": []
132 | }
133 | }
134 |
--------------------------------------------------------------------------------
/test_scripts/test_attentions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from datasets import load_dataset
4 | from esm_plusplus.modeling_esm_plusplus import ESMplusplusModel
5 | from tqdm import tqdm
6 |
7 |
8 | def test_attention_outputs(model, tokenizer, seqs, batch_size=4, tolerances=[1e-3, 1e-5, 1e-7, 1e-9]):
9 | """
10 | Test if hidden states are the same with and without attention output at different tolerance levels.
11 |
12 | Args:
13 | model: The model to test
14 | tokenizer: The tokenizer to use
15 | seqs: List of sequences to process
16 | batch_size: Batch size for processing
17 | tolerances: List of tolerance values to test with torch.allclose
18 |
19 | Returns:
20 | dict: Results for each tolerance level
21 | """
22 | results = {tol: True for tol in tolerances}
23 | max_diff = 0.0
24 |
25 | with torch.no_grad():
26 | for i in tqdm(range(0, len(seqs), batch_size), desc='Testing attention outputs'):
27 | batch_seqs = seqs[i:i+batch_size]
28 |
29 | # Tokenize the batch
30 | tokenized = tokenizer(batch_seqs, padding=True, return_tensors='pt')
31 | tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
32 |
33 | # Get output without attention
34 | output_no_att = model(**tokenized).last_hidden_state.detach().cpu()
35 |
36 | # Get output with attention
37 | output_with_att = model(**tokenized, output_attentions=True).last_hidden_state.detach().cpu()
38 |
39 | # Calculate maximum difference
40 | diff = (output_no_att - output_with_att).abs().max().item()
41 | max_diff = max(max_diff, diff)
42 | print(max_diff)
43 |
44 | # Check for NaN or infinite values
45 | has_nan_or_inf_no_att = torch.isnan(output_no_att).any() or torch.isinf(output_no_att).any()
46 | has_nan_or_inf_with_att = torch.isnan(output_with_att).any() or torch.isinf(output_with_att).any()
47 | if has_nan_or_inf_no_att or has_nan_or_inf_with_att:
48 | print(f"WARNING: Found NaN or infinite values in the outputs! No att: {has_nan_or_inf_no_att}, With att: {has_nan_or_inf_with_att}")
49 |
50 | # Test different tolerance levels
51 | for tol in tolerances:
52 | if not torch.allclose(output_no_att, output_with_att, atol=tol):
53 | results[tol] = False
54 |
55 | return results, max_diff
56 |
57 |
58 | if __name__ == '__main__':
59 | # py -m test_scripts.test_attentions
60 | parser = argparse.ArgumentParser(description='Test attention outputs in ESM++ models')
61 | parser.add_argument('--model', type=str, default='Synthyra/ESMplusplus_small', help='Model to test')
62 | parser.add_argument('--num_samples', type=int, default=100, help='Number of samples to test')
63 | parser.add_argument('--batch_size', type=int, default=4, help='Batch size for processing')
64 | args = parser.parse_args()
65 |
66 | # Load data
67 | seqs = load_dataset('Synthyra/NEGATOME', split='manual_stringent').filter(lambda x: len(x['SeqA']) <= 256).select(range(args.num_samples))['SeqA']
68 | seqs = list(set(seqs))
69 | # Load model
70 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
71 | model = ESMplusplusModel.from_pretrained(args.model).to(device)
72 | tokenizer = model.tokenizer
73 |
74 | # Define tolerance levels to test
75 | tolerances = [1e-2, 1e-4, 1e-6, 1e-8]
76 |
77 | # Run tests
78 | print(f"Testing model: {args.model}")
79 | print(f"Device: {device}")
80 | print(f"Testing {len(seqs)} sequences with batch size {args.batch_size}")
81 |
82 | results, max_diff = test_attention_outputs(
83 | model,
84 | tokenizer,
85 | seqs,
86 | batch_size=args.batch_size,
87 | tolerances=tolerances
88 | )
89 |
90 | # Report results
91 | print("\nTest Results:")
92 | print(f"Maximum absolute difference: {max_diff:.10e}")
93 | print("\nTolerance tests:")
94 | for tol in sorted(tolerances):
95 | status = "PASSED" if results[tol] else "FAILED"
96 | print(f" Tolerance {tol:.0e}: {status}")
97 |
98 | # Overall result
99 | if all(results.values()):
100 |         print("\nAll tests PASSED! Hidden states match at every tested tolerance with and without attention output.")
101 | else:
102 | min_passing_tol = min([tol for tol, passed in results.items() if passed], default=None)
103 | if min_passing_tol:
104 | print(f"\nTest PASSED at tolerance {min_passing_tol:.0e} and above.")
105 | else:
106 | print("\nAll tests FAILED. Hidden states differ significantly when output_attentions is True vs False.")
--------------------------------------------------------------------------------
/update_HF.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import HfApi, login
2 |
3 |
4 | FAST_ESM_MODELS = [
5 | 'Synthyra/ESM2-8M',
6 | 'Synthyra/ESM2-35M',
7 | 'Synthyra/ESM2-150M',
8 | 'Synthyra/ESM2-650M',
9 | 'Synthyra/ESM2-3B',
10 | 'Synthyra/FastESM2_650'
11 | ]
12 |
13 | ESM_PLUSPLUS_MODELS = [
14 | 'Synthyra/ESMplusplus_small',
15 | 'Synthyra/ESMplusplus_large',
16 | ]
17 |
18 | E1_MODELS = [
19 | 'Synthyra/Profluent-E1-150M',
20 | 'Synthyra/Profluent-E1-300M',
21 | 'Synthyra/Profluent-E1-600M',
22 | ]
23 |
24 | ANKH_MODELS = [
25 | 'Synthyra/ANKH_base',
26 | 'Synthyra/ANKH_large',
27 | 'Synthyra/ANKH2_large'
28 | ]
29 |
30 |
31 | if __name__ == "__main__":
32 | # py -m update_HF
33 | import argparse
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--token', type=str, default=None)
36 | args = parser.parse_args()
37 |
38 | if args.token:
39 | login(token=args.token)
40 |
41 | api = HfApi()
42 |
43 | for path in FAST_ESM_MODELS:
44 | print(path.lower())
45 | api.upload_file(
46 | path_or_fileobj="esm2/modeling_fastesm.py",
47 | path_in_repo="modeling_fastesm.py",
48 | repo_id=path,
49 | repo_type="model",
50 | )
51 | # Upload license file for FastESM models
52 | api.upload_file(
53 | path_or_fileobj="LICENSE",
54 | path_in_repo="LICENSE",
55 | repo_id=path,
56 | repo_type="model",
57 | )
58 | if 'esm2' in path.lower():
59 | api.upload_file(
60 | path_or_fileobj="readmes/fastesm2_readme.md",
61 | path_in_repo="README.md",
62 | repo_id=path,
63 | repo_type="model",
64 | )
65 |
66 | if 'fastesm' in path.lower():
67 | api.upload_file(
68 | path_or_fileobj="readmes/fastesm_650_readme.md",
69 | path_in_repo="README.md",
70 | repo_id=path,
71 | repo_type="model",
72 | )
73 |
74 |
75 | for path in ESM_PLUSPLUS_MODELS:
76 | print(path)
77 | api.upload_file(
78 | path_or_fileobj="esm_plusplus/modeling_esm_plusplus.py",
79 | path_in_repo="modeling_esm_plusplus.py",
80 | repo_id=path,
81 | repo_type="model",
82 | )
83 | if path == 'Synthyra/ESMplusplus_small':
84 | api.upload_file(
85 | path_or_fileobj="readmes/esm_plusplus_small_readme.md",
86 | path_in_repo="README.md",
87 | repo_id=path,
88 | repo_type="model",
89 | )
90 | # Upload license file for ESM++ small model
91 | api.upload_file(
92 | path_or_fileobj="LICENSE",
93 | path_in_repo="LICENSE",
94 | repo_id=path,
95 | repo_type="model",
96 | )
97 |
98 | if path == 'Synthyra/ESMplusplus_large':
99 | api.upload_file(
100 | path_or_fileobj="readmes/esm_plusplus_large_readme.md",
101 | path_in_repo="README.md",
102 | repo_id=path,
103 | repo_type="model",
104 | )
105 | # Upload license file for ESM++ large model
106 | api.upload_file(
107 | path_or_fileobj="LICENSE",
108 | path_in_repo="LICENSE",
109 | repo_id=path,
110 | repo_type="model",
111 | )
112 |
113 |
114 | for path in E1_MODELS:
115 | print(path)
116 | api.upload_file(
117 | path_or_fileobj="e1/modeling_e1.py",
118 | path_in_repo="modeling_e1.py",
119 | repo_id=path,
120 | repo_type="model",
121 | )
122 |         # Upload license file for E1 models
123 | api.upload_file(
124 | path_or_fileobj="LICENSE",
125 | path_in_repo="LICENSE",
126 | repo_id=path,
127 | repo_type="model",
128 | )
129 | api.upload_file(
130 | path_or_fileobj="readmes/e1_readme.md",
131 | path_in_repo="README.md",
132 | repo_id=path,
133 | repo_type="model",
134 | )
135 | api.upload_file(
136 | path_or_fileobj="e1/tokenizer.json",
137 | path_in_repo="tokenizer.json",
138 | repo_id=path,
139 | repo_type="model",
140 | )
141 |
142 |
143 | # Add code to upload files for ANKH models
144 | for path in ANKH_MODELS:
145 | print(path)
146 | # Upload license file for ANKH models
147 | api.upload_file(
148 | path_or_fileobj="LICENSE",
149 | path_in_repo="LICENSE",
150 | repo_id=path,
151 | repo_type="model",
152 | )
153 |
--------------------------------------------------------------------------------
/test_scripts/test_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import sqlite3
5 |
6 | from typing import List
7 | from transformers import AutoModel
8 | from esm2.modeling_fastesm import FastEsmModel
9 | from esm_plusplus.modeling_esm_plusplus import ESMplusplusModel
10 | from e1.modeling_e1 import E1Model
11 |
12 |
13 | CANONICAL_AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
14 |
15 |
16 | def generate_random_sequence(length: int) -> str:
17 | return 'M' + "".join(random.choices(CANONICAL_AMINO_ACIDS, k=length))
18 |
19 |
20 | if __name__ == "__main__":
21 | # py -m test_scripts.test_embedding
22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 |
24 | def test_embedding(model, sequences: List[str], name: str):
25 | embeddings = model.embed_dataset(
26 | sequences=sequences,
27 | tokenizer=model.tokenizer if hasattr(model, 'tokenizer') else None,
28 | sql=False, # return dictionary of sequences and embeddings
29 | save=False,
30 | )
31 |
32 | count = 0
33 | for k, v in embeddings.items():
34 | print(k)
35 | print(v.dtype, v.shape)
36 | count += 1
37 | if count > 5:
38 | break
39 |
40 | embeddings = model.embed_dataset(
41 | sequences=sequences,
42 | tokenizer=model.tokenizer if hasattr(model, 'tokenizer') else None,
43 | full_embeddings=True,
44 | sql=False, # return dictionary of sequences and embeddings
45 | save=False,
46 | )
47 |
48 | count = 0
49 | for k, v in embeddings.items():
50 | print(k)
51 | print(v.dtype, v.shape)
52 | count += 1
53 | if count > 5:
54 | break
55 |
56 | db_path = f'embeddings_{name}.db'
57 | _ = model.embed_dataset(
58 | sequences=sequences,
59 | tokenizer=model.tokenizer if hasattr(model, 'tokenizer') else None,
60 | pooling_types=['cls', 'mean'],
61 | sql=True,
62 | sql_db_path=db_path,
63 | save=True,
64 | )
65 |
66 | # Verify database contents
67 | conn = sqlite3.connect(db_path)
68 | c = conn.cursor()
69 |
70 | # Check number of sequences
71 | c.execute('SELECT COUNT(*) FROM embeddings')
72 | db_count = c.fetchone()[0]
73 | print(f"\nNumber of sequences in database: {db_count}")
74 |
75 | count = 0
76 | for seq in sequences:
77 | c.execute('SELECT embedding FROM embeddings WHERE sequence = ?', (seq,))
78 | result = c.fetchone()
79 | assert result is not None, f"Sequence {seq} not found in database"
80 | if count < 10:
81 | embedding = np.frombuffer(result[0], dtype=np.float32)
82 | print(seq)
83 | print(f"Embedding shape: {embedding.shape}")
84 | count += 1
85 |
86 | # Make sure to close the connection before attempting to delete the file
87 | c.close()
88 | conn.close()
89 |
90 | print("Testing E1 model...")
91 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)]
92 | model = E1Model.from_pretrained("Synthyra/Profluent-E1-150M", dtype=torch.bfloat16).to(device)
93 | print(model)
94 | test_embedding(model, sequences, 'e1')
95 |
96 | print("Testing FastESM model...")
97 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)]
98 | model = FastEsmModel.from_pretrained("Synthyra/ESM2-8M", dtype=torch.float16).to(device)
99 | print(model)
100 | test_embedding(model, sequences, 'fastesm')
101 |
102 | print("Testing ESM++ model...")
103 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)]
104 | model = ESMplusplusModel.from_pretrained("Synthyra/ESMplusplus_small", dtype=torch.float16).to(device)
105 | print(model)
106 | test_embedding(model, sequences, 'esmplusplus')
107 |
108 | print("Testing E1 model with AutoModel...")
109 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)]
110 | model = AutoModel.from_pretrained("Synthyra/Profluent-E1-150M", dtype=torch.bfloat16, trust_remote_code=True).to(device)
111 | print(model)
112 | test_embedding(model, sequences, 'e1_auto')
113 |
114 | print("Testing FastESM model with AutoModel...")
115 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)]
116 | model = AutoModel.from_pretrained("Synthyra/ESM2-8M", dtype=torch.float16, trust_remote_code=True).to(device)
117 | print(model)
118 | test_embedding(model, sequences, 'fastesm_auto')
119 |
120 | print("Testing ESM++ model with AutoModel...")
121 | sequences = [generate_random_sequence(random.randint(4, 16)) for _ in range(100)]
122 | model = AutoModel.from_pretrained("Synthyra/ESMplusplus_small", dtype=torch.float16, trust_remote_code=True).to(device)
123 | print(model)
124 | test_embedding(model, sequences, 'esmplusplus_auto')
--------------------------------------------------------------------------------
/test_scripts/test_compliance_esmc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import random
4 | import argparse
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import contextlib
8 | from huggingface_hub import login
9 | from tqdm.auto import tqdm
10 |
11 | from esm.pretrained import ESMC_600M_202412, ESMC_300M_202412
12 | from esm.sdk.api import ESMProtein, LogitsConfig
13 |
14 | from esm_plusplus.modeling_esm_plusplus import ESMplusplus_300M, ESMplusplus_600M
15 |
16 |
17 | """
18 | Testing if ESM++ outputs are compliant with ESMC outputs
19 | """
20 |
21 | def set_seed(seed: int):
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 | torch.backends.cudnn.deterministic = True
26 | random.seed(seed)
27 | np.random.seed(seed)
28 |
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--model_path', type=str, default='Synthyra/ESMplusplus_small')
32 | parser.add_argument('--token', type=str, default=None)
33 | args = parser.parse_args()
34 |
35 |
36 | if __name__ == "__main__":
37 | # py -m test_scripts.test_compliance_esmc
38 | if args.token:
39 | login(args.token)
40 |
41 | model_path = args.model_path
42 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
43 | length = 128
44 | seq_count = 10
45 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46 |
47 | set_seed(42)
48 |
49 | def generate_random_sequence(length: int) -> str:
50 | return 'M' + "".join(random.choices(canonical_amino_acids, k=length-3))
51 |
52 |
53 | sequences = [generate_random_sequence(length) for _ in range(seq_count)]
54 |
55 |
56 | if 'small' in model_path:
57 | esmc = ESMC_300M_202412(device=device, use_flash_attn=False)
58 | else:
59 | esmc = ESMC_600M_202412(device=device, use_flash_attn=False)
60 |
61 |
62 | # Get esmc model outputs
63 | base_outputs = []
64 | base_logits = []
65 | with (
66 | torch.no_grad(),
67 | torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16)
68 | if device.type == "cuda" else contextlib.nullcontext()
69 | ):
70 | for seq in tqdm(sequences):
71 | protein = ESMProtein(sequence=seq)
72 | protein_tensor = esmc.encode(protein)
73 | logits_result = esmc.logits(
74 | protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
75 | )
76 | embeddings = logits_result.embeddings
77 | logits = logits_result.logits.sequence
78 | manual_logits = esmc.sequence_head(embeddings)
79 |
80 | assert (manual_logits == logits).all(), "Logits are not equal"
81 | base_outputs.append(embeddings.float().cpu())
82 | base_logits.append(logits.float().cpu())
83 | esmc.cpu()
84 | del esmc
85 | torch.cuda.empty_cache()
86 |
87 |
88 | # Get plusplus outputs
89 | total_mse_embeddings = 0
90 | total_mse_logits = 0
91 |     model = ESMplusplus_300M(device=device) if 'small' in model_path else ESMplusplus_600M(device=device)  # match the ESM++ size to the ESMC checkpoint selected above
92 | tokenizer = model.tokenizer
93 | with (
94 | torch.no_grad(),
95 | torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16)
96 | if device.type == "cuda" else contextlib.nullcontext()
97 | ):
98 | for i, seq in tqdm(enumerate(sequences), total=len(sequences)):
99 | input = tokenizer(seq, return_tensors="pt").to(device)
100 | outputs = model(**input)
101 | embeddings = outputs.last_hidden_state.float().cpu()
102 | logits = outputs.logits.float().cpu()
103 |
104 | # Compare embeddings
105 | mse_embeddings = F.mse_loss(base_outputs[i], embeddings).item()
106 | # Compare logits
107 | mse_logits = F.mse_loss(base_logits[i], logits).item()
108 |
109 | if mse_embeddings > 0.01 or mse_logits > 0.01:
110 | print(f"Sequence {i}:")
111 | print(f" Embeddings MSE: {mse_embeddings:.8f}")
112 | print(f" Logits MSE: {mse_logits:.8f}")
113 |
114 | # Find positions where tensors differ
115 | diff_embeddings = torch.abs(base_outputs[i] - embeddings)
116 | diff_logits = torch.abs(base_logits[i] - logits)
117 |
118 | # plot diffs
119 | plt.figure(figsize=(12, 5))
120 | plt.subplot(1, 2, 1)
121 | plt.imshow(diff_embeddings[0].detach().numpy())
122 | plt.title("Embeddings Difference")
123 |
124 | plt.subplot(1, 2, 2)
125 | plt.imshow(diff_logits[0].detach().numpy())
126 | plt.title("Logits Difference")
127 | plt.show()
128 |
129 | total_mse_embeddings += mse_embeddings
130 | total_mse_logits += mse_logits
131 | model.cpu()
132 | del model
133 | torch.cuda.empty_cache()
134 |
135 | print(f"Average Embeddings MSE: {total_mse_embeddings / seq_count}")
136 | print(f"Average Logits MSE: {total_mse_logits / seq_count}")
137 |
--------------------------------------------------------------------------------
/test_scripts/test_compliance_esm2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import random
4 | import argparse
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from datasets import load_dataset
8 | from huggingface_hub import login
9 | from tqdm.auto import tqdm
10 | from transformers import EsmForMaskedLM
11 |
12 | from esm2.modeling_fastesm import FastEsmForMaskedLM
13 |
14 |
15 | """
16 | Testing if FastESM outputs are compliant with ESM2 outputs
17 | """
18 |
19 | def set_seed(seed: int):
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | torch.backends.cudnn.deterministic = True
24 | random.seed(seed)
25 | np.random.seed(seed)
26 |
27 |
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument('--token', type=str, default=None)
30 | args = parser.parse_args()
31 |
32 |
33 | if __name__ == "__main__":
34 | # py -m test_scripts.test_compliance_esm2
35 | if args.token:
36 | login(args.token)
37 |
38 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
39 | length = 1024
40 | seq_count = 10
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 |
43 | set_seed(42)
44 |
45 | sequences = load_dataset('lhallee/ccds_human_512', split='train')['AA'][:seq_count]
46 | sequences = [seq.replace('L', '') for seq in sequences]
47 |
48 | esm2 = EsmForMaskedLM.from_pretrained('facebook/esm2_t33_650M_UR50D').to(device).eval()
49 | #fastesm = FastEsmForMaskedLM.from_pretrained('facebook/esm2_t33_650M_UR50D').to(device)
50 | #fastesm.lm_head.load_state_dict(esm2.lm_head.state_dict())
51 | fastesm = FastEsmForMaskedLM.from_pretrained('Synthyra/ESM2-650M').to(device).eval()
52 | tokenizer = fastesm.tokenizer
53 |
54 |     # Get ESM2 model outputs
55 | base_outputs = []
56 | base_logits = []
57 | with torch.no_grad():
58 | for seq in tqdm(sequences):
59 | input = tokenizer(seq, return_tensors="pt").to(device)
60 | outputs = esm2(**input, output_hidden_states=True)
61 | embeddings = outputs.hidden_states[-1].float().cpu()
62 | logits = outputs.logits.float().cpu()
63 | base_outputs.append(embeddings)
64 | base_logits.append(logits)
65 | esm2.cpu()
66 | del esm2
67 | torch.cuda.empty_cache()
68 |
69 |     # Get FastESM outputs
70 | total_mse_embeddings = 0
71 | total_mse_logits = 0
72 | total_max_diff_embeddings = 0
73 | total_max_diff_logits = 0
74 | total_accuracy = 0
75 |
76 | with torch.no_grad():
77 | for i, seq in tqdm(enumerate(sequences), total=len(sequences)):
78 | input = tokenizer(seq, return_tensors="pt").to(device)
79 | outputs = fastesm(**input, output_attentions=False)
80 | embeddings = outputs.last_hidden_state.float().cpu()
81 | logits = outputs.logits.float().cpu()
82 |
83 | # Compare embeddings
84 | mse_embeddings = F.mse_loss(base_outputs[i], embeddings).item()
85 | # Compare logits
86 | mse_logits = F.mse_loss(base_logits[i], logits).item()
87 | max_diff_embeddings = torch.max(torch.abs(base_outputs[i] - embeddings)).item()
88 | max_diff_logits = torch.max(torch.abs(base_logits[i] - logits)).item()
89 |
90 | # Calculate accuracy of argmaxed logits
91 | base_argmax = torch.argmax(base_logits[i], dim=-1)
92 | fastesm_argmax = torch.argmax(logits, dim=-1)
93 | accuracy = (base_argmax == fastesm_argmax).float().mean().item()
94 | if mse_embeddings > 0.01 or mse_logits > 0.1:
95 | print(f"Sequence {i}:")
96 | print(f" Embeddings MSE: {mse_embeddings:.8f}")
97 | print(f" Logits MSE: {mse_logits:.8f}")
98 | print(f" Argmax Accuracy: {accuracy:.6f}")
99 |
100 | # Find positions where tensors differ
101 | diff_embeddings = torch.abs(base_outputs[i] - embeddings)
102 | diff_logits = torch.abs(base_logits[i] - logits)
103 |
104 | # plot diffs
105 | plt.figure(figsize=(12, 5))
106 | plt.subplot(1, 2, 1)
107 | plt.imshow(diff_embeddings[0].detach().numpy())
108 | plt.title("Embeddings Difference")
109 |
110 | plt.subplot(1, 2, 2)
111 | plt.imshow(diff_logits[0].detach().numpy())
112 | plt.title("Logits Difference")
113 | plt.show()
114 |
115 | total_mse_embeddings += mse_embeddings
116 | total_mse_logits += mse_logits
117 | total_max_diff_embeddings += max_diff_embeddings
118 | total_max_diff_logits += max_diff_logits
119 | total_accuracy += accuracy
120 | fastesm.cpu()
121 | del fastesm
122 | torch.cuda.empty_cache()
123 |
124 | print(f"Average Embeddings MSE: {total_mse_embeddings / seq_count}")
125 | print(f"Average Logits MSE: {total_mse_logits / seq_count}")
126 | print(f"Average Max Diff Embeddings: {total_max_diff_embeddings / seq_count}")
127 | print(f"Average Max Diff Logits: {total_max_diff_logits / seq_count}")
128 | print(f"Average Argmax Accuracy: {total_accuracy / seq_count:.6f}")
--------------------------------------------------------------------------------
/readmes/fastesm2_readme.md:
--------------------------------------------------------------------------------
1 | ---
2 | library_name: transformers
3 | tags: []
4 | ---
5 |
6 | # NOTE
7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git)
8 |
9 | # FastESM
10 | FastESM is a Huggingface-compatible, plug-in version of ESM2 rewritten with a newer PyTorch attention implementation (SDPA).
11 | 
12 | Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.
13 | 
14 | Outputting attention maps (or using the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
15 | Various other optimizations also make the base implementation slightly different from the one in transformers.
16 |
17 | ## Use with 🤗 transformers
18 |
19 | ### Supported models
20 | ```python
21 | model_dict = {
22 | # Synthyra/ESM2-8M
23 | 'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
24 | # Synthyra/ESM2-35M
25 | 'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
26 | # Synthyra/ESM2-150M
27 | 'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
28 | # Synthyra/ESM2-650M
29 | 'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
30 | # Synthyra/ESM2-3B
31 | 'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
32 | }
33 | ```
34 |
35 | ### For working with embeddings
36 | ```python
37 | import torch
38 | from transformers import AutoModel, AutoTokenizer
39 |
40 | model_path = 'Synthyra/ESM2-8M'
41 | model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
42 | tokenizer = model.tokenizer
43 |
44 | sequences = ['MPRTEIN', 'MSEQWENCE']
45 | tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
46 | with torch.no_grad():
47 | embeddings = model(**tokenized).last_hidden_state
48 |
49 | print(embeddings.shape) # (2, 11, 320) for ESM2-8M
50 | ```
51 |
52 | ### For working with sequence logits
53 | ```python
54 | import torch
55 | from transformers import AutoModelForMaskedLM, AutoTokenizer
56 |
57 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
58 | with torch.no_grad():
59 | logits = model(**tokenized).logits
60 |
61 | print(logits.shape) # (2, 11, 33)
62 | ```
63 |
64 | ### For working with attention maps
65 | ```python
66 | import torch
67 | from transformers import AutoModel, AutoTokenizer
68 |
69 | model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
70 | with torch.no_grad():
71 |     attentions = model(**tokenized, output_attentions=True).attentions # tuple of (batch_size, num_heads, seq_len, seq_len), one per layer
72 |
73 | print(attentions[-1].shape) # (2, 20, 11, 11)
74 | ```
75 |
76 | ### Contact prediction
77 | Because we can output attentions using the naive attention implementation, the contact prediction is also supported
78 | ```python
79 | with torch.no_grad():
80 | contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy() # (seq_len, seq_len)
81 | ```
82 | 
83 |
84 | ## Embed entire datasets with no new code
85 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted by length to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual time it will take.
86 |
87 | Example:
88 | ```python
89 | embedding_dict = model.embed_dataset(
90 | sequences=[
91 | 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
92 | ],
93 | tokenizer=model.tokenizer,
94 | batch_size=2, # adjust for your GPU memory
95 | max_len=512, # adjust for your needs
96 | full_embeddings=False, # if True, no pooling is performed
97 | embed_dtype=torch.float32, # cast to what dtype you want
98 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
99 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
100 | sql=False, # if True, embeddings will be stored in SQLite database
101 | sql_db_path='embeddings.db',
102 | save=True, # if True, embeddings will be saved as a .pth file
103 | save_path='embeddings.pth',
104 | )
105 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
106 | ```
107 |
108 | ```
109 | model.embed_dataset()
110 | Args:
111 | sequences: List of protein sequences
112 | batch_size: Batch size for processing
113 | max_len: Maximum sequence length
114 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
115 |     pooling_types: List of pooling types to apply ('mean', 'cls'); multiple types are concatenated
116 | num_workers: Number of workers for data loading, 0 for the main process
117 | sql: Whether to store embeddings in SQLite database - will be stored in float32
118 | sql_db_path: Path to SQLite database
119 |
120 | Returns:
121 | Dictionary mapping sequences to embeddings, or None if sql=True
122 |
123 | Note:
124 | - If sql=True, embeddings can only be stored in float32
125 | - sql is ideal if you need to stream a very large dataset for training in real-time
126 | - save=True is ideal if you can store the entire embedding dictionary in RAM
127 | - sql will be used if it is True and save is True or False
128 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
129 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
130 | ```
131 |
132 |
133 | ### Citation
134 | If you use any of this implementation or work please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
135 | ```
136 | @misc {FastPLMs,
137 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
138 |     title = { FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel.},
139 | year = {2024},
140 | url = { https://huggingface.co/Synthyra/ESMplusplus_small },
141 | DOI = { 10.57967/hf/3726 },
142 | publisher = { Hugging Face }
143 | }
144 | ```
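145 | 
146 | ### Plotting a predicted contact map
147 | A short, optional sketch (assuming `model` and `tokenized` from the contact prediction example above) for visualizing the predicted contacts with matplotlib:
148 | ```python
149 | import matplotlib.pyplot as plt
150 | 
151 | with torch.no_grad():
152 |     contact_maps = model.predict_contacts(**tokenized).cpu().numpy() # (batch_size, seq_len, seq_len)
153 | 
154 | plt.imshow(contact_maps[0], cmap='viridis')
155 | plt.xlabel('Residue')
156 | plt.ylabel('Residue')
157 | plt.colorbar(label='Contact probability')
158 | plt.savefig('contact_map.png', dpi=300)
159 | ```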
--------------------------------------------------------------------------------
/test_scripts/test_throughput_esmfast.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import random
4 | import argparse
5 | import matplotlib.pyplot as plt
6 | import seaborn as sns
7 | from huggingface_hub import login
8 | from transformers import EsmForMaskedLM, AutoModelForMaskedLM, EsmTokenizer
9 |
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--model_path', type=str, default='Synthyra/FastESM2_650')
13 | parser.add_argument('--token', type=str, default=None)
14 | args = parser.parse_args()
15 |
16 | if args.token:
17 | login(args.token)
18 |
19 | model_path = args.model_path
20 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 |
23 |
24 | def generate_random_sequence(length: int) -> str:
25 | return 'M' + "".join(random.choices(canonical_amino_acids, k=length-1))
26 |
27 |
28 | def time_model(model, inputs, warmup=10):
29 | model.eval()
30 | with torch.no_grad():
31 | # Warmup
32 | for _ in range(warmup):
33 | _ = model(**inputs[0])
34 |
35 | start_time = time.time()
36 | for input_batch in inputs:
37 | _ = model(**input_batch)
38 | return time.time() - start_time
39 |
40 |
41 | def get_gpu_memory():
42 | torch.cuda.synchronize()
43 | return torch.cuda.max_memory_allocated() / 1024**2 # Convert to MB
44 |
45 |
46 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
47 |
48 | # Test different sequence lengths and batch sizes
49 | #lengths = [128, 256, 512, 1024, 2048]
50 | #batch_sizes = [1, 4, 16, 32]
51 | lengths = [8, 16]
52 | batch_sizes = [1, 2]
53 |
54 |
55 | results = []
56 |
57 | # Generate all test sequences first
58 | all_test_inputs = {}
59 | for length in lengths:
60 | for batch_size in batch_sizes:
61 | print(f"\nGenerating sequences for length={length}, batch_size={batch_size}")
62 | all_sequences = []
63 | for _ in range(100):
64 | batch_sequences = [generate_random_sequence(length) for _ in range(batch_size)]
65 | all_sequences.append(batch_sequences)
66 | inputs = [tokenizer(sequences, padding=True, return_tensors="pt").to(device) for sequences in all_sequences]
67 | all_test_inputs[(length, batch_size)] = inputs
68 |
69 | # Test ESM model in fp32
70 | print("\nTesting ESM model in FP32...")
71 | model = EsmForMaskedLM.from_pretrained('facebook/esm2_t33_650M_UR50D').to(device)
72 | for length in lengths:
73 | for batch_size in batch_sizes:
74 | print(f"\nTesting length={length}, batch_size={batch_size}")
75 | inputs = all_test_inputs[(length, batch_size)]
76 |
77 | torch.cuda.reset_peak_memory_stats()
78 | esm_time = time_model(model, inputs)
79 | esm_memory = get_gpu_memory()
80 |
81 | results.append({
82 | 'Length': length,
83 | 'Batch Size': batch_size,
84 | 'ESM Time': esm_time,
85 | 'ESM Memory': esm_memory
86 | })
87 | print(f"ESM FP32 time: {esm_time:.2f}s, memory: {esm_memory:.0f}MB")
88 | torch.cuda.empty_cache()
89 |
90 | model.cpu()
91 | del model
92 | torch.cuda.empty_cache()
93 |
94 | # Test FastESM in fp16
95 | print("\nTesting FastESM model in FP16...")
96 | model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True, dtype=torch.float16).to(device)
97 | for i, (length, batch_size) in enumerate([(l,b) for l in lengths for b in batch_sizes]):
98 | print(f"\nTesting length={length}, batch_size={batch_size}")
99 | inputs = all_test_inputs[(length, batch_size)]
100 |
101 | torch.cuda.reset_peak_memory_stats()
102 | fast_time = time_model(model, inputs)
103 | fast_memory = get_gpu_memory()
104 |
105 | results[i].update({
106 | 'FastESM Time': fast_time,
107 | 'FastESM Memory': fast_memory,
108 | 'Speedup': results[i]['ESM Time']/fast_time
109 | })
110 | print(f"FastESM FP16 time: {fast_time:.2f}s, memory: {fast_memory:.0f}MB")
111 | print(f"Speedup: {results[i]['Speedup']:.2f}x")
112 | torch.cuda.empty_cache()
113 |
114 | model.cpu()
115 | del model
116 | torch.cuda.empty_cache()
117 |
118 | # Create plots
119 | plt.figure(figsize=(15, 10))
120 |
121 | # Speedup heatmap
122 | plt.subplot(221)
123 | speedup_data = [[r['Speedup'] for r in results if r['Length']==l] for l in lengths]
124 | sns.heatmap(speedup_data,
125 | xticklabels=batch_sizes,
126 | yticklabels=lengths,
127 | annot=True,
128 | fmt='.2f',
129 | cmap='viridis')
130 | plt.title('Speedup (ESM/FastESM)')
131 | plt.xlabel('Batch Size')
132 | plt.ylabel('Sequence Length')
133 |
134 | # Absolute times line plot
135 | plt.subplot(222)
136 | for length in lengths:
137 | length_results = [r for r in results if r['Length']==length]
138 | plt.plot([r['Batch Size'] for r in length_results],
139 | [r['FastESM Time'] for r in length_results],
140 | label=f'Length {length}',
141 | marker='o')
142 |
143 | plt.xlabel('Batch Size')
144 | plt.ylabel('Time (s)')
145 | plt.title('FastESM Processing Time')
146 | plt.legend()
147 | plt.xscale('log')
148 | plt.yscale('log')
149 |
150 | # ESM Memory heatmap
151 | plt.subplot(223)
152 | memory_data = [[r['ESM Memory'] for r in results if r['Length']==l] for l in lengths]
153 | sns.heatmap(memory_data,
154 | xticklabels=batch_sizes,
155 | yticklabels=lengths,
156 | annot=True,
157 | fmt='.0f',
158 | cmap='viridis')
159 | plt.title('ESM FP32 Memory Usage (MB)')
160 | plt.xlabel('Batch Size')
161 | plt.ylabel('Sequence Length')
162 |
163 | # FastESM Memory heatmap
164 | plt.subplot(224)
165 | memory_data = [[r['FastESM Memory'] for r in results if r['Length']==l] for l in lengths]
166 | sns.heatmap(memory_data,
167 | xticklabels=batch_sizes,
168 | yticklabels=lengths,
169 | annot=True,
170 | fmt='.0f',
171 | cmap='viridis')
172 | plt.title('FastESM FP16 Memory Usage (MB)')
173 | plt.xlabel('Batch Size')
174 | plt.ylabel('Sequence Length')
175 |
176 | plt.tight_layout()
177 | plt.savefig('throughput_results.png')
178 | plt.close()
179 |
180 | print("\nPlot saved as throughput_results.png")
181 |
--------------------------------------------------------------------------------
/readmes/e1_readme.md:
--------------------------------------------------------------------------------
1 | ---
2 | library_name: transformers
3 | tags: []
4 | ---
5 |
6 | # NOTE
7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git)
8 |
9 | # Profluent-E1
10 | [Synthyra's version of Profluent-E1](https://github.com/Synthyra/Profluent-E1-300M) is a faithful implementation of Profluent's [E1](https://www.profluent.bio/showcase/e1) models ([license](https://github.com/Profluent-AI/E1/tree/main?tab=License-1-ov-file)) that integrates Huggingface AutoModel compatibility and convenient embedding functionality.
11 |
12 |
13 | ## Use with 🤗 transformers
14 | ### Supported models
15 | ```python
16 | model_dict = {
17 | # Synthyra/Profluent-E1-150M
18 | 'Profluent-E1-150M': 'Profluent-Bio/E1-150m',
19 | # Synthyra/Profluent-E1-300M
20 | 'Profluent-E1-300M': 'Profluent-Bio/E1-300m',
21 | # Synthyra/Profluent-E1-600M
22 | 'Profluent-E1-600M': 'Profluent-Bio/E1-600m',
23 | }
24 | ```
25 |
26 | ```python
27 | import torch
28 | from transformers import AutoModelForMaskedLM
29 |
30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/Profluent-E1-150M', trust_remote_code=True, dtype=torch.bfloat16).eval().to(device)
32 |
33 | sequences = ['MPRTEIN', 'MSEQWENCE']
34 | batch = model.prep_tokens.get_batch_kwargs(sequences, device=device)
35 |
36 | output = model(**batch) # get all hidden states with output_hidden_states=True
37 | print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 34)
38 | print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 768)
39 | print(output.loss) # language modeling loss if you passed labels
40 | #print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
41 | #print(output.attentions) # all attention matrices if you passed output_attentions=True (in tuple)
42 | ```
43 |
44 | Our E1 implementation also supports sequence and token level classification tasks like ESM2. Simply pass the number of labels during initialization.
45 |
46 | ```python
47 | from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification
48 |
49 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/Profluent-E1-150M', num_labels=2, trust_remote_code=True)
50 | logits = model(**batch).logits # pass labels=labels to also get a classification loss
51 | print(logits.shape) # (batch_size, num_labels), (2, 2)
52 | ```
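53 | 
54 | A matching sketch for token-level (per-residue) classification (output shape shown for the example batch above):
55 | ```python
56 | model = AutoModelForTokenClassification.from_pretrained('Synthyra/Profluent-E1-150M', num_labels=2, trust_remote_code=True)
57 | logits = model(**batch).logits
58 | print(logits.shape) # (batch_size, seq_len, num_labels), (2, 11, 2)
59 | ```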
53 |
54 | E1 weights were trained in bf16 and are in bf16 by default. You can load them in the precision of your choosing by leveraging the dtype parameter:
55 | ```python
56 | import torch
57 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/Profluent-E1-150M', trust_remote_code=True, dtype=torch.float) # fp32
58 | ```
59 |
60 | ## Embed entire datasets with no new code
61 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take.
62 |
63 | Example:
64 | ```python
65 | embedding_dict = model.embed_dataset(
66 | sequences=[
67 | 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
68 | ],
69 | batch_size=2, # adjust for your GPU memory
70 | max_len=512, # adjust for your needs
71 | full_embeddings=False, # if True, no pooling is performed
72 | embed_dtype=torch.float32, # cast to what dtype you want
73 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
74 | sql=False, # if True, embeddings will be stored in SQLite database
75 | sql_db_path='embeddings.db',
76 | save=True, # if True, embeddings will be saved as a .pth file
77 | save_path='embeddings.pth',
78 | )
79 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
80 | ```
81 |
82 | ```
83 | model.embed_dataset()
84 | Args:
85 | sequences: List of protein sequences
86 | batch_size: Batch size for processing
87 | max_len: Maximum sequence length
88 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
89 | pooling_type: Type of pooling ('mean' or 'cls')
90 | sql: Whether to store embeddings in SQLite database - will be stored in float32
91 | sql_db_path: Path to SQLite database
92 |
93 | Returns:
94 | Dictionary mapping sequences to embeddings, or None if sql=True
95 |
96 | Note:
97 | - If sql=True, embeddings can only be stored in float32
98 | - sql is ideal if you need to stream a very large dataset for training in real-time
99 | - save=True is ideal if you can store the entire embedding dictionary in RAM
100 | - sql will be used if it is True and save is True or False
101 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
102 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
103 | ```
104 |
105 | ## Fine-tuning with 🤗 peft
106 | ```python
107 | from peft import LoraConfig, get_peft_model
108 | 
109 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/Profluent-E1-150M', num_labels=2, trust_remote_code=True)
108 | # these modules handle E1 attention layers
109 | target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
110 |
111 | lora_config = LoraConfig(
112 | r=8, # choose lora parameters to your liking
113 | lora_alpha=16,
114 | lora_dropout=0.01,
115 | bias="none",
116 | target_modules=target_modules,
117 | )
118 |
119 | # Apply LoRA to the model
120 | model = get_peft_model(model, lora_config)
121 |
122 | # Unfreeze the classifier head
123 | for param in model.classifier.parameters():
124 | param.requires_grad = True
125 | ```
126 |
127 | For a more thorough example of fine-tuning, check out our example script [here](https://github.com/Synthyra/FastPLMs/blob/main/fine_tuning_example.py).
128 |
129 |
130 | ### Citation
131 | If you use any of this implementation or work, please cite the following DOI and Profluent's paper.
132 |
133 | ```
134 | @misc {FastPLMs,
135 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
136 | title = { FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel.},
137 | year = {2024},
138 | url = { https://huggingface.co/Synthyra/ESMplusplus_small },
139 | DOI = { 10.57967/hf/3726 },
140 | publisher = { Hugging Face }
141 | }
142 | ```
143 |
144 | ```
145 | @article{Jain_Beazer_Ruffolo_Bhatnagar_Madani_2025,
146 | title={E1: Retrieval-Augmented Protein Encoder Models},
147 | url={https://www.biorxiv.org/content/early/2025/11/13/2025.11.12.688125},
148 | DOI={10.1101/2025.11.12.688125},
149 | journal={bioRxiv},
150 | publisher={Cold Spring Harbor Laboratory},
151 | author={Jain, Sarthak and Beazer, Joel and Ruffolo, Jeffrey A and Bhatnagar, Aadyot and Madani, Ali},
152 | year={2025}
153 | }
154 | ```
--------------------------------------------------------------------------------
/readmes/fastesm_650_readme.md:
--------------------------------------------------------------------------------
1 | ---
2 | library_name: transformers
3 | tags: []
4 | ---
5 |
6 | # NOTE
7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git)
8 |
9 | # FastESM
10 | FastESM is a Huggingface-compatible, plug-in version of ESM2 rewritten with a newer PyTorch attention implementation.
11 | 
12 | Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.
13 | 
14 | Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
15 | Various other optimizations also make the base implementation slightly different from the one in transformers.
16 |
17 | # FastESM2-650
18 |
19 | ## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context
20 | To enhance the weights with longer context and better fp16 support, we trained ESM2-650 for 50,000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) at sequence lengths up to **2048**.
21 |
22 | ## Use with 🤗 transformers
23 |
24 | ### For working with embeddings
25 | ```python
26 | import torch
27 | from transformers import AutoModel, AutoTokenizer
28 |
29 | model_path = 'Synthyra/FastESM2_650'
30 | model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
31 | tokenizer = model.tokenizer
32 |
33 | sequences = ['MPRTEIN', 'MSEQWENCE']
34 | tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
35 | with torch.no_grad():
36 | embeddings = model(**tokenized).last_hidden_state
37 |
38 | print(embeddings.shape) # (2, 11, 1280)
39 | ```
40 |
41 | ### For working with sequence logits
42 | ```python
43 | import torch
44 | from transformers import AutoModelForMaskedLM, AutoTokenizer
45 |
46 | model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
47 | with torch.no_grad():
48 | logits = model(**tokenized).logits
49 |
50 | print(logits.shape) # (2, 11, 33)
51 | ```
52 |
53 | ### For working with attention maps
54 | ```python
55 | import torch
56 | from transformers import AutoModel, AutoTokenizer
57 |
58 | model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
59 | with torch.no_grad():
60 | attentions = model(**tokenized, output_attentions=True).attentions # tuple of tensors, each (batch_size, num_heads, seq_len, seq_len)
61 |
62 | print(attentions[-1].shape) # (2, 20, 11, 11)
63 | ```
64 |
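65 | ### For working with contact maps
66 | The contact prediction head mentioned above can also be used. A minimal sketch, mirroring test_scripts/test_contact_maps.py in the GitHub repo (and assuming your checkpoint includes the ESM2 contact prediction head):
67 | ```python
68 | import torch
69 | from transformers import AutoModel
70 | 
71 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True).eval()
72 | inputs = model.tokenizer('MPRTEIN', return_tensors='pt')
73 | with torch.no_grad():
74 |     contacts = model.predict_contacts(inputs['input_ids'], inputs['attention_mask'])
75 | 
76 | print(contacts.shape)  # roughly (batch_size, seq_len, seq_len) contact probabilities
77 | ```
78 | 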
65 | ## Embed entire datasets with no new code
66 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take.
67 |
68 | Example:
69 | ```python
70 | embedding_dict = model.embed_dataset(
71 | sequences=[
72 | 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
73 | ],
74 | tokenizer=model.tokenizer,
75 | batch_size=2, # adjust for your GPU memory
76 | max_len=512, # adjust for your needs
77 | full_embeddings=False, # if True, no pooling is performed
78 | embed_dtype=torch.float32, # cast to what dtype you want
79 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
80 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
81 | sql=False, # if True, embeddings will be stored in SQLite database
82 | sql_db_path='embeddings.db',
83 | save=True, # if True, embeddings will be saved as a .pth file
84 | save_path='embeddings.pth',
85 | )
86 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
87 | ```
88 |
89 | ```
90 | model.embed_dataset()
91 | Args:
92 | sequences: List of protein sequences
93 | batch_size: Batch size for processing
94 | max_len: Maximum sequence length
95 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
96 | pooling_type: Type of pooling ('mean' or 'cls')
97 | num_workers: Number of workers for data loading, 0 for the main process
98 | sql: Whether to store embeddings in SQLite database - will be stored in float32
99 | sql_db_path: Path to SQLite database
100 |
101 | Returns:
102 | Dictionary mapping sequences to embeddings, or None if sql=True
103 |
104 | Note:
105 | - If sql=True, embeddings can only be stored in float32
106 | - sql is ideal if you need to stream a very large dataset for training in real-time
107 | - save=True is ideal if you can store the entire embedding dictionary in RAM
108 | - sql will be used if it is True and save is True or False
109 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
110 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
111 | ```
112 |
113 | ## Model probes
114 | We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well.
115 |
116 | The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression tasks are averaged between Spearman rho and R2.
117 | 
118 |
119 | ## Comparison of half precisions
120 | Presumably because we trained in fp16 mixed precision, fp16 outputs are closer to the fp32 weights than bf16 outputs. Therefore, we recommend loading in fp16.
121 |
122 | When summing the MSE of 1000 sequences vs. the fp32 weights:
123 |
124 | Average MSE for FP16: 0.00000140
125 |
126 | Average MSE for BF16: 0.00004125
127 |
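128 | A minimal sketch of this kind of fp16 vs. fp32 comparison (not necessarily the exact script used):
129 | ```python
130 | import torch
131 | from transformers import AutoModel
132 | 
133 | model_fp32 = AutoModel.from_pretrained(model_path, trust_remote_code=True).eval()
134 | model_fp16 = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
135 | 
136 | tokenized = model_fp32.tokenizer(['MPRTEIN'], return_tensors='pt')
137 | with torch.no_grad():
138 |     ref = model_fp32(**tokenized).last_hidden_state
139 |     half = model_fp16(**tokenized).last_hidden_state.float()
140 | 
141 | print(torch.mean((ref - half) ** 2).item())  # average this over many sequences in practice
142 | ```
143 | 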
128 | ### Inference speed
129 | We look at various ESM models and their throughput on an H100. FastESM is over twice as fast as ESM2-650 with longer sequences. Requires PyTorch 2.5+ for the most savings, see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
130 | 
131 |
132 | ### Citation
133 | If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
134 | ```
135 | @misc {FastPLMs,
136 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
137 | title = { FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel.},
138 | year = {2024},
139 | url = { https://huggingface.co/Synthyra/ESMplusplus_small },
140 | DOI = { 10.57967/hf/3726 },
141 | publisher = { Hugging Face }
142 | }
143 | ```
144 |
--------------------------------------------------------------------------------
/pooler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import networkx as nx
4 | from typing import Optional, List
5 |
6 |
7 | class Pooler:
8 | def __init__(self, pooling_types: List[str]):
9 | self.pooling_types = pooling_types
10 | self.pooling_options = {
11 | 'mean': self.mean_pooling,
12 | 'max': self.max_pooling,
13 | 'norm': self.norm_pooling,
14 | 'median': self.median_pooling,
15 | 'std': self.std_pooling,
16 | 'var': self.var_pooling,
17 | 'cls': self.cls_pooling,
18 | 'parti': self._pool_parti,
19 | }
20 |
21 | def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
22 | maxed_attentions = torch.max(attentions, dim=1)[0]
23 | return maxed_attentions
24 |
25 | def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
26 | # Run PageRank on the attention matrix converted to a graph.
27 | # Raises exceptions if the graph doesn't match the token sequence or has no edges.
28 | # Returns the PageRank scores for each token node.
29 | G = self._convert_to_graph(attention_matrix)
30 | if G.number_of_nodes() != attention_matrix.shape[0]:
31 | raise Exception(
32 | f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
33 | if G.number_of_edges() == 0:
34 | raise Exception(f"You don't seem to have any attention edges left in the graph.")
35 |
36 | return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
37 |
38 | def _convert_to_graph(self, matrix):
39 | # Convert a matrix (e.g., attention scores) to a directed graph using networkx.
40 | # Each element in the matrix represents a directed edge with a weight.
41 | G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
42 | return G
43 |
44 | def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
45 | # Remove keys where attention_mask is 0
46 | if attention_mask is not None:
47 | for k in list(dict_importance.keys()):
48 | if attention_mask[k] == 0:
49 | del dict_importance[k]
50 |
51 | #dict_importance[0] # remove cls
52 | #dict_importance[-1] # remove eos
53 | total = sum(dict_importance.values())
54 | return np.array([v / total for _, v in dict_importance.items()])
55 |
56 | def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
57 | maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
58 | # emb is (b, L, d), maxed_attentions is (b, L, L)
59 | emb_pooled = []
60 | for e, a, mask in zip(emb, maxed_attentions, attention_mask):
61 | dict_importance = self._page_rank(a)
62 | importance_weights = self._calculate_importance_weights(dict_importance, mask)
63 | num_tokens = int(mask.sum().item())
64 | emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0))
65 | pooled = torch.tensor(np.array(emb_pooled))
66 | return pooled
67 |
68 | def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
69 | if attention_mask is None:
70 | return emb.mean(dim=1)
71 | else:
72 | attention_mask = attention_mask.unsqueeze(-1)
73 | return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
74 |
75 | def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
76 | if attention_mask is None:
77 | return emb.max(dim=1).values
78 | else:
79 | attention_mask = attention_mask.unsqueeze(-1)
80 | return (emb * attention_mask).max(dim=1).values
81 |
82 | def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
83 | if attention_mask is None:
84 | return emb.norm(dim=1, p=2)
85 | else:
86 | attention_mask = attention_mask.unsqueeze(-1)
87 | return (emb * attention_mask).norm(dim=1, p=2)
88 |
89 | def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
90 | if attention_mask is None:
91 | return emb.median(dim=1).values
92 | else:
93 | attention_mask = attention_mask.unsqueeze(-1)
94 | return (emb * attention_mask).median(dim=1).values
95 |
96 | def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
97 | if attention_mask is None:
98 | return emb.std(dim=1)
99 | else:
100 | # Compute variance correctly over non-masked positions, then take sqrt
101 | var = self.var_pooling(emb, attention_mask, **kwargs)
102 | return torch.sqrt(var)
103 |
104 | def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
105 | if attention_mask is None:
106 | return emb.var(dim=1)
107 | else:
108 | # Correctly compute variance over only non-masked positions
109 | attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1)
110 | # Compute mean over non-masked positions
111 | mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
112 | mean = mean.unsqueeze(1) # (b, 1, d)
113 | # Compute squared differences from mean, only over non-masked positions
114 | squared_diff = (emb - mean) ** 2 # (b, L, d)
115 | # Sum squared differences over non-masked positions and divide by count
116 | var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
117 | return var
118 |
119 | def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
120 | return emb[:, 0, :]
121 |
122 | def __call__(
123 | self,
124 | emb: torch.Tensor,
125 | attention_mask: Optional[torch.Tensor] = None,
126 | attentions: Optional[torch.Tensor] = None
127 | ): # [mean, max]
128 | final_emb = []
129 | for pooling_type in self.pooling_types:
130 | final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
131 | return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
132 |
133 |
134 | if __name__ == "__main__":
135 | # py -m pooler
136 | pooler = Pooler(pooling_types=['max', 'parti'])
137 |
138 | batch_size = 8
139 | seq_len = 64
140 | hidden_size = 128
141 | num_layers = 12
142 | emb = torch.randn(batch_size, seq_len, hidden_size)
143 | attentions = torch.randn(batch_size, num_layers, seq_len, seq_len)
144 | attention_mask = torch.ones(batch_size, seq_len)
145 |
146 | y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions)
147 | print(y.shape)
148 |
--------------------------------------------------------------------------------
/test_scripts/test_contact_maps.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import torch
5 | import os
6 | import random
7 | import requests
8 | import tempfile
9 | import sys
10 | from Bio.PDB import PDBParser, PPBuilder
11 |
12 | from esm2.modeling_fastesm import FastEsmModel
13 |
14 |
15 | def download_random_pdb():
16 | """
17 | Download a random protein chain PDB file.
18 |
19 | Returns:
20 | str: Path to the downloaded PDB file.
21 | """
22 | example_pdbs = ["1AKE"]
23 |
24 | # Select a random PDB ID
25 | pdb_id = random.choice(example_pdbs)
26 | print(f"Selected random PDB ID: {pdb_id}")
27 |
28 | # Create a temporary file to store the PDB
29 | temp_file = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
30 | temp_file_path = temp_file.name
31 | temp_file.close()
32 |
33 | # Download the PDB file
34 | url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
35 | response = requests.get(url)
36 |
37 | if response.status_code == 200:
38 | with open(temp_file_path, 'wb') as f:
39 | f.write(response.content)
40 | print(f"Downloaded PDB file to: {temp_file_path}")
41 | return temp_file_path
42 | else:
43 | raise Exception(f"Failed to download PDB file: {response.status_code}")
44 |
45 |
46 | def parse_pdb(pdb_file):
47 | """
48 | Parse a PDB file and extract the protein sequence and CA atom coordinates.
49 |
50 | Parameters:
51 | pdb_file (str): Path to the PDB file.
52 |
53 | Returns:
54 | tuple: (sequence (str), coords (np.ndarray of shape (L, 3)))
55 | """
56 | parser = PDBParser(QUIET=True)
57 | structure = parser.get_structure("protein", pdb_file)
58 | ppb = PPBuilder()
59 |
60 | # Assume a single protein chain; take the first polypeptide found.
61 | for pp in ppb.build_peptides(structure):
62 | sequence = str(pp.get_sequence())
63 | coords = []
64 | for residue in pp:
65 | # Only add the CA atom if available.
66 | if 'CA' in residue:
67 | coords.append(residue['CA'].get_coord())
68 | if len(coords) == 0:
69 | raise ValueError("No CA atoms found in the polypeptide.")
70 | return sequence, np.array(coords)
71 |
72 | raise ValueError("No polypeptide chains were found in the PDB file.")
73 |
74 |
75 | def compute_distance_matrix(coords):
76 | """
77 | Compute the pairwise Euclidean distance matrix from a set of coordinates.
78 |
79 | Parameters:
80 | coords (np.ndarray): Array of shape (L, 3) where L is the number of residues.
81 |
82 | Returns:
83 | np.ndarray: A matrix of shape (L, L) containing distances.
84 | """
85 | diff = coords[:, None, :] - coords[None, :, :]
86 | dist_matrix = np.sqrt(np.sum(diff**2, axis=-1))
87 |
88 | return dist_matrix
89 |
90 |
91 | def get_esm_contact_map(sequence):
92 | """
93 | Use the ESM model to predict a contact map for the given protein sequence.
94 |
95 | Parameters:
96 | sequence (str): Amino acid sequence.
97 |
98 | Returns:
99 | np.ndarray: A 2D array (L x L) with contact probabilities.
100 | """
101 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102 | model_path = "Synthyra/ESM2-650M"
103 | model = FastEsmModel.from_pretrained(model_path).eval().to(device)
104 | tokenizer = model.tokenizer
105 |
106 | inputs = tokenizer(sequence, return_tensors="pt")
107 | inputs = {key: value.to(device) for key, value in inputs.items()}
108 | with torch.no_grad():
109 | contact_map = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])
110 | print(contact_map.shape)
111 | contact_map = contact_map.squeeze().cpu().numpy()
112 | print(contact_map.shape)
113 | return contact_map
114 |
115 |
116 | def plot_maps(true_contact_map, predicted_contact_map, pdb_file):
117 | """
118 | Generate two subplots:
119 | 1. ESM predicted contact map.
120 | 2. True contact map from the PDB (binary, thresholded).
121 |
122 | Parameters:
123 | true_contact_map (np.ndarray): Binary (0/1) contact map from PDB.
124 | predicted_contact_map (np.ndarray): Predicted contact probabilities from ESM.
125 | pdb_file (str): Path to the PDB file, used to generate output filename.
126 | """
127 | fig, axs = plt.subplots(1, 2, figsize=(12, 6))
128 |
129 | # Plot the ESM-predicted contact map.
130 | im0 = axs[0].imshow(predicted_contact_map, cmap='RdYlBu_r', aspect='equal')
131 | axs[0].set_title("Predicted contact probabilities")
132 | axs[0].set_xlabel("Residue index")
133 | axs[0].set_ylabel("Residue index")
134 | fig.colorbar(im0, ax=axs[0], fraction=0.046, pad=0.04)
135 |
136 | # Plot the true contact map (binary contacts).
137 | im1 = axs[1].imshow(true_contact_map, cmap='RdYlBu_r', aspect='equal')
138 | axs[1].set_title("True contacts (PDB, threshold = 8 Å)")
139 | axs[1].set_xlabel("Residue index")
140 | axs[1].set_ylabel("Residue index")
141 | fig.colorbar(im1, ax=axs[1], fraction=0.046, pad=0.04)
142 |
143 | plt.tight_layout()
144 |
145 | # Generate output filename from PDB filename
146 | pdb_name = os.path.splitext(os.path.basename(pdb_file))[0]
147 | output_file = f"contact_maps_{pdb_name}.png"
148 | plt.savefig(output_file, dpi=300, bbox_inches='tight')
149 | plt.close()
150 |
151 |
152 | def main():
153 | # py test_scripts/test_contact_maps.py
154 | parser = argparse.ArgumentParser(
155 | description="Extract protein sequence and compute contact maps from a PDB file using ESM predictions."
156 | )
157 | parser.add_argument("--pdb_file", type=str, help="Path to the PDB file of the protein. If not provided, a random PDB will be downloaded.", default=None)
158 | parser.add_argument(
159 | "--threshold",
160 | type=float,
161 | default=8.0,
162 | help="Distance threshold (in Å) for defining true contacts (default: 8.0 Å)."
163 | )
164 | args = parser.parse_args()
165 |
166 | # If no PDB file is provided, download a random one
167 | if args.pdb_file is None:
168 | pdb_file = download_random_pdb()
169 | else:
170 | pdb_file = args.pdb_file
171 |
172 | try:
173 | # Parse the PDB file.
174 | sequence, coords = parse_pdb(pdb_file)
175 | print("Extracted Protein Sequence:")
176 | print(sequence)
177 |
178 | # Compute the pairwise distance matrix.
179 | dist_matrix = compute_distance_matrix(coords)
180 |
181 | # Create a binary contact map from the distance matrix using the threshold.
182 | true_contact_map = (dist_matrix < args.threshold).astype(float)
183 |
184 | # Get the predicted contact map from the ESM model.
185 | predicted_contact_map = get_esm_contact_map(sequence)
186 |
187 | # Check that the dimensions agree.
188 | if predicted_contact_map.shape[0] != true_contact_map.shape[0]:
189 | print("Warning: The predicted contact map and true contact map have different dimensions.")
190 |
191 | # Plot the maps.
192 | plot_maps(true_contact_map, predicted_contact_map, pdb_file)
193 |
194 | print(f"Contact maps saved to: contact_maps_{os.path.splitext(os.path.basename(pdb_file))[0]}.png")
195 |
196 | finally:
197 | # Clean up the temporary file if we downloaded a random PDB
198 | if args.pdb_file is None and os.path.exists(pdb_file):
199 | os.remove(pdb_file)
200 | print(f"Removed temporary PDB file: {pdb_file}")
201 |
202 |
203 | if __name__ == '__main__':
204 | main()
205 |
206 |
207 |
208 |
--------------------------------------------------------------------------------
/test_scripts/test_throughput.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import random
4 | import argparse
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 | from huggingface_hub import login
8 | from transformers import AutoModelForMaskedLM, EsmTokenizer
9 | from esm.models.esmc import ESMC
10 | from esm.sdk.api import ESMProtein, LogitsConfig
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--model_paths', nargs='+', type=str, default=[
15 | #'facebook/esm2_t6_8M_UR50D',
16 | 'Synthyra/FastESM2_650',
17 | 'facebook/esm2_t12_35M_UR50D',
18 | 'facebook/esm2_t30_150M_UR50D',
19 | 'facebook/esm2_t33_650M_UR50D',
20 | 'esmc_300m', # esmc model
21 | 'esmc_600m', # esmc model
22 | 'Synthyra/ESMplusplus_small',
23 | 'Synthyra/ESMplusplus_large'
24 | ])
25 | parser.add_argument('--token', type=str, default=None)
26 | parser.add_argument('--test', action='store_true', help='Generate random results for testing')
27 | args = parser.parse_args()
28 |
29 | if args.token:
30 | login(args.token)
31 |
32 | model_paths = args.model_paths
33 | canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
34 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35 |
36 |
37 | class ESMCForEmbedding(torch.nn.Module):
38 | def __init__(self, esm):
39 | super().__init__()
40 | self.esm = esm
41 |
42 | def forward(self, seq):
43 | protein = ESMProtein(sequence=seq)
44 | protein_tensor = self.esm.encode(protein)
45 | embeddings = self.esm.logits(
46 | protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
47 | ).embeddings.cpu()
48 | return embeddings
49 |
50 |
51 | def generate_random_sequence(length: int) -> str:
52 | return 'M' + "".join(random.choices(canonical_amino_acids, k=length-3))
53 |
54 |
55 | def generate_batch_sequences(length: int, batch_size: int, num_batches: int = 100) -> list:
56 | all_sequences = []
57 | for _ in range(num_batches):
58 | batch_sequences = [generate_random_sequence(length) for _ in range(batch_size)]
59 | all_sequences.append(batch_sequences)
60 | return all_sequences
61 |
62 |
63 | def time_model(model, inputs, warmup=4):
64 | model.eval()
65 | with torch.no_grad():
66 | # Warmup
67 | for _ in range(warmup):
68 | _ = model(**inputs[0])
69 |
70 | start_time = time.time()
71 | for input_batch in inputs:
72 | _ = model(**input_batch)
73 | return time.time() - start_time
74 |
75 |
76 | def time_model_esmc(model, sequences, warmup=10):
77 | model.eval()
78 | with torch.no_grad():
79 | # Warmup
80 | for _ in range(warmup):
81 | for seq in sequences[0]:
82 | _ = model(seq)
83 |
84 | start_time = time.time()
85 | for batch in sequences:
86 | for seq in batch:
87 | _ = model(seq)
88 | return time.time() - start_time
89 |
90 |
91 | def get_gpu_memory():
92 | torch.cuda.synchronize()
93 | return torch.cuda.max_memory_allocated() / 1024**2 # Convert to MB
94 |
95 |
96 | # Test different sequence lengths and batch sizes
97 | lengths = [32, 64, 128, 256, 512, 1024, 2048]
98 | batch_sizes = [1, 2, 4, 8, 16, 32]
99 | num_batches = 16
100 | results = []
101 |
102 | if not args.test:
103 | # Generate all test sequences first
104 | all_sequences = {}
105 | for length in lengths:
106 | for batch_size in batch_sizes:
107 | print(f"\nGenerating sequences for length={length}, batch_size={batch_size}")
108 | all_sequences[(length, batch_size)] = generate_batch_sequences(length, batch_size, num_batches)
109 |
110 | # Test each model
111 | for model_path in model_paths:
112 | print(f"\nTesting {model_path}...")
113 | if 'esmc' in model_path.lower():
114 | esm = ESMC.from_pretrained(model_path, device=device).to(device)
115 | model = ESMCForEmbedding(esm).to(device)
116 | tokenizer = None
117 | elif 'synthyra' in model_path.lower():
118 | model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True, dtype=torch.float16).to(device)
119 | tokenizer = model.tokenizer
120 | else:
121 | model = AutoModelForMaskedLM.from_pretrained(model_path).to(device)
122 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
123 |
124 | for length in lengths:
125 | for batch_size in batch_sizes:
126 | print(f"\nTesting length={length}, batch_size={batch_size}")
127 | sequences = all_sequences[(length, batch_size)]
128 |
129 | torch.cuda.reset_peak_memory_stats()
130 | if isinstance(model, ESMCForEmbedding):
131 | model_time = time_model_esmc(model, sequences)
132 | else:
133 | inputs = [tokenizer(batch_seq, padding=True, return_tensors="pt").to(device) for batch_seq in sequences]
134 | model_time = time_model(model, inputs)
135 | model_memory = get_gpu_memory()
136 |
137 | results.append({
138 | 'Model': model_path,
139 | 'Length': length,
140 | 'Batch Size': batch_size,
141 | 'Time': model_time,
142 | 'Memory': model_memory
143 | })
144 | print(f"Time: {model_time:.2f}s, memory: {model_memory:.0f}MB")
145 | torch.cuda.empty_cache()
146 |
147 | model.cpu()
148 | del model
149 | torch.cuda.empty_cache()
150 | else:
151 | # Generate random test results
152 | for model_path in model_paths:
153 | for length in lengths:
154 | for batch_size in batch_sizes:
155 | # Generate random time between 0.1 and 10 seconds, scaling with length and batch size
156 | model_time = random.uniform(0.1, 10) * (length/2) * (batch_size/1)
157 | # Generate random memory between 100 and 5000 MB, scaling with length and batch size
158 | model_memory = random.uniform(100, 5000) * (length/2) * (batch_size/1)
159 |
160 | results.append({
161 | 'Model': model_path,
162 | 'Length': length,
163 | 'Batch Size': batch_size,
164 | 'Time': model_time,
165 | 'Memory': model_memory
166 | })
167 | print(f"Generated random - Time: {model_time:.2f}s, memory: {model_memory:.0f}MB")
168 |
169 | # Save results to CSV
170 | df = pd.DataFrame(results)
171 | df.to_csv('model_benchmarks.csv', index=False)
172 |
173 | # Create visualization for throughput
174 | num_batch_sizes = len(batch_sizes)
175 | plt.figure(figsize=(15, 5 * num_batch_sizes))
176 |
177 | for i, batch_size in enumerate(batch_sizes):
178 | plt.subplot(num_batch_sizes, 1, i + 1)
179 | for model_path in model_paths:
180 | model_results = [(r['Length'], r['Time']) for r in results
181 | if r['Model'] == model_path and r['Batch Size'] == batch_size]
182 | if model_results:
183 | lengths, times = zip(*model_results)
184 | throughput = [batch_size * len * num_batches / time for len, time in zip(lengths, times)]
185 | plt.plot(lengths, throughput, marker='o', label=model_path)
186 |
187 | plt.title(f'Model Throughput vs Sequence Length (Batch Size = {batch_size})')
188 | plt.xlabel('Sequence Length')
189 | plt.ylabel('Throughput (tokens/second)')
190 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
191 | plt.grid(True)
192 |
193 | plt.tight_layout()
194 | plt.savefig('model_throughput.png', bbox_inches='tight', dpi=300)
195 | plt.close()
196 |
197 | # Create visualization for memory usage
198 | plt.figure(figsize=(15, 5 * num_batch_sizes))
199 |
200 | for i, batch_size in enumerate(batch_sizes):
201 | plt.subplot(num_batch_sizes, 1, i + 1)
202 | for model_path in model_paths:
203 | model_results = [(r['Length'], r['Memory']) for r in results
204 | if r['Model'] == model_path and r['Batch Size'] == batch_size]
205 | if model_results:
206 | lengths, memory = zip(*model_results)
207 | plt.plot(lengths, memory, marker='o', label=model_path)
208 |
209 | plt.title(f'GPU Memory Usage vs Sequence Length (Batch Size = {batch_size})')
210 | plt.xlabel('Sequence Length')
211 | plt.ylabel('Memory Usage (MB)')
212 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
213 | plt.grid(True)
214 |
215 | plt.tight_layout()
216 | plt.savefig('model_memory.png', bbox_inches='tight', dpi=300)
217 | plt.close()
218 |
--------------------------------------------------------------------------------
/readmes/esm_plusplus_large_readme.md:
--------------------------------------------------------------------------------
1 | ---
2 | library_name: transformers
3 | tags: []
4 | ---
5 |
6 | # NOTE
7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git)
8 |
9 | # ESM++
10 | [ESM++](https://github.com/Synthyra/ESMplusplus) is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) ([license](https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement)) that allows for batching and standard Huggingface compatibility without requiring the ESM Python package.
11 | The large version corresponds to the 600 million parameter version of ESMC.
12 |
13 |
14 | ## Use with 🤗 transformers
15 | ```python
16 | from transformers import AutoModelForMaskedLM
17 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True)
18 | tokenizer = model.tokenizer
19 |
20 | sequences = ['MPRTEIN', 'MSEQWENCE']
21 | tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
22 |
23 | # tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training
24 |
25 | output = model(**tokenized) # get all hidden states with output_hidden_states=True
26 | print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
27 | print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 1152)
28 | print(output.loss) # language modeling loss if you passed labels
29 | #print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
30 | ```
31 |
32 | ESM++ also supports sequence and token level classification tasks like ESM2. Simply pass the number of labels during initialization.
33 |
34 | ```python
35 | from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification
36 |
37 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
38 | logits = model(**tokenized).logits
39 | print(logits.shape) # (batch_size, num_labels), (2, 2)
40 | ```
41 |
42 | ESM++ weights are fp32 by default. You can load them in fp16 or bf16 like this:
43 | ```python
44 | import torch
45 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True, dtype=torch.float16) # or torch.bfloat16
46 | ```
47 |
48 | ## Embed entire datasets with no new code
49 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take.
50 |
51 | Example:
52 | ```python
53 | embedding_dict = model.embed_dataset(
54 | sequences=[
55 | 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
56 | ],
57 | tokenizer=model.tokenizer,
58 | batch_size=2, # adjust for your GPU memory
59 | max_len=512, # adjust for your needs
60 | full_embeddings=False, # if True, no pooling is performed
61 | embed_dtype=torch.float32, # cast to what dtype you want
62 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
63 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
64 | sql=False, # if True, embeddings will be stored in SQLite database
65 | sql_db_path='embeddings.db',
66 | save=True, # if True, embeddings will be saved as a .pth file
67 | save_path='embeddings.pth',
68 | )
69 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
70 | ```
71 |
72 | ```
73 | model.embed_dataset()
74 | Args:
75 | sequences: List of protein sequences
76 | batch_size: Batch size for processing
77 | max_len: Maximum sequence length
78 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
79 | pooling_type: Type of pooling ('mean' or 'cls')
80 | num_workers: Number of workers for data loading, 0 for the main process
81 | sql: Whether to store embeddings in SQLite database - will be stored in float32
82 | sql_db_path: Path to SQLite database
83 |
84 | Returns:
85 | Dictionary mapping sequences to embeddings, or None if sql=True
86 |
87 | Note:
88 | - If sql=True, embeddings can only be stored in float32
89 | - sql is ideal if you need to stream a very large dataset for training in real-time
90 | - save=True is ideal if you can store the entire embedding dictionary in RAM
91 | - sql will be used if it is True and save is True or False
92 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
93 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
94 | ```
95 |
96 | ## Fine-tuning with 🤗 peft
97 | ```python
98 | from peft import LoraConfig, get_peft_model
99 | 
100 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
99 | # these modules handle ESM++ and ESM2 attention layers
100 | target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"]
101 |
102 | lora_config = LoraConfig(
103 | r=8, # choose lora parameters to your liking
104 | lora_alpha=16,
105 | lora_dropout=0.01,
106 | bias="none",
107 | target_modules=target_modules,
108 | )
109 |
110 | # Apply LoRA to the model
111 | model = get_peft_model(model, lora_config)
112 |
113 | # Unfreeze the classifier head
114 | for param in model.classifier.parameters():
115 | param.requires_grad = True
116 | ```
117 |
118 | For a more thorough example of fine-tuning, check out our example script [here](https://github.com/Synthyra/FastPLMs/blob/main/fine_tuning_example.py).
119 |
120 |
121 | ## Returning attention maps
122 | Usually F.scaled_dot_product_attention is used for the attention calculations, which is much faster than computing attention manually in PyTorch. However, it cannot return attention maps.
123 | ESM++ has the option to ```output_attentions```, which will calculate attention manually. This is much slower, so do not use unless you need the attention maps.
124 |
125 | ```python
126 | output = model(**tokenized, output_attentions=True)
127 | att = output.attentions
128 | len(att) # 33, one for each layer, size (batch_size, num_heads, seq_len, seq_len) each
129 | ```
130 |
131 | ## Comparison across floating-point precision and implementations
132 | We measured the difference between the last hidden states of the fp32 weights and the fp16 or bf16 weights. We find that fp16 is closer to the fp32 outputs, so we recommend loading in fp16.
133 | Please note that the ESM package also loads ESMC in fp32 but casts to bf16 by default, which has its share of advantages and disadvantages in inference / training - so load whichever you like for half precision.
134 |
135 | Average MSE for FP16: 0.00000003
136 |
137 | Average MSE for BF16: 0.00000122
138 |
139 | We also measured the difference between the outputs of ESM++ vs. ESMC (both in bfloat16) on 1000 random sequences to ensure compliance with the ESM package.
140 |
141 | Average MSE of last hidden state: 2.46e-09
142 |
143 | You can load the weights from the ESM package instead of transformers by replacing .from_pretrained(...) with .from_pretrained_esm('esmc_600m').
144 |
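145 | A minimal sketch of the ESM++ vs. ESMC comparison above (not necessarily the exact script used; requires the `esm` package):
146 | ```python
147 | import torch
148 | from transformers import AutoModelForMaskedLM
149 | from esm.models.esmc import ESMC
150 | from esm.sdk.api import ESMProtein, LogitsConfig
151 | 
152 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
153 | seq = 'MPRTEIN'
154 | 
155 | esmpp = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True, dtype=torch.bfloat16).eval().to(device)
156 | esmc = ESMC.from_pretrained('esmc_600m', device=device).to(device).eval()
157 | 
158 | tokenized = esmpp.tokenizer(seq, return_tensors='pt').to(device)
159 | with torch.no_grad():
160 |     esmpp_hidden = esmpp(**tokenized).last_hidden_state.float().cpu()
161 |     protein_tensor = esmc.encode(ESMProtein(sequence=seq))
162 |     esmc_hidden = esmc.logits(protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)).embeddings.float().cpu()
163 | 
164 | print(torch.mean((esmpp_hidden - esmc_hidden) ** 2).item())  # average this over many sequences in practice
165 | ```
166 | 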
145 | ## Model probes
146 | We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. ESMC (and thus ESM++) perform very well.
147 |
148 | The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression tasks are averaged between Spearman rho and R2.
149 | 
150 |
151 | ## Inference speeds
152 | We look at various ESM models and their throughput on an H100. The efficient batching added in ESM++ significantly improves throughput over ESMC, and ESM++ is also faster than ESMC at batch size one. ESM++ small is even faster than ESM2-35M with long sequences! The largest gains are seen with PyTorch 2.5+ on Linux machines.
153 | 
154 |
155 | ### Citation
156 | If you use any of this implementation or work, please cite it (as well as the ESMC preprint).
157 |
158 | ```
159 | @misc {FastPLMs,
160 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
161 | title = { FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel.},
162 | year = {2024},
163 | url = { https://huggingface.co/Synthyra/ESMplusplus_small },
164 | DOI = { 10.57967/hf/3726 },
165 | publisher = { Hugging Face }
166 | }
167 | ```
--------------------------------------------------------------------------------
/readmes/esm_plusplus_small_readme.md:
--------------------------------------------------------------------------------
1 | ---
2 | library_name: transformers
3 | tags: []
4 | ---
5 |
6 | # NOTE
7 | The GitHub with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git)
8 |
9 | # ESM++
10 | [ESM++](https://github.com/Synthyra/ESMplusplus) is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) ([license](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement)) that allows for batching and standard Huggingface compatibility without requiring the ESM Python package.
11 | The small version corresponds to the 300 million parameter version of ESMC.
12 |
13 |
14 | ## Use with 🤗 transformers
15 | ```python
16 | from transformers import AutoModelForMaskedLM
17 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
18 | tokenizer = model.tokenizer
19 |
20 | sequences = ['MPRTEIN', 'MSEQWENCE']
21 | tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
22 |
23 | # tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training
24 |
25 | output = model(**tokenized) # get all hidden states with output_hidden_states=True
26 | print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
27 | print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960)
28 | print(output.loss) # language modeling loss if you passed labels
29 | #print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
30 | ```
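31 | 
32 | A minimal sketch of the label masking hinted at in the comment above (20% masking for illustration, assuming a tokenizer with the standard `mask_token_id` and `all_special_ids` attributes):
33 | ```python
34 | import torch
35 | 
36 | mlm_inputs = {k: v.clone() for k, v in tokenized.items()}  # leave the original batch untouched
37 | labels = mlm_inputs['input_ids'].clone()
38 | special = torch.isin(labels, torch.tensor(tokenizer.all_special_ids))
39 | mask = (torch.rand(labels.shape) < 0.2) & ~special  # select ~20% of real residues
40 | mlm_inputs['input_ids'] = mlm_inputs['input_ids'].masked_fill(mask, tokenizer.mask_token_id)
41 | labels[~mask] = -100  # only masked positions contribute to the loss
42 | mlm_inputs['labels'] = labels
43 | 
44 | output = model(**mlm_inputs)
45 | print(output.loss)
46 | ```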
31 |
32 | ESM++ also supports sequence and token level classification tasks like ESM2. Simply pass the number of labels during initialization.
33 |
34 | ```python
35 | from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification
36 |
37 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
38 | logits = model(**tokenized).logits
39 | print(logits.shape) # (batch_size, num_labels), (2, 2)
40 | ```
41 |
42 | ESM++ weights are fp32 by default. You can load them in fp16 or bf16 like this:
43 | ```python
44 | import torch
45 | model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True, dtype=torch.float16) # or torch.bfloat16
46 | ```
47 |
48 | ## Embed entire datasets with no new code
49 | To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take.
50 |
51 | Example:
52 | ```python
53 | embedding_dict = model.embed_dataset(
54 | sequences=[
55 | 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
56 | ],
57 | tokenizer=model.tokenizer,
58 | batch_size=2, # adjust for your GPU memory
59 | max_len=512, # adjust for your needs
60 | full_embeddings=False, # if True, no pooling is performed
61 | embed_dtype=torch.float32, # cast to what dtype you want
62 | pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
63 | num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
64 | sql=False, # if True, embeddings will be stored in SQLite database
65 | sql_db_path='embeddings.db',
66 | save=True, # if True, embeddings will be saved as a .pth file
67 | save_path='embeddings.pth',
68 | )
69 | # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
70 | ```
71 |
72 | ```
73 | model.embed_dataset()
74 | Args:
75 | sequences: List of protein sequences
76 | batch_size: Batch size for processing
77 | max_len: Maximum sequence length
78 | full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
79 | pooling_type: Type of pooling ('mean' or 'cls')
80 | num_workers: Number of workers for data loading, 0 for the main process
81 | sql: Whether to store embeddings in SQLite database - will be stored in float32
82 | sql_db_path: Path to SQLite database
83 |
84 | Returns:
85 | Dictionary mapping sequences to embeddings, or None if sql=True
86 |
87 | Note:
88 | - If sql=True, embeddings can only be stored in float32
89 | - sql is ideal if you need to stream a very large dataset for training in real-time
90 | - save=True is ideal if you can store the entire embedding dictionary in RAM
91 | - sql will be used if it is True and save is True or False
92 | - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
93 | - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
94 | ```
95 |
96 | ## Fine-tuning with 🤗 peft
97 | ```python
98 | from peft import LoraConfig, get_peft_model
99 | 
100 | model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
99 | # these modules handle ESM++ and ESM2 attention layers
100 | target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"]
101 |
102 | lora_config = LoraConfig(
103 | r=8, # choose lora parameters to your liking
104 | lora_alpha=16,
105 | lora_dropout=0.01,
106 | bias="none",
107 | target_modules=target_modules,
108 | )
109 |
110 | # Apply LoRA to the model
111 | model = get_peft_model(model, lora_config)
112 |
113 | # Unfreeze the classifier head
114 | for param in model.classifier.parameters():
115 | param.requires_grad = True
116 | ```
117 |
118 | For a more thorough example of fine-tuning, check out our example script [here](https://github.com/Synthyra/FastPLMs/blob/main/fine_tuning_example.py).
119 |
120 |
121 | ## Returning attention maps
122 | Usually F.scaled_dot_product_attention is used for the attention calculations, which is much faster than computing attention manually in PyTorch. However, it cannot return attention maps.
123 | ESM++ has the option to ```output_attentions```, which will calculate attention manually. This is much slower, so do not use unless you need the attention maps.
124 |
125 | ```python
126 | output = model(**tokenized, output_attentions=True)
127 | att = output.attentions
128 | len(att) # 30, one for each layer, size (batch_size, num_heads, seq_len, seq_len) each
129 | ```
130 |
131 | ## Comparison across floating-point precision and implementations
132 | We measured the difference between the last hidden states of the fp32 weights and the fp16 or bf16 weights. We find that fp16 is closer to the fp32 outputs, so we recommend loading in fp16.
133 | Please note that the ESM package also loads ESMC in fp32 but casts to bf16 by default, which has its share of advantages and disadvantages in inference / training - so load whichever you like for half precision.
134 |
135 | Average MSE FP32 vs. FP16: 0.00000003
136 |
137 | Average MSE FP32 vs. BF16: 0.00000140
138 |
139 | We also measured the difference between the outputs of ESM++ vs. ESMC (both in bfloat16) on 1000 random sequences to ensure compliance with the ESM package.
140 |
141 | Average MSE of last hidden state: 7.74e-10
142 |
143 | You can load the weights from the ESM package instead of transformers by replacing .from_pretrained(...) with .from_pretrained_esm('esmc_300m').
144 |
145 | ## Model probes
146 | We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. ESMC (and thus ESM++) perform very well.
147 |
148 | The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression tasks are averaged between Spearman rho and R2.
149 | 
150 |
151 | ## Inference speeds
152 | We benchmarked the throughput of various ESM models on an H100. ESM++ adds efficient batching over ESMC, which significantly improves throughput, and ESM++ is also faster than ESMC at batch size one. ESM++ small is even faster than ESM2-35M on long sequences!
153 | The largest gains will be seen with PyTorch > 2.5 on Linux machines.
154 | 
155 |
156 | ### Citation
157 | If you use any of this implementation or work, please cite it (as well as the ESMC preprint).
158 |
159 | ```
160 | @misc {FastPLMs,
161 | author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
162 |   title = { FastPLMs: Fast, efficient protein language model inference from Huggingface AutoModel.},
163 | year = {2024},
164 | url = { https://huggingface.co/Synthyra/ESMplusplus_small },
165 | DOI = { 10.57967/hf/3726 },
166 | publisher = { Hugging Face }
167 | }
168 | ```
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | PLEASE NOTE THE APACHE LICENSE ONLY APPLIES TO THE CODE IN THE FastPLMs GITHUB AND ASSOCIATED HUGGINGFACE REPOSITORIES, NOT NECESSARILY THE MODEL WEIGHTS. THOSE LICENSES CAN BE FOUND HERE https://github.com/Synthyra/FastPLMs/tree/main/licenses
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------
/wip/t5/t5_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import math
4 | from typing import Optional
5 | from torch import nn
6 | from transformers import T5Config
7 | from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
8 |
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class T5AttentionTransformers(nn.Module):
14 | def __init__(
15 | self,
16 | config: T5Config,
17 | has_relative_attention_bias=False,
18 | layer_idx: Optional[int] = None,
19 | ):
20 | super().__init__()
21 | self.is_decoder = config.is_decoder
22 | self.has_relative_attention_bias = has_relative_attention_bias
23 | self.relative_attention_num_buckets = config.relative_attention_num_buckets
24 | self.relative_attention_max_distance = config.relative_attention_max_distance
25 | self.d_model = config.d_model
26 | self.key_value_proj_dim = config.d_kv
27 | self.n_heads = config.num_heads
28 | self.dropout = config.dropout_rate
29 | self.inner_dim = self.n_heads * self.key_value_proj_dim
30 | self.layer_idx = layer_idx
31 | if layer_idx is None and self.is_decoder:
32 | logger.warning_once(
33 | f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
34 |                 "will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
35 | "when creating this class."
36 | )
37 |
38 | # Mesh TensorFlow initialization to avoid scaling before softmax
39 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
40 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
41 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
42 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
43 |
44 | if self.has_relative_attention_bias:
45 | self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
46 | self.pruned_heads = set()
47 | self.gradient_checkpointing = False
48 |
49 | def prune_heads(self, heads):
50 | if len(heads) == 0:
51 | return
52 | heads, index = find_pruneable_heads_and_indices(
53 | heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
54 | )
55 | # Prune linear layers
56 | self.q = prune_linear_layer(self.q, index)
57 | self.k = prune_linear_layer(self.k, index)
58 | self.v = prune_linear_layer(self.v, index)
59 | self.o = prune_linear_layer(self.o, index, dim=1)
60 | # Update hyper params
61 | self.n_heads = self.n_heads - len(heads)
62 | self.inner_dim = self.key_value_proj_dim * self.n_heads
63 | self.pruned_heads = self.pruned_heads.union(heads)
64 |
65 | @staticmethod
66 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
67 | """
68 | Adapted from Mesh Tensorflow:
69 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
70 |
71 | Translate relative position to a bucket number for relative attention. The relative position is defined as
72 | memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
73 | position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
74 | small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
75 | positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
76 | This should allow for more graceful generalization to longer sequences than the model has been trained on
77 |
78 | Args:
79 | relative_position: an int32 Tensor
80 | bidirectional: a boolean - whether the attention is bidirectional
81 | num_buckets: an integer
82 | max_distance: an integer
83 |
84 | Returns:
85 | a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
86 | """
87 | relative_buckets = 0
88 | if bidirectional:
89 | num_buckets //= 2
90 | relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
91 | relative_position = torch.abs(relative_position)
92 | else:
93 | relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
94 | # now relative_position is in the range [0, inf)
95 |
96 | # half of the buckets are for exact increments in positions
97 | max_exact = num_buckets // 2
98 | is_small = relative_position < max_exact
99 |
100 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
101 | relative_position_if_large = max_exact + (
102 | torch.log(relative_position.float() / max_exact)
103 | / math.log(max_distance / max_exact)
104 | * (num_buckets - max_exact)
105 | ).to(torch.long)
106 | relative_position_if_large = torch.min(
107 | relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
108 | )
109 |
110 | relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
111 | return relative_buckets
112 |
113 | def compute_bias(self, query_length, key_length, device=None, cache_position=None):
114 | """Compute binned relative position bias"""
115 | if device is None:
116 | device = self.relative_attention_bias.weight.device
117 | if cache_position is None:
118 | context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
119 | else:
120 | context_position = cache_position[:, None].to(device)
121 | memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
122 | relative_position = memory_position - context_position # shape (query_length, key_length)
123 | relative_position_bucket = self._relative_position_bucket(
124 | relative_position, # shape (query_length, key_length)
125 | bidirectional=(not self.is_decoder),
126 | num_buckets=self.relative_attention_num_buckets,
127 | max_distance=self.relative_attention_max_distance,
128 | )
129 | values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
130 | values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
131 | return values
132 |
133 | def forward(
134 | self,
135 | hidden_states,
136 | mask=None,
137 | key_value_states=None,
138 | position_bias=None,
139 | past_key_value=None,
140 | layer_head_mask=None,
141 | query_length=None,
142 | use_cache=False,
143 | output_attentions=False,
144 | cache_position=None,
145 | ):
146 | """
147 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
148 | """
149 | # Input is (batch_size, seq_length, dim)
150 | # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
151 | batch_size, seq_length = hidden_states.shape[:2]
152 |
153 | # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
154 | is_cross_attention = key_value_states is not None
155 |
156 | query_states = self.q(hidden_states)
157 | query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
158 |
159 | if cache_position is None:
160 | if past_key_value is not None:
161 | logger.warning_once(
162 | f"{self.__class__.__name__} forward called without cache_position when using cache, which might result in errors. "
163 | "Please provide a cache_position when calling this function. "
164 | "See 'Best Practices for Generation with Cache' in the docs for more information. "
165 | "Assuming cache position starts at 0."
166 | )
167 | cache_position = torch.arange(seq_length)
168 |
169 | if past_key_value is not None:
170 | is_updated = past_key_value.is_updated.get(self.layer_idx)
171 | if is_cross_attention:
172 | # after the first generated id, we can subsequently re-use all key/value_states from cache
173 | curr_past_key_value = past_key_value.cross_attention_cache
174 | else:
175 | curr_past_key_value = past_key_value.self_attention_cache
176 |
177 | current_states = key_value_states if is_cross_attention else hidden_states
178 | if is_cross_attention and past_key_value is not None and is_updated:
179 | # reuse k,v, cross_attentions
180 | key_states = curr_past_key_value.key_cache[self.layer_idx]
181 | value_states = curr_past_key_value.value_cache[self.layer_idx]
182 | else:
183 | key_states = self.k(current_states)
184 | value_states = self.v(current_states)
185 | key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
186 | value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
187 |
188 | if past_key_value is not None:
189 | # save all key/value_states to cache to be re-used for fast auto-regressive generation
190 | cache_position = cache_position if not is_cross_attention else None
191 | key_states, value_states = curr_past_key_value.update(
192 | key_states, value_states, self.layer_idx, {"cache_position": cache_position}
193 | )
194 | # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
195 | if is_cross_attention:
196 | past_key_value.is_updated[self.layer_idx] = True
197 |
198 | # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
199 | scores = torch.matmul(query_states, key_states.transpose(3, 2))
200 |
201 | if position_bias is None:
202 | key_length = key_states.shape[-2]
203 | # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
204 | real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
205 | if not self.has_relative_attention_bias:
206 | position_bias = torch.zeros(
207 | (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
208 | )
209 | if self.gradient_checkpointing and self.training:
210 | position_bias.requires_grad = True
211 | else:
212 | position_bias = self.compute_bias(
213 | real_seq_length, key_length, device=scores.device, cache_position=cache_position
214 | )
215 | position_bias = position_bias[:, :, -seq_length:, :]
216 |
217 | if mask is not None:
218 | causal_mask = mask[:, :, :, : key_states.shape[-2]]
219 | position_bias = position_bias + causal_mask
220 |
221 | if self.pruned_heads:
222 | mask = torch.ones(position_bias.shape[1])
223 | mask[list(self.pruned_heads)] = 0
224 | position_bias_masked = position_bias[:, mask.bool()]
225 | else:
226 | position_bias_masked = position_bias
227 |
228 | scores += position_bias_masked
229 |
230 | # (batch_size, n_heads, seq_length, key_length)
231 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
232 | attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
233 |
234 | # Mask heads if we want to
235 | if layer_head_mask is not None:
236 | attn_weights = attn_weights * layer_head_mask
237 |
238 | attn_output = torch.matmul(attn_weights, value_states)
239 |
240 | attn_output = attn_output.transpose(1, 2).contiguous()
241 | attn_output = attn_output.view(batch_size, -1, self.inner_dim)
242 | attn_output = self.o(attn_output)
243 |
244 | outputs = (attn_output, past_key_value, position_bias)
245 |
246 | if output_attentions:
247 | outputs = outputs + (attn_weights,)
248 | return outputs
--------------------------------------------------------------------------------
/wip/t5/t5_flex_attention.py:
--------------------------------------------------------------------------------
1 | """
2 | TODO
3 | Future addition to current T5 PLM models to use much more efficient flex attention
4 | Waiting for official support of learned relative position embeddings in the score_mode
5 | https://github.com/pytorch-labs/attention-gym/issues/20
6 | https://pytorch.org/blog/flexattention/
7 | https://github.com/pytorch-labs/attention-gym/pull/84/commits/4d045e172474e2b964e3961d794177ceca1549f8
8 | """
9 | import torch
10 | import torch.nn as nn
11 | from typing import Optional, Tuple, Union
12 | from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
13 | from torch.nn.attention.flex_attention import flex_attention
14 | from .t5_attention import T5AttentionTransformers
15 |
16 |
17 | class T5FlexAttentionMixin:
18 | """
19 | A mixin class to enable flex attention in T5 models.
20 |
21 | This mixin replaces the standard attention mechanism in T5Attention with flex attention,
22 | preserving the position bias mechanism that T5 uses.
23 | """
24 |
25 | def __init__(self, compile_flex=None, softcap=None):
26 | """
27 | Initialize the T5FlexAttentionMixin.
28 |
29 | Args:
30 | compile_flex: Whether to compile the flex attention function.
31 |             softcap: Optional soft cap value applied to the attention scores inside score_mod.
32 | """
33 | self.compile_flex = compile_flex and (torch.cuda.is_available() or torch.backends.mps.is_available())
34 | self.flex_attention_fn = torch.compile(flex_attention) if self.compile_flex else flex_attention
35 | self.softcap = softcap
36 |
37 | def _apply_flex_attention(self, query, key, value, position_bias=None, attention_mask=None, head_mask=None):
38 | # adapted from https://github.com/huggingface/transformers/blob/752ef3fd4e70869626ec70657a770a85c0ad9219/src/transformers/integrations/flex_attention.py
39 | def score_mod(score, b, h, q_idx, kv_idx):
40 | if self.softcap is not None:
41 | score = self.softcap * torch.tanh(score / self.softcap)
42 | if position_bias is not None:
43 | score = score + position_bias[0][h][q_idx][kv_idx]
44 | if attention_mask is not None:
45 | score = score + attention_mask[b][0][q_idx][kv_idx]
46 | if head_mask is not None:
47 | score = score + head_mask[b][h][q_idx][kv_idx]
48 | return score
49 |
50 |         # Apply flex attention; return_lse=True means the second return value is the log-sum-exp of the scores, not a full attention-weight matrix
51 | attn_output, attention_weights = self.flex_attention_fn(
52 | query=query,
53 | key=key,
54 | value=value,
55 | score_mod=score_mod,
56 | enable_gqa=True,
57 | scale=1.0,
58 | return_lse=True,
59 | )
60 | return attn_output, attention_weights
61 |
62 |
63 | class T5FlexAttention(nn.Module, T5FlexAttentionMixin):
64 | """
65 | Drop-in replacement for T5Attention that uses flex attention.
66 |
67 | This class preserves the interface of T5Attention but uses flex attention
68 | for the core attention computation.
69 | """
70 |
71 | def __init__(self, config, has_relative_attention_bias=False, layer_idx=None, compile_flex=False):
72 | nn.Module.__init__(self)
73 | T5FlexAttentionMixin.__init__(self, compile_flex=compile_flex)
74 |
75 | self.is_decoder = config.is_decoder
76 | self.has_relative_attention_bias = has_relative_attention_bias
77 | self.layer_idx = layer_idx
78 |
79 | self.relative_attention_num_buckets = config.relative_attention_num_buckets
80 | self.relative_attention_max_distance = config.relative_attention_max_distance
81 | self.d_model = config.d_model
82 | self.key_value_proj_dim = config.d_kv
83 | self.n_heads = config.num_heads
84 | self.dropout = config.dropout_rate
85 | self.inner_dim = self.n_heads * self.key_value_proj_dim
86 |
87 | # Regular T5 projection layers
88 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
89 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
90 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
91 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
92 |
93 | if self.has_relative_attention_bias:
94 | self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
95 |
96 | self.pruned_heads = set()
97 |
98 | def prune_heads(self, heads):
99 | # Implementation similar to T5Attention.prune_heads
100 | if len(heads) == 0:
101 | return
102 | heads, index = find_pruneable_heads_and_indices(
103 | heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
104 | )
105 | # Prune linear layers
106 | self.q = prune_linear_layer(self.q, index)
107 | self.k = prune_linear_layer(self.k, index)
108 | self.v = prune_linear_layer(self.v, index)
109 | self.o = prune_linear_layer(self.o, index, dim=1)
110 | # Update hyper params
111 | self.n_heads = self.n_heads - len(heads)
112 | self.inner_dim = self.key_value_proj_dim * self.n_heads
113 | self.pruned_heads = self.pruned_heads.union(heads)
114 |
115 | @staticmethod
116 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
117 | """
118 | Adapted from T5Attention._relative_position_bucket
119 | """
120 | relative_buckets = 0
121 | if bidirectional:
122 | num_buckets //= 2
123 | relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
124 | relative_position = torch.abs(relative_position)
125 | else:
126 | relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
127 |
128 | # half of the buckets are for exact increments in positions
129 | max_exact = num_buckets // 2
130 | is_small = relative_position < max_exact
131 |
132 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
133 | relative_position_if_large = max_exact + (
134 | torch.log(relative_position.float() / max_exact)
135 | / torch.log(torch.tensor(max_distance / max_exact, device=relative_position.device))
136 | * (num_buckets - max_exact)
137 | ).to(torch.long)
138 | relative_position_if_large = torch.min(
139 | relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
140 | )
141 |
142 | relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
143 | return relative_buckets
144 |
145 | def compute_bias(self, query_length, key_length, device=None, cache_position=None):
146 | """
147 | Compute binned relative position bias, same as in T5Attention
148 | """
149 | if device is None:
150 | device = self.relative_attention_bias.weight.device
151 |
152 | if cache_position is None:
153 | context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
154 | else:
155 | context_position = cache_position[:, None].to(device)
156 |
157 | memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
158 | relative_position = memory_position - context_position # shape (query_length, key_length)
159 |
160 | relative_position_bucket = self._relative_position_bucket(
161 | relative_position, # shape (query_length, key_length)
162 | bidirectional=(not self.is_decoder),
163 | num_buckets=self.relative_attention_num_buckets,
164 | max_distance=self.relative_attention_max_distance,
165 | )
166 |
167 | values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
168 | values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
169 | return values
170 |
171 | def forward(
172 | self,
173 | hidden_states,
174 | mask=None,
175 | key_value_states=None,
176 | position_bias=None,
177 | past_key_value=None,
178 | layer_head_mask=None,
179 | query_length=None,
180 | use_cache=False,
181 | output_attentions=False,
182 | cache_position=None,
183 | ):
184 | """
185 | Forward pass, similar to T5Attention but using flex attention
186 | """
187 | batch_size, seq_length = hidden_states.shape[:2]
188 |
189 | # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
190 | is_cross_attention = key_value_states is not None
191 |
192 | # Project hidden states to query vectors
193 | query_states = self.q(hidden_states)
194 | query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
195 |
196 | # Handle past key values for caching
197 | if past_key_value is not None:
198 | is_updated = past_key_value.is_updated.get(self.layer_idx)
199 | if is_cross_attention:
200 | # after the first generated id, we can subsequently re-use all key/value_states from cache
201 | curr_past_key_value = past_key_value.cross_attention_cache
202 | else:
203 | curr_past_key_value = past_key_value.self_attention_cache
204 |
205 | # Get current states for key and value
206 | current_states = key_value_states if is_cross_attention else hidden_states
207 |
208 | # Reuse cached key/value states if available and updated
209 | if is_cross_attention and past_key_value is not None and is_updated:
210 | # reuse k,v, cross_attentions
211 | key_states = curr_past_key_value.key_cache[self.layer_idx]
212 | value_states = curr_past_key_value.value_cache[self.layer_idx]
213 | else:
214 | # Compute new key and value states
215 | key_states = self.k(current_states)
216 | value_states = self.v(current_states)
217 | key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
218 | value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
219 |
220 | # Update cache if needed
221 | if past_key_value is not None:
222 | # save all key/value_states to cache to be re-used for fast auto-regressive generation
223 | cache_position = cache_position if not is_cross_attention else None
224 | key_states, value_states = curr_past_key_value.update(
225 | key_states, value_states, self.layer_idx, {"cache_position": cache_position}
226 | )
227 | # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
228 | if is_cross_attention:
229 | past_key_value.is_updated[self.layer_idx] = True
230 |
231 | # Compute position bias if not provided
232 | if position_bias is None:
233 | key_length = key_states.shape[-2]
234 | # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
235 | real_seq_length = query_length if query_length is not None else (cache_position[-1] + 1 if cache_position is not None else seq_length)
236 |
237 | if not self.has_relative_attention_bias:
238 | position_bias = torch.zeros(
239 | (1, self.n_heads, seq_length, key_length), device=query_states.device, dtype=query_states.dtype
240 | )
241 | if getattr(self, 'gradient_checkpointing', False) and self.training:
242 | position_bias.requires_grad = True
243 | else:
244 | position_bias = self.compute_bias(
245 | real_seq_length, key_length, device=query_states.device, cache_position=cache_position
246 | )
247 | position_bias = position_bias[:, :, -seq_length:, :]
248 |
249 | # Add mask to position bias if provided
250 | if mask is not None:
251 | causal_mask = mask[:, :, :, : key_states.shape[-2]]
252 | position_bias = position_bias + causal_mask
253 |
254 | # Apply pruned heads if any
255 | if self.pruned_heads:
256 | mask = torch.ones(position_bias.shape[1])
257 | mask[list(self.pruned_heads)] = 0
258 | position_bias_masked = position_bias[:, mask.bool()]
259 | else:
260 | position_bias_masked = position_bias
261 |
262 | # Apply flex attention
263 | attn_output, attention_weights = self._apply_flex_attention(
264 | query=query_states,
265 | key=key_states,
266 | value=value_states,
267 | position_bias=position_bias_masked,
268 | attention_mask=mask,
269 | head_mask=layer_head_mask
270 | )
271 |
272 | # Reshape output and apply output projection
273 | attn_output = attn_output.transpose(1, 2).contiguous()
274 | attn_output = attn_output.view(batch_size, -1, self.inner_dim)
275 | attn_output = self.o(attn_output)
276 |
277 | # Prepare outputs
278 | outputs = (attn_output, past_key_value, position_bias)
279 | if output_attentions:
280 | outputs = outputs + (attention_weights,)
281 | return outputs
282 |
283 |
284 | def replace_t5_attention_with_flex(model, compile_flex=False):
285 | """
286 | Replace all T5Attention modules in a T5 model with T5FlexAttention.
287 |
288 | Args:
289 | model: A T5 model instance
290 |
291 | Returns:
292 | The modified model with flex attention
293 | """
294 | # Recursively replace all T5Attention modules
295 | for name, module in model.named_children():
296 | if isinstance(module, T5AttentionTransformers):
297 | # Create a new T5FlexAttention with the same parameters
298 | flex_attn = T5FlexAttention(
299 | config=model.config,
300 | has_relative_attention_bias=module.has_relative_attention_bias,
301 | layer_idx=module.layer_idx,
302 | compile_flex=compile_flex
303 | )
304 |
305 | # Copy weights
306 | flex_attn.q.weight.data = module.q.weight.data.clone()
307 | flex_attn.k.weight.data = module.k.weight.data.clone()
308 | flex_attn.v.weight.data = module.v.weight.data.clone()
309 | flex_attn.o.weight.data = module.o.weight.data.clone()
310 |
311 | if module.has_relative_attention_bias:
312 | flex_attn.relative_attention_bias.weight.data = module.relative_attention_bias.weight.data.clone()
313 |
314 | # Replace the module
315 | setattr(model, name, flex_attn)
316 | else:
317 | # Recursively process child modules
318 |             replace_t5_attention_with_flex(module, compile_flex=compile_flex)
319 |
320 | return model
321 |
322 |
323 |
324 | if __name__ == "__main__":
325 | # py -m wip.t5.t5_flex_attention
326 | from transformers import T5Config
327 |
328 | config = T5Config()
329 |
330 | attention_layer = T5FlexAttention(config=config, has_relative_attention_bias=True, layer_idx=0, compile_flex=False)
331 | hidden_states = torch.randn(1, 10, config.d_model)
332 |
333 | output = attention_layer(hidden_states)
334 |
335 | print(output)
336 |
--------------------------------------------------------------------------------
/licenses/e1_license.txt:
--------------------------------------------------------------------------------
1 | Your use of the Profluent-E1 model code is governed by the Apache License 2.0, while your use of the Profluent-E1 model weights and the full release of the Profluent-E1 model is governed by a similarly permissive license with additional attribution requirements - see the NOTICE file for details. You can use, share, and modify Profluent-E1 for free, but you must follow our ATTRIBUTION guidelines to give credit, include the license when you share and follow some other basic rules. Profluent is not responsible for what you build, and may terminate your rights to use Profluent-E1 if you breach the license.
2 |
3 | Code in src/E1/model/flash_attention_utils.py is adapted from flash-attention project under BSD-3-Clause license.
4 |
5 | Profluent-E1 Notice File
6 | ------------------------
7 |
8 | Copyright 2025 Profluent Bio Inc.
9 |
10 | Licensed under the Profluent-E1 Clickthrough License Agreement (the “Agreement”); you may not use
11 | this file except in compliance with the Agreement. Unless required by applicable law or agreed to
12 | in writing, software distributed under the Agreement is distributed on an "AS IS" BASIS, WITHOUT
13 | WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the Agreement for the specific
14 | language governing permissions and limitations under the Agreement and the Attribution Guidelines
15 | for more information about attribution requirements for use of Profluent-E1. You may obtain a copy
16 | of the Agreement at https://github.com/Profluent-AI/E1/blob/main/LICENSE and a copy of the Attribution
17 | Guidelines at https://github.com/Profluent-AI/E1/blob/main/ATTRIBUTION, each as may be updated or
18 | amended from time to time.
19 |
20 | Profluent-E1 Clickthrough License Agreement
21 | -------------------------------------------
22 |
23 | Please read this Profluent-E1 Clickthrough License Agreement (the “Agreement”) carefully before using Profluent-E1 (as defined below), which is offered by Profluent Bio Inc. (“Profluent”).
24 |
25 | By downloading Profluent-E1, or otherwise using Profluent-E1 in any manner, You agree that You have read and agree to be bound by the terms of this Agreement. If You are accessing Profluent-E1 on behalf of an organization or entity, You represent and warrant that You are authorized to enter into this Agreement on that organization's or entity's behalf and bind them to the terms of this Agreement (in which case, the references to “You” and “Your” in this Agreement, except for in this sentence, refer to that organization or entity). Use of Profluent-E1 and all other Profluent-E1 IP is expressly conditioned upon Your assent to all terms of this Agreement, including the Attribution Guidelines incorporated herein, to the exclusion of all other terms.
26 |
27 | 1. Definitions.
28 |
29 | 1.1 “AAA Rules” shall mean the Commercial Arbitration Rules and Mediation Procedures of the American Arbitration Association (“AAA”).
30 |
31 | 1.2 “Claim” shall mean any claim (including any tort claim), cause of action, and/or dispute under, arising out of, or relating to this Agreement.
32 |
33 | 1.3 “Contribution” shall mean any work of authorship, including the original version of Profluent-E1 and any modifications or additions to that Profluent-E1 or Derivative Works thereof, that is intentionally submitted to Profluent for inclusion in Profluent-E1 by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to Profluent or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, Profluent for the purpose of discussing and improving Profluent-E1, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
34 |
35 | 1.4 “Contributor” shall mean Profluent and any individual or Legal Entity on behalf of whom a Contribution has been received by Profluent and subsequently incorporated within Profluent-E1.
36 |
37 | 1.5 “Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) Profluent-E1 and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this Agreement, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, Profluent-E1 and Derivative Works thereof.
38 |
39 | 1.6 “GitHub Page” shall mean the page made available at https://github.com/Profluent-AI/E1, as may be updated and amended from time to time.
40 |
41 | 1.7 “Hugging Face Pages” shall mean the pages made available at https://huggingface.co/Profluent-Bio/E1-600m, https://huggingface.co/Profluent-Bio/E1-300m and https://huggingface.co/Profluent-Bio/E1-150m, each as may be updated and amended from time to time.
42 |
43 | 1.8 “Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
44 |
45 | 1.9 “Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
46 |
47 | 1.10 “Profluent-E1” shall mean the Profluent-E1 Model Code, Profluent-E1 Model Weights and all software, algorithms, machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed or made available to You by Profluent on the GitHub Page, each as may be updated and amended from time to time, whether in Source or Object form.
48 |
49 | 1.11 “Profluent-E1 Model Code” shall mean the code and data for Profluent-E1 made available to You at https://github.com/Profluent-AI/E1, as may be updated and amended from time to time.
50 |
51 | 1.12 “Profluent-E1 Model Weights” shall mean the trained model weights for Profluent-E1 made available to You on one or more of the Hugging Face Pages, as may be updated and amended from time to time, including all model weights that are directly or indirectly accessed or copied from the Hugging Face Pages by You or any third party.
52 |
53 | 1.13 “Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
54 |
55 | 1.14 “You” or “Your” shall mean an individual entering into this Agreement or the organization or Legal Entity on whose behalf such individual is entering into this Agreement.
56 |
57 | 2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute Profluent-E1 and such Derivative Works in Source or Object form.
58 |
59 | 3. Grant of Patent License. Subject to the terms and conditions of this Agreement, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer Profluent-E1, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with Profluent-E1 to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that Profluent-E1 or a Contribution incorporated within Profluent-E1 constitutes direct or contributory patent infringement, then any patent licenses granted to You under this Agreement for Profluent-E1 shall terminate as of the date such litigation is filed.
60 |
61 | 4. Use of Profluent-E1 Model Code. Your access to and use of the Profluent-E1 Model Code separate and apart from Profluent-E1 and the Profluent-E1 Model Weights is subject to the Apache License, Version 2.0 made available at https://www.apache.org/licenses/LICENSE-2.0, as may be updated and amended from time to time.
62 |
63 | 5. Redistribution. You may reproduce and distribute copies of Profluent-E1 or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
64 |
65 | (a) You must give any other recipients of Profluent-E1 or Derivative Works a copy of this Agreement; and
66 |
67 | (b) You must cause any modified files to carry prominent notices stating that You changed the files; and
68 |
69 | (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of Profluent-E1, excluding those notices that do not pertain to any part of the Derivative Works; and
70 |
71 | (d) If Profluent-E1 includes a "NOTICE" text file as part of its distribution or there is otherwise a “NOTICE” text file available on the GitHub Page, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify this Agreement. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from Profluent-E1, provided that such additional attribution notices cannot be construed as modifying this Agreement; and
72 |
73 | (e) You must comply with the Attribution Guidelines and provide attribution to Profluent-E1 in accordance therewith.
74 |
75 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of Profluent-E1 otherwise complies with the conditions stated in this License.
76 |
77 | 6. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in Profluent-E1 by You to Profluent shall be under the terms and conditions of this Agreement, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Profluent regarding such Contributions.
78 |
79 | 7. Trademarks. This Agreement does not grant permission to use any Profluent trade names, trademarks, service marks, or product names, except as required for reasonable and customary use in describing the origin of Profluent-E1, reproducing the content of the NOTICE file and complying with the Attribution Guidelines.
80 |
81 | 8. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Profluent provides Profluent-E1 (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing Profluent-E1 and assume any risks associated with Your exercise of permissions under this Agreement. FOR THE AVOIDANCE OF DOUBT AND NOTWITHSTANDING ANYTHING TO THE CONTRARY, YOU ACKNOWLEDGE AND AGREE THAT PROFLUENT IS NOT RESPONSIBLE OR LIABLE FOR ANYTHING YOU BUILD, CREATE, DEVELOP OR DERIVE FROM YOUR USE OF PROFLUENT-E1, INCLUDING ANY DERIVATIVE WORKS.
82 |
83 | 9. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use Profluent-E1 (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
84 |
85 | 10. Accepting Warranty or Additional Liability. While redistributing Profluent-E1 or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this Agreement. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
86 |
87 | 11. Term and Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to Profluent-E1 and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Profluent may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of Profluent-E1. All provisions of this Agreement which by their nature should survive termination shall survive termination, including, without limitation, ownership provisions, warranty disclaimers, indemnity obligations, limitations of liability and provisions regarding dispute resolution.
88 |
89 | 12. General.
90 |
91 | 12.1 This Agreement constitutes the entire agreement between You and Profluent relating to the subject matter hereof and supersedes all proposals, understandings, or discussions, whether written or oral, relating to the subject matter of this Agreement and all past dealing or industry custom. The failure of either party to enforce its rights under this Agreement at any time for any period shall not be construed as a waiver of such rights.
92 |
93 | 12.2 Profluent may amend or modify this Agreement from time to time and will use reasonable efforts to provide You with notice of any material changes that may negatively impact Your use of Profluent-E1 through the GitHub Page or through another means made available to You. No other changes, modifications or waivers to this Agreement will be effective unless in writing and signed by both parties.
94 |
95 | 12.3 You and Profluent are independent contractors, and nothing herein shall be deemed to constitute either party as the agent or representative of the other or both parties as joint venturers or partners for any purpose.
96 |
97 | 12.4 You shall comply with the U.S. Foreign Corrupt Practices Act and all applicable export laws, restrictions and regulations of the U.S. Department of Commerce, and any other applicable U.S. and foreign authority.
98 |
99 | 12.5 This Agreement and the rights and obligations herein may not be assigned or transferred, in whole or in part, by You without the prior written consent of Profluent. Any assignment in violation of this provision is void. Profluent may freely assign or transfer this Agreement, in whole or in part. This Agreement shall be binding upon, and inure to the benefit of, the successors and permitted assigns of the parties.
100 |
101 | 12.6 This Agreement shall be governed by and construed under the Federal Arbitration Act, applicable federal law, and the laws of the State of California and the United States without regard to conflicts of laws provisions thereof, and without regard to the Uniform Computer Information Transactions Act.
102 |
103 | 12.7 If You have a Claim, You agree to provide Profluent with at least sixty (60) days' written notice so that the parties may attempt to resolve the Claim internally before requesting arbitration. In case of any Claims or disputes between You and Profluent that cannot be resolved through informal internal discussions, You agree that, subject to the provisions of this section, any and all Claims will be settled by final and binding arbitration in accordance with the AAA Rules by an arbitrator selected pursuant to the AAA Rules; provided, however, that Claims arising out of or relating to Your violations of Profluent's intellectual property rights, including copyright infringement, patent infringement, trademark infringement, or efforts to use Profluent-E1 in unauthorized ways will not be subject to such obligation for settlement by final and binding arbitration, and such claims will instead be brought in the state and federal courts of Alameda County, California. The arbitrator's decision and award will be non-appealable and may be entered in, and will be enforceable in, any court of competent jurisdiction. The arbitration will take place in Alameda County, California. Each party will bear its own costs of arbitration unless the arbitrator directs otherwise. The arbitrator may not award relief or damages in excess of or contrary to what this Agreement provides, or order consolidation or arbitration on a class-wide or representative basis.
104 |
105 | 12.8 If any provision of this Agreement is held to be invalid, illegal or unenforceable in any respect, that provision shall be limited or eliminated to the minimum extent necessary so that this Agreement otherwise remains in full force and effect and enforceable.
106 |
--------------------------------------------------------------------------------
/licenses/ankh_license.txt:
--------------------------------------------------------------------------------
1 | License for ANKH models, from their repo https://github.com/agemagician/Ankh?tab=License-1-ov-file
2 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
3 |
4 | Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
5 |
6 | Using Creative Commons Public Licenses
7 |
8 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
9 |
10 | Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors : wiki.creativecommons.org/Considerations_for_licensors
11 |
12 | Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public : wiki.creativecommons.org/Considerations_for_licensees
13 |
14 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
15 |
16 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
17 |
18 | Section 1 – Definitions.
19 |
20 | a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
21 | b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
22 | c. BY-NC-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License.
23 | d. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
24 | e. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
25 | f. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
26 | g. License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
27 | h. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
28 | i. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
29 | j. Licensor means the individual(s) or entity(ies) granting rights under this Public License.
30 | k. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
31 | l. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
32 | m. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
33 | n. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
34 | Section 2 – Scope.
35 |
36 | a. License grant.
37 | Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
38 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
39 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
40 | Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
41 | Term. The term of this Public License is specified in Section 6(a).
42 | Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
43 | Downstream recipients.
44 | A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
45 | B. Additional offer from the Licensor – Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter's License You apply.
46 | C. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
47 | No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
48 | b. Other rights.
49 | Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
50 | Patent and trademark rights are not licensed under this Public License.
51 | To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
52 | Section 3 – License Conditions.
53 |
54 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
55 |
56 | a. Attribution.
57 | If You Share the Licensed Material (including in modified form), You must:
58 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
59 |
60 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
61 | ii. a copyright notice;
62 | iii. a notice that refers to this Public License;
63 | iv. a notice that refers to the disclaimer of warranties;
64 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
65 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
66 |
67 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
68 |
69 | You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
70 | If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
71 | b. ShareAlike.In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
72 | The Adapter's License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
73 | You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
74 | You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
75 | Section 4 – Sui Generis Database Rights.
76 |
77 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
78 |
79 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
80 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
81 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
82 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
83 | Section 5 – Disclaimer of Warranties and Limitation of Liability.
84 |
85 | a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.
86 | b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.
87 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
88 | Section 6 – Term and Termination.
89 |
90 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
91 |
92 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
93 |
94 | automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
95 | upon express reinstatement by the Licensor.
96 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
97 |
98 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
99 |
100 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
101 |
102 | Section 7 – Other Terms and Conditions.
103 |
104 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
105 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
106 | Section 8 – Interpretation.
107 |
108 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
109 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
110 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
111 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
112 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
113 |
114 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/wip/t5/test_t5_flex_attention.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import numpy as np
4 | import time
5 | import matplotlib.pyplot as plt
6 | from transformers import T5Config, T5EncoderModel
7 | from .t5_flex_attention import T5FlexAttention, replace_t5_attention_with_flex
8 | from .t5_attention import T5AttentionTransformers
9 |
10 |
11 | def test_attention_layer_equivalence(
12 | batch_size=2,
13 | seq_length=16,
14 | d_model=32,
15 | num_heads=4,
16 | d_kv=8,
17 | seed=42,
18 | device=None,
19 | measure_time=False,
20 | num_warmup=10,
21 | num_timing_runs=10,
22 | compile_flex=False,
23 | ):
24 | """
25 | Test the equivalence between T5Attention and T5FlexAttention layers.
26 |
27 | Args:
28 | batch_size: Batch size for test inputs
29 | seq_length: Sequence length for test inputs
30 | d_model: Model dimension
31 | num_heads: Number of attention heads
32 | d_kv: Dimension of key and value vectors
33 | seed: Random seed for reproducibility
34 | device: Device to run the test on
35 | measure_time: Whether to measure execution time
36 | num_warmup: Number of warmup runs before timing
37 | num_timing_runs: Number of runs to average timing over
38 | compile_flex: Whether to compile the flex attention module
39 |
40 | Returns:
41 | max_diff: Maximum absolute difference between outputs
42 | mean_diff: Mean absolute difference between outputs
43 | std_time: Average execution time for standard attention (if measure_time=True)
44 | flex_time: Average execution time for flex attention (if measure_time=True)
45 | """
46 | print(f"\n{'='*20} Testing Attention Layer Equivalence {'='*20}")
47 | print(f"Compile flex: {compile_flex}")
48 |
49 | # Set random seed for reproducibility
50 | torch.manual_seed(seed)
51 | np.random.seed(seed)
52 |
53 | # Create a simple T5 config
54 | config = T5Config(
55 | d_model=d_model,
56 | d_kv=d_kv,
57 | num_heads=num_heads,
58 | is_decoder=False,
59 | use_cache=False,
60 | )
61 |
62 | # Create standard T5Attention and T5FlexAttention layers
63 | std_attention = T5AttentionTransformers(config, has_relative_attention_bias=True).to(device)
64 | flex_attention = T5FlexAttention(config, has_relative_attention_bias=True, compile_flex=compile_flex).to(device)
65 |
66 | # Copy weights from standard attention to flex attention
67 | flex_attention.q.weight.data = std_attention.q.weight.data.clone()
68 | flex_attention.k.weight.data = std_attention.k.weight.data.clone()
69 | flex_attention.v.weight.data = std_attention.v.weight.data.clone()
70 | flex_attention.o.weight.data = std_attention.o.weight.data.clone()
71 | flex_attention.relative_attention_bias.weight.data = std_attention.relative_attention_bias.weight.data.clone()
72 |
73 | # Create random input
74 | hidden_states = torch.randn(batch_size, seq_length, d_model).to(device)
75 |
76 | # Set both models to eval mode
77 | std_attention.eval()
78 | flex_attention.eval()
79 |
80 | # Timing measurements
81 | std_time = 0
82 | flex_time = 0
83 |
84 | if measure_time:
85 | # Warmup runs
86 | print(f"Performing {num_warmup} warmup runs...")
87 | for _ in range(num_warmup):
88 | with torch.no_grad():
89 | _ = std_attention(hidden_states)[0]
90 | _ = flex_attention(hidden_states)[0]
91 |
92 | # Timing runs for standard attention
93 | print(f"Measuring standard attention over {num_timing_runs} runs...")
94 |         if device.type == 'cuda': torch.cuda.synchronize()
95 | start_time = time.time()
96 | for _ in range(num_timing_runs):
97 | with torch.no_grad():
98 | _ = std_attention(hidden_states)[0]
99 |         if device.type == 'cuda': torch.cuda.synchronize()
100 | std_time = (time.time() - start_time) / num_timing_runs
101 |
102 | # Timing runs for flex attention
103 | print(f"Measuring flex attention over {num_timing_runs} runs...")
104 |         if device.type == 'cuda': torch.cuda.synchronize()
105 | start_time = time.time()
106 | for _ in range(num_timing_runs):
107 | with torch.no_grad():
108 | _ = flex_attention(hidden_states)[0]
109 |         if device.type == 'cuda': torch.cuda.synchronize()
110 | flex_time = (time.time() - start_time) / num_timing_runs
111 |
112 | print(f"Standard attention average time: {std_time*1000:.4f} ms")
113 | print(f"Flex attention average time: {flex_time*1000:.4f} ms")
114 | print(f"Speedup: {std_time/flex_time:.2f}x")
115 |
116 | # Forward pass for correctness check
117 | with torch.no_grad():
118 | std_output = std_attention(hidden_states)[0]
119 | flex_output = flex_attention(hidden_states)[0]
120 |
121 | # Calculate differences
122 | abs_diff = torch.abs(std_output - flex_output)
123 | max_diff = torch.max(abs_diff).item()
124 | mean_diff = torch.mean(abs_diff).item()
125 |
126 | print(f"Max absolute difference: {max_diff:.8f}")
127 | print(f"Mean absolute difference: {mean_diff:.8f}")
128 |
129 | if measure_time:
130 | return max_diff, mean_diff, std_time, flex_time
131 | else:
132 | return max_diff, mean_diff
133 |
134 |
135 | def test_model_equivalence(
136 | model_name,
137 | batch_size=2,
138 | seq_length=16,
139 | seed=42,
140 | device=None,
141 | measure_time=False,
142 | num_warmup=10,
143 | num_timing_runs=10,
144 | compile_flex=False,
145 | ):
146 | """
147 | Test the equivalence between a standard T5EncoderModel and one with flex attention.
148 |
149 | Args:
150 | model_name: Name or path of the pretrained model
151 | batch_size: Batch size for test inputs
152 | seq_length: Sequence length for test inputs
153 | seed: Random seed for reproducibility
154 | device: Device to run the test on
155 | measure_time: Whether to measure execution time
156 | num_warmup: Number of warmup runs before timing
157 | num_timing_runs: Number of runs to average timing over
158 | compile_flex: Whether to compile the flex attention module
159 | Returns:
160 | max_diff: Maximum absolute difference between outputs
161 | mean_diff: Mean absolute difference between outputs
162 | std_time: Average execution time for standard model (if measure_time=True)
163 | flex_time: Average execution time for flex model (if measure_time=True)
164 | """
165 | print(f"\n{'='*20} Testing Model Equivalence {'='*20}")
166 | print(f"Using pretrained model: {model_name}")
167 | print(f"Compile flex: {compile_flex}")
168 |
169 | # Set random seed for reproducibility
170 | torch.manual_seed(seed)
171 | np.random.seed(seed)
172 |
173 | # Load the pretrained model
174 | std_model = T5EncoderModel.from_pretrained(model_name).to(device)
175 |
176 | # Create a copy with flex attention
177 | flex_model = T5EncoderModel.from_pretrained(model_name).to(device)
178 | flex_model = replace_t5_attention_with_flex(flex_model, compile_flex=compile_flex)
179 |
180 | # Create random input IDs
181 | input_ids = torch.randint(0, std_model.config.vocab_size, (batch_size, seq_length)).to(device)
182 | attention_mask = torch.ones_like(input_ids).to(device)
183 |
184 | # Set both models to eval mode
185 | std_model.eval()
186 | flex_model.eval()
187 |
188 | # Timing measurements
189 | std_time = 0
190 | flex_time = 0
191 |
192 | if measure_time:
193 | # Warmup runs
194 | print(f"Performing {num_warmup} warmup runs...")
195 | for _ in range(num_warmup):
196 | with torch.no_grad():
197 | _ = std_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
198 | _ = flex_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
199 |
200 | # Timing runs for standard model
201 | print(f"Measuring standard model over {num_timing_runs} runs...")
202 |         if device.type == 'cuda': torch.cuda.synchronize()
203 | start_time = time.time()
204 | for _ in range(num_timing_runs):
205 | with torch.no_grad():
206 | _ = std_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
207 |         if device.type == 'cuda': torch.cuda.synchronize()
208 | std_time = (time.time() - start_time) / num_timing_runs
209 |
210 | # Timing runs for flex model
211 | print(f"Measuring flex model over {num_timing_runs} runs...")
212 |         if device.type == 'cuda': torch.cuda.synchronize()
213 | start_time = time.time()
214 | for _ in range(num_timing_runs):
215 | with torch.no_grad():
216 | _ = flex_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
217 |         if device.type == 'cuda': torch.cuda.synchronize()
218 | flex_time = (time.time() - start_time) / num_timing_runs
219 |
220 | print(f"Standard model average time: {std_time*1000:.4f} ms")
221 | print(f"Flex model average time: {flex_time*1000:.4f} ms")
222 | print(f"Speedup: {std_time/flex_time:.2f}x")
223 |
224 | # Forward pass for correctness check
225 | with torch.no_grad():
226 | std_output = std_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
227 | flex_output = flex_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
228 |
229 | # Calculate differences
230 | abs_diff = torch.abs(std_output - flex_output)
231 | max_diff = torch.max(abs_diff).item()
232 | mean_diff = torch.mean(abs_diff).item()
233 |
234 | print(f"Max absolute difference in last_hidden_state: {max_diff:.8f}")
235 | print(f"Mean absolute difference in last_hidden_state: {mean_diff:.8f}")
236 |
237 | if measure_time:
238 | return max_diff, mean_diff, std_time, flex_time
239 | else:
240 | return max_diff, mean_diff
241 |
242 |
243 | if __name__ == "__main__":
244 | # py -m wip.t5.test_t5_flex_attention
245 | # py -m wip.t5.test_t5_flex_attention --measure_time --seq_length_range --num_timing_runs 100
246 | parser = argparse.ArgumentParser(description="Test T5 Flex Attention equivalence")
247 | parser.add_argument("--model_name", type=str, default="Synthyra/ANKH_base",
248 | help="Pretrained model name or path")
249 | parser.add_argument("--batch_size", type=int, default=2,
250 | help="Batch size for test inputs (default: 2)")
251 | parser.add_argument("--seq_length", type=int, default=16,
252 | help="Sequence length for test inputs (default: 16)")
253 | parser.add_argument("--seq_length_range", action="store_true",
254 | help="Test a range of sequence lengths from 8 to 2048")
255 | parser.add_argument("--seed", type=int, default=42,
256 | help="Random seed for reproducibility (default: 42)")
257 | parser.add_argument("--tolerance", type=float, default=1e-5,
258 | help="Tolerance for differences (default: 1e-5)")
259 | parser.add_argument("--measure_time", action="store_true",
260 | help="Measure execution time")
261 | parser.add_argument("--num_warmup", type=int, default=10,
262 | help="Number of warmup runs before timing (default: 10)")
263 | parser.add_argument("--num_timing_runs", type=int, default=10,
264 | help="Number of runs to average timing over (default: 10)")
265 | parser.add_argument("--compile_flex", action="store_true",
266 | help="Compile flex attention")
267 | args = parser.parse_args()
268 |
269 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
270 | print(f"Running tests on device: {device}")
271 |
272 | if args.seq_length_range and args.measure_time:
273 | # Test a range of sequence lengths
274 | seq_lengths = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
275 | std_times = []
276 | flex_times = []
277 | max_diffs = []
278 |
279 | print(f"\n{'='*20} Testing Sequence Length Range {'='*20}")
280 |
281 | for seq_length in seq_lengths:
282 | print(f"\nTesting sequence length: {seq_length}")
283 |
284 | # Test attention layer equivalence
285 | layer_max_diff, layer_mean_diff, layer_std_time, layer_flex_time = test_attention_layer_equivalence(
286 | batch_size=args.batch_size,
287 | seq_length=seq_length,
288 | seed=args.seed,
289 | device=device,
290 | measure_time=True,
291 | num_warmup=args.num_warmup,
292 | num_timing_runs=args.num_timing_runs,
293 | compile_flex=args.compile_flex
294 | )
295 |
296 | std_times.append(layer_std_time * 1000) # Convert to ms
297 | flex_times.append(layer_flex_time * 1000) # Convert to ms
298 | max_diffs.append(layer_max_diff)
299 |
300 | # Create plot for timing results
301 | plt.figure(figsize=(12, 8))
302 |
303 | # Plot timing results
304 | plt.subplot(2, 1, 1)
305 | plt.plot(seq_lengths, std_times, 'o-', label='Standard Attention')
306 | plt.plot(seq_lengths, flex_times, 'o-', label='Flex Attention')
307 | plt.xscale('log', base=2)
308 | plt.yscale('log')
309 | plt.xlabel('Sequence Length')
310 | plt.ylabel('Time (ms)')
311 | plt.title('Attention Layer Execution Time vs Sequence Length')
312 | plt.grid(True, which="both", ls="--")
313 | plt.legend()
314 |
315 | # Plot speedup
316 | plt.subplot(2, 1, 2)
317 | speedups = [std/flex for std, flex in zip(std_times, flex_times)]
318 | plt.plot(seq_lengths, speedups, 'o-', color='green')
319 | plt.xscale('log', base=2)
320 | plt.xlabel('Sequence Length')
321 | plt.ylabel('Speedup (x)')
322 | plt.title('Flex Attention Speedup vs Sequence Length')
323 | plt.grid(True, which="both", ls="--")
324 |
325 | plt.tight_layout()
326 | plt.savefig('sequence_length_timing.png')
327 | print(f"\nPlot saved to sequence_length_timing.png")
328 |
329 | # Print results in table format
330 | print(f"\n{'='*60}")
331 | print(f"{'Sequence Length':^15} | {'Standard (ms)':^15} | {'Flex (ms)':^15} | {'Speedup':^10}")
332 | print(f"{'-'*60}")
333 | for i, seq_len in enumerate(seq_lengths):
334 | print(f"{seq_len:^15} | {std_times[i]:^15.4f} | {flex_times[i]:^15.4f} | {speedups[i]:^10.2f}")
335 | print(f"{'='*60}")
336 |
337 | else:
338 | # Test attention layer equivalence
339 | if args.measure_time:
340 | layer_max_diff, layer_mean_diff, layer_std_time, layer_flex_time = test_attention_layer_equivalence(
341 | batch_size=args.batch_size,
342 | seq_length=args.seq_length,
343 | seed=args.seed,
344 | device=device,
345 | measure_time=True,
346 | num_warmup=args.num_warmup,
347 | num_timing_runs=args.num_timing_runs,
348 | compile_flex=args.compile_flex
349 | )
350 | else:
351 | layer_max_diff, layer_mean_diff = test_attention_layer_equivalence(
352 | batch_size=args.batch_size,
353 | seq_length=args.seq_length,
354 | seed=args.seed,
355 | device=device,
356 | compile_flex=args.compile_flex
357 | )
358 |
359 | # Test model equivalence
360 | if args.measure_time:
361 | model_max_diff, model_mean_diff, model_std_time, model_flex_time = test_model_equivalence(
362 | model_name=args.model_name,
363 | batch_size=args.batch_size,
364 | seq_length=args.seq_length,
365 | seed=args.seed,
366 | device=device,
367 | measure_time=True,
368 | num_warmup=args.num_warmup,
369 | num_timing_runs=args.num_timing_runs,
370 | compile_flex=args.compile_flex
371 | )
372 | else:
373 | model_max_diff, model_mean_diff = test_model_equivalence(
374 | model_name=args.model_name,
375 | batch_size=args.batch_size,
376 | seq_length=args.seq_length,
377 | seed=args.seed,
378 | device=device,
379 | compile_flex=args.compile_flex
380 | )
381 |
382 | # Check if differences are within tolerance
383 | print(f"\n{'='*20} Results {'='*20}")
384 | print(f"Tolerance threshold: {args.tolerance}")
385 |
386 | if layer_max_diff <= args.tolerance:
387 | print(f"✅ Attention layer test PASSED: Max diff {layer_max_diff:.8f} <= {args.tolerance}")
388 | else:
389 | print(f"❌ Attention layer test FAILED: Max diff {layer_max_diff:.8f} > {args.tolerance}")
390 |
391 | if model_max_diff <= args.tolerance:
392 | print(f"✅ Model test PASSED: Max diff {model_max_diff:.8f} <= {args.tolerance}")
393 | else:
394 | print(f"❌ Model test FAILED: Max diff {model_max_diff:.8f} > {args.tolerance}")
395 |
396 | # Print timing summary if measured
397 | if args.measure_time:
398 | print(f"\n{'='*20} Timing Summary {'='*20}")
399 | print(f"Attention Layer (compile_flex={args.compile_flex}):")
400 | print(f" Standard: {layer_std_time*1000:.4f} ms")
401 | print(f" Flex: {layer_flex_time*1000:.4f} ms")
402 | print(f" Speedup: {layer_std_time/layer_flex_time:.2f}x")
403 |
404 | print(f"\nFull Model (compile_flex={args.compile_flex}):")
405 | print(f" Standard: {model_std_time*1000:.4f} ms")
406 | print(f" Flex: {model_flex_time*1000:.4f} ms")
407 | print(f" Speedup: {model_std_time/model_flex_time:.2f}x")
--------------------------------------------------------------------------------