├── ckpt ├── .placeholder └── huggingface │ └── .placeholder ├── .gitignore ├── figures └── overview.png ├── examples └── inputs.tsv ├── environment.sh ├── model ├── Evolla │ ├── llm_interface.py │ ├── encoder_interface.py │ ├── sequence_encoder_saprot.py │ ├── fusion_module.py │ ├── injection_module.py │ ├── Evolla_model.py │ └── llama_llm.py └── model_interface.py ├── config └── Evolla_10B.yaml ├── LICENSE ├── scripts └── inference.py ├── utils ├── easydict.py └── others.py └── README.md /ckpt/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ckpt/huggingface/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .idea/ 3 | __pycache__/ 4 | tests 5 | ckpt -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/westlake-repl/Evolla/HEAD/figures/overview.png -------------------------------------------------------------------------------- /examples/inputs.tsv: -------------------------------------------------------------------------------- 1 | C9RH78 MLLEETLKSCPIVKRGKYHYFIHPISDGVPLVEPKLLREVATRIIKIGNFEGVNKIVTAEAMGIPLVTTLSLYTDIPYVIMRKREYKLPGEVPVFQSTGYSKGQLYLNGIEKGDKVIIIDDVISTGGTMIAIINALERAGAEIKDIICVIERGDGKKIVEEKTGYKIKTLVKIDVVDGEVVIL dvvvvqqqpfawdddppdtdgcgclapvpdpddpvvlvvllvlcvvpadpvqaqeeeeeddscpsnvvsncvvpvhyydywylddppdppkdwqwf######gitidpdqaaaheyeyeeaeqdqlrvvlsvvvrcvvrnyhhrayeyaeyhycnqvvccvvpvghyhynwywdqdpsgidtd "What is the catalytic activity of this protein?" 2 | -------------------------------------------------------------------------------- /environment.sh: -------------------------------------------------------------------------------- 1 | # works on 2025/01/09 2 | pip install pyyaml 3 | pip3 install torch torchvision torchaudio 4 | # pip3 install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple 5 | pip install tqdm 6 | pip install lightning 7 | # pip install lightning -i https://pypi.tuna.tsinghua.edu.cn/simple 8 | pip install transformers 9 | # pip install transformers -i https://pypi.tuna.tsinghua.edu.cn/simple 10 | pip install einops 11 | pip install einops_exts 12 | pip install peft 13 | # pip install peft -i https://pypi.tuna.tsinghua.edu.cn/simple 14 | pip install -U bitsandbytes 15 | # pip install -U bitsandbytes -i https://pypi.tuna.tsinghua.edu.cn/simple -------------------------------------------------------------------------------- /model/Evolla/llm_interface.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def register_llm(cls): 4 | global now_cls 5 | now_cls = cls 6 | return cls 7 | 8 | class LLMInterface: 9 | @classmethod 10 | def init_llm(cls, model_py_path, **kwargs): 11 | """ 12 | Initialize model from python file. 13 | Args: 14 | model_py_path: Path to model python file. e.g. 
model/transformer.py 15 | **kwargs: Kwargs for model initialization 16 | 17 | Returns: 18 | Initialized model 19 | """ 20 | sub_dirs = model_py_path.split(os.sep) 21 | cmd = f"from {'.'.join(sub_dirs[:-1])} import {sub_dirs[-1].split('.')[0]}" 22 | exec(cmd) 23 | return now_cls(**kwargs) 24 | -------------------------------------------------------------------------------- /model/model_interface.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def register_model(cls): 4 | global now_cls 5 | now_cls = cls 6 | return cls 7 | 8 | class ModelInterface: 9 | @classmethod 10 | def init_model(cls, model_py_path, **kwargs): 11 | """ 12 | Initialize model from python file. 13 | Args: 14 | model_py_path: Path to model python file. e.g. model/transformer.py 15 | **kwargs: Kwargs for model initialization 16 | 17 | Returns: 18 | Initialized model 19 | """ 20 | sub_dirs = model_py_path.split(os.sep) 21 | cmd = f"from {'.'.join(sub_dirs[:-1])} import {sub_dirs[-1].split('.')[0]}" 22 | exec(cmd) 23 | return now_cls(**kwargs) 24 | -------------------------------------------------------------------------------- /model/Evolla/encoder_interface.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def register_encoder(cls): 4 | global now_cls 5 | now_cls = cls 6 | return cls 7 | 8 | class EncoderInterface: 9 | @classmethod 10 | def init_encoder(cls, model_py_path, **kwargs): 11 | """ 12 | Initialize model from python file. 13 | Args: 14 | model_py_path: Path to model python file. e.g. model/transformer.py 15 | **kwargs: Kwargs for model initialization 16 | 17 | Returns: 18 | Initialized model 19 | """ 20 | sub_dirs = model_py_path.split(os.sep) 21 | cmd = f"from {'.'.join(sub_dirs[:-1])} import {sub_dirs[-1].split('.')[0]}" 22 | exec(cmd) 23 | return now_cls(**kwargs) 24 | -------------------------------------------------------------------------------- /config/Evolla_10B.yaml: -------------------------------------------------------------------------------- 1 | setting: 2 | seed: 42 3 | # from_checkpoint: ckpt/Evolla-10B 4 | from_checkpoint: ckpt/huggingface/Evolla-10B/Evolla-10B 5 | 6 | model: 7 | cls: model/Evolla/Evolla_model.py 8 | generate_config: 9 | max_new_tokens: 512 10 | do_sample: True 11 | temperature: 0.6 12 | top_p: 0.9 13 | config: 14 | text_length: 2048 15 | protein_encoder: 16 | cls: model/Evolla/sequence_encoder_saprot.py 17 | config_path: ckpt/huggingface/SaProt_650M_AF2 18 | fusion_module: 19 | cls: SequenceCompressorResampler 20 | depth: 6 21 | heads: 8 22 | num_latents: 64 23 | ff_mult: 4 24 | llm: 25 | cls: model/Evolla/llama_llm.py 26 | hf_dir: ckpt/huggingface/meta-llama_Meta-Llama-3-8B-Instruct 27 | cross_attention_config: 28 | ffn_mult: 4 29 | enable_bias: true 30 | attention_probs_dropout_prob: 0.1 31 | quantization: 8bit 32 | # quantization: false 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 westlake-repl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | 
furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | HOME_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 4 | sys.path.append(HOME_PATH) 5 | 6 | import argparse 7 | import json 8 | 9 | import traceback 10 | from threading import Thread 11 | from utils.others import setup_seed, load_config, load_model_from_config 12 | 13 | from transformers import TextIteratorStreamer 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--config_path", type=str, required=True) 17 | parser.add_argument("--input_path", type=str, required=True, help="Path to the input file, each line is a tab-separated triplet of Uniprot ID (or any other identifier), sequence, foldseek, and question (in JSON format)") 18 | args = parser.parse_args() 19 | CONFIG_PATH = args.config_path 20 | input_path = args.input_path 21 | 22 | config = load_config(CONFIG_PATH) 23 | 24 | if config.setting.seed: 25 | setup_seed(config.setting.seed) 26 | 27 | 28 | model = load_model_from_config(config, local_rank=0, dtype="bf16") 29 | 30 | with open(input_path, "r") as f: 31 | for line in f: 32 | line = line.strip() 33 | uniprot_id, sequence, foldseek, question = line.split("\t") 34 | question = json.loads(question) 35 | streamer = TextIteratorStreamer( 36 | model.llm_tokenizer, 37 | # skip_prompt=True, 38 | skip_prompt=False, 39 | skip_special_tokens=True, 40 | ) 41 | 42 | mixed_sequence = "".join([s+f for s, f in zip(sequence, foldseek)]) 43 | print(f"{uniprot_id}") 44 | print(f"{question}") 45 | print(f"{mixed_sequence}") 46 | generation_kwargs = { 47 | "seqs": [mixed_sequence], 48 | "foldseeks": [None], 49 | "questions": [question], 50 | "streamer": streamer, 51 | } 52 | 53 | def generate_wrapper(): 54 | try: 55 | model.generate(**generation_kwargs, **model.generate_config) 56 | except Exception as e: 57 | # traceback the exception 58 | traceback.print_exc() 59 | print(f"Exception in generate_wrapper: {e}") 60 | 61 | thread = Thread(target=generate_wrapper) 62 | thread.start() 63 | for a in streamer: 64 | print(a, end="", flush=True) 65 | thread.join() 66 | print("=" * 50) 67 | -------------------------------------------------------------------------------- /model/Evolla/sequence_encoder_saprot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from transformers import EsmConfig, EsmForMaskedLM, EsmModel, EsmTokenizer 5 | 6 | from .encoder_interface import register_encoder 7 | from .fusion_module import SequenceCompressorResampler 8 | 9 | 10 | @register_encoder 11 | class SaProtSequenceEncoder(nn.Module): 12 | def __init__( 13 | 
self, 14 | config_path: str, 15 | load_pretrained: bool = True, 16 | fusion_module: dict = None, 17 | **kwargs, 18 | ): 19 | super().__init__() 20 | if load_pretrained: 21 | # self.model = EsmModel.from_pretrained(config_path) 22 | self.model = EsmForMaskedLM.from_pretrained(config_path) 23 | self.config = EsmConfig.from_pretrained(config_path) 24 | else: 25 | self.config = EsmConfig.from_pretrained(config_path) 26 | # self.model = EsmModel(self.config) 27 | self.model = EsmForMaskedLM(self.config) 28 | 29 | self.tokenizer = EsmTokenizer.from_pretrained(config_path) 30 | 31 | fusion_cls = fusion_module.pop("cls", None) 32 | if fusion_cls is None or fusion_cls == "SequenceCompressorResampler": 33 | self.resampler = SequenceCompressorResampler(**fusion_module) 34 | else: 35 | raise ValueError(f"Unknown fusion module class: {fusion_cls}") 36 | 37 | @property 38 | def num_layers(self): 39 | return len(self.model.encoder.layer) 40 | 41 | def sequence_encode(self, seqs): 42 | """ 43 | Encode protein sequence into protein representation 44 | """ 45 | seqs = [seq if seq is not None else "" for seq in seqs] 46 | protein_tokens = self.tokenizer.batch_encode_plus( 47 | seqs, return_tensors="pt", truncation=True, max_length=1026, padding=True 48 | ).to(self.model.device) 49 | 50 | protein_output = self.model( 51 | protein_tokens["input_ids"], 52 | protein_tokens["attention_mask"], 53 | return_dict=True, 54 | output_hidden_states=True, 55 | ) 56 | 57 | protein_embeds = protein_output.hidden_states[-1] 58 | 59 | mask = protein_tokens["attention_mask"] 60 | 61 | return protein_embeds, mask 62 | 63 | def forward(self, seqs): 64 | # create batch mask for seqs 65 | seqs_batch_mask = torch.tensor( 66 | [True if seq is not None else False for seq in seqs] 67 | ) 68 | # print("this is structure encoder", flush=True) 69 | sequence_embeds, mask = self.sequence_encode(seqs) 70 | 71 | sequence_repr = self.resampler(sequence_embeds, mask) 72 | 73 | return sequence_repr, sequence_embeds, mask, seqs_batch_mask 74 | 75 | -------------------------------------------------------------------------------- /utils/easydict.py: -------------------------------------------------------------------------------- 1 | # from easydict package 2 | # https://github.com/makinacorpus/easydict 3 | class MyEasyDict(dict): 4 | """ 5 | Get attributes 6 | 7 | >>> d = EasyDict({'foo':3}) 8 | >>> d['foo'] 9 | 3 10 | >>> d.foo 11 | 3 12 | >>> d.bar 13 | Traceback (most recent call last): 14 | ... 
15 | AttributeError: 'EasyDict' object has no attribute 'bar' 16 | 17 | Works recursively 18 | 19 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) 20 | >>> isinstance(d.bar, dict) 21 | True 22 | >>> d.bar.x 23 | 1 24 | 25 | Bullet-proof 26 | 27 | >>> EasyDict({}) 28 | {} 29 | >>> EasyDict(d={}) 30 | {} 31 | >>> EasyDict(None) 32 | {} 33 | >>> d = {'a': 1} 34 | >>> EasyDict(**d) 35 | {'a': 1} 36 | 37 | Set attributes 38 | 39 | >>> d = EasyDict() 40 | >>> d.foo = 3 41 | >>> d.foo 42 | 3 43 | >>> d.bar = {'prop': 'value'} 44 | >>> d.bar.prop 45 | 'value' 46 | >>> d 47 | {'foo': 3, 'bar': {'prop': 'value'}} 48 | >>> d.bar.prop = 'newer' 49 | >>> d.bar.prop 50 | 'newer' 51 | 52 | 53 | Values extraction 54 | 55 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 56 | >>> isinstance(d.bar, list) 57 | True 58 | >>> from operator import attrgetter 59 | >>> map(attrgetter('x'), d.bar) 60 | [1, 3] 61 | >>> map(attrgetter('y'), d.bar) 62 | [2, 4] 63 | >>> d = EasyDict() 64 | >>> d.keys() 65 | [] 66 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) 67 | >>> d.foo 68 | 3 69 | >>> d.bar.x 70 | 1 71 | 72 | Still like a dict though 73 | 74 | >>> o = EasyDict({'clean':True}) 75 | >>> o.items() 76 | [('clean', True)] 77 | 78 | And like a class 79 | 80 | >>> class Flower(EasyDict): 81 | ... power = 1 82 | ... 83 | >>> f = Flower() 84 | >>> f.power 85 | 1 86 | >>> f = Flower({'height': 12}) 87 | >>> f.height 88 | 12 89 | >>> f['power'] 90 | 1 91 | >>> sorted(f.keys()) 92 | ['height', 'power'] 93 | 94 | update and pop items 95 | >>> d = EasyDict(a=1, b='2') 96 | >>> e = EasyDict(c=3.0, a=9.0) 97 | >>> d.update(e) 98 | >>> d.c 99 | 3.0 100 | >>> d['c'] 101 | 3.0 102 | >>> d.get('c') 103 | 3.0 104 | >>> d.update(a=4, b=4) 105 | >>> d.b 106 | 4 107 | >>> d.pop('a') 108 | 4 109 | >>> d.a 110 | Traceback (most recent call last): 111 | ... 112 | AttributeError: 'EasyDict' object has no attribute 'a' 113 | """ 114 | 115 | def __init__(self, d=None, **kwargs): 116 | if d is None: 117 | d = {} 118 | if kwargs: 119 | d.update(**kwargs) 120 | for k, v in d.items(): 121 | setattr(self, k, v) 122 | # Class attributes 123 | for k in self.__class__.__dict__.keys(): 124 | if not (k.startswith("__") and k.endswith("__")) and not k in ( 125 | "update", 126 | "pop", 127 | ): 128 | setattr(self, k, getattr(self, k)) 129 | 130 | def __setattr__(self, name, value): 131 | if isinstance(value, (list, tuple)): 132 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value] 133 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 134 | value = self.__class__(value) 135 | super(MyEasyDict, self).__setattr__(name, value) 136 | super(MyEasyDict, self).__setitem__(name, value) 137 | 138 | __setitem__ = __setattr__ 139 | 140 | def update(self, e=None, **f): 141 | d = e or dict() 142 | d.update(f) 143 | for k in d: 144 | setattr(self, k, d[k]) 145 | 146 | def pop(self, k, d=None): 147 | if k not in self: 148 | return d 149 | delattr(self, k) 150 | return super(MyEasyDict, self).pop(k, d) 151 | 152 | def __getattr__(self, name): 153 | return self.__class__.__dict__.get(name, None) 154 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Evolla 2 | 3 | 4 | 5 | 6 | 7 | A frontier protein-language generative model designed to decode the molecular language of proteins. 8 | 9 | *Quickly try our online server (Evolla-10B) [here](http://www.chat-protein.com/).* 10 | 11 |
Table of contents 12 | 13 | - [News](#News) 14 | - [Overview](#Overview) 15 | - [Environment installation](#Environment-installation) 16 | - [Prepare the Evolla model](#Prepare-the-Evolla-model) 17 | - [Prepare input data](#Prepare-input-data) 18 | - [Run Evolla](#Run-Evolla) 19 | - [Citation](#Citation) 20 |
21 | 22 | > We have 2 PhD positions for international students at Westlake University, China! See [here](https://x.com/duguyuan/status/1897101692665258245). 23 | > 24 | ## News 25 | - **2025/01/06** We released our paper [Decoding the Molecular Language of Proteins with Evolla](https://doi.org/10.1101/2025.01.05.630192). 26 | - **2024/12/06** We uploaded the [Evolla-10B model](https://huggingface.co/westlake-repl/Evolla-10B) to the Hugging Face Hub. 27 | ## Overview 28 | 29 | ![](figures/overview.png) 30 | 31 | ## Environment installation 32 | 33 | ### Create a virtual environment 34 | ``` 35 | conda create -n Evolla python=3.10 36 | conda activate Evolla 37 | ``` 38 | 39 | ### Install packages 40 | ``` 41 | bash environment.sh 42 | ``` 43 | 44 | ## Prepare the Evolla model 45 | 46 | We provide the pre-trained Evolla-10B model on the Hugging Face Hub. You can download it, together with the SaProt and Llama checkpoints it depends on, by running the following commands: 47 | ``` 48 | cd ckpt/huggingface 49 | 50 | git lfs install 51 | 52 | git clone https://huggingface.co/westlake-repl/Evolla-10B 53 | 54 | git clone https://huggingface.co/westlake-repl/SaProt_650M_AF2 55 | 56 | git clone https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct 57 | ``` 58 | 59 | ### Model checkpoints 60 | 61 | |**Name** |**Size** | 62 | |---------|---------| 63 | |[Evolla-10B](https://huggingface.co/westlake-repl/Evolla-10B) | 10B | 64 | 65 | ## Prepare input data 66 | 67 | We provide a sample input file `examples/inputs.tsv` for testing the Evolla model. The input file should be tab-separated, with each line representing `(protein_id, aa_sequence, foldseek_sequence, question_in_json_string)`. 68 | 69 | Note: `protein_id` is the identifier of the line, `aa_sequence` is the amino acid sequence of the protein, and `foldseek_sequence` is the FoldSeek structure sequence of the protein. `question_in_json_string` is the question, serialized with the `json.dumps` function (see the example snippet just before the inference command below). 70 | 71 | 72 | ## Run Evolla 73 | 74 | ### Use `inference.py` 75 | 76 | The following script runs inference on a TSV input file. 77 | 78 | Replace `/your/path/to/Evolla` with the path to your local `Evolla` directory. 
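If you need to build an input file for your own proteins, a line can be assembled as follows (a minimal sketch; the identifier, sequences, and output path are placeholders):

```
import json

protein_id = "my_protein"            # any identifier you like
aa_sequence = "MKTAYIAKQR"           # amino acid sequence (placeholder)
foldseek_sequence = "d#pvqad#qd"     # FoldSeek tokens, same length as aa_sequence; "#" marks positions without structure
question = json.dumps("What is the catalytic activity of this protein?")

with open("my_inputs.tsv", "w") as f:
    f.write("\t".join([protein_id, aa_sequence, foldseek_sequence, question]) + "\n")
```
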
79 | 80 | ``` 81 | cd /your/path/to/Evolla 82 | python scripts/inference.py --config_path config/Evolla_10B.yaml --input_path examples/inputs.tsv 83 | ``` 84 | 85 | ## Citation 86 | 87 | If you find this repository useful, please cite our paper: 88 | 89 | ``` 90 | @article{zhou2025decoding, 91 | title={Decoding the Molecular Language of Proteins with Evolla}, 92 | author={Zhou, Xibin and Han, Chenchen and Zhang, Yingqi and Su, Jin and Zhuang, Kai and Jiang, Shiyu and Yuan, Zichen and Zheng, Wei and Dai, Fengyuan and Zhou, Yuyang and others}, 93 | journal={bioRxiv}, 94 | pages={2025--01}, 95 | year={2025}, 96 | publisher={Cold Spring Harbor Laboratory} 97 | } 98 | ``` 99 | ### Other resources 100 | 101 | - [ProTrek](https://www.biorxiv.org/content/10.1101/2024.05.30.596740v2) and its [online server](http://search-protrek.com/) 102 | - [Pinal](https://www.biorxiv.org/content/10.1101/2024.08.01.606258v2) and its [online server](http://www.denovo-pinal.com/) 103 | - [SaprotHub](https://www.biorxiv.org/content/10.1101/2024.05.24.595648v5) and its [online server](https://colab.research.google.com/github/westlake-repl/SaprotHub/blob/main/colab/SaprotHub_v2.ipynb?hl=en) 104 | -------------------------------------------------------------------------------- /model/Evolla/fusion_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | from einops_exts import rearrange_many 4 | from torch import einsum, nn 5 | 6 | 7 | def FeedForward(dim, mult=4): 8 | inner_dim = int(dim * mult) 9 | return nn.Sequential( 10 | nn.LayerNorm(dim), 11 | nn.Linear(dim, inner_dim, bias=False), 12 | nn.GELU(), 13 | nn.Linear(inner_dim, dim, bias=False), 14 | ) 15 | 16 | 17 | class SequenceCompressorAttention(nn.Module): 18 | def __init__(self, dim, dim_head=64, heads=8): 19 | super().__init__() 20 | self.scale = dim_head**-0.5 21 | self.heads = heads 22 | inner_dim = dim_head * heads 23 | 24 | self.norm_media = nn.LayerNorm(dim) 25 | self.norm_latents = nn.LayerNorm(dim) 26 | 27 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 28 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 29 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 30 | 31 | def forward(self, x, latents, mask): 32 | """ 33 | Args: 34 | x (torch.Tensor): image features 35 | shape (b, n1, D) 36 | latent (torch.Tensor): latent features 37 | shape (b, n2, D); n2: num of latent tokens 38 | """ 39 | x = self.norm_media(x) 40 | latents = self.norm_latents(latents) 41 | 42 | h = self.heads 43 | 44 | q = self.to_q(latents) 45 | kv_input = torch.cat((x, latents), dim=-2) 46 | k, v = self.to_kv(kv_input).chunk( 47 | 2, dim=-1 48 | ) # each: batch_size, max_protein_length+num_latents, dim_head*num_heads 49 | 50 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) 51 | q = q * self.scale # batch_size, num_heads, num_latents, dim_head 52 | 53 | # attention 54 | sim = einsum("... i d, ... j d -> ... i j", q, k) 55 | 56 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 57 | 58 | bs, nh, skd, okd = sim.shape 59 | mask = repeat(mask, "bs okd -> bs nh skd okd", nh=nh, skd=skd) 60 | 61 | sim = sim.masked_fill((1 - mask).bool(), -1e4) 62 | # sim = sim + (1 - mask) * torch.tensor(float('-inf'), dtype=sim.dtype) # 加上mask 63 | attn = sim.softmax(dim=-1) 64 | 65 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 66 | 67 | out = rearrange(out, "b h n d -> b n (h d)", h=h) 68 | return self.to_out(out) 69 | 70 | 71 | class SequenceCompressorResampler(nn.Module): 72 | def __init__( 73 | self, 74 | protein_repr_dim, 75 | output_repr_dim, 76 | depth=6, 77 | dim_head=64, 78 | heads=8, 79 | num_latents=64, 80 | ff_mult=4, 81 | ): 82 | super().__init__() 83 | self.latents = nn.Parameter(torch.randn(num_latents, protein_repr_dim)) 84 | 85 | self.layers = nn.ModuleList([]) 86 | for _ in range(depth): 87 | self.layers.append( 88 | nn.ModuleList( 89 | [ 90 | SequenceCompressorAttention( 91 | dim=protein_repr_dim, dim_head=dim_head, heads=heads 92 | ), 93 | FeedForward(dim=protein_repr_dim, mult=ff_mult), 94 | ] 95 | ) 96 | ) 97 | 98 | self.norm = nn.LayerNorm(output_repr_dim) 99 | 100 | self.protein_projector = nn.Linear(protein_repr_dim, output_repr_dim) 101 | 102 | self.num_latents = num_latents 103 | 104 | @property 105 | def device(self): 106 | return self.latents.device 107 | 108 | @property 109 | def dtype(self): 110 | return self.latents.dtype 111 | 112 | def forward(self, embeds, mask): 113 | 114 | b = embeds.shape[0] 115 | 116 | bs, _ = mask.shape # bs, max_protein_length 117 | latent_mask = torch.ones(bs, self.num_latents).to(mask.device) 118 | mask = torch.cat( 119 | (mask, latent_mask), dim=1 120 | ) # bs, max_protein_length + num_latents 121 | 122 | # blocks 123 | latents = repeat(self.latents, "n d -> b n d", b=b) 124 | for attn, ff in self.layers: 125 | latents = attn(embeds, latents, mask) + latents 126 | latents = ff(latents) + latents 127 | 128 | transformed_feature = self.protein_projector(latents) 129 | 130 | return self.norm(transformed_feature) 131 | 132 | class MLPResampler(nn.Module): 133 | def __init__( 134 | self, 135 | protein_repr_dim, 136 | output_repr_dim, 137 | ): 138 | super().__init__() 139 | self.model = nn.Sequential( 140 | nn.Linear(protein_repr_dim, output_repr_dim), 141 | nn.ReLU(), 142 | nn.Linear(output_repr_dim, output_repr_dim), 143 | nn.LayerNorm(output_repr_dim), 144 | ) 145 | 146 | def forward(self, embeds, mask): 147 | return self.model(embeds) -------------------------------------------------------------------------------- /utils/others.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | # from Bio import SeqIO 9 | from torch.nn.utils.rnn import pad_sequence 10 | from tqdm import tqdm 11 | import yaml 12 | from utils.easydict import MyEasyDict 13 | from model.model_interface import ModelInterface 14 | 15 | structure_encoder_name_2_protein_dim = { 16 | "SaProt_35M_AF2": 480, 17 | "SaProt_650M_AF2": 1280, 18 | } 19 | 20 | protein_encoder_name_2_protein_dim = { 21 | "esm2_t12_35M_UR50D": 480, 22 | "esm2_t33_650M_UR50D": 1280, 23 | "SaProt_35M_AF2": 480, 24 | "SaProt_650M_AF2": 1280, 25 | "ProTrek_35M_seq": 480, 26 | "ProTrek_650M_seq": 1280, 27 | } 28 | 29 | llm_name_2_llm_embedding_dim = { 30 | "opt-350m": 512, 31 | "facebook-opt-350m": 512, 32 | "meta-llama_Meta-Llama-3-8B": 4096, 33 | "meta-llama_Meta-Llama-3-8B-Instruct": 4096, 34 | "opt-2.7b": 2560, 35 | "Qwen1.5-0.5B": 1024, 36 | "Qwen1.5-4B-Chat": 1024, 37 | "phi-1_5": 2048, 38 | "phi-2": 2560, 39 | "Llama2hf7b": 4096, 40 | } 41 | 42 | def setup_seed(seed): 43 | """set random seed for reproducibility. 44 | Args: 45 | seed (int): random seed to use. 
46 | """ 47 | torch.manual_seed(seed) 48 | torch.cuda.manual_seed_all(seed) 49 | np.random.seed(seed) 50 | random.seed(seed) 51 | # torch.backends.cudnn.deterministic = True 52 | 53 | 54 | def align_model_config(config: MyEasyDict): 55 | """Align model config. Different model sometimes should share the same dimension, but it's not easy to set them manually. 56 | Args: 57 | config (MyEasyDict): model config. 58 | 59 | Returns: 60 | config (MyEasyDict): aligned model config. 61 | """ 62 | 63 | # if config.fusion_module.output_repr_dim is not set, it should be same as llm embedding dim 64 | llm_name = config.llm.hf_dir.split("/")[-1] # example: opt-350m 65 | llm_embedding_dim = llm_name_2_llm_embedding_dim[llm_name] 66 | 67 | if config.protein_encoder is not None: 68 | # get protein dim by protein_encoder.config_path 69 | protein_encoder_name = config.protein_encoder.config_path.split("/")[ 70 | -1 71 | ] # example: esm2_t12_35M_UR50D 72 | protein_encoder_dim = protein_encoder_name_2_protein_dim[protein_encoder_name] 73 | # assign protein_encoder_dim to config.protein_encoder.fusion_module.protein_repr_dim 74 | # config.fusion_module.protein_repr_dim = protein_encoder_dim 75 | config.protein_encoder.fusion_module.protein_repr_dim = protein_encoder_dim 76 | # config.fusion_module.output_repr_dim = llm_embedding_dim 77 | if config.protein_encoder.fusion_module.output_repr_dim is None: 78 | config.protein_encoder.fusion_module.output_repr_dim = llm_embedding_dim 79 | 80 | # align config.llm.cross_attention_config.encoder_dim with config.fusion_module.output_repr_dim 81 | if config.llm.get("cross_attention_config", None) is not None: 82 | # config.llm.cross_attention_config.encoder_dim = config.fusion_module.output_repr_dim 83 | config.llm.cross_attention_config.protein_encoder_dim = ( 84 | config.protein_encoder.fusion_module.output_repr_dim 85 | ) 86 | 87 | if config.structure_encoder is not None: 88 | if "config_path" in config.structure_encoder: # for saprot 89 | structure_encoder_name = config.structure_encoder.config_path.split("/")[-1] 90 | elif "tokenizer_path" in config.structure_encoder: # for structure embedding 91 | structure_encoder_name = config.structure_encoder.tokenizer_path.split("/")[ 92 | -1 93 | ] 94 | else: # for GNN 95 | structure_encoder_name = None 96 | if structure_encoder_name is not None: 97 | structure_encoder_dim = structure_encoder_name_2_protein_dim[ 98 | structure_encoder_name 99 | ] 100 | else: 101 | structure_encoder_dim = 512 # TODO 102 | 103 | if "fusion_module" in config.structure_encoder: 104 | config.structure_encoder.fusion_module.protein_repr_dim = ( 105 | structure_encoder_dim 106 | ) 107 | 108 | if config.structure_encoder.fusion_module.output_repr_dim is None: 109 | config.structure_encoder.fusion_module.output_repr_dim = ( 110 | llm_embedding_dim 111 | ) 112 | 113 | # align config.llm.cross_attention_config.encoder_dim with config.fusion_module.output_repr_dim 114 | if config.llm.get("cross_attention_config", None) is not None: 115 | if "fusion_module" in config.structure_encoder: 116 | config.llm.cross_attention_config.structure_encoder_dim = ( 117 | config.structure_encoder.fusion_module.output_repr_dim 118 | ) 119 | else: 120 | config.llm.cross_attention_config.structure_encoder_dim = ( 121 | structure_encoder_dim 122 | ) 123 | 124 | if config.msa_encoder is not None: 125 | msa_encoder_dim = 768 126 | 127 | if "fusion_module" in config.msa_encoder: 128 | config.msa_encoder.fusion_module.protein_repr_dim = msa_encoder_dim 129 | 130 | if 
config.msa_encoder.fusion_module.output_repr_dim is None: 131 | config.msa_encoder.fusion_module.output_repr_dim = llm_embedding_dim 132 | 133 | # align config.llm.cross_attention_config.encoder_dim with config.fusion_module.output_repr_dim 134 | if config.llm.get("cross_attention_config", None) is not None: 135 | if "fusion_module" in config.msa_encoder: 136 | config.llm.cross_attention_config.msa_encoder_dim = ( 137 | config.msa_encoder.fusion_module.output_repr_dim 138 | ) 139 | else: 140 | config.llm.cross_attention_config.msa_encoder_dim = msa_encoder_dim 141 | 142 | return config 143 | 144 | 145 | def filter_llama_weights(state_dict): 146 | """Filter out llama weights from state_dict because of training issues. The llama weights have already been loaded while initializing the model.""" 147 | llama_keys = [] 148 | for k, v in state_dict.items(): 149 | if k.startswith("llm.") and 'adapter' not in k: 150 | llama_keys.append(k) 151 | if k.startswith("model.3.") and 'adapter' not in k: 152 | llama_keys.append(k) 153 | for k in llama_keys: 154 | state_dict.pop(k) 155 | return state_dict 156 | 157 | 158 | def get_prompt(sequence, structure, question): 159 | """Generate prompt and SA sequence for SaProt. 160 | 161 | Args: 162 | sequence (str): amino acid sequence. 163 | structure (str): structure sequence represented by foldseek. 164 | question (str): question for the model. 165 | 166 | Returns: 167 | prompt (str): prompt for the model. 168 | sequence (str): sequence with structure information. 169 | """ 170 | sequence_template = "Question: {Question} Answer: " 171 | structure_template = "Question: {Question} Answer: " 172 | saprot_template = "Question: {Question} Answer: " 173 | if sequence is not None and structure is not None: 174 | if len(sequence) != len(structure): 175 | raise ValueError(f"The length of sequence and structure are not equal. {len(sequence)}!= {len(structure)}") 176 | _sequence = sequence.upper() 177 | _structure = structure.lower() 178 | sequence = "".join([f"{_seq}{_struct}" for _seq, _struct in zip(_sequence, _structure)]) 179 | print("all", sequence) 180 | prompt = saprot_template.format(Question=question) 181 | elif sequence is not None: 182 | _sequence = sequence.upper() 183 | _structure = "#" * len(_sequence) 184 | sequence = "".join([f"{_seq}{_struct}" for _seq, _struct in zip(_sequence, _structure)]) 185 | print("seqonly", sequence) 186 | prompt = sequence_template.format(Question=question) 187 | elif structure is not None: 188 | _sequence = "#" * len(structure) 189 | _structure = structure.lower() 190 | sequence = "".join([f"{_seq}{_struct}" for _seq, _struct in zip(_sequence, _structure)]) 191 | prompt = structure_template.format(Question=question) 192 | print("structonly", sequence) 193 | return prompt, sequence 194 | 195 | 196 | 197 | def load_config(config_path): 198 | with open(config_path, 'r', encoding='utf-8') as r: 199 | config = MyEasyDict(yaml.safe_load(r)) 200 | config.model.config = align_model_config(config.model.config) 201 | return config 202 | 203 | def load_model_from_config(config, local_rank=0, dtype=None): 204 | """load model from config. 205 | Args: 206 | config (MyEasyDict): config of the model. 207 | local_rank (int): local rank of the current process. 208 | dtype (str): data type of the model. Default is None. Options are "fp32", "fp16", "bf16". 209 | 210 | Returns: 211 | model (nn.Module): loaded model. 
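        Example (mirrors the usage in scripts/inference.py):
            >>> config = load_config("config/Evolla_10B.yaml")
            >>> model = load_model_from_config(config, local_rank=0, dtype="bf16")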
212 | """ 213 | model_py_path = config.model.pop("cls") 214 | model = ModelInterface.init_model(model_py_path, **config.model) 215 | model.eval() 216 | 217 | ckpt = torch.load(os.path.join(config.setting.from_checkpoint, "checkpoint", "mp_rank_00_model_states.pt"), map_location=f'cpu') 218 | state_dict = ckpt["module"] 219 | state_dict = filter_llama_weights(state_dict) 220 | model.load_state_dict(state_dict, strict=False) 221 | if dtype is None: 222 | pass 223 | elif dtype == "fp32": 224 | model.to(torch.float32) 225 | elif dtype == "bf16": 226 | model.to(torch.bfloat16) 227 | elif dtype == "fp16": 228 | model.to(torch.float16) 229 | else: 230 | raise ValueError(f"Unsupported data type: {dtype}, supported data types are 'fp32', 'fp16', 'bf16'") 231 | model.to(f'cuda:{local_rank}') 232 | return model -------------------------------------------------------------------------------- /model/Evolla/injection_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | class RMSNorm(torch.nn.Module): 7 | def __init__(self, dim: int, eps: float = 1e-6): 8 | super().__init__() 9 | self.eps = eps 10 | self.weight = nn.Parameter(torch.ones(dim)) 11 | 12 | def _norm(self, x): 13 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 14 | 15 | def forward(self, x): 16 | output = self._norm(x.float()).type_as(x) 17 | return output * self.weight 18 | 19 | def FeedForward(dim, mult=None): 20 | if mult is None: 21 | mult = 4 22 | inner_dim = int(dim * mult) 23 | return nn.Sequential( 24 | nn.LayerNorm(dim), 25 | nn.Linear(dim, inner_dim, bias=False), 26 | nn.GELU(), 27 | nn.Linear(inner_dim, dim, bias=False), 28 | ) 29 | 30 | 31 | class CrossAttention(nn.Module): 32 | def __init__( 33 | self, 34 | num_attention_heads, 35 | hidden_size, 36 | protein_encoder_dim=None, # protein dim in fusion module 37 | structure_encoder_dim=None, # structure dim in fusion module 38 | msa_encoder_dim=None, # msa dim in fusion module 39 | ffn_mult=None, 40 | attention_probs_dropout_prob=None, 41 | enable_bias=False, 42 | ): 43 | super().__init__() 44 | self.scale = num_attention_heads**-0.5 45 | self.num_attention_heads = num_attention_heads 46 | self.attention_head_size = int(hidden_size / num_attention_heads) 47 | self.all_head_size = self.num_attention_heads * self.attention_head_size 48 | 49 | self.query = nn.Linear(hidden_size, self.all_head_size) 50 | if protein_encoder_dim is not None: 51 | self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size) 52 | self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size) 53 | else: 54 | self.key_protein = None 55 | self.value_protein = None 56 | 57 | if structure_encoder_dim is not None: 58 | self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size) 59 | self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size) 60 | else: 61 | self.key_structure = None 62 | self.value_structure = None 63 | 64 | if msa_encoder_dim is not None: 65 | self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size) 66 | self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size) 67 | else: 68 | self.key_msa = None 69 | self.value_msa = None 70 | 71 | self.attention_norm = RMSNorm(hidden_size) 72 | 73 | self.dropout = nn.Dropout(attention_probs_dropout_prob) 74 | 75 | self.out_proj = nn.Linear(hidden_size, hidden_size, bias=enable_bias) 76 | 77 | self.ff = FeedForward(hidden_size, ffn_mult) 78 | self.gate_attention = 
nn.Parameter(torch.tensor([0.0])) 79 | self.gate_ffw = nn.Parameter(torch.tensor([0.0])) 80 | 81 | def cross_attention( 82 | self, 83 | query_states, 84 | protein_key_value_states, 85 | structure_key_value_states, 86 | msa_key_value_states, 87 | query_attn_mask, 88 | protein_kv_attn_mask, 89 | structure_kv_attn_mask, 90 | msa_kv_attn_mask, 91 | ): 92 | """ 93 | query_states: text 94 | key_value_states: protein 95 | query_states: [bs, query_seq_len, dim] 96 | key_value_states: [bs, kv_seq_len, dim] 97 | query_attn_mask: [bs, query_seq_len] 98 | kv_attn_mask: [bs, kv_seq_len] 99 | """ 100 | 101 | # Concatenate protein and structure 102 | kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask] 103 | kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None] 104 | if not kv_attn_mask: 105 | raise ValueError( 106 | "At least one modality should be provided for cross attention." 107 | ) 108 | kv_attn_mask = torch.cat(kv_attn_mask, dim=1) 109 | 110 | query_layer = self.attention_norm(query_states) 111 | 112 | # Warning: This place might cause issues, refers to 113 | # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13 114 | # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable 115 | # Apply linear transformation to input_query, input_key, and input_value 116 | query_layer = self.query(query_layer) # [bs, querylength, dim] 117 | 118 | if self.key_protein is not None and self.value_protein is not None: 119 | protein_key_value_states = protein_key_value_states.to(query_states) 120 | key_layer_protein = self.key_protein( 121 | protein_key_value_states 122 | ) # [bs, keylength, dim] 123 | value_layer_protein = self.value_protein( 124 | protein_key_value_states 125 | ) # [bs, keylength, dim] 126 | else: 127 | key_layer_protein = None 128 | value_layer_protein = None 129 | 130 | if self.key_structure is not None and self.value_structure is not None: 131 | structure_key_value_states = structure_key_value_states.to(query_states) 132 | key_layer_structure = self.key_structure( 133 | structure_key_value_states 134 | ) # [bs, keylength, dim] 135 | value_layer_structure = self.value_structure( 136 | structure_key_value_states 137 | ) # [bs, keylength, dim] 138 | else: 139 | key_layer_structure = None 140 | value_layer_structure = None 141 | 142 | if self.key_msa is not None and self.value_msa is not None: 143 | msa_key_value_states = msa_key_value_states.to(query_states) 144 | key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim] 145 | value_layer_msa = self.value_msa( 146 | msa_key_value_states 147 | ) # [bs, keylength, dim] 148 | else: 149 | key_layer_msa = None 150 | value_layer_msa = None 151 | 152 | key_layer = [key_layer_protein, key_layer_structure, key_layer_msa] 153 | key_layer = [_ for _ in key_layer if _ is not None] 154 | key_layer = torch.cat(key_layer, dim=1) 155 | 156 | value_layer = [value_layer_protein, value_layer_structure, value_layer_msa] 157 | value_layer = [_ for _ in value_layer if _ is not None] 158 | value_layer = torch.cat(value_layer, dim=1) 159 | 160 | query_layer = self.transpose_for_scores( 161 | query_layer 162 | ) # [bs, numheads, querylength, dim/numheads] 163 | key_layer = self.transpose_for_scores( 164 | key_layer 165 | ) # [bs, numheads, keylength, dim/numheads] 166 | value_layer = self.transpose_for_scores( 167 | value_layer 168 | ) # [bs, numheads, keylength, dim/numheads] 169 | 170 | query_layer = query_layer * self.scale 171 | 
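        # The joint mask built below is the outer product of the text-side (query) mask and the
        # concatenated protein/structure/MSA (key) mask, so an attention score survives only when
        # both the query token and the key token are valid.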
172 | # attention_mask: [bs, 1, querylength, keylength] 173 | attention_mask = ( 174 | query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :] 175 | ) 176 | # Compute the scaled dot-product attention scores 177 | attn_weights = torch.matmul( 178 | query_layer, key_layer.transpose(-1, -2) 179 | ) # [bs, numheads, querylength, keylength] 180 | attn_weights = ( 181 | attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() 182 | ) # To stablize score 183 | attention_scores = attn_weights.masked_fill( 184 | (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min 185 | ) # [bs, numheads, querylength, keylength] 186 | 187 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 188 | 189 | # attention_probs_dropped = self.dropout(attention_probs) 190 | 191 | context_layer = torch.matmul( 192 | attention_probs, value_layer 193 | ) # [bs, numheads, querylength, dim/numheads] 194 | 195 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 196 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 197 | context_layer = context_layer.view(*new_context_layer_shape) 198 | 199 | context_layer = self.out_proj(context_layer) 200 | 201 | return context_layer 202 | 203 | def forward( 204 | self, 205 | query_states, 206 | protein_kv_states, 207 | structure_kv_states, 208 | msa_kv_states, 209 | query_attn_mask, 210 | protein_kv_attn_mask=None, 211 | structure_kv_attn_mask=None, 212 | msa_kv_attn_mask=None, 213 | protein_batch_mask=None, 214 | structure_batch_mask=None, 215 | msa_batch_mask=None, 216 | past_key_value=None, 217 | ): 218 | """ 219 | kv_states: protein 220 | query_states: text 221 | 222 | query_states: [bs, query_seq_len, dim] 223 | kv_states: [bs, kv_seq_len, dim] 224 | query_attn_mask: [bs, query_seq_len] 225 | kv_attn_mask: [bs, kv_seq_len], default None 226 | past_key_value: [bs, past_kv_seq_len, dim], default None 227 | """ 228 | query_seq_len = query_states.shape[1] 229 | if protein_kv_states is not None: 230 | bs, protein_kv_seq_len, dim = protein_kv_states.shape 231 | if protein_kv_attn_mask is None: 232 | protein_kv_attn_mask = ( 233 | torch.ones(bs, protein_kv_seq_len) 234 | * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T 235 | ).to(protein_kv_states.device) 236 | else: 237 | protein_kv_attn_mask = None 238 | 239 | if structure_kv_states is not None: 240 | bs, structure_kv_seq_len, dim = structure_kv_states.shape 241 | if structure_kv_attn_mask is None: 242 | structure_kv_attn_mask = ( 243 | torch.ones(bs, structure_kv_seq_len) 244 | * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T 245 | ).to(structure_kv_states.device) 246 | else: 247 | structure_kv_attn_mask = None 248 | 249 | if msa_kv_states is not None: 250 | bs, msa_kv_seq_len, dim = msa_kv_states.shape 251 | if msa_kv_attn_mask is None: 252 | msa_kv_attn_mask = ( 253 | torch.ones(bs, msa_kv_seq_len) 254 | * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T 255 | ).to(msa_kv_states.device) 256 | else: 257 | msa_kv_attn_mask = None 258 | hidden_states = query_states 259 | # only when there's at least one valid modality, crossattention will be performed 260 | if (protein_kv_states is not None and protein_kv_attn_mask.any()) or ( 261 | structure_kv_states is not None and structure_kv_attn_mask.any() 262 | ) or ( 263 | msa_kv_states is not None and msa_kv_attn_mask.any() 264 | ): 265 | residual = hidden_states 266 | hidden_states = self.cross_attention( 267 | query_states=hidden_states, 268 | protein_key_value_states=protein_kv_states, 269 | 
structure_key_value_states=structure_kv_states, 270 | msa_key_value_states=msa_kv_states, 271 | query_attn_mask=query_attn_mask, 272 | protein_kv_attn_mask=protein_kv_attn_mask, 273 | structure_kv_attn_mask=structure_kv_attn_mask, 274 | msa_kv_attn_mask=msa_kv_attn_mask, 275 | ) # [bs, query_seq_len, dim] 276 | # tanh gate 277 | hidden_states = torch.tanh(self.gate_attention) * hidden_states 278 | 279 | hidden_states = residual + hidden_states # input_query 280 | 281 | residual = hidden_states 282 | hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw) 283 | hidden_states = residual + hidden_states 284 | 285 | return hidden_states 286 | 287 | def transpose_for_scores(self, x): 288 | new_x_shape = x.size()[:-1] + ( 289 | self.num_attention_heads, 290 | self.attention_head_size, 291 | ) 292 | x = x.view(*new_x_shape) 293 | return x.permute(0, 2, 1, 3) 294 | -------------------------------------------------------------------------------- /model/Evolla/Evolla_model.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from model.model_interface import register_model 3 | from utils.easydict import MyEasyDict 4 | import torch 5 | 6 | from .encoder_interface import EncoderInterface 7 | from .llm_interface import LLMInterface 8 | 9 | @register_model 10 | class EvollaModel(pl.LightningModule): 11 | def __init__(self, 12 | config: MyEasyDict, 13 | **kwargs): 14 | """ 15 | Initialize the Evolla. 16 | Args: 17 | config (MyEasyDict): Configuration of the Evolla. 18 | """ 19 | super().__init__() 20 | self.verbose = config.get('verbose', False) 21 | self.config = config 22 | self.generate_config = kwargs.pop('generate_config', {}) 23 | 24 | if len(self.generate_config) == 0: 25 | print("Warning: No generate config is provided, the generate config now is \{\}") 26 | else: 27 | print("Generate config is provided, the generate config is: ", self.generate_config) 28 | 29 | self.initialize_model() 30 | 31 | self.special_pad_id = -100 32 | 33 | @staticmethod 34 | def init_protein_encoder(config: dict): 35 | """ 36 | Initialize protein encoder 37 | Args: 38 | config: A dictionary containing the configuration of the protein encoder 39 | 40 | Returns: 41 | A protein encoder 42 | """ 43 | encoder_py_path = config.pop("cls") 44 | model = EncoderInterface.init_encoder(encoder_py_path, **config) 45 | return model 46 | 47 | @staticmethod 48 | def init_structure_encoder(config: dict): 49 | """ 50 | Initialize structure encoder 51 | Args: 52 | config: A dictionary containing the configuration of the structure encoder 53 | Returns: 54 | A structure encoder 55 | """ 56 | encoder_py_path = config.pop("cls") 57 | model = EncoderInterface.init_encoder(encoder_py_path, **config) 58 | return model 59 | 60 | @staticmethod 61 | def init_msa_transformer_encoder(config: dict): 62 | """ 63 | Initialize protein encoder 64 | Args: 65 | config: A dictionary containing the configuration of the protein encoder 66 | 67 | Returns: 68 | A protein evoformer encoder 69 | """ 70 | msa_transformer_py_path = config.pop("cls") 71 | model = EncoderInterface.init_encoder(msa_transformer_py_path, **config) 72 | return model 73 | 74 | @staticmethod 75 | def init_llm(config: dict): 76 | """ 77 | Initialize LLM 78 | Args: 79 | config: A dictionary containing the configuration of the LLM 80 | 81 | Returns: 82 | A LLM 83 | """ 84 | llm_py_path = config.pop("cls") 85 | model = LLMInterface.init_llm(llm_py_path, **config) 86 | return model 87 | 88 | def 
initialize_model(self) -> None: 89 | """Initialize the Evolla model.""" 90 | # torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) 91 | if "protein_encoder" in self.config: 92 | if self.verbose: 93 | print("Loading Sequence Encoder...", flush=True) 94 | self.protein_encoder = self.init_protein_encoder( 95 | self.config["protein_encoder"] 96 | ) 97 | else: 98 | self.protein_encoder = None 99 | 100 | if "msa_encoder" in self.config: 101 | if self.verbose: 102 | print("Loading MSA Tranformer Encoder...", flush=True) 103 | self.msa_encoder = self.init_msa_transformer_encoder( 104 | self.config["msa_encoder"] 105 | ) 106 | else: 107 | self.msa_encoder = None 108 | 109 | if "structure_encoder" in self.config: 110 | if self.verbose: 111 | print("Loading Structure Encoder...", flush=True) 112 | self.structure_encoder = self.init_structure_encoder( 113 | self.config["structure_encoder"] 114 | ) 115 | else: 116 | self.structure_encoder = None 117 | # print("Loading Fusion Module...", flush=True) 118 | # self.fusion_module = self.init_fusion_module(self.config["fusion_module"]) 119 | if self.verbose: 120 | print("Loading LLM...", flush=True) 121 | self.llm = self.init_llm(self.config["llm"]) 122 | self.llm_tokenizer = self.llm.tokenizer 123 | 124 | if self.protein_encoder is not None: 125 | self.freeze_protein_encoder_layers() 126 | 127 | if self.structure_encoder is not None: 128 | self.freeze_structure_encoder_layers() 129 | 130 | if self.msa_encoder is not None: 131 | self.freeze_msa_encoder_layers() 132 | 133 | self.freeze_llm_layers() 134 | 135 | def freeze_protein_encoder_layers(self): 136 | for name, param in self.protein_encoder.named_parameters(): 137 | param.requires_grad = False 138 | if "resampler" in name: 139 | param.requires_grad = True 140 | 141 | def freeze_structure_encoder_layers(self): 142 | for name, param in self.structure_encoder.named_parameters(): 143 | param.requires_grad = False 144 | if "resampler" in name: 145 | param.requires_grad = True 146 | 147 | def freeze_msa_encoder_layers(self): 148 | for name, param in self.msa_encoder.named_parameters(): 149 | param.requires_grad = False 150 | if "resampler" in name: 151 | param.requires_grad = True 152 | 153 | def freeze_llm_layers(self): 154 | for name, param in self.llm.named_parameters(): 155 | if "adapter" in name: 156 | param.requires_grad = True 157 | else: 158 | param.requires_grad = False 159 | 160 | 161 | def input_process( 162 | self, 163 | questions: list, 164 | answers: list = None, 165 | ): 166 | """ 167 | Args: 168 | protein_embeds: encoded embedding of protein sequence 169 | templates: template used as container of question and answer pair 170 | questions: A list of prompts. 171 | answers: A list of answers. 172 | """ 173 | return self.llm.input_process( 174 | questions=questions, 175 | answers=answers, 176 | max_length=self.config["text_length"], 177 | special_pad_id=self.special_pad_id, 178 | ) 179 | 180 | def forward( 181 | self, 182 | seqs: tuple, 183 | foldseeks: tuple, 184 | questions: list, 185 | answers: list, 186 | msa_embeds: torch.Tensor = None, 187 | msa_atts: torch.Tensor = None, 188 | **kwargs, 189 | ): 190 | """Forward pass of the Evolla model. 191 | Args: 192 | seqs (tuple): Amino acid sequences of proteins. 193 | foldseeks (tuple): Foldseek sequences of proteins. 194 | questions (list): A list of prompts. 195 | answers (list): A list of answers. 196 | msa_embeds (torch.Tensor, Optional): MSA embeddings. 197 | msa_atts (torch.Tensor, Optional): MSA attention masks. 
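            Note: `seqs`, `foldseeks` and `questions` are batch-aligned, with one entry per sample.
            A sequence or foldseek entry may be None when that modality is missing for a sample; the
            SaProt encoder encodes a None sequence as an empty string and masks it out downstream.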
198 | 199 | Returns: 200 | return_dict (dict): A dictionary containing the predicted logits, prompts, answers, and raw text masks. 201 | labels (torch.Tensor): A tensor containing the labels. 202 | """ 203 | 204 | if self.protein_encoder is not None: 205 | resampler_protein_repr, protein_repr, protein_attn, protein_batch_mask = self.protein_encoder(seqs) 206 | else: 207 | resampler_protein_repr = None 208 | protein_batch_mask = None 209 | protein_repr = None 210 | protein_attn = None 211 | 212 | if self.structure_encoder is not None: 213 | resampler_structure_repr, structure_repr, structure_attn, structure_batch_mask = self.structure_encoder(foldseeks) 214 | else: 215 | resampler_structure_repr = None 216 | structure_batch_mask = None 217 | structure_repr = None 218 | structure_attn = None 219 | 220 | if self.msa_encoder is not None: 221 | resampler_msa_repr, msa_repr, msa_attn, msa_batch_mask = self.msa_encoder( 222 | msa_embeds, 223 | msa_atts, 224 | ) 225 | else: 226 | resampler_msa_repr = None 227 | msa_repr = None 228 | msa_attn = None 229 | msa_batch_mask = None 230 | 231 | input_ids, embeds, attns, labels, raw_text_masks = self.input_process( 232 | questions=questions, 233 | answers=answers, 234 | ) 235 | 236 | outputs = self.llm.forward( 237 | input_ids=input_ids, 238 | inputs_embeds=embeds, 239 | inputs_mask=attns, 240 | protein_feats=resampler_protein_repr, 241 | structure_feats=resampler_structure_repr, 242 | msa_feats=resampler_msa_repr, 243 | protein_batch_mask=protein_batch_mask, 244 | structure_batch_mask=structure_batch_mask, 245 | msa_batch_mask=msa_batch_mask, 246 | ) 247 | logits = outputs.logits 248 | 249 | return_dict = { 250 | "logits": logits, 251 | "prompts": questions, 252 | "answers": answers, 253 | "raw_text_masks": raw_text_masks, 254 | } 255 | if "comment_types" in kwargs: 256 | return_dict["comment_types"] = kwargs["comment_types"] 257 | 258 | return return_dict, labels 259 | 260 | 261 | def generate( 262 | self, 263 | seqs: tuple, 264 | foldseeks: tuple, 265 | questions: list, 266 | msa_embeds: torch.Tensor = None, 267 | msa_atts: torch.Tensor = None, 268 | **kwargs, 269 | ) -> str: 270 | """ 271 | Generate answer for the question. 272 | Args: 273 | seqs (tuple): Amino acid sequences of proteins. 274 | foldseeks (tuple): Foldseek sequences of proteins. 275 | questions (list): A list of questions. 276 | msa_embeds (torch.Tensor, Optional): MSA embeddings. 277 | msa_atts (torch.Tensor, Optional): MSA attention masks. 278 | 279 | Returns: 280 | answers (list): A list of predicted answers. 
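        Example (illustrative; the SA-token sequence and question are placeholders):
            >>> answers = model.generate(
            ...     seqs=["M#K#T#A#Y#"],  # amino acids interleaved with FoldSeek tokens; "#" = unknown structure
            ...     foldseeks=[None],
            ...     questions=["What is the catalytic activity of this protein?"],
            ...     max_new_tokens=512,
            ... )
            >>> print(answers[0])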
281 | """ 282 | 283 | with torch.no_grad(): 284 | if self.protein_encoder is not None: 285 | ( 286 | resampler_protein_repr, 287 | protein_repr, 288 | protein_attn, 289 | protein_batch_mask, 290 | ) = self.protein_encoder(seqs) 291 | else: 292 | resampler_protein_repr = None 293 | protein_batch_mask = None 294 | protein_repr = None 295 | protein_attn = None 296 | 297 | if self.structure_encoder is not None: 298 | ( 299 | resampler_structure_repr, 300 | structure_repr, 301 | structure_attn, 302 | structure_batch_mask, 303 | ) = self.structure_encoder(foldseeks) 304 | else: 305 | resampler_structure_repr = None 306 | structure_batch_mask = None 307 | structure_repr = None 308 | structure_attn = None 309 | 310 | if self.msa_encoder is not None: 311 | resampler_msa_repr, msa_repr, msa_attn, msa_batch_mask = self.msa_encoder( 312 | msa_embeds, 313 | msa_atts, 314 | ) 315 | else: 316 | resampler_msa_repr = None 317 | msa_batch_mask = None 318 | msa_repr = None 319 | msa_attn = None 320 | 321 | input_ids, embeds, attns, labels, raw_text_masks = self.input_process( 322 | questions=questions, 323 | ) 324 | 325 | predicted_answer = self.llm.generate( 326 | input_ids=input_ids, 327 | inputs_mask=attns, 328 | protein_feats=resampler_protein_repr, 329 | structure_feats=resampler_structure_repr, 330 | msa_feats=resampler_msa_repr, 331 | protein_batch_mask=protein_batch_mask, 332 | structure_batch_mask=structure_batch_mask, 333 | msa_batch_mask=msa_batch_mask, 334 | **kwargs, 335 | ) 336 | 337 | return self.llm.tokenizer.batch_decode( 338 | predicted_answer, 339 | skip_special_tokens=True, 340 | clean_up_tokenization_spaces=False, 341 | ) -------------------------------------------------------------------------------- /model/Evolla/llama_llm.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import os 4 | import random 5 | import types 6 | from pathlib import Path 7 | from typing import List, Optional, Tuple, Union 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training 14 | 15 | from torch.optim.lr_scheduler import StepLR 16 | from transformers import (AutoTokenizer, BitsAndBytesConfig, LlamaConfig, 17 | LlamaForCausalLM) 18 | from transformers.cache_utils import Cache, DynamicCache 19 | 20 | from transformers.modeling_outputs import (BaseModelOutputWithPast, 21 | CausalLMOutputWithPast, 22 | QuestionAnsweringModelOutput, 23 | SequenceClassifierOutputWithPast) 24 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 25 | from transformers.utils import (add_start_docstrings, 26 | add_start_docstrings_to_model_forward, 27 | is_flash_attn_2_available, 28 | is_flash_attn_greater_or_equal_2_10, logging, 29 | replace_return_docstrings) 30 | 31 | from .injection_module import CrossAttention 32 | # from .llama.modeling_llama import LlamaForCausalLM, LlamaModel 33 | from .llm_interface import register_llm 34 | from transformers import AutoConfig 35 | 36 | # Copyright (c) Meta Platforms, Inc. and affiliates. 37 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 38 | 39 | 40 | _CONFIG_FOR_DOC = "LlamaConfig" 41 | LLAMA_INPUTS_DOCSTRING = r""" 42 | Args: 43 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 44 | Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide 45 | it. 46 | 47 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 48 | [`PreTrainedTokenizer.__call__`] for details. 49 | 50 | [What are input IDs?](../glossary#input-ids) 51 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 52 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 53 | 54 | - 1 for tokens that are **not masked**, 55 | - 0 for tokens that are **masked**. 56 | 57 | [What are attention masks?](../glossary#attention-mask) 58 | 59 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 60 | [`PreTrainedTokenizer.__call__`] for details. 61 | 62 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 63 | `past_key_values`). 64 | 65 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 66 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 67 | information on the default strategy. 68 | 69 | - 1 indicates the head is **not masked**, 70 | - 0 indicates the head is **masked**. 71 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 72 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 73 | config.n_positions - 1]`. 74 | 75 | [What are position IDs?](../glossary#position-ids) 76 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): 77 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 78 | blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` 79 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 80 | 81 | Two formats are allowed: 82 | - a [`~cache_utils.Cache`] instance; 83 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 84 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy 85 | cache format. 86 | 87 | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the 88 | legacy cache format will be returned. 89 | 90 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 91 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 92 | of shape `(batch_size, sequence_length)`. 93 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 94 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 95 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 96 | model's internal embedding lookup matrix. 97 | use_cache (`bool`, *optional*): 98 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 99 | `past_key_values`). 100 | output_attentions (`bool`, *optional*): 101 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 102 | tensors for more detail. 103 | output_hidden_states (`bool`, *optional*): 104 | Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for 105 | more detail. 106 | return_dict (`bool`, *optional*): 107 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 108 | cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): 109 | Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, 110 | this tensor is not affected by padding. It is used to update the cache in the correct position and to infer 111 | the complete sequence length. 112 | """ 113 | 114 | 115 | def add_adapter_for_LlamaForCausalLM(llama_for_causalLM, cross_attention_config, num_add_layers=8): 116 | # follow the same config as the original model 117 | if num_add_layers < 1: 118 | return llama_for_causalLM 119 | if hasattr(llama_for_causalLM.model.layers[0].self_attn, "num_heads"): 120 | cross_attention_config["num_attention_heads"] = llama_for_causalLM.model.layers[0].self_attn.num_heads 121 | elif hasattr(llama_for_causalLM.model.layers[0].self_attn.config, "num_attention_heads"): 122 | cross_attention_config["num_attention_heads"] = llama_for_causalLM.model.layers[0].self_attn.config.num_attention_heads 123 | else: 124 | raise ValueError("Cannot find num_heads or num_attention_heads in self_attn of the first layer of the model.") 125 | 126 | if hasattr(llama_for_causalLM.model.layers[0].self_attn, "hidden_size"): 127 | cross_attention_config["hidden_size"] = llama_for_causalLM.model.layers[0].self_attn.hidden_size 128 | elif hasattr(llama_for_causalLM.model.layers[0].self_attn.config, "hidden_size"): 129 | cross_attention_config["hidden_size"] = llama_for_causalLM.model.layers[0].self_attn.config.hidden_size 130 | else: 131 | raise ValueError("Cannot find hidden_size in self_attn of the first layer of the model.") 132 | 133 | num_layers = len(llama_for_causalLM.model.layers) 134 | every_n_layers = max(num_layers // num_add_layers, 1) 135 | # add adapter for each decoder layer 136 | for i, layer in enumerate(llama_for_causalLM.model.layers): 137 | if (i + 1) % every_n_layers == 0: 138 | llama_for_causalLM.model.layers[i].adapter = CrossAttention(**cross_attention_config) 139 | 140 | return llama_for_causalLM 141 | 142 | 143 | def bind_forward_for_llama(llama_for_causalLM): 144 | """Bind `forward` function for llama models by `types.MethodType`""" 145 | 146 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 147 | def llama_model_forward( 148 | self, 149 | input_ids: torch.LongTensor = None, 150 | attention_mask: Optional[torch.Tensor] = None, 151 | position_ids: Optional[torch.LongTensor] = None, 152 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 153 | inputs_embeds: Optional[torch.FloatTensor] = None, 154 | use_cache: Optional[bool] = None, 155 | output_attentions: Optional[bool] = None, 156 | output_hidden_states: Optional[bool] = None, 157 | return_dict: Optional[bool] = None, 158 | cache_position: Optional[torch.LongTensor] = None, 159 | protein_feats: Optional[torch.FloatTensor] = None, 160 | structure_feats: Optional[torch.FloatTensor] = None, 161 | msa_feats: Optional[torch.FloatTensor] = None, 162 | protein_batch_mask: Optional[torch.Tensor] = None, 163 | structure_batch_mask: Optional[torch.Tensor] = None, 164 | msa_batch_mask: Optional[torch.Tensor] = None, 165 | **kwargs, 166 | ) -> Union[Tuple, BaseModelOutputWithPast]: 167 | output_attentions = ( 168 | output_attentions 169 | if output_attentions is not None 170 | else self.config.output_attentions 171 | ) 172 | 
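# Note on the extra keyword arguments above: protein_feats, structure_feats, msa_feats
# and their per-sample batch masks are not consumed by the stock Llama decoder layers.
# They are routed to the CrossAttention adapters attached by
# add_adapter_for_LlamaForCausalLM and applied to the hidden states immediately after
# each adapter-equipped decoder layer in the loop below.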
output_hidden_states = ( 173 | output_hidden_states 174 | if output_hidden_states is not None 175 | else self.config.output_hidden_states 176 | ) 177 | use_cache = use_cache if use_cache is not None else self.config.use_cache 178 | return_dict = ( 179 | return_dict if return_dict is not None else self.config.use_return_dict 180 | ) 181 | 182 | if (input_ids is None) ^ (inputs_embeds is not None): 183 | raise ValueError( 184 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 185 | ) 186 | 187 | if self.gradient_checkpointing and self.training and use_cache: 188 | use_cache = False 189 | 190 | if inputs_embeds is None: 191 | inputs_embeds = self.embed_tokens(input_ids) 192 | 193 | return_legacy_cache = False 194 | if use_cache and not isinstance( 195 | past_key_values, Cache 196 | ): # kept for BC (non `Cache` `past_key_values` inputs) 197 | return_legacy_cache = True 198 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 199 | 200 | if cache_position is None: 201 | past_seen_tokens = ( 202 | past_key_values.get_seq_length() if past_key_values is not None else 0 203 | ) 204 | cache_position = torch.arange( 205 | past_seen_tokens, 206 | past_seen_tokens + inputs_embeds.shape[1], 207 | device=inputs_embeds.device, 208 | ) 209 | if position_ids is None: 210 | position_ids = cache_position.unsqueeze(0) 211 | 212 | causal_mask = self._update_causal_mask( 213 | attention_mask, 214 | inputs_embeds, 215 | cache_position, 216 | past_key_values, 217 | output_attentions, 218 | ) 219 | 220 | # embed positions 221 | hidden_states = inputs_embeds 222 | 223 | # decoder layers 224 | all_hidden_states = () if output_hidden_states else None 225 | all_self_attns = () if output_attentions else None 226 | next_decoder_cache = None 227 | 228 | for decoder_layer in self.layers: 229 | if output_hidden_states: 230 | all_hidden_states += (hidden_states,) 231 | 232 | if self.gradient_checkpointing and self.training: 233 | if not hasattr(decoder_layer, 'adapter'): 234 | layer_outputs = self._gradient_checkpointing_func( 235 | decoder_layer.__call__, 236 | hidden_states, 237 | causal_mask, 238 | position_ids, 239 | past_key_values, 240 | output_attentions, 241 | use_cache, 242 | cache_position, 243 | ) 244 | else: 245 | layer_outputs = self._gradient_checkpointing_func( 246 | decoder_layer.__call__, 247 | hidden_states, 248 | causal_mask, 249 | position_ids, 250 | past_key_values, 251 | output_attentions, 252 | use_cache, 253 | cache_position, 254 | ) 255 | # keep the hidden_states only, cache other outputs 256 | hidden_states = layer_outputs[0] 257 | other_outputs = layer_outputs[1:] 258 | hidden_states = decoder_layer.adapter( 259 | query_states=hidden_states, 260 | protein_kv_states=protein_feats, 261 | structure_kv_states=structure_feats, 262 | msa_kv_states=msa_feats, 263 | protein_batch_mask=protein_batch_mask, 264 | structure_batch_mask=structure_batch_mask, 265 | msa_batch_mask=msa_batch_mask, 266 | query_attn_mask=attention_mask, 267 | ) 268 | layer_outputs = (hidden_states,) + other_outputs 269 | else: 270 | if not hasattr(decoder_layer, 'adapter'): 271 | layer_outputs = decoder_layer( 272 | hidden_states, 273 | attention_mask=causal_mask, 274 | position_ids=position_ids, 275 | past_key_value=past_key_values, 276 | output_attentions=output_attentions, 277 | use_cache=use_cache, 278 | cache_position=cache_position, 279 | ) 280 | else: 281 | layer_outputs = decoder_layer( 282 | hidden_states, 283 | attention_mask=causal_mask, 284 | 
position_ids=position_ids, 285 | past_key_value=past_key_values, 286 | output_attentions=output_attentions, 287 | use_cache=use_cache, 288 | cache_position=cache_position, 289 | ) 290 | 291 | # keep the hidden_states only, cache other outputs 292 | hidden_states = layer_outputs[0] 293 | other_outputs = layer_outputs[1:] 294 | hidden_states = decoder_layer.adapter( 295 | query_states=hidden_states, 296 | protein_kv_states=protein_feats, 297 | structure_kv_states=structure_feats, 298 | msa_kv_states=msa_feats, 299 | protein_batch_mask=protein_batch_mask, 300 | structure_batch_mask=structure_batch_mask, 301 | msa_batch_mask=msa_batch_mask, 302 | query_attn_mask=attention_mask, 303 | ) 304 | layer_outputs = (hidden_states,) + other_outputs 305 | 306 | hidden_states = layer_outputs[0] 307 | 308 | if use_cache: 309 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 310 | 311 | if output_attentions: 312 | all_self_attns += (layer_outputs[1],) 313 | 314 | hidden_states = self.norm(hidden_states) 315 | 316 | # add hidden states from the last decoder layer 317 | if output_hidden_states: 318 | all_hidden_states += (hidden_states,) 319 | 320 | next_cache = next_decoder_cache if use_cache else None 321 | if return_legacy_cache: 322 | next_cache = next_cache.to_legacy_cache() 323 | 324 | if not return_dict: 325 | return tuple( 326 | v 327 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] 328 | if v is not None 329 | ) 330 | return BaseModelOutputWithPast( 331 | last_hidden_state=hidden_states, 332 | past_key_values=next_cache, 333 | hidden_states=all_hidden_states, 334 | attentions=all_self_attns, 335 | ) 336 | 337 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 338 | @replace_return_docstrings( 339 | output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC 340 | ) 341 | def llama_for_causalLM_forward( 342 | self, 343 | input_ids: torch.LongTensor = None, 344 | attention_mask: Optional[torch.Tensor] = None, 345 | position_ids: Optional[torch.LongTensor] = None, 346 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 347 | inputs_embeds: Optional[torch.FloatTensor] = None, 348 | labels: Optional[torch.LongTensor] = None, 349 | use_cache: Optional[bool] = None, 350 | output_attentions: Optional[bool] = None, 351 | output_hidden_states: Optional[bool] = None, 352 | return_dict: Optional[bool] = None, 353 | cache_position: Optional[torch.LongTensor] = None, 354 | protein_feats: Optional[torch.FloatTensor] = None, 355 | structure_feats: Optional[torch.FloatTensor] = None, 356 | msa_feats: Optional[torch.FloatTensor] = None, 357 | protein_batch_mask: Optional[torch.Tensor] = None, 358 | structure_batch_mask: Optional[torch.Tensor] = None, 359 | msa_batch_mask: Optional[torch.Tensor] = None, 360 | **kwargs 361 | ) -> Union[Tuple, CausalLMOutputWithPast]: 362 | r""" 363 | Args: 364 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 365 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 366 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 367 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
368 | 369 | Returns: 370 | 371 | Example: 372 | 373 | ```python 374 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 375 | 376 | >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") 377 | >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") 378 | 379 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 380 | >>> inputs = tokenizer(prompt, return_tensors="pt") 381 | 382 | >>> # Generate 383 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 384 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 385 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 386 | ```""" 387 | output_attentions = ( 388 | output_attentions 389 | if output_attentions is not None 390 | else self.config.output_attentions 391 | ) 392 | output_hidden_states = ( 393 | output_hidden_states 394 | if output_hidden_states is not None 395 | else self.config.output_hidden_states 396 | ) 397 | return_dict = ( 398 | return_dict if return_dict is not None else self.config.use_return_dict 399 | ) 400 | 401 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 402 | outputs = self.model( 403 | input_ids=input_ids, 404 | attention_mask=attention_mask, 405 | position_ids=position_ids, 406 | past_key_values=past_key_values, 407 | inputs_embeds=inputs_embeds, 408 | use_cache=use_cache, 409 | output_attentions=output_attentions, 410 | output_hidden_states=output_hidden_states, 411 | return_dict=return_dict, 412 | cache_position=cache_position, 413 | protein_feats=protein_feats, 414 | structure_feats=structure_feats, 415 | msa_feats=msa_feats, 416 | protein_batch_mask=protein_batch_mask, 417 | structure_batch_mask=structure_batch_mask, 418 | msa_batch_mask=msa_batch_mask, 419 | ) 420 | 421 | hidden_states = outputs[0] 422 | if self.config.pretraining_tp > 1: 423 | lm_head_slices = self.lm_head.weight.split( 424 | self.vocab_size // self.config.pretraining_tp, dim=0 425 | ) 426 | logits = [ 427 | F.linear(hidden_states, lm_head_slices[i]) 428 | for i in range(self.config.pretraining_tp) 429 | ] 430 | logits = torch.cat(logits, dim=-1) 431 | else: 432 | logits = self.lm_head(hidden_states) 433 | logits = logits.float() 434 | 435 | loss = None 436 | if labels is not None: 437 | # Shift so that tokens < n predict n 438 | shift_logits = logits[..., :-1, :].contiguous() 439 | shift_labels = labels[..., 1:].contiguous() 440 | # Flatten the tokens 441 | loss_fct = CrossEntropyLoss() 442 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 443 | shift_labels = shift_labels.view(-1) 444 | # Enable model parallelism 445 | shift_labels = shift_labels.to(shift_logits.device) 446 | loss = loss_fct(shift_logits, shift_labels) 447 | 448 | if not return_dict: 449 | output = (logits,) + outputs[1:] 450 | return (loss,) + output if loss is not None else output 451 | 452 | return CausalLMOutputWithPast( 453 | loss=loss, 454 | logits=logits, 455 | past_key_values=outputs.past_key_values, 456 | hidden_states=outputs.hidden_states, 457 | attentions=outputs.attentions, 458 | ) 459 | 460 | llama_for_causalLM.model.forward = types.MethodType( 461 | llama_model_forward, llama_for_causalLM.model 462 | ) 463 | llama_for_causalLM.forward = types.MethodType( 464 | llama_for_causalLM_forward, llama_for_causalLM 465 | ) 466 | 467 | return llama_for_causalLM 468 | 469 | 470 | def add_special_tokens_to_model_and_tokenizer(model, tokenizer, special_token): 471 | # 
add special tokens to tokenizer # 50265 472 | tokenizer.add_special_tokens({"additional_special_tokens": [special_token]}) 473 | # add special tokens to model # 50272 474 | if len(tokenizer) <= model.model.embed_tokens.weight.shape[0]: 475 | return model, tokenizer 476 | else: 477 | embedding_layer = model.model.embed_tokens 478 | embedding_layer.weight.data = torch.cat( 479 | [ 480 | embedding_layer.weight.data, 481 | torch.zeros(1, embedding_layer.weight.shape[1]).to( 482 | embedding_layer.weight.data 483 | ), 484 | ], 485 | dim=0, 486 | ) 487 | return model, tokenizer 488 | 489 | 490 | def bind_function_for_llama(llama_for_causalLM): 491 | def llama_for_casualLM_prepare_inputs_for_generation( 492 | self, 493 | input_ids, 494 | past_key_values=None, 495 | attention_mask=None, 496 | inputs_embeds=None, 497 | cache_position=None, 498 | position_ids=None, 499 | use_cache=True, 500 | **kwargs, 501 | ): 502 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens 503 | # Exception 1: when passing input_embeds, input_ids may be missing entries 504 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here 505 | if past_key_values is not None: 506 | if inputs_embeds is not None: # Exception 1 507 | input_ids = input_ids[:, -cache_position.shape[0] :] 508 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) 509 | input_ids = input_ids[:, cache_position] 510 | 511 | if attention_mask is not None and position_ids is None: 512 | # create position_ids on the fly for batch generation 513 | position_ids = attention_mask.long().cumsum(-1) - 1 514 | position_ids.masked_fill_(attention_mask == 0, 1) 515 | if past_key_values: 516 | position_ids = position_ids[:, -input_ids.shape[1] :] 517 | 518 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 519 | if inputs_embeds is not None and cache_position[0] == 0: 520 | model_inputs = {"inputs_embeds": inputs_embeds} 521 | else: 522 | model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases 523 | 524 | model_inputs.update( 525 | { 526 | "position_ids": position_ids, 527 | "cache_position": cache_position, 528 | "past_key_values": past_key_values, 529 | "use_cache": use_cache, 530 | "attention_mask": attention_mask, 531 | } 532 | ) 533 | model_inputs.update(kwargs) 534 | return model_inputs 535 | 536 | llama_for_causalLM.prepare_inputs_for_generation = types.MethodType( 537 | llama_for_casualLM_prepare_inputs_for_generation, llama_for_causalLM 538 | ) 539 | return llama_for_causalLM 540 | 541 | from transformers import AutoConfig, AutoModelForCausalLM 542 | from accelerate import load_checkpoint_and_dispatch 543 | from accelerate import init_empty_weights 544 | from transformers.integrations import HfDeepSpeedConfig 545 | 546 | @register_llm 547 | class LlamaAdapterModel(nn.Module): 548 | def __init__( 549 | self, 550 | hf_dir, 551 | cross_attention_config, 552 | load_pretrained=True, 553 | quantization=False, 554 | attn_implementation="sdpa", 555 | num_add_layers=8, 556 | ): 557 | """Adapter model for Llama. 558 | Args: 559 | hf_dir (str): Directory of the Hugging Face model. 560 | cross_attention_config (dict): Configuration of the cross-attention layer. 561 | load_pretrained (bool): Whether to load the pretrained model. Defaults to True. 562 | quantization (bool or str): Whether to use quantization. Defaults to False. 
Acceptable values are True, False, '8bit', and '4bit'. True means 8-bit quantization. '8bit' means 8-bit quantization. '4bit' means 4-bit quantization. 563 | attn_implementation (str): Implementation of the attention layer. Defaults to "sdpa". 564 | num_add_layers (int): Number of additional layers to add. Defaults to 8. 565 | """ 566 | super().__init__() 567 | if quantization is True or quantization == '8bit': 568 | assert load_pretrained, "load_pretrained should be True" 569 | quantization_config = BitsAndBytesConfig(load_in_8bit=True) 570 | print("8-bit Quantization is enabled") 571 | elif quantization == '4bit': 572 | quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) 573 | print("4-bit Quantization is enabled") 574 | else: 575 | quantization_config = None 576 | print("Quantization is disabled") 577 | 578 | if load_pretrained: 579 | self.model = LlamaForCausalLM.from_pretrained( 580 | hf_dir, 581 | quantization_config=quantization_config, 582 | attn_implementation=attn_implementation, 583 | torch_dtype=torch.bfloat16, 584 | ).train() 585 | self.model = prepare_model_for_kbit_training(self.model) 586 | else: 587 | config = AutoConfig.from_pretrained(hf_dir) 588 | self.model = LlamaForCausalLM(config) 589 | 590 | self.model = add_adapter_for_LlamaForCausalLM( 591 | self.model, cross_attention_config, num_add_layers=num_add_layers 592 | ) 593 | # bind `forward` function for llama models by `types.MethodType` 594 | self.model = bind_forward_for_llama(self.model) 595 | self.model = bind_function_for_llama(self.model) 596 | 597 | self.tokenizer = AutoTokenizer.from_pretrained(hf_dir, use_fast=False) 598 | self.tokenizer.pad_token = "<|reserved_special_token_0|>" 599 | 600 | def forward( 601 | self, 602 | input_ids, 603 | inputs_mask, 604 | protein_feats, 605 | structure_feats, 606 | msa_feats, 607 | protein_batch_mask, 608 | structure_batch_mask, 609 | msa_batch_mask, 610 | **kwargs 611 | ): 612 | output = self.model.forward( 613 | input_ids=input_ids, 614 | attention_mask=inputs_mask, 615 | protein_feats=protein_feats, 616 | structure_feats=structure_feats, 617 | msa_feats=msa_feats, 618 | protein_batch_mask=protein_batch_mask, 619 | structure_batch_mask=structure_batch_mask, 620 | msa_batch_mask=msa_batch_mask, 621 | output_hidden_states=True, 622 | ) 623 | return output 624 | 625 | def generate( 626 | self, 627 | input_ids, 628 | inputs_mask, 629 | protein_feats, 630 | structure_feats, 631 | msa_feats, 632 | protein_batch_mask, 633 | structure_batch_mask, 634 | msa_batch_mask, 635 | **kwargs 636 | ): 637 | terminators = [ 638 | self.tokenizer.eos_token_id, 639 | self.tokenizer.convert_tokens_to_ids("<|eot_id|>"), 640 | ] 641 | output = self.model.generate( 642 | input_ids, 643 | use_cache=False, 644 | attention_mask=inputs_mask, 645 | protein_feats=protein_feats, 646 | structure_feats=structure_feats, 647 | msa_feats=msa_feats, 648 | protein_batch_mask=protein_batch_mask, 649 | structure_batch_mask=structure_batch_mask, 650 | msa_batch_mask=msa_batch_mask, 651 | bos_token_id=self.tokenizer.bos_token_id, 652 | eos_token_id=terminators, 653 | **kwargs, 654 | ) 655 | output = output[:, input_ids.shape[-1]:] 656 | return output 657 | 658 | def embed_tokens(self, tokens): 659 | return self.model.model.embed_tokens(tokens.to(self.model.device)) 660 | 661 | def generate_prompt(self, question: str) -> str: 662 | """ 663 | Generate QA prompt for the Llama3-instruct 664 | 665 | Returns: Formatted prompt 666 | """ 667 | messages = [ 668 | {"role": 
"system", "content": "You are an AI expert that can answer any questions about protein."}, 669 | {"role": "user", "content": question}, 670 | ] 671 | 672 | prompt = self.tokenizer.apply_chat_template( 673 | messages, 674 | tokenize=False, 675 | add_generation_prompt=True, 676 | ) 677 | return prompt 678 | 679 | def input_process(self, 680 | questions: list, 681 | answers: list = None, 682 | max_length: int = 512, 683 | special_pad_id: int = -100): 684 | 685 | # Record original padding side 686 | original_padding_side = self.tokenizer.padding_side 687 | 688 | # Generate prompts for questions 689 | prompts = [self.generate_prompt(q) for q in questions] 690 | 691 | # Tokenize prompts and add left paddings 692 | self.tokenizer.padding_side = "left" 693 | prompt_inputs = self.tokenizer( 694 | prompts, 695 | add_special_tokens=False, 696 | return_tensors="pt", 697 | padding="longest", 698 | truncation=True, 699 | max_length=max_length, 700 | ) 701 | 702 | input_ids = prompt_inputs["input_ids"] 703 | attns = prompt_inputs["attention_mask"] 704 | embeds = self.embed_tokens(input_ids) 705 | 706 | # Create labels 707 | labels = torch.full_like(input_ids, special_pad_id) 708 | # Create raw text mask 709 | raw_text_mask = torch.zeros_like(input_ids) 710 | 711 | if answers is not None: 712 | # Add eos token 713 | answers_eos = [a + self.tokenizer.eos_token for a in answers] 714 | 715 | # Tokenize answers and add right paddings 716 | self.tokenizer.padding_side = "right" 717 | answer_inputs = self.tokenizer( 718 | answers_eos, 719 | add_special_tokens=False, 720 | return_tensors="pt", 721 | padding="longest", 722 | truncation=True, 723 | max_length=max_length, 724 | ) 725 | 726 | # Concatenate inputs ids 727 | answer_ids = answer_inputs["input_ids"] 728 | input_ids = torch.cat([input_ids, answer_ids], dim=-1) 729 | 730 | # Concatenate attention masks 731 | answer_mask = answer_inputs["attention_mask"] 732 | attns = torch.cat([attns, answer_mask], dim=-1) 733 | 734 | # Concatenate embeddings 735 | answer_embeds = self.embed_tokens(answer_ids) 736 | embeds = torch.cat([embeds, answer_embeds], dim=1) 737 | 738 | # Concatenate labels 739 | answer_labels = answer_ids.masked_fill(answer_ids == self.tokenizer.pad_token_id, special_pad_id) 740 | labels = torch.cat([labels, answer_labels], dim=-1) 741 | 742 | # Concatenate raw text mask 743 | raw_text_mask = torch.cat([raw_text_mask, torch.ones_like(answer_ids)], dim=-1) 744 | raw_text_mask = raw_text_mask.masked_fill(labels == special_pad_id, 0) 745 | 746 | labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, special_pad_id) 747 | # Restore original padding side 748 | self.tokenizer.padding_side = original_padding_side 749 | 750 | # Convert to current device 751 | device = self.model.device 752 | input_ids = input_ids.to(device) 753 | embeds = embeds.to(device) 754 | attns = attns.to(device) 755 | labels = labels.to(device) 756 | raw_text_mask = raw_text_mask.to(device) 757 | 758 | return input_ids, embeds, attns, labels, raw_text_mask --------------------------------------------------------------------------------