├── ckpt
│   ├── .placeholder
│   └── huggingface
│       └── .placeholder
├── .gitignore
├── figures
│   └── overview.png
├── examples
│   └── inputs.tsv
├── environment.sh
├── model
│   ├── Evolla
│   │   ├── llm_interface.py
│   │   ├── encoder_interface.py
│   │   ├── sequence_encoder_saprot.py
│   │   ├── fusion_module.py
│   │   ├── injection_module.py
│   │   ├── Evolla_model.py
│   │   └── llama_llm.py
│   └── model_interface.py
├── config
│   └── Evolla_10B.yaml
├── LICENSE
├── scripts
│   └── inference.py
├── utils
│   ├── easydict.py
│   └── others.py
└── README.md
/ckpt/.placeholder:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ckpt/huggingface/.placeholder:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | .idea/
3 | __pycache__/
4 | tests
5 | ckpt
--------------------------------------------------------------------------------
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/westlake-repl/Evolla/HEAD/figures/overview.png
--------------------------------------------------------------------------------
/examples/inputs.tsv:
--------------------------------------------------------------------------------
1 | C9RH78 MLLEETLKSCPIVKRGKYHYFIHPISDGVPLVEPKLLREVATRIIKIGNFEGVNKIVTAEAMGIPLVTTLSLYTDIPYVIMRKREYKLPGEVPVFQSTGYSKGQLYLNGIEKGDKVIIIDDVISTGGTMIAIINALERAGAEIKDIICVIERGDGKKIVEEKTGYKIKTLVKIDVVDGEVVIL dvvvvqqqpfawdddppdtdgcgclapvpdpddpvvlvvllvlcvvpadpvqaqeeeeeddscpsnvvsncvvpvhyydywylddppdppkdwqwf######gitidpdqaaaheyeyeeaeqdqlrvvlsvvvrcvvrnyhhrayeyaeyhycnqvvccvvpvghyhynwywdqdpsgidtd "What is the catalytic activity of this protein?"
2 |
--------------------------------------------------------------------------------
/environment.sh:
--------------------------------------------------------------------------------
1 | # works on 2025/01/09
2 | pip install pyyaml
3 | pip3 install torch torchvision torchaudio
4 | # pip3 install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple
5 | pip install tqdm
6 | pip install lightning
7 | # pip install lightning -i https://pypi.tuna.tsinghua.edu.cn/simple
8 | pip install transformers
9 | # pip install transformers -i https://pypi.tuna.tsinghua.edu.cn/simple
10 | pip install einops
11 | pip install einops_exts
12 | pip install peft
13 | # pip install peft -i https://pypi.tuna.tsinghua.edu.cn/simple
14 | pip install -U bitsandbytes
15 | # pip install -U bitsandbytes -i https://pypi.tuna.tsinghua.edu.cn/simple
--------------------------------------------------------------------------------
/model/Evolla/llm_interface.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def register_llm(cls):
4 | global now_cls
5 | now_cls = cls
6 | return cls
7 |
8 | class LLMInterface:
9 | @classmethod
10 | def init_llm(cls, model_py_path, **kwargs):
11 | """
12 |         Initialize the LLM from a python file.
13 |         Args:
14 |             model_py_path: Path to the LLM python file, e.g. model/Evolla/llama_llm.py
15 |             **kwargs: Kwargs for model initialization
16 |
17 |         Returns:
18 |             Initialized LLM
19 | """
20 | sub_dirs = model_py_path.split(os.sep)
21 | cmd = f"from {'.'.join(sub_dirs[:-1])} import {sub_dirs[-1].split('.')[0]}"
22 | exec(cmd)
23 | return now_cls(**kwargs)
24 |
--------------------------------------------------------------------------------
/model/model_interface.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def register_model(cls):
4 | global now_cls
5 | now_cls = cls
6 | return cls
7 |
8 | class ModelInterface:
9 | @classmethod
10 | def init_model(cls, model_py_path, **kwargs):
11 | """
12 | Initialize model from python file.
13 | Args:
14 | model_py_path: Path to model python file. e.g. model/transformer.py
15 | **kwargs: Kwargs for model initialization
16 |
17 | Returns:
18 | Initialized model
19 | """
20 | sub_dirs = model_py_path.split(os.sep)
21 | cmd = f"from {'.'.join(sub_dirs[:-1])} import {sub_dirs[-1].split('.')[0]}"
22 | exec(cmd)
23 | return now_cls(**kwargs)
24 |
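# How registration works: `register_model` is applied as a class decorator (see
# model/Evolla/Evolla_model.py). For cls="model/Evolla/Evolla_model.py", init_model
# exec's "from model.Evolla import Evolla_model"; importing that module runs the
# decorator, which stores the decorated class in the module-global `now_cls`, and the
# class is then instantiated with the remaining kwargs.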
--------------------------------------------------------------------------------
/model/Evolla/encoder_interface.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def register_encoder(cls):
4 | global now_cls
5 | now_cls = cls
6 | return cls
7 |
8 | class EncoderInterface:
9 | @classmethod
10 | def init_encoder(cls, model_py_path, **kwargs):
11 | """
12 |         Initialize the encoder from a python file.
13 |         Args:
14 |             model_py_path: Path to the encoder python file, e.g. model/Evolla/sequence_encoder_saprot.py
15 |             **kwargs: Kwargs for model initialization
16 |
17 |         Returns:
18 |             Initialized encoder
19 | """
20 | sub_dirs = model_py_path.split(os.sep)
21 | cmd = f"from {'.'.join(sub_dirs[:-1])} import {sub_dirs[-1].split('.')[0]}"
22 | exec(cmd)
23 | return now_cls(**kwargs)
24 |
--------------------------------------------------------------------------------
/config/Evolla_10B.yaml:
--------------------------------------------------------------------------------
1 | setting:
2 | seed: 42
3 | # from_checkpoint: ckpt/Evolla-10B
4 | from_checkpoint: ckpt/huggingface/Evolla-10B/Evolla-10B
5 |
6 | model:
7 | cls: model/Evolla/Evolla_model.py
8 | generate_config:
9 | max_new_tokens: 512
10 | do_sample: True
11 | temperature: 0.6
12 | top_p: 0.9
13 | config:
14 | text_length: 2048
15 | protein_encoder:
16 | cls: model/Evolla/sequence_encoder_saprot.py
17 | config_path: ckpt/huggingface/SaProt_650M_AF2
18 | fusion_module:
19 | cls: SequenceCompressorResampler
20 | depth: 6
21 | heads: 8
22 | num_latents: 64
23 | ff_mult: 4
24 | llm:
25 | cls: model/Evolla/llama_llm.py
26 | hf_dir: ckpt/huggingface/meta-llama_Meta-Llama-3-8B-Instruct
27 | cross_attention_config:
28 | ffn_mult: 4
29 | enable_bias: true
30 | attention_probs_dropout_prob: 0.1
31 | quantization: 8bit
32 | # quantization: false
33 |
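# Notes on this config:
# - `cls` entries are python file paths that are resolved at runtime by
#   ModelInterface / EncoderInterface / LLMInterface.
# - The fusion module's `protein_repr_dim` / `output_repr_dim` and the cross
#   attention's `protein_encoder_dim` are filled in automatically by
#   `align_model_config` in utils/others.py, based on the encoder name
#   (SaProt_650M_AF2 -> 1280) and the LLM embedding size
#   (meta-llama_Meta-Llama-3-8B-Instruct -> 4096).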
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 westlake-repl
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | HOME_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
4 | sys.path.append(HOME_PATH)
5 |
6 | import argparse
7 | import json
8 |
9 | import traceback
10 | from threading import Thread
11 | from utils.others import setup_seed, load_config, load_model_from_config
12 |
13 | from transformers import TextIteratorStreamer
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--config_path", type=str, required=True)
17 | parser.add_argument("--input_path", type=str, required=True, help="Path to the input file, each line is a tab-separated triplet of Uniprot ID (or any other identifier), sequence, foldseek, and question (in JSON format)")
18 | args = parser.parse_args()
19 | CONFIG_PATH = args.config_path
20 | input_path = args.input_path
21 |
22 | config = load_config(CONFIG_PATH)
23 |
24 | if config.setting.seed:
25 | setup_seed(config.setting.seed)
26 |
27 |
28 | model = load_model_from_config(config, local_rank=0, dtype="bf16")
29 |
30 | with open(input_path, "r") as f:
31 | for line in f:
32 | line = line.strip()
33 | uniprot_id, sequence, foldseek, question = line.split("\t")
34 | question = json.loads(question)
35 | streamer = TextIteratorStreamer(
36 | model.llm_tokenizer,
37 | # skip_prompt=True,
38 | skip_prompt=False,
39 | skip_special_tokens=True,
40 | )
41 |
42 | mixed_sequence = "".join([s+f for s, f in zip(sequence, foldseek)])
43 | print(f"{uniprot_id}")
44 | print(f"{question}")
45 | print(f"{mixed_sequence}")
46 | generation_kwargs = {
47 | "seqs": [mixed_sequence],
48 | "foldseeks": [None],
49 | "questions": [question],
50 | "streamer": streamer,
51 | }
52 |
53 | def generate_wrapper():
54 | try:
55 | model.generate(**generation_kwargs, **model.generate_config)
56 | except Exception as e:
57 |                 # print the full traceback for debugging
58 | traceback.print_exc()
59 | print(f"Exception in generate_wrapper: {e}")
60 |
61 | thread = Thread(target=generate_wrapper)
62 | thread.start()
63 | for a in streamer:
64 | print(a, end="", flush=True)
65 | thread.join()
66 | print("=" * 50)
67 |
--------------------------------------------------------------------------------
/model/Evolla/sequence_encoder_saprot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from transformers import EsmConfig, EsmForMaskedLM, EsmModel, EsmTokenizer
5 |
6 | from .encoder_interface import register_encoder
7 | from .fusion_module import SequenceCompressorResampler
8 |
9 |
10 | @register_encoder
11 | class SaProtSequenceEncoder(nn.Module):
12 | def __init__(
13 | self,
14 | config_path: str,
15 | load_pretrained: bool = True,
16 | fusion_module: dict = None,
17 | **kwargs,
18 | ):
19 | super().__init__()
20 | if load_pretrained:
21 | # self.model = EsmModel.from_pretrained(config_path)
22 | self.model = EsmForMaskedLM.from_pretrained(config_path)
23 | self.config = EsmConfig.from_pretrained(config_path)
24 | else:
25 | self.config = EsmConfig.from_pretrained(config_path)
26 | # self.model = EsmModel(self.config)
27 | self.model = EsmForMaskedLM(self.config)
28 |
29 | self.tokenizer = EsmTokenizer.from_pretrained(config_path)
30 |
31 | fusion_cls = fusion_module.pop("cls", None)
32 | if fusion_cls is None or fusion_cls == "SequenceCompressorResampler":
33 | self.resampler = SequenceCompressorResampler(**fusion_module)
34 | else:
35 | raise ValueError(f"Unknown fusion module class: {fusion_cls}")
36 |
37 | @property
38 | def num_layers(self):
39 |         return len(self.model.esm.encoder.layer)  # EsmForMaskedLM exposes the base EsmModel as .esm
40 |
41 | def sequence_encode(self, seqs):
42 | """
43 | Encode protein sequence into protein representation
44 | """
45 | seqs = [seq if seq is not None else "" for seq in seqs]
46 | protein_tokens = self.tokenizer.batch_encode_plus(
47 | seqs, return_tensors="pt", truncation=True, max_length=1026, padding=True
48 | ).to(self.model.device)
49 |
50 | protein_output = self.model(
51 | protein_tokens["input_ids"],
52 | protein_tokens["attention_mask"],
53 | return_dict=True,
54 | output_hidden_states=True,
55 | )
56 |
57 | protein_embeds = protein_output.hidden_states[-1]
58 |
59 | mask = protein_tokens["attention_mask"]
60 |
61 | return protein_embeds, mask
62 |
63 | def forward(self, seqs):
64 | # create batch mask for seqs
65 | seqs_batch_mask = torch.tensor(
66 | [True if seq is not None else False for seq in seqs]
67 | )
68 | # print("this is structure encoder", flush=True)
69 | sequence_embeds, mask = self.sequence_encode(seqs)
70 |
71 | sequence_repr = self.resampler(sequence_embeds, mask)
72 |
73 | return sequence_repr, sequence_embeds, mask, seqs_batch_mask
74 |
75 |
--------------------------------------------------------------------------------
/utils/easydict.py:
--------------------------------------------------------------------------------
1 | # from easydict package
2 | # https://github.com/makinacorpus/easydict
3 | class MyEasyDict(dict):
4 | """
5 | Get attributes
6 |
7 | >>> d = EasyDict({'foo':3})
8 | >>> d['foo']
9 | 3
10 | >>> d.foo
11 | 3
12 | >>> d.bar
13 | Traceback (most recent call last):
14 | ...
15 | AttributeError: 'EasyDict' object has no attribute 'bar'
16 |
17 | Works recursively
18 |
19 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
20 | >>> isinstance(d.bar, dict)
21 | True
22 | >>> d.bar.x
23 | 1
24 |
25 | Bullet-proof
26 |
27 | >>> EasyDict({})
28 | {}
29 | >>> EasyDict(d={})
30 | {}
31 | >>> EasyDict(None)
32 | {}
33 | >>> d = {'a': 1}
34 | >>> EasyDict(**d)
35 | {'a': 1}
36 |
37 | Set attributes
38 |
39 | >>> d = EasyDict()
40 | >>> d.foo = 3
41 | >>> d.foo
42 | 3
43 | >>> d.bar = {'prop': 'value'}
44 | >>> d.bar.prop
45 | 'value'
46 | >>> d
47 | {'foo': 3, 'bar': {'prop': 'value'}}
48 | >>> d.bar.prop = 'newer'
49 | >>> d.bar.prop
50 | 'newer'
51 |
52 |
53 | Values extraction
54 |
55 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
56 | >>> isinstance(d.bar, list)
57 | True
58 | >>> from operator import attrgetter
59 | >>> map(attrgetter('x'), d.bar)
60 | [1, 3]
61 | >>> map(attrgetter('y'), d.bar)
62 | [2, 4]
63 | >>> d = EasyDict()
64 | >>> d.keys()
65 | []
66 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
67 | >>> d.foo
68 | 3
69 | >>> d.bar.x
70 | 1
71 |
72 | Still like a dict though
73 |
74 | >>> o = EasyDict({'clean':True})
75 | >>> o.items()
76 | [('clean', True)]
77 |
78 | And like a class
79 |
80 | >>> class Flower(EasyDict):
81 | ... power = 1
82 | ...
83 | >>> f = Flower()
84 | >>> f.power
85 | 1
86 | >>> f = Flower({'height': 12})
87 | >>> f.height
88 | 12
89 | >>> f['power']
90 | 1
91 | >>> sorted(f.keys())
92 | ['height', 'power']
93 |
94 | update and pop items
95 | >>> d = EasyDict(a=1, b='2')
96 | >>> e = EasyDict(c=3.0, a=9.0)
97 | >>> d.update(e)
98 | >>> d.c
99 | 3.0
100 | >>> d['c']
101 | 3.0
102 | >>> d.get('c')
103 | 3.0
104 | >>> d.update(a=4, b=4)
105 | >>> d.b
106 | 4
107 | >>> d.pop('a')
108 | 4
109 | >>> d.a
110 | Traceback (most recent call last):
111 | ...
112 | AttributeError: 'EasyDict' object has no attribute 'a'
113 | """
114 |
115 | def __init__(self, d=None, **kwargs):
116 | if d is None:
117 | d = {}
118 | if kwargs:
119 | d.update(**kwargs)
120 | for k, v in d.items():
121 | setattr(self, k, v)
122 | # Class attributes
123 | for k in self.__class__.__dict__.keys():
124 | if not (k.startswith("__") and k.endswith("__")) and not k in (
125 | "update",
126 | "pop",
127 | ):
128 | setattr(self, k, getattr(self, k))
129 |
130 | def __setattr__(self, name, value):
131 | if isinstance(value, (list, tuple)):
132 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
133 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
134 | value = self.__class__(value)
135 | super(MyEasyDict, self).__setattr__(name, value)
136 | super(MyEasyDict, self).__setitem__(name, value)
137 |
138 | __setitem__ = __setattr__
139 |
140 | def update(self, e=None, **f):
141 | d = e or dict()
142 | d.update(f)
143 | for k in d:
144 | setattr(self, k, d[k])
145 |
146 | def pop(self, k, d=None):
147 | if k not in self:
148 | return d
149 | delattr(self, k)
150 | return super(MyEasyDict, self).pop(k, d)
151 |
152 | def __getattr__(self, name):
153 | return self.__class__.__dict__.get(name, None)
154 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Evolla
2 |
3 |
4 |
5 |
6 |
7 | A frontier protein-language generative model designed to decode the molecular language of proteins.
8 |
9 | *Quickly try our online server (Evolla-10B) [here](http://www.chat-protein.com/).*
10 |
11 | Table of contents
12 |
13 | - [News](#News)
14 | - [Overview](#Overview)
15 | - [Environment installation](#Environment-installation)
16 | - [Prepare the Evolla model](#Prepare-the-Evolla-model)
17 | - [Prepare input data](#Prepare-input-data)
18 | - [Run Evolla](#Run-Evolla)
19 | - [Citation](#Citation)
20 |
21 |
22 | > We have 2 PhD positions for international students at Westlake University, China! See [here](https://x.com/duguyuan/status/1897101692665258245).
23 | >
24 | ## News
25 | - **2025/01/06** We released our paper [Decoding the Molecular Language of Proteins with Evolla](https://doi.org/10.1101/2025.01.05.630192).
26 | - **2024/12/06** We uploaded the [Evolla-10B model](https://huggingface.co/westlake-repl/Evolla-10B) to the Hugging Face Hub.
27 | ## Overview
28 |
29 | 
30 |
31 | ## Environment installation
32 |
33 | ### Create a virtual environment
34 | ```
35 | conda create -n Evolla python=3.10
36 | conda activate Evolla
37 | ```
38 |
39 | ### Install packages
40 | ```
41 | bash environment.sh
42 | ```
43 |
44 | ## Prepare the Evolla model
45 |
46 | We provide the pre-trained Evolla-10B model on the Hugging Face Hub. You can download it, together with the SaProt encoder and Llama-3 weights referenced in the config, by running the following commands:
47 | ```
48 | cd ckpt/huggingface
49 |
50 | git lfs install
51 |
52 | git clone https://huggingface.co/westlake-repl/Evolla-10B
53 |
54 | git clone https://huggingface.co/westlake-repl/SaProt_650M_AF2
55 |
56 | git clone https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct meta-llama_Meta-Llama-3-8B-Instruct  # directory name expected by hf_dir in config/Evolla_10B.yaml
57 | ```
58 |
59 | ### Model checkpoints
60 |
61 | |**Name** |**Size** |
62 | |---------|---------|
63 | |[Evolla-10B](https://huggingface.co/westlake-repl/Evolla-10B) | 10B |
64 |
65 | ## Prepare input data
66 |
67 | We provide a sample input file `examples/inputs.tsv` for testing the Evolla model. The input must be a tab-separated file in which each line contains `(protein_id, aa_sequence, foldseek_sequence, question_in_json_string)`.
68 |
69 | Note: `protein_id` is an arbitrary identifier for the line, `aa_sequence` is the amino acid sequence of the protein, `foldseek_sequence` is the structure-aware sequence of the protein produced by Foldseek (the example uses `#` for positions without structure information), and `question_in_json_string` is the question serialized with `json.dumps`.
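
For reference, an input line in this format can be assembled as follows (a minimal sketch; the identifier and sequences are truncated placeholders taken from the example file, and `examples/my_inputs.tsv` is a hypothetical output path):

```
import json

protein_id = "C9RH78"                     # any identifier you like
aa_sequence = "MLLEETLKSC"                # amino acid sequence (placeholder, truncated)
foldseek_sequence = "dvvvvqqqpf"          # lowercase Foldseek sequence, same length as aa_sequence
question = "What is the catalytic activity of this protein?"

with open("examples/my_inputs.tsv", "w") as f:
    fields = [protein_id, aa_sequence, foldseek_sequence, json.dumps(question)]
    f.write("\t".join(fields) + "\n")
```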
70 |
71 |
72 | ## Run Evolla
73 |
74 | ### Use `inference.py`
75 |
76 | The following commands run inference on a TSV input file.
77 |
78 | Replace `/your/path/to/Evolla` with the path to your local `Evolla` directory.
79 |
80 | ```
81 | cd /your/path/to/Evolla
82 | python scripts/inference.py --config_path config/Evolla_10B.yaml --input_path examples/inputs.tsv
83 | ```
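
If you prefer to call the model from Python rather than through the TSV-driven script, the helpers used by `scripts/inference.py` can be used directly. A minimal sketch, run from the repository root (the sequence is a placeholder; as in the script, amino-acid and lowercase Foldseek characters are interleaved into a single SA sequence):

```
from utils.others import load_config, load_model_from_config

config = load_config("config/Evolla_10B.yaml")
model = load_model_from_config(config, local_rank=0, dtype="bf16")

answers = model.generate(
    seqs=["MdLvLvEe"],  # interleaved amino-acid + Foldseek characters (placeholder)
    foldseeks=[None],
    questions=["What is the catalytic activity of this protein?"],
    **model.generate_config,
)
print(answers[0])
```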
84 |
85 | ## Citation
86 |
87 | If you find this repository useful, please cite our paper:
88 |
89 | ```
90 | @article{zhou2025decoding,
91 | title={Decoding the Molecular Language of Proteins with Evolla},
92 | author={Zhou, Xibin and Han, Chenchen and Zhang, Yingqi and Su, Jin and Zhuang, Kai and Jiang, Shiyu and Yuan, Zichen and Zheng, Wei and Dai, Fengyuan and Zhou, Yuyang and others},
93 | journal={bioRxiv},
94 | pages={2025--01},
95 | year={2025},
96 | publisher={Cold Spring Harbor Laboratory}
97 | }
98 | ```
99 | ### Other resources
100 |
101 | - [ProTrek](https://www.biorxiv.org/content/10.1101/2024.05.30.596740v2) and its [online server](http://search-protrek.com/)
102 | - [Pinal](https://www.biorxiv.org/content/10.1101/2024.08.01.606258v2) and its [online server](http://www.denovo-pinal.com/)
103 | - [SaprotHub](https://www.biorxiv.org/content/10.1101/2024.05.24.595648v5) and its [online server](https://colab.research.google.com/github/westlake-repl/SaprotHub/blob/main/colab/SaprotHub_v2.ipynb?hl=en)
104 |
--------------------------------------------------------------------------------
/model/Evolla/fusion_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange, repeat
3 | from einops_exts import rearrange_many
4 | from torch import einsum, nn
5 |
6 |
7 | def FeedForward(dim, mult=4):
8 | inner_dim = int(dim * mult)
9 | return nn.Sequential(
10 | nn.LayerNorm(dim),
11 | nn.Linear(dim, inner_dim, bias=False),
12 | nn.GELU(),
13 | nn.Linear(inner_dim, dim, bias=False),
14 | )
15 |
16 |
17 | class SequenceCompressorAttention(nn.Module):
18 | def __init__(self, dim, dim_head=64, heads=8):
19 | super().__init__()
20 | self.scale = dim_head**-0.5
21 | self.heads = heads
22 | inner_dim = dim_head * heads
23 |
24 | self.norm_media = nn.LayerNorm(dim)
25 | self.norm_latents = nn.LayerNorm(dim)
26 |
27 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
28 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
29 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
30 |
31 | def forward(self, x, latents, mask):
32 | """
33 | Args:
34 |             x (torch.Tensor): protein sequence features, shape (b, n1, D)
35 |             latents (torch.Tensor): latent features, shape (b, n2, D); n2: num of latent tokens
36 |             mask (torch.Tensor): attention mask over the sequence and latent tokens,
37 |                 shape (b, n1 + n2)
38 |         """
39 | x = self.norm_media(x)
40 | latents = self.norm_latents(latents)
41 |
42 | h = self.heads
43 |
44 | q = self.to_q(latents)
45 | kv_input = torch.cat((x, latents), dim=-2)
46 | k, v = self.to_kv(kv_input).chunk(
47 | 2, dim=-1
48 | ) # each: batch_size, max_protein_length+num_latents, dim_head*num_heads
49 |
50 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
51 | q = q * self.scale # batch_size, num_heads, num_latents, dim_head
52 |
53 | # attention
54 | sim = einsum("... i d, ... j d -> ... i j", q, k)
55 |
56 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
57 |
58 | bs, nh, skd, okd = sim.shape
59 | mask = repeat(mask, "bs okd -> bs nh skd okd", nh=nh, skd=skd)
60 |
61 | sim = sim.masked_fill((1 - mask).bool(), -1e4)
62 |         # sim = sim + (1 - mask) * torch.tensor(float('-inf'), dtype=sim.dtype)  # apply the mask
63 | attn = sim.softmax(dim=-1)
64 |
65 | out = einsum("... i j, ... j d -> ... i d", attn, v)
66 |
67 | out = rearrange(out, "b h n d -> b n (h d)", h=h)
68 | return self.to_out(out)
69 |
70 |
71 | class SequenceCompressorResampler(nn.Module):
72 | def __init__(
73 | self,
74 | protein_repr_dim,
75 | output_repr_dim,
76 | depth=6,
77 | dim_head=64,
78 | heads=8,
79 | num_latents=64,
80 | ff_mult=4,
81 | ):
82 | super().__init__()
83 | self.latents = nn.Parameter(torch.randn(num_latents, protein_repr_dim))
84 |
85 | self.layers = nn.ModuleList([])
86 | for _ in range(depth):
87 | self.layers.append(
88 | nn.ModuleList(
89 | [
90 | SequenceCompressorAttention(
91 | dim=protein_repr_dim, dim_head=dim_head, heads=heads
92 | ),
93 | FeedForward(dim=protein_repr_dim, mult=ff_mult),
94 | ]
95 | )
96 | )
97 |
98 | self.norm = nn.LayerNorm(output_repr_dim)
99 |
100 | self.protein_projector = nn.Linear(protein_repr_dim, output_repr_dim)
101 |
102 | self.num_latents = num_latents
103 |
104 | @property
105 | def device(self):
106 | return self.latents.device
107 |
108 | @property
109 | def dtype(self):
110 | return self.latents.dtype
111 |
112 | def forward(self, embeds, mask):
113 |
114 | b = embeds.shape[0]
115 |
116 | bs, _ = mask.shape # bs, max_protein_length
117 | latent_mask = torch.ones(bs, self.num_latents).to(mask.device)
118 | mask = torch.cat(
119 | (mask, latent_mask), dim=1
120 | ) # bs, max_protein_length + num_latents
121 |
122 | # blocks
123 | latents = repeat(self.latents, "n d -> b n d", b=b)
124 | for attn, ff in self.layers:
125 | latents = attn(embeds, latents, mask) + latents
126 | latents = ff(latents) + latents
127 |
128 | transformed_feature = self.protein_projector(latents)
129 |
130 | return self.norm(transformed_feature)
131 |
132 | class MLPResampler(nn.Module):
133 | def __init__(
134 | self,
135 | protein_repr_dim,
136 | output_repr_dim,
137 | ):
138 | super().__init__()
139 | self.model = nn.Sequential(
140 | nn.Linear(protein_repr_dim, output_repr_dim),
141 | nn.ReLU(),
142 | nn.Linear(output_repr_dim, output_repr_dim),
143 | nn.LayerNorm(output_repr_dim),
144 | )
145 |
146 | def forward(self, embeds, mask):
147 | return self.model(embeds)
--------------------------------------------------------------------------------
/utils/others.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import re
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 | # from Bio import SeqIO
9 | from torch.nn.utils.rnn import pad_sequence
10 | from tqdm import tqdm
11 | import yaml
12 | from utils.easydict import MyEasyDict
13 | from model.model_interface import ModelInterface
14 |
15 | structure_encoder_name_2_protein_dim = {
16 | "SaProt_35M_AF2": 480,
17 | "SaProt_650M_AF2": 1280,
18 | }
19 |
20 | protein_encoder_name_2_protein_dim = {
21 | "esm2_t12_35M_UR50D": 480,
22 | "esm2_t33_650M_UR50D": 1280,
23 | "SaProt_35M_AF2": 480,
24 | "SaProt_650M_AF2": 1280,
25 | "ProTrek_35M_seq": 480,
26 | "ProTrek_650M_seq": 1280,
27 | }
28 |
29 | llm_name_2_llm_embedding_dim = {
30 | "opt-350m": 512,
31 | "facebook-opt-350m": 512,
32 | "meta-llama_Meta-Llama-3-8B": 4096,
33 | "meta-llama_Meta-Llama-3-8B-Instruct": 4096,
34 | "opt-2.7b": 2560,
35 | "Qwen1.5-0.5B": 1024,
36 | "Qwen1.5-4B-Chat": 1024,
37 | "phi-1_5": 2048,
38 | "phi-2": 2560,
39 | "Llama2hf7b": 4096,
40 | }
41 |
42 | def setup_seed(seed):
43 | """set random seed for reproducibility.
44 | Args:
45 | seed (int): random seed to use.
46 | """
47 | torch.manual_seed(seed)
48 | torch.cuda.manual_seed_all(seed)
49 | np.random.seed(seed)
50 | random.seed(seed)
51 | # torch.backends.cudnn.deterministic = True
52 |
53 |
54 | def align_model_config(config: MyEasyDict):
55 | """Align model config. Different model sometimes should share the same dimension, but it's not easy to set them manually.
56 | Args:
57 | config (MyEasyDict): model config.
58 |
59 | Returns:
60 | config (MyEasyDict): aligned model config.
61 | """
62 |
63 | # if config.fusion_module.output_repr_dim is not set, it should be same as llm embedding dim
64 | llm_name = config.llm.hf_dir.split("/")[-1] # example: opt-350m
65 | llm_embedding_dim = llm_name_2_llm_embedding_dim[llm_name]
66 |
67 | if config.protein_encoder is not None:
68 | # get protein dim by protein_encoder.config_path
69 | protein_encoder_name = config.protein_encoder.config_path.split("/")[
70 | -1
71 | ] # example: esm2_t12_35M_UR50D
72 | protein_encoder_dim = protein_encoder_name_2_protein_dim[protein_encoder_name]
73 | # assign protein_encoder_dim to config.protein_encoder.fusion_module.protein_repr_dim
74 | # config.fusion_module.protein_repr_dim = protein_encoder_dim
75 | config.protein_encoder.fusion_module.protein_repr_dim = protein_encoder_dim
76 | # config.fusion_module.output_repr_dim = llm_embedding_dim
77 | if config.protein_encoder.fusion_module.output_repr_dim is None:
78 | config.protein_encoder.fusion_module.output_repr_dim = llm_embedding_dim
79 |
80 | # align config.llm.cross_attention_config.encoder_dim with config.fusion_module.output_repr_dim
81 | if config.llm.get("cross_attention_config", None) is not None:
82 | # config.llm.cross_attention_config.encoder_dim = config.fusion_module.output_repr_dim
83 | config.llm.cross_attention_config.protein_encoder_dim = (
84 | config.protein_encoder.fusion_module.output_repr_dim
85 | )
86 |
87 | if config.structure_encoder is not None:
88 | if "config_path" in config.structure_encoder: # for saprot
89 | structure_encoder_name = config.structure_encoder.config_path.split("/")[-1]
90 | elif "tokenizer_path" in config.structure_encoder: # for structure embedding
91 | structure_encoder_name = config.structure_encoder.tokenizer_path.split("/")[
92 | -1
93 | ]
94 | else: # for GNN
95 | structure_encoder_name = None
96 | if structure_encoder_name is not None:
97 | structure_encoder_dim = structure_encoder_name_2_protein_dim[
98 | structure_encoder_name
99 | ]
100 | else:
101 | structure_encoder_dim = 512 # TODO
102 |
103 | if "fusion_module" in config.structure_encoder:
104 | config.structure_encoder.fusion_module.protein_repr_dim = (
105 | structure_encoder_dim
106 | )
107 |
108 | if config.structure_encoder.fusion_module.output_repr_dim is None:
109 | config.structure_encoder.fusion_module.output_repr_dim = (
110 | llm_embedding_dim
111 | )
112 |
113 | # align config.llm.cross_attention_config.encoder_dim with config.fusion_module.output_repr_dim
114 | if config.llm.get("cross_attention_config", None) is not None:
115 | if "fusion_module" in config.structure_encoder:
116 | config.llm.cross_attention_config.structure_encoder_dim = (
117 | config.structure_encoder.fusion_module.output_repr_dim
118 | )
119 | else:
120 | config.llm.cross_attention_config.structure_encoder_dim = (
121 | structure_encoder_dim
122 | )
123 |
124 | if config.msa_encoder is not None:
125 | msa_encoder_dim = 768
126 |
127 | if "fusion_module" in config.msa_encoder:
128 | config.msa_encoder.fusion_module.protein_repr_dim = msa_encoder_dim
129 |
130 | if config.msa_encoder.fusion_module.output_repr_dim is None:
131 | config.msa_encoder.fusion_module.output_repr_dim = llm_embedding_dim
132 |
133 | # align config.llm.cross_attention_config.encoder_dim with config.fusion_module.output_repr_dim
134 | if config.llm.get("cross_attention_config", None) is not None:
135 | if "fusion_module" in config.msa_encoder:
136 | config.llm.cross_attention_config.msa_encoder_dim = (
137 | config.msa_encoder.fusion_module.output_repr_dim
138 | )
139 | else:
140 | config.llm.cross_attention_config.msa_encoder_dim = msa_encoder_dim
141 |
142 | return config
143 |
144 |
145 | def filter_llama_weights(state_dict):
146 | """Filter out llama weights from state_dict because of training issues. The llama weights have already been loaded while initializing the model."""
147 | llama_keys = []
148 | for k, v in state_dict.items():
149 | if k.startswith("llm.") and 'adapter' not in k:
150 | llama_keys.append(k)
151 | if k.startswith("model.3.") and 'adapter' not in k:
152 | llama_keys.append(k)
153 | for k in llama_keys:
154 | state_dict.pop(k)
155 | return state_dict
156 |
157 |
158 | def get_prompt(sequence, structure, question):
159 | """Generate prompt and SA sequence for SaProt.
160 |
161 | Args:
162 | sequence (str): amino acid sequence.
163 | structure (str): structure sequence represented by foldseek.
164 | question (str): question for the model.
165 |
166 | Returns:
167 | prompt (str): prompt for the model.
168 | sequence (str): sequence with structure information.
169 | """
170 | sequence_template = "Question: {Question} Answer: "
171 | structure_template = "Question: {Question} Answer: "
172 | saprot_template = "Question: {Question} Answer: "
173 | if sequence is not None and structure is not None:
174 | if len(sequence) != len(structure):
175 | raise ValueError(f"The length of sequence and structure are not equal. {len(sequence)}!= {len(structure)}")
176 | _sequence = sequence.upper()
177 | _structure = structure.lower()
178 | sequence = "".join([f"{_seq}{_struct}" for _seq, _struct in zip(_sequence, _structure)])
179 | print("all", sequence)
180 | prompt = saprot_template.format(Question=question)
181 | elif sequence is not None:
182 | _sequence = sequence.upper()
183 | _structure = "#" * len(_sequence)
184 | sequence = "".join([f"{_seq}{_struct}" for _seq, _struct in zip(_sequence, _structure)])
185 | print("seqonly", sequence)
186 | prompt = sequence_template.format(Question=question)
187 | elif structure is not None:
188 | _sequence = "#" * len(structure)
189 | _structure = structure.lower()
190 | sequence = "".join([f"{_seq}{_struct}" for _seq, _struct in zip(_sequence, _structure)])
191 | prompt = structure_template.format(Question=question)
192 | print("structonly", sequence)
193 | return prompt, sequence
194 |
195 |
196 |
197 | def load_config(config_path):
198 | with open(config_path, 'r', encoding='utf-8') as r:
199 | config = MyEasyDict(yaml.safe_load(r))
200 | config.model.config = align_model_config(config.model.config)
201 | return config
202 |
203 | def load_model_from_config(config, local_rank=0, dtype=None):
204 | """load model from config.
205 | Args:
206 | config (MyEasyDict): config of the model.
207 | local_rank (int): local rank of the current process.
208 | dtype (str): data type of the model. Default is None. Options are "fp32", "fp16", "bf16".
209 |
210 | Returns:
211 | model (nn.Module): loaded model.
212 | """
213 | model_py_path = config.model.pop("cls")
214 | model = ModelInterface.init_model(model_py_path, **config.model)
215 | model.eval()
216 |
217 |     ckpt = torch.load(os.path.join(config.setting.from_checkpoint, "checkpoint", "mp_rank_00_model_states.pt"), map_location="cpu")
218 | state_dict = ckpt["module"]
219 | state_dict = filter_llama_weights(state_dict)
220 | model.load_state_dict(state_dict, strict=False)
221 | if dtype is None:
222 | pass
223 | elif dtype == "fp32":
224 | model.to(torch.float32)
225 | elif dtype == "bf16":
226 | model.to(torch.bfloat16)
227 | elif dtype == "fp16":
228 | model.to(torch.float16)
229 | else:
230 | raise ValueError(f"Unsupported data type: {dtype}, supported data types are 'fp32', 'fp16', 'bf16'")
231 | model.to(f'cuda:{local_rank}')
232 | return model
--------------------------------------------------------------------------------
/model/Evolla/injection_module.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 | class RMSNorm(torch.nn.Module):
7 | def __init__(self, dim: int, eps: float = 1e-6):
8 | super().__init__()
9 | self.eps = eps
10 | self.weight = nn.Parameter(torch.ones(dim))
11 |
12 | def _norm(self, x):
13 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
14 |
15 | def forward(self, x):
16 | output = self._norm(x.float()).type_as(x)
17 | return output * self.weight
18 |
19 | def FeedForward(dim, mult=None):
20 | if mult is None:
21 | mult = 4
22 | inner_dim = int(dim * mult)
23 | return nn.Sequential(
24 | nn.LayerNorm(dim),
25 | nn.Linear(dim, inner_dim, bias=False),
26 | nn.GELU(),
27 | nn.Linear(inner_dim, dim, bias=False),
28 | )
29 |
30 |
31 | class CrossAttention(nn.Module):
32 | def __init__(
33 | self,
34 | num_attention_heads,
35 | hidden_size,
36 | protein_encoder_dim=None, # protein dim in fusion module
37 | structure_encoder_dim=None, # structure dim in fusion module
38 | msa_encoder_dim=None, # msa dim in fusion module
39 | ffn_mult=None,
40 | attention_probs_dropout_prob=None,
41 | enable_bias=False,
42 | ):
43 | super().__init__()
44 | self.scale = num_attention_heads**-0.5
45 | self.num_attention_heads = num_attention_heads
46 | self.attention_head_size = int(hidden_size / num_attention_heads)
47 | self.all_head_size = self.num_attention_heads * self.attention_head_size
48 |
49 | self.query = nn.Linear(hidden_size, self.all_head_size)
50 | if protein_encoder_dim is not None:
51 | self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
52 | self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
53 | else:
54 | self.key_protein = None
55 | self.value_protein = None
56 |
57 | if structure_encoder_dim is not None:
58 | self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
59 | self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
60 | else:
61 | self.key_structure = None
62 | self.value_structure = None
63 |
64 | if msa_encoder_dim is not None:
65 | self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
66 | self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
67 | else:
68 | self.key_msa = None
69 | self.value_msa = None
70 |
71 | self.attention_norm = RMSNorm(hidden_size)
72 |
73 | self.dropout = nn.Dropout(attention_probs_dropout_prob)
74 |
75 | self.out_proj = nn.Linear(hidden_size, hidden_size, bias=enable_bias)
76 |
77 | self.ff = FeedForward(hidden_size, ffn_mult)
78 | self.gate_attention = nn.Parameter(torch.tensor([0.0]))
79 | self.gate_ffw = nn.Parameter(torch.tensor([0.0]))
80 |
81 | def cross_attention(
82 | self,
83 | query_states,
84 | protein_key_value_states,
85 | structure_key_value_states,
86 | msa_key_value_states,
87 | query_attn_mask,
88 | protein_kv_attn_mask,
89 | structure_kv_attn_mask,
90 | msa_kv_attn_mask,
91 | ):
92 | """
93 | query_states: text
94 | key_value_states: protein
95 | query_states: [bs, query_seq_len, dim]
96 | key_value_states: [bs, kv_seq_len, dim]
97 | query_attn_mask: [bs, query_seq_len]
98 | kv_attn_mask: [bs, kv_seq_len]
99 | """
100 |
101 | # Concatenate protein and structure
102 | kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask]
103 | kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None]
104 | if not kv_attn_mask:
105 | raise ValueError(
106 | "At least one modality should be provided for cross attention."
107 | )
108 | kv_attn_mask = torch.cat(kv_attn_mask, dim=1)
109 |
110 | query_layer = self.attention_norm(query_states)
111 |
112 | # Warning: This place might cause issues, refers to
113 | # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13
114 | # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable
115 | # Apply linear transformation to input_query, input_key, and input_value
116 | query_layer = self.query(query_layer) # [bs, querylength, dim]
117 |
118 | if self.key_protein is not None and self.value_protein is not None:
119 | protein_key_value_states = protein_key_value_states.to(query_states)
120 | key_layer_protein = self.key_protein(
121 | protein_key_value_states
122 | ) # [bs, keylength, dim]
123 | value_layer_protein = self.value_protein(
124 | protein_key_value_states
125 | ) # [bs, keylength, dim]
126 | else:
127 | key_layer_protein = None
128 | value_layer_protein = None
129 |
130 | if self.key_structure is not None and self.value_structure is not None:
131 | structure_key_value_states = structure_key_value_states.to(query_states)
132 | key_layer_structure = self.key_structure(
133 | structure_key_value_states
134 | ) # [bs, keylength, dim]
135 | value_layer_structure = self.value_structure(
136 | structure_key_value_states
137 | ) # [bs, keylength, dim]
138 | else:
139 | key_layer_structure = None
140 | value_layer_structure = None
141 |
142 | if self.key_msa is not None and self.value_msa is not None:
143 | msa_key_value_states = msa_key_value_states.to(query_states)
144 | key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim]
145 | value_layer_msa = self.value_msa(
146 | msa_key_value_states
147 | ) # [bs, keylength, dim]
148 | else:
149 | key_layer_msa = None
150 | value_layer_msa = None
151 |
152 | key_layer = [key_layer_protein, key_layer_structure, key_layer_msa]
153 | key_layer = [_ for _ in key_layer if _ is not None]
154 | key_layer = torch.cat(key_layer, dim=1)
155 |
156 | value_layer = [value_layer_protein, value_layer_structure, value_layer_msa]
157 | value_layer = [_ for _ in value_layer if _ is not None]
158 | value_layer = torch.cat(value_layer, dim=1)
159 |
160 | query_layer = self.transpose_for_scores(
161 | query_layer
162 | ) # [bs, numheads, querylength, dim/numheads]
163 | key_layer = self.transpose_for_scores(
164 | key_layer
165 | ) # [bs, numheads, keylength, dim/numheads]
166 | value_layer = self.transpose_for_scores(
167 | value_layer
168 | ) # [bs, numheads, keylength, dim/numheads]
169 |
170 | query_layer = query_layer * self.scale
171 |
172 | # attention_mask: [bs, 1, querylength, keylength]
173 | attention_mask = (
174 | query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :]
175 | )
176 | # Compute the scaled dot-product attention scores
177 | attn_weights = torch.matmul(
178 | query_layer, key_layer.transpose(-1, -2)
179 | ) # [bs, numheads, querylength, keylength]
180 | attn_weights = (
181 | attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach()
182 |         )  # To stabilize the scores
183 | attention_scores = attn_weights.masked_fill(
184 | (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min
185 | ) # [bs, numheads, querylength, keylength]
186 |
187 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
188 |
189 | # attention_probs_dropped = self.dropout(attention_probs)
190 |
191 | context_layer = torch.matmul(
192 | attention_probs, value_layer
193 | ) # [bs, numheads, querylength, dim/numheads]
194 |
195 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
196 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
197 | context_layer = context_layer.view(*new_context_layer_shape)
198 |
199 | context_layer = self.out_proj(context_layer)
200 |
201 | return context_layer
202 |
203 | def forward(
204 | self,
205 | query_states,
206 | protein_kv_states,
207 | structure_kv_states,
208 | msa_kv_states,
209 | query_attn_mask,
210 | protein_kv_attn_mask=None,
211 | structure_kv_attn_mask=None,
212 | msa_kv_attn_mask=None,
213 | protein_batch_mask=None,
214 | structure_batch_mask=None,
215 | msa_batch_mask=None,
216 | past_key_value=None,
217 | ):
218 | """
219 | kv_states: protein
220 | query_states: text
221 |
222 | query_states: [bs, query_seq_len, dim]
223 | kv_states: [bs, kv_seq_len, dim]
224 | query_attn_mask: [bs, query_seq_len]
225 | kv_attn_mask: [bs, kv_seq_len], default None
226 | past_key_value: [bs, past_kv_seq_len, dim], default None
227 | """
228 | query_seq_len = query_states.shape[1]
229 | if protein_kv_states is not None:
230 | bs, protein_kv_seq_len, dim = protein_kv_states.shape
231 | if protein_kv_attn_mask is None:
232 | protein_kv_attn_mask = (
233 | torch.ones(bs, protein_kv_seq_len)
234 | * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T
235 | ).to(protein_kv_states.device)
236 | else:
237 | protein_kv_attn_mask = None
238 |
239 | if structure_kv_states is not None:
240 | bs, structure_kv_seq_len, dim = structure_kv_states.shape
241 | if structure_kv_attn_mask is None:
242 | structure_kv_attn_mask = (
243 | torch.ones(bs, structure_kv_seq_len)
244 | * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T
245 | ).to(structure_kv_states.device)
246 | else:
247 | structure_kv_attn_mask = None
248 |
249 | if msa_kv_states is not None:
250 | bs, msa_kv_seq_len, dim = msa_kv_states.shape
251 | if msa_kv_attn_mask is None:
252 | msa_kv_attn_mask = (
253 | torch.ones(bs, msa_kv_seq_len)
254 | * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T
255 | ).to(msa_kv_states.device)
256 | else:
257 | msa_kv_attn_mask = None
258 | hidden_states = query_states
259 | # only when there's at least one valid modality, crossattention will be performed
260 | if (protein_kv_states is not None and protein_kv_attn_mask.any()) or (
261 | structure_kv_states is not None and structure_kv_attn_mask.any()
262 | ) or (
263 | msa_kv_states is not None and msa_kv_attn_mask.any()
264 | ):
265 | residual = hidden_states
266 | hidden_states = self.cross_attention(
267 | query_states=hidden_states,
268 | protein_key_value_states=protein_kv_states,
269 | structure_key_value_states=structure_kv_states,
270 | msa_key_value_states=msa_kv_states,
271 | query_attn_mask=query_attn_mask,
272 | protein_kv_attn_mask=protein_kv_attn_mask,
273 | structure_kv_attn_mask=structure_kv_attn_mask,
274 | msa_kv_attn_mask=msa_kv_attn_mask,
275 | ) # [bs, query_seq_len, dim]
276 | # tanh gate
277 | hidden_states = torch.tanh(self.gate_attention) * hidden_states
278 |
279 | hidden_states = residual + hidden_states # input_query
280 |
281 | residual = hidden_states
282 | hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw)
283 | hidden_states = residual + hidden_states
284 |
285 | return hidden_states
286 |
287 | def transpose_for_scores(self, x):
288 | new_x_shape = x.size()[:-1] + (
289 | self.num_attention_heads,
290 | self.attention_head_size,
291 | )
292 | x = x.view(*new_x_shape)
293 | return x.permute(0, 2, 1, 3)
294 |
--------------------------------------------------------------------------------
/model/Evolla/Evolla_model.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from model.model_interface import register_model
3 | from utils.easydict import MyEasyDict
4 | import torch
5 |
6 | from .encoder_interface import EncoderInterface
7 | from .llm_interface import LLMInterface
8 |
9 | @register_model
10 | class EvollaModel(pl.LightningModule):
11 | def __init__(self,
12 | config: MyEasyDict,
13 | **kwargs):
14 | """
15 |         Initialize the Evolla model.
16 |         Args:
17 |             config (MyEasyDict): Configuration of the Evolla model.
18 | """
19 | super().__init__()
20 | self.verbose = config.get('verbose', False)
21 | self.config = config
22 | self.generate_config = kwargs.pop('generate_config', {})
23 |
24 | if len(self.generate_config) == 0:
25 | print("Warning: No generate config is provided, the generate config now is \{\}")
26 | else:
27 | print("Generate config is provided, the generate config is: ", self.generate_config)
28 |
29 | self.initialize_model()
30 |
31 | self.special_pad_id = -100
32 |
33 | @staticmethod
34 | def init_protein_encoder(config: dict):
35 | """
36 | Initialize protein encoder
37 | Args:
38 | config: A dictionary containing the configuration of the protein encoder
39 |
40 | Returns:
41 | A protein encoder
42 | """
43 | encoder_py_path = config.pop("cls")
44 | model = EncoderInterface.init_encoder(encoder_py_path, **config)
45 | return model
46 |
47 | @staticmethod
48 | def init_structure_encoder(config: dict):
49 | """
50 | Initialize structure encoder
51 | Args:
52 | config: A dictionary containing the configuration of the structure encoder
53 | Returns:
54 | A structure encoder
55 | """
56 | encoder_py_path = config.pop("cls")
57 | model = EncoderInterface.init_encoder(encoder_py_path, **config)
58 | return model
59 |
60 | @staticmethod
61 | def init_msa_transformer_encoder(config: dict):
62 | """
63 |         Initialize MSA transformer encoder
64 |         Args:
65 |             config: A dictionary containing the configuration of the MSA transformer encoder
66 |
67 |         Returns:
68 |             An MSA transformer encoder
69 | """
70 | msa_transformer_py_path = config.pop("cls")
71 | model = EncoderInterface.init_encoder(msa_transformer_py_path, **config)
72 | return model
73 |
74 | @staticmethod
75 | def init_llm(config: dict):
76 | """
77 | Initialize LLM
78 | Args:
79 | config: A dictionary containing the configuration of the LLM
80 |
81 | Returns:
82 | A LLM
83 | """
84 | llm_py_path = config.pop("cls")
85 | model = LLMInterface.init_llm(llm_py_path, **config)
86 | return model
87 |
88 | def initialize_model(self) -> None:
89 | """Initialize the Evolla model."""
90 | # torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
91 | if "protein_encoder" in self.config:
92 | if self.verbose:
93 | print("Loading Sequence Encoder...", flush=True)
94 | self.protein_encoder = self.init_protein_encoder(
95 | self.config["protein_encoder"]
96 | )
97 | else:
98 | self.protein_encoder = None
99 |
100 | if "msa_encoder" in self.config:
101 | if self.verbose:
102 | print("Loading MSA Tranformer Encoder...", flush=True)
103 | self.msa_encoder = self.init_msa_transformer_encoder(
104 | self.config["msa_encoder"]
105 | )
106 | else:
107 | self.msa_encoder = None
108 |
109 | if "structure_encoder" in self.config:
110 | if self.verbose:
111 | print("Loading Structure Encoder...", flush=True)
112 | self.structure_encoder = self.init_structure_encoder(
113 | self.config["structure_encoder"]
114 | )
115 | else:
116 | self.structure_encoder = None
117 | # print("Loading Fusion Module...", flush=True)
118 | # self.fusion_module = self.init_fusion_module(self.config["fusion_module"])
119 | if self.verbose:
120 | print("Loading LLM...", flush=True)
121 | self.llm = self.init_llm(self.config["llm"])
122 | self.llm_tokenizer = self.llm.tokenizer
123 |
124 | if self.protein_encoder is not None:
125 | self.freeze_protein_encoder_layers()
126 |
127 | if self.structure_encoder is not None:
128 | self.freeze_structure_encoder_layers()
129 |
130 | if self.msa_encoder is not None:
131 | self.freeze_msa_encoder_layers()
132 |
133 | self.freeze_llm_layers()
134 |
135 | def freeze_protein_encoder_layers(self):
136 | for name, param in self.protein_encoder.named_parameters():
137 | param.requires_grad = False
138 | if "resampler" in name:
139 | param.requires_grad = True
140 |
141 | def freeze_structure_encoder_layers(self):
142 | for name, param in self.structure_encoder.named_parameters():
143 | param.requires_grad = False
144 | if "resampler" in name:
145 | param.requires_grad = True
146 |
147 | def freeze_msa_encoder_layers(self):
148 | for name, param in self.msa_encoder.named_parameters():
149 | param.requires_grad = False
150 | if "resampler" in name:
151 | param.requires_grad = True
152 |
153 | def freeze_llm_layers(self):
154 | for name, param in self.llm.named_parameters():
155 | if "adapter" in name:
156 | param.requires_grad = True
157 | else:
158 | param.requires_grad = False
159 |
160 |
161 | def input_process(
162 | self,
163 | questions: list,
164 | answers: list = None,
165 | ):
166 | """
167 | Args:
168 |             questions: A list of question prompts.
169 |             answers: A list of answers, or None at inference time.
170 |         Returns:
171 |             (input_ids, embeds, attns, labels, raw_text_masks) from the underlying LLM's input_process.
172 | """
173 | return self.llm.input_process(
174 | questions=questions,
175 | answers=answers,
176 | max_length=self.config["text_length"],
177 | special_pad_id=self.special_pad_id,
178 | )
179 |
180 | def forward(
181 | self,
182 | seqs: tuple,
183 | foldseeks: tuple,
184 | questions: list,
185 | answers: list,
186 | msa_embeds: torch.Tensor = None,
187 | msa_atts: torch.Tensor = None,
188 | **kwargs,
189 | ):
190 | """Forward pass of the Evolla model.
191 | Args:
192 | seqs (tuple): Amino acid sequences of proteins.
193 | foldseeks (tuple): Foldseek sequences of proteins.
194 | questions (list): A list of prompts.
195 | answers (list): A list of answers.
196 | msa_embeds (torch.Tensor, Optional): MSA embeddings.
197 | msa_atts (torch.Tensor, Optional): MSA attention masks.
198 |
199 | Returns:
200 | return_dict (dict): A dictionary containing the predicted logits, prompts, answers, and raw text masks.
201 | labels (torch.Tensor): A tensor containing the labels.
202 | """
203 |
204 | if self.protein_encoder is not None:
205 | resampler_protein_repr, protein_repr, protein_attn, protein_batch_mask = self.protein_encoder(seqs)
206 | else:
207 | resampler_protein_repr = None
208 | protein_batch_mask = None
209 | protein_repr = None
210 | protein_attn = None
211 |
212 | if self.structure_encoder is not None:
213 | resampler_structure_repr, structure_repr, structure_attn, structure_batch_mask = self.structure_encoder(foldseeks)
214 | else:
215 | resampler_structure_repr = None
216 | structure_batch_mask = None
217 | structure_repr = None
218 | structure_attn = None
219 |
220 | if self.msa_encoder is not None:
221 | resampler_msa_repr, msa_repr, msa_attn, msa_batch_mask = self.msa_encoder(
222 | msa_embeds,
223 | msa_atts,
224 | )
225 | else:
226 | resampler_msa_repr = None
227 | msa_repr = None
228 | msa_attn = None
229 | msa_batch_mask = None
230 |
231 | input_ids, embeds, attns, labels, raw_text_masks = self.input_process(
232 | questions=questions,
233 | answers=answers,
234 | )
235 |
236 | outputs = self.llm.forward(
237 | input_ids=input_ids,
238 | inputs_embeds=embeds,
239 | inputs_mask=attns,
240 | protein_feats=resampler_protein_repr,
241 | structure_feats=resampler_structure_repr,
242 | msa_feats=resampler_msa_repr,
243 | protein_batch_mask=protein_batch_mask,
244 | structure_batch_mask=structure_batch_mask,
245 | msa_batch_mask=msa_batch_mask,
246 | )
247 | logits = outputs.logits
248 |
249 | return_dict = {
250 | "logits": logits,
251 | "prompts": questions,
252 | "answers": answers,
253 | "raw_text_masks": raw_text_masks,
254 | }
255 | if "comment_types" in kwargs:
256 | return_dict["comment_types"] = kwargs["comment_types"]
257 |
258 | return return_dict, labels
259 |
260 |
261 | def generate(
262 | self,
263 | seqs: tuple,
264 | foldseeks: tuple,
265 | questions: list,
266 | msa_embeds: torch.Tensor = None,
267 | msa_atts: torch.Tensor = None,
268 | **kwargs,
269 | ) -> str:
270 | """
271 | Generate answer for the question.
272 | Args:
273 | seqs (tuple): Amino acid sequences of proteins.
274 | foldseeks (tuple): Foldseek sequences of proteins.
275 | questions (list): A list of questions.
276 | msa_embeds (torch.Tensor, Optional): MSA embeddings.
277 | msa_atts (torch.Tensor, Optional): MSA attention masks.
278 |
279 | Returns:
280 | answers (list): A list of predicted answers.
281 | """
282 |
283 | with torch.no_grad():
284 | if self.protein_encoder is not None:
285 | (
286 | resampler_protein_repr,
287 | protein_repr,
288 | protein_attn,
289 | protein_batch_mask,
290 | ) = self.protein_encoder(seqs)
291 | else:
292 | resampler_protein_repr = None
293 | protein_batch_mask = None
294 | protein_repr = None
295 | protein_attn = None
296 |
297 | if self.structure_encoder is not None:
298 | (
299 | resampler_structure_repr,
300 | structure_repr,
301 | structure_attn,
302 | structure_batch_mask,
303 | ) = self.structure_encoder(foldseeks)
304 | else:
305 | resampler_structure_repr = None
306 | structure_batch_mask = None
307 | structure_repr = None
308 | structure_attn = None
309 |
310 | if self.msa_encoder is not None:
311 | resampler_msa_repr, msa_repr, msa_attn, msa_batch_mask = self.msa_encoder(
312 | msa_embeds,
313 | msa_atts,
314 | )
315 | else:
316 | resampler_msa_repr = None
317 | msa_batch_mask = None
318 | msa_repr = None
319 | msa_attn = None
320 |
321 | input_ids, embeds, attns, labels, raw_text_masks = self.input_process(
322 | questions=questions,
323 | )
324 |
325 | predicted_answer = self.llm.generate(
326 | input_ids=input_ids,
327 | inputs_mask=attns,
328 | protein_feats=resampler_protein_repr,
329 | structure_feats=resampler_structure_repr,
330 | msa_feats=resampler_msa_repr,
331 | protein_batch_mask=protein_batch_mask,
332 | structure_batch_mask=structure_batch_mask,
333 | msa_batch_mask=msa_batch_mask,
334 | **kwargs,
335 | )
336 |
337 | return self.llm.tokenizer.batch_decode(
338 | predicted_answer,
339 | skip_special_tokens=True,
340 | clean_up_tokenization_spaces=False,
341 | )
--------------------------------------------------------------------------------
/model/Evolla/llama_llm.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import json
3 | import os
4 | import random
5 | import types
6 | from pathlib import Path
7 | from typing import List, Optional, Tuple, Union
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
14 |
15 | from torch.optim.lr_scheduler import StepLR
16 | from transformers import (AutoTokenizer, BitsAndBytesConfig, LlamaConfig,
17 | LlamaForCausalLM)
18 | from transformers.cache_utils import Cache, DynamicCache
19 |
20 | from transformers.modeling_outputs import (BaseModelOutputWithPast,
21 | CausalLMOutputWithPast,
22 | QuestionAnsweringModelOutput,
23 | SequenceClassifierOutputWithPast)
24 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer
25 | from transformers.utils import (add_start_docstrings,
26 | add_start_docstrings_to_model_forward,
27 | is_flash_attn_2_available,
28 | is_flash_attn_greater_or_equal_2_10, logging,
29 | replace_return_docstrings)
30 |
31 | from .injection_module import CrossAttention
32 | # from .llama.modeling_llama import LlamaForCausalLM, LlamaModel
33 | from .llm_interface import register_llm
34 | from transformers import AutoConfig
35 |
36 | # Copyright (c) Meta Platforms, Inc. and affiliates.
37 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
38 |
39 |
40 | _CONFIG_FOR_DOC = "LlamaConfig"
41 | LLAMA_INPUTS_DOCSTRING = r"""
42 | Args:
43 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
44 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
45 | it.
46 |
47 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
48 | [`PreTrainedTokenizer.__call__`] for details.
49 |
50 | [What are input IDs?](../glossary#input-ids)
51 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
52 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
53 |
54 | - 1 for tokens that are **not masked**,
55 | - 0 for tokens that are **masked**.
56 |
57 | [What are attention masks?](../glossary#attention-mask)
58 |
59 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
60 | [`PreTrainedTokenizer.__call__`] for details.
61 |
62 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
63 | `past_key_values`).
64 |
65 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
66 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
67 | information on the default strategy.
68 |
69 | - 1 indicates the head is **not masked**,
70 | - 0 indicates the head is **masked**.
71 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
72 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
73 | config.n_positions - 1]`.
74 |
75 | [What are position IDs?](../glossary#position-ids)
76 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
77 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
78 |             blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
79 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
80 |
81 | Two formats are allowed:
82 | - a [`~cache_utils.Cache`] instance;
83 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
84 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
85 | cache format.
86 |
87 | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
88 | legacy cache format will be returned.
89 |
90 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
91 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
92 | of shape `(batch_size, sequence_length)`.
93 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
94 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
95 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
96 | model's internal embedding lookup matrix.
97 | use_cache (`bool`, *optional*):
98 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
99 | `past_key_values`).
100 | output_attentions (`bool`, *optional*):
101 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
102 | tensors for more detail.
103 | output_hidden_states (`bool`, *optional*):
104 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
105 | more detail.
106 | return_dict (`bool`, *optional*):
107 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
108 | cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
109 | Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
110 | this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
111 | the complete sequence length.
112 | """
113 |
114 |
115 | def add_adapter_for_LlamaForCausalLM(llama_for_causalLM, cross_attention_config, num_add_layers=8):
116 | # follow the same config as the original model
117 | if num_add_layers < 1:
118 | return llama_for_causalLM
119 | if hasattr(llama_for_causalLM.model.layers[0].self_attn, "num_heads"):
120 | cross_attention_config["num_attention_heads"] = llama_for_causalLM.model.layers[0].self_attn.num_heads
121 | elif hasattr(llama_for_causalLM.model.layers[0].self_attn.config, "num_attention_heads"):
122 | cross_attention_config["num_attention_heads"] = llama_for_causalLM.model.layers[0].self_attn.config.num_attention_heads
123 | else:
124 | raise ValueError("Cannot find num_heads or num_attention_heads in self_attn of the first layer of the model.")
125 |
126 | if hasattr(llama_for_causalLM.model.layers[0].self_attn, "hidden_size"):
127 | cross_attention_config["hidden_size"] = llama_for_causalLM.model.layers[0].self_attn.hidden_size
128 | elif hasattr(llama_for_causalLM.model.layers[0].self_attn.config, "hidden_size"):
129 | cross_attention_config["hidden_size"] = llama_for_causalLM.model.layers[0].self_attn.config.hidden_size
130 | else:
131 | raise ValueError("Cannot find hidden_size in self_attn of the first layer of the model.")
132 |
133 | num_layers = len(llama_for_causalLM.model.layers)
134 | every_n_layers = max(num_layers // num_add_layers, 1)
135 | # add adapter for each decoder layer
136 | for i, layer in enumerate(llama_for_causalLM.model.layers):
137 | if (i + 1) % every_n_layers == 0:
138 | llama_for_causalLM.model.layers[i].adapter = CrossAttention(**cross_attention_config)
139 |
140 | return llama_for_causalLM
141 |
142 |
143 | def bind_forward_for_llama(llama_for_causalLM):
144 |     """Bind custom `forward` methods onto the Llama backbone and causal-LM wrapper via `types.MethodType`."""
145 |
146 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
147 | def llama_model_forward(
148 | self,
149 | input_ids: torch.LongTensor = None,
150 | attention_mask: Optional[torch.Tensor] = None,
151 | position_ids: Optional[torch.LongTensor] = None,
152 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
153 | inputs_embeds: Optional[torch.FloatTensor] = None,
154 | use_cache: Optional[bool] = None,
155 | output_attentions: Optional[bool] = None,
156 | output_hidden_states: Optional[bool] = None,
157 | return_dict: Optional[bool] = None,
158 | cache_position: Optional[torch.LongTensor] = None,
159 | protein_feats: Optional[torch.FloatTensor] = None,
160 | structure_feats: Optional[torch.FloatTensor] = None,
161 | msa_feats: Optional[torch.FloatTensor] = None,
162 | protein_batch_mask: Optional[torch.Tensor] = None,
163 | structure_batch_mask: Optional[torch.Tensor] = None,
164 | msa_batch_mask: Optional[torch.Tensor] = None,
165 | **kwargs,
166 | ) -> Union[Tuple, BaseModelOutputWithPast]:
167 | output_attentions = (
168 | output_attentions
169 | if output_attentions is not None
170 | else self.config.output_attentions
171 | )
172 | output_hidden_states = (
173 | output_hidden_states
174 | if output_hidden_states is not None
175 | else self.config.output_hidden_states
176 | )
177 | use_cache = use_cache if use_cache is not None else self.config.use_cache
178 | return_dict = (
179 | return_dict if return_dict is not None else self.config.use_return_dict
180 | )
181 |
182 | if (input_ids is None) ^ (inputs_embeds is not None):
183 | raise ValueError(
184 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
185 | )
186 |
187 | if self.gradient_checkpointing and self.training and use_cache:
188 | use_cache = False
189 |
190 | if inputs_embeds is None:
191 | inputs_embeds = self.embed_tokens(input_ids)
192 |
193 | return_legacy_cache = False
194 | if use_cache and not isinstance(
195 | past_key_values, Cache
196 | ): # kept for BC (non `Cache` `past_key_values` inputs)
197 | return_legacy_cache = True
198 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
199 |
200 | if cache_position is None:
201 | past_seen_tokens = (
202 | past_key_values.get_seq_length() if past_key_values is not None else 0
203 | )
204 | cache_position = torch.arange(
205 | past_seen_tokens,
206 | past_seen_tokens + inputs_embeds.shape[1],
207 | device=inputs_embeds.device,
208 | )
209 | if position_ids is None:
210 | position_ids = cache_position.unsqueeze(0)
211 |
212 | causal_mask = self._update_causal_mask(
213 | attention_mask,
214 | inputs_embeds,
215 | cache_position,
216 | past_key_values,
217 | output_attentions,
218 | )
219 |
220 | # embed positions
221 | hidden_states = inputs_embeds
222 |
223 | # decoder layers
224 | all_hidden_states = () if output_hidden_states else None
225 | all_self_attns = () if output_attentions else None
226 | next_decoder_cache = None
227 |
228 | for decoder_layer in self.layers:
229 | if output_hidden_states:
230 | all_hidden_states += (hidden_states,)
231 |
232 | if self.gradient_checkpointing and self.training:
233 | if not hasattr(decoder_layer, 'adapter'):
234 | layer_outputs = self._gradient_checkpointing_func(
235 | decoder_layer.__call__,
236 | hidden_states,
237 | causal_mask,
238 | position_ids,
239 | past_key_values,
240 | output_attentions,
241 | use_cache,
242 | cache_position,
243 | )
244 | else:
245 | layer_outputs = self._gradient_checkpointing_func(
246 | decoder_layer.__call__,
247 | hidden_states,
248 | causal_mask,
249 | position_ids,
250 | past_key_values,
251 | output_attentions,
252 | use_cache,
253 | cache_position,
254 | )
255 | # keep the hidden_states only, cache other outputs
256 | hidden_states = layer_outputs[0]
257 | other_outputs = layer_outputs[1:]
258 | hidden_states = decoder_layer.adapter(
259 | query_states=hidden_states,
260 | protein_kv_states=protein_feats,
261 | structure_kv_states=structure_feats,
262 | msa_kv_states=msa_feats,
263 | protein_batch_mask=protein_batch_mask,
264 | structure_batch_mask=structure_batch_mask,
265 | msa_batch_mask=msa_batch_mask,
266 | query_attn_mask=attention_mask,
267 | )
268 | layer_outputs = (hidden_states,) + other_outputs
269 | else:
270 | if not hasattr(decoder_layer, 'adapter'):
271 | layer_outputs = decoder_layer(
272 | hidden_states,
273 | attention_mask=causal_mask,
274 | position_ids=position_ids,
275 | past_key_value=past_key_values,
276 | output_attentions=output_attentions,
277 | use_cache=use_cache,
278 | cache_position=cache_position,
279 | )
280 | else:
281 | layer_outputs = decoder_layer(
282 | hidden_states,
283 | attention_mask=causal_mask,
284 | position_ids=position_ids,
285 | past_key_value=past_key_values,
286 | output_attentions=output_attentions,
287 | use_cache=use_cache,
288 | cache_position=cache_position,
289 | )
290 |
291 | # keep the hidden_states only, cache other outputs
292 | hidden_states = layer_outputs[0]
293 | other_outputs = layer_outputs[1:]
294 | hidden_states = decoder_layer.adapter(
295 | query_states=hidden_states,
296 | protein_kv_states=protein_feats,
297 | structure_kv_states=structure_feats,
298 | msa_kv_states=msa_feats,
299 | protein_batch_mask=protein_batch_mask,
300 | structure_batch_mask=structure_batch_mask,
301 | msa_batch_mask=msa_batch_mask,
302 | query_attn_mask=attention_mask,
303 | )
304 | layer_outputs = (hidden_states,) + other_outputs
305 |
306 | hidden_states = layer_outputs[0]
307 |
308 | if use_cache:
309 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
310 |
311 | if output_attentions:
312 | all_self_attns += (layer_outputs[1],)
313 |
314 | hidden_states = self.norm(hidden_states)
315 |
316 | # add hidden states from the last decoder layer
317 | if output_hidden_states:
318 | all_hidden_states += (hidden_states,)
319 |
320 | next_cache = next_decoder_cache if use_cache else None
321 | if return_legacy_cache:
322 | next_cache = next_cache.to_legacy_cache()
323 |
324 | if not return_dict:
325 | return tuple(
326 | v
327 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
328 | if v is not None
329 | )
330 | return BaseModelOutputWithPast(
331 | last_hidden_state=hidden_states,
332 | past_key_values=next_cache,
333 | hidden_states=all_hidden_states,
334 | attentions=all_self_attns,
335 | )
336 |
337 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
338 | @replace_return_docstrings(
339 | output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
340 | )
341 | def llama_for_causalLM_forward(
342 | self,
343 | input_ids: torch.LongTensor = None,
344 | attention_mask: Optional[torch.Tensor] = None,
345 | position_ids: Optional[torch.LongTensor] = None,
346 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
347 | inputs_embeds: Optional[torch.FloatTensor] = None,
348 | labels: Optional[torch.LongTensor] = None,
349 | use_cache: Optional[bool] = None,
350 | output_attentions: Optional[bool] = None,
351 | output_hidden_states: Optional[bool] = None,
352 | return_dict: Optional[bool] = None,
353 | cache_position: Optional[torch.LongTensor] = None,
354 | protein_feats: Optional[torch.FloatTensor] = None,
355 | structure_feats: Optional[torch.FloatTensor] = None,
356 | msa_feats: Optional[torch.FloatTensor] = None,
357 | protein_batch_mask: Optional[torch.Tensor] = None,
358 | structure_batch_mask: Optional[torch.Tensor] = None,
359 | msa_batch_mask: Optional[torch.Tensor] = None,
360 | **kwargs
361 | ) -> Union[Tuple, CausalLMOutputWithPast]:
362 | r"""
363 | Args:
364 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
365 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
366 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
367 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
368 |
369 | Returns:
370 |
371 | Example:
372 |
373 | ```python
374 | >>> from transformers import AutoTokenizer, LlamaForCausalLM
375 |
376 | >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
377 | >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
378 |
379 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
380 | >>> inputs = tokenizer(prompt, return_tensors="pt")
381 |
382 | >>> # Generate
383 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
384 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
385 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
386 | ```"""
387 | output_attentions = (
388 | output_attentions
389 | if output_attentions is not None
390 | else self.config.output_attentions
391 | )
392 | output_hidden_states = (
393 | output_hidden_states
394 | if output_hidden_states is not None
395 | else self.config.output_hidden_states
396 | )
397 | return_dict = (
398 | return_dict if return_dict is not None else self.config.use_return_dict
399 | )
400 |
401 |         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
402 | outputs = self.model(
403 | input_ids=input_ids,
404 | attention_mask=attention_mask,
405 | position_ids=position_ids,
406 | past_key_values=past_key_values,
407 | inputs_embeds=inputs_embeds,
408 | use_cache=use_cache,
409 | output_attentions=output_attentions,
410 | output_hidden_states=output_hidden_states,
411 | return_dict=return_dict,
412 | cache_position=cache_position,
413 | protein_feats=protein_feats,
414 | structure_feats=structure_feats,
415 | msa_feats=msa_feats,
416 | protein_batch_mask=protein_batch_mask,
417 | structure_batch_mask=structure_batch_mask,
418 | msa_batch_mask=msa_batch_mask,
419 | )
420 |
421 | hidden_states = outputs[0]
422 | if self.config.pretraining_tp > 1:
423 | lm_head_slices = self.lm_head.weight.split(
424 | self.vocab_size // self.config.pretraining_tp, dim=0
425 | )
426 | logits = [
427 | F.linear(hidden_states, lm_head_slices[i])
428 | for i in range(self.config.pretraining_tp)
429 | ]
430 | logits = torch.cat(logits, dim=-1)
431 | else:
432 | logits = self.lm_head(hidden_states)
433 | logits = logits.float()
434 |
435 | loss = None
436 | if labels is not None:
437 | # Shift so that tokens < n predict n
438 | shift_logits = logits[..., :-1, :].contiguous()
439 | shift_labels = labels[..., 1:].contiguous()
440 | # Flatten the tokens
441 |             loss_fct = nn.CrossEntropyLoss()
442 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
443 | shift_labels = shift_labels.view(-1)
444 | # Enable model parallelism
445 | shift_labels = shift_labels.to(shift_logits.device)
446 | loss = loss_fct(shift_logits, shift_labels)
447 |
448 | if not return_dict:
449 | output = (logits,) + outputs[1:]
450 | return (loss,) + output if loss is not None else output
451 |
452 | return CausalLMOutputWithPast(
453 | loss=loss,
454 | logits=logits,
455 | past_key_values=outputs.past_key_values,
456 | hidden_states=outputs.hidden_states,
457 | attentions=outputs.attentions,
458 | )
459 |
460 | llama_for_causalLM.model.forward = types.MethodType(
461 | llama_model_forward, llama_for_causalLM.model
462 | )
463 | llama_for_causalLM.forward = types.MethodType(
464 | llama_for_causalLM_forward, llama_for_causalLM
465 | )
466 |
467 | return llama_for_causalLM
468 |
469 |
470 | def add_special_tokens_to_model_and_tokenizer(model, tokenizer, special_token):
471 |     # add the special token to the tokenizer
472 |     tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
473 |     # grow the model's input embedding matrix if the tokenizer vocabulary now exceeds it
474 | if len(tokenizer) <= model.model.embed_tokens.weight.shape[0]:
475 | return model, tokenizer
476 | else:
477 | embedding_layer = model.model.embed_tokens
478 | embedding_layer.weight.data = torch.cat(
479 | [
480 | embedding_layer.weight.data,
481 | torch.zeros(1, embedding_layer.weight.shape[1]).to(
482 | embedding_layer.weight.data
483 | ),
484 | ],
485 | dim=0,
486 | )
487 | return model, tokenizer
488 |
489 |
490 | def bind_function_for_llama(llama_for_causalLM):
491 |     def llama_for_causalLM_prepare_inputs_for_generation(
492 | self,
493 | input_ids,
494 | past_key_values=None,
495 | attention_mask=None,
496 | inputs_embeds=None,
497 | cache_position=None,
498 | position_ids=None,
499 | use_cache=True,
500 | **kwargs,
501 | ):
502 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
503 | # Exception 1: when passing input_embeds, input_ids may be missing entries
504 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
505 | if past_key_values is not None:
506 | if inputs_embeds is not None: # Exception 1
507 | input_ids = input_ids[:, -cache_position.shape[0] :]
508 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
509 | input_ids = input_ids[:, cache_position]
510 |
511 | if attention_mask is not None and position_ids is None:
512 | # create position_ids on the fly for batch generation
513 | position_ids = attention_mask.long().cumsum(-1) - 1
514 | position_ids.masked_fill_(attention_mask == 0, 1)
515 | if past_key_values:
516 | position_ids = position_ids[:, -input_ids.shape[1] :]
517 |
518 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
519 | if inputs_embeds is not None and cache_position[0] == 0:
520 | model_inputs = {"inputs_embeds": inputs_embeds}
521 | else:
522 | model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
523 |
524 | model_inputs.update(
525 | {
526 | "position_ids": position_ids,
527 | "cache_position": cache_position,
528 | "past_key_values": past_key_values,
529 | "use_cache": use_cache,
530 | "attention_mask": attention_mask,
531 | }
532 | )
533 | model_inputs.update(kwargs)
534 | return model_inputs
535 |
536 | llama_for_causalLM.prepare_inputs_for_generation = types.MethodType(
537 |         llama_for_causalLM_prepare_inputs_for_generation, llama_for_causalLM
538 | )
539 | return llama_for_causalLM
540 |
541 | from transformers import AutoConfig, AutoModelForCausalLM
542 | from accelerate import load_checkpoint_and_dispatch
543 | from accelerate import init_empty_weights
544 | from transformers.integrations import HfDeepSpeedConfig
545 |
546 | @register_llm
547 | class LlamaAdapterModel(nn.Module):
548 | def __init__(
549 | self,
550 | hf_dir,
551 | cross_attention_config,
552 | load_pretrained=True,
553 | quantization=False,
554 | attn_implementation="sdpa",
555 | num_add_layers=8,
556 | ):
557 | """Adapter model for Llama.
558 | Args:
559 | hf_dir (str): Directory of the Hugging Face model.
560 | cross_attention_config (dict): Configuration of the cross-attention layer.
561 | load_pretrained (bool): Whether to load the pretrained model. Defaults to True.
562 |             quantization (bool or str): Whether to use quantization. Defaults to False. Acceptable values are True or '8bit' (8-bit quantization), '4bit' (4-bit quantization), and False (no quantization).
563 | attn_implementation (str): Implementation of the attention layer. Defaults to "sdpa".
564 | num_add_layers (int): Number of additional layers to add. Defaults to 8.
565 | """
566 | super().__init__()
567 | if quantization is True or quantization == '8bit':
568 | assert load_pretrained, "load_pretrained should be True"
569 | quantization_config = BitsAndBytesConfig(load_in_8bit=True)
570 | print("8-bit Quantization is enabled")
571 | elif quantization == '4bit':
572 | quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
573 | print("4-bit Quantization is enabled")
574 | else:
575 | quantization_config = None
576 | print("Quantization is disabled")
577 |
578 | if load_pretrained:
579 | self.model = LlamaForCausalLM.from_pretrained(
580 | hf_dir,
581 | quantization_config=quantization_config,
582 | attn_implementation=attn_implementation,
583 | torch_dtype=torch.bfloat16,
584 | ).train()
585 | self.model = prepare_model_for_kbit_training(self.model)
586 | else:
587 | config = AutoConfig.from_pretrained(hf_dir)
588 | self.model = LlamaForCausalLM(config)
589 |
590 | self.model = add_adapter_for_LlamaForCausalLM(
591 | self.model, cross_attention_config, num_add_layers=num_add_layers
592 | )
593 | # bind `forward` function for llama models by `types.MethodType`
594 | self.model = bind_forward_for_llama(self.model)
595 | self.model = bind_function_for_llama(self.model)
596 |
597 | self.tokenizer = AutoTokenizer.from_pretrained(hf_dir, use_fast=False)
598 | self.tokenizer.pad_token = "<|reserved_special_token_0|>"
599 |
600 | def forward(
601 | self,
602 | input_ids,
603 | inputs_mask,
604 | protein_feats,
605 | structure_feats,
606 | msa_feats,
607 | protein_batch_mask,
608 | structure_batch_mask,
609 | msa_batch_mask,
610 | **kwargs
611 | ):
612 | output = self.model.forward(
613 | input_ids=input_ids,
614 | attention_mask=inputs_mask,
615 | protein_feats=protein_feats,
616 | structure_feats=structure_feats,
617 | msa_feats=msa_feats,
618 | protein_batch_mask=protein_batch_mask,
619 | structure_batch_mask=structure_batch_mask,
620 | msa_batch_mask=msa_batch_mask,
621 | output_hidden_states=True,
622 | )
623 | return output
624 |
625 | def generate(
626 | self,
627 | input_ids,
628 | inputs_mask,
629 | protein_feats,
630 | structure_feats,
631 | msa_feats,
632 | protein_batch_mask,
633 | structure_batch_mask,
634 | msa_batch_mask,
635 | **kwargs
636 | ):
637 | terminators = [
638 | self.tokenizer.eos_token_id,
639 | self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
640 | ]
641 | output = self.model.generate(
642 | input_ids,
643 | use_cache=False,
644 | attention_mask=inputs_mask,
645 | protein_feats=protein_feats,
646 | structure_feats=structure_feats,
647 | msa_feats=msa_feats,
648 | protein_batch_mask=protein_batch_mask,
649 | structure_batch_mask=structure_batch_mask,
650 | msa_batch_mask=msa_batch_mask,
651 | bos_token_id=self.tokenizer.bos_token_id,
652 | eos_token_id=terminators,
653 | **kwargs,
654 | )
655 | output = output[:, input_ids.shape[-1]:]
656 | return output
657 |
658 | def embed_tokens(self, tokens):
659 | return self.model.model.embed_tokens(tokens.to(self.model.device))
660 |
661 | def generate_prompt(self, question: str) -> str:
662 | """
663 |         Generate a QA prompt using the Llama-3-Instruct chat template.
664 |
665 | Returns: Formatted prompt
666 | """
667 | messages = [
668 | {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
669 | {"role": "user", "content": question},
670 | ]
671 |
672 | prompt = self.tokenizer.apply_chat_template(
673 | messages,
674 | tokenize=False,
675 | add_generation_prompt=True,
676 | )
677 | return prompt
678 |
679 | def input_process(self,
680 | questions: list,
681 | answers: list = None,
682 | max_length: int = 512,
683 | special_pad_id: int = -100):
684 |
685 | # Record original padding side
686 | original_padding_side = self.tokenizer.padding_side
687 |
688 | # Generate prompts for questions
689 | prompts = [self.generate_prompt(q) for q in questions]
690 |
691 | # Tokenize prompts and add left paddings
692 | self.tokenizer.padding_side = "left"
693 | prompt_inputs = self.tokenizer(
694 | prompts,
695 | add_special_tokens=False,
696 | return_tensors="pt",
697 | padding="longest",
698 | truncation=True,
699 | max_length=max_length,
700 | )
701 |
702 | input_ids = prompt_inputs["input_ids"]
703 | attns = prompt_inputs["attention_mask"]
704 | embeds = self.embed_tokens(input_ids)
705 |
706 | # Create labels
707 | labels = torch.full_like(input_ids, special_pad_id)
708 | # Create raw text mask
709 | raw_text_mask = torch.zeros_like(input_ids)
710 |
711 | if answers is not None:
712 | # Add eos token
713 | answers_eos = [a + self.tokenizer.eos_token for a in answers]
714 |
715 | # Tokenize answers and add right paddings
716 | self.tokenizer.padding_side = "right"
717 | answer_inputs = self.tokenizer(
718 | answers_eos,
719 | add_special_tokens=False,
720 | return_tensors="pt",
721 | padding="longest",
722 | truncation=True,
723 | max_length=max_length,
724 | )
725 |
726 | # Concatenate inputs ids
727 | answer_ids = answer_inputs["input_ids"]
728 | input_ids = torch.cat([input_ids, answer_ids], dim=-1)
729 |
730 | # Concatenate attention masks
731 | answer_mask = answer_inputs["attention_mask"]
732 | attns = torch.cat([attns, answer_mask], dim=-1)
733 |
734 | # Concatenate embeddings
735 | answer_embeds = self.embed_tokens(answer_ids)
736 | embeds = torch.cat([embeds, answer_embeds], dim=1)
737 |
738 | # Concatenate labels
739 | answer_labels = answer_ids.masked_fill(answer_ids == self.tokenizer.pad_token_id, special_pad_id)
740 | labels = torch.cat([labels, answer_labels], dim=-1)
741 |
742 | # Concatenate raw text mask
743 | raw_text_mask = torch.cat([raw_text_mask, torch.ones_like(answer_ids)], dim=-1)
744 | raw_text_mask = raw_text_mask.masked_fill(labels == special_pad_id, 0)
745 |
746 | labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, special_pad_id)
747 | # Restore original padding side
748 | self.tokenizer.padding_side = original_padding_side
749 |
750 | # Convert to current device
751 | device = self.model.device
752 | input_ids = input_ids.to(device)
753 | embeds = embeds.to(device)
754 | attns = attns.to(device)
755 | labels = labels.to(device)
756 | raw_text_mask = raw_text_mask.to(device)
757 |
758 | return input_ids, embeds, attns, labels, raw_text_mask
--------------------------------------------------------------------------------
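As a cross-check on add_adapter_for_LlamaForCausalLM, the spacing rule can be reproduced in isolation. The 32-layer decoder below is an assumption used only to make the arithmetic concrete; with the default num_add_layers=8 it places a CrossAttention adapter after every fourth decoder layer.

# Standalone sketch of the adapter-placement arithmetic (illustrative layer count).
num_layers = 32                                        # assumed decoder depth
num_add_layers = 8                                     # default in LlamaAdapterModel
every_n_layers = max(num_layers // num_add_layers, 1)  # -> 4
adapter_layer_ids = [i for i in range(num_layers) if (i + 1) % every_n_layers == 0]
print(adapter_layer_ids)                               # [3, 7, 11, 15, 19, 23, 27, 31]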
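The label layout that input_process builds for training can also be illustrated with toy tensors: prompt positions are filled with special_pad_id (-100), which the cross-entropy loss ignores, so only answer tokens (including the appended EOS) are scored. The token ids below are made up.

import torch

special_pad_id = -100
prompt_ids = torch.tensor([[128000, 15339, 30]])    # made-up prompt token ids
answer_ids = torch.tensor([[9906, 13, 128001]])     # made-up answer ids ending in an EOS id
labels = torch.cat(
    [torch.full_like(prompt_ids, special_pad_id), answer_ids], dim=-1
)
print(labels)  # -> [[-100, -100, -100, 9906, 13, 128001]]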