├── LICENSE ├── README.md ├── alpaca_finetuning.py ├── config └── deepspeed.json ├── misc ├── layer_similarity_animation.gif └── layer_similarity_animation.mp4 ├── models └── Llama2-7b-hf │ ├── config.json │ ├── configuration_llama.py │ ├── generation_config.json │ ├── modeling_llama.py │ ├── special_tokens_map.json │ ├── tokenizer.json │ ├── tokenizer.model │ └── tokenizer_config.json └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 METACARBON 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Beyond KV Caching: Shared Attention for Efficient LLMs 2 | [[paper](https://arxiv.org/abs/2407.12866)] 3 | 4 | ## Abstract 5 | The efficiency of large language models (LLMs) remains a critical challenge, particularly in contexts where computational resources are limited. Traditional attention mechanisms in these models, while powerful, require significant computational and memory resources due to the necessity of recalculating and storing attention weights across different layers. This paper introduces a novel Shared Attention (SA) mechanism, designed to enhance the efficiency of LLMs by directly sharing computed attention weights across multiple layers. Unlike previous methods that focus on sharing intermediate Key-Value (KV) caches, our approach utilizes the isotropic tendencies of attention distributions observed in advanced LLMs post-pretraining to reduce both the computational flops and the size of the KV cache required during inference. We empirically demonstrate that implementing SA across various LLMs results in minimal accuracy loss on standard benchmarks. Our findings suggest that SA not only conserves computational resources but also maintains robust model performance, thereby facilitating the deployment of more efficient LLMs in resource-constrained environments. 
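To make the mechanism concrete: in the modified `modeling_llama.py`, a designated layer computes its attention matrix as usual, and the layers configured to share it skip their own key projection and softmax and simply reuse that matrix (they also store only values, not keys, in the cache). The snippet below is a minimal, self-contained sketch of this idea, not the repository's implementation; it uses a toy single-head layer without causal masking, and all names are illustrative.

```python
import torch
import torch.nn.functional as F

def attention_weights(q, k):
    # Scaled dot-product attention weights: softmax(Q K^T / sqrt(d)).
    d = q.size(-1)
    return F.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)

def attn_layer(x, w_q, w_k, w_v, shared_weights=None):
    # A toy single-head attention layer. When `shared_weights` is provided,
    # the layer skips its Q/K projections and softmax and reuses the matrix.
    v = x @ w_v
    if shared_weights is None:
        weights = attention_weights(x @ w_q, x @ w_k)  # computed once at the "source" layer
    else:
        weights = shared_weights                       # reused by the sharing layers
    return weights @ v, weights

torch.manual_seed(0)
seq_len, dim = 4, 8
x = torch.randn(seq_len, dim)
w_q, w_k, w_v = (0.1 * torch.randn(dim, dim) for _ in range(3))

h, shared = attn_layer(x, w_q, w_k, w_v)        # layer that computes attention
h, _ = attn_layer(h, w_q, w_k, w_v, shared)     # layer that shares it (no K projection)
print(h.shape)  # torch.Size([4, 8])
```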
6 | 7 | ## Dynamic Animation 8 | ![](https://github.com/metacarbon/shareAtt/blob/main/misc/layer_similarity_animation.gif) 9 | 10 | ## Usage 11 | 12 | ### Environment Setup 13 | 14 | ```bash 15 | conda create -n shareAtt python=3.8 16 | conda activate shareAtt 17 | 18 | pip install torch torchvision torchaudio 19 | pip install transformers==4.33.0 accelerate datasets scipy sentencepiece 20 | ``` 21 | ### Prepare Weights 22 | 23 | Download the Llama-2-7B-hf weights (.safetensors files) into the `models/Llama2-7b-hf` folder. 24 | 25 | ### Direct Application of Shared Attention 26 | 27 | To apply Shared Attention, modify `modeling_llama.py` in `models/Llama2-7b-hf` at line 262. 28 | For example, to apply SA across layers 27 to 30, exclude those layer indices from the list: 29 | ```python 30 | self.share_attn = self.layer_idx not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31] 31 | ``` 32 | ### Reproduction of Evaluations 33 | 34 | Install `lm-evaluation-harness` from [EleutherAI's repository](https://github.com/EleutherAI/lm-evaluation-harness). 35 | 36 | Replace the `modeling_llama.py` file in the transformers library with the modified file in `models/Llama2-7b-hf`. 37 | 38 | Run the evaluation: 39 | 40 | ```bash 41 | CUDA_VISIBLE_DEVICES=0 lm_eval --model hf --model_args pretrained=./models/Llama2-7b-hf/ --tasks mmlu,glue,gsm8k,hellaswag --batch_size auto --output_path ./eval_out/llama2-7b-23_26 --use_cache ./eval_cache/llama2-7b-23_26 42 | ``` 43 | 44 | ### Fine-tuning 45 | 46 | Set up Accelerate with DeepSpeed: 47 | 48 | ```bash 49 | accelerate config 50 | ``` 51 | 52 | Download the Llama-3-8B weights and apply the same Shared Attention modifications to its model files. 53 | 54 | Download the cleaned Alpaca instruction dataset `alpaca_data_cleaned.json` from [gururise's repository](https://github.com/gururise/AlpacaDataCleaned). 55 | 56 | Train the model: 57 | ```bash 58 | ACCELERATE_USE_DEEPSPEED=true CUDA_VISIBLE_DEVICES="0,1" accelerate launch alpaca_finetuning.py 59 | ``` 60 | 61 | ## Citation 62 | 63 | If you find our work useful or relevant to your project and research, please kindly cite our paper: 64 | 65 | ```bibtex 66 | @article{liao2024shareAtt, 67 | title={Beyond KV Caching: Shared Attention for Efficient LLMs}, 68 | author={Bingli Liao and Danilo Vasconcellos Vargas}, 69 | journal={arXiv}, 70 | year={2024} 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /alpaca_finetuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.nn.utils.rnn import pad_sequence 4 | from accelerate import Accelerator 5 | from accelerate.logging import get_logger 6 | from accelerate.utils import set_seed, DummyOptim 7 | from transformers import AutoModelForCausalLM, AutoTokenizer 8 | import argparse 9 | import logging 10 | import json 11 | from tqdm import tqdm 12 | 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | logger = get_logger(__name__) 16 | 17 | IGNORE_TOKEN_ID = -100 18 | DEFAULT_PAD_TOKEN = "[PAD]" 19 | DEFAULT_EOS_TOKEN = "</s>" 20 | DEFAULT_BOS_TOKEN = "<s>" 21 | DEFAULT_UNK_TOKEN = "<unk>" 22 | PROMPT_DICT = { 23 | "prompt_input": ( 24 | "Below is an instruction that describes a task, paired with an input that provides further context. 
" 25 | "Write a response that appropriately completes the request.\n\n" 26 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 27 | ), 28 | "prompt_no_input": ( 29 | "Below is an instruction that describes a task. " 30 | "Write a response that appropriately completes the request.\n\n" 31 | "### Instruction:\n{instruction}\n\n### Response:" 32 | ), 33 | } 34 | 35 | 36 | def safe_ids(ids, max_value, pad_id): 37 | return [i if i < max_value else pad_id for i in ids] 38 | 39 | 40 | def tokenize(messages, tokenizer): 41 | input_ids = [] 42 | labels = [] 43 | 44 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 45 | 46 | if messages.get("input", "") != "": 47 | prompt = prompt_input.format_map(messages) 48 | else: 49 | prompt = prompt_no_input.format_map(messages) 50 | 51 | response = f"{messages['output']}{DEFAULT_EOS_TOKEN}" 52 | 53 | 54 | prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) 55 | response_ids = tokenizer.encode(response, add_special_tokens=False) 56 | 57 | # Append all the sections to the input_ids 58 | input_ids += prompt_ids + response_ids 59 | 60 | # The labels should ignore the prompt (set to IGNORE_TOKEN_ID), 61 | # and should match the response_ids for the response part 62 | labels += [IGNORE_TOKEN_ID] * len(prompt_ids) + response_ids 63 | 64 | # Ensure lengths do not exceed model's max length 65 | input_ids = input_ids[:tokenizer.model_max_length] 66 | labels = labels[:tokenizer.model_max_length] 67 | 68 | input_ids = safe_ids(input_ids, tokenizer.vocab_size, tokenizer.eos_token_id) 69 | labels = safe_ids(labels, tokenizer.vocab_size, IGNORE_TOKEN_ID) 70 | return input_ids, labels 71 | 72 | 73 | class AlpacaData(Dataset): 74 | def __init__(self, data, tokenizer): 75 | self.data = data 76 | self.tokenizer = tokenizer 77 | 78 | def __len__(self): 79 | return len(self.data) 80 | 81 | def __getitem__(self, item): 82 | item = self.data[item] 83 | input_ids, labels = tokenize(item, self.tokenizer) 84 | return torch.tensor(input_ids), torch.tensor(labels) 85 | 86 | def collate_fn(self, data): 87 | input_ids, labels = zip(*data) 88 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.eos_token_id) 89 | labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID) 90 | attention_mask = input_ids.ne(self.tokenizer.eos_token_id) 91 | features = { 92 | 'input_ids': input_ids.long(), 93 | 'labels': labels.long(), 94 | 'attention_mask': attention_mask.long(), 95 | } 96 | return features 97 | 98 | 99 | def main(): 100 | parser = argparse.ArgumentParser(description='Fine-tuning LLM') 101 | parser.add_argument('--model_path', type=str, default='./models/Llama-3-8b/', help='Path to the pre-trained model') 102 | parser.add_argument('--save_path', type=str, default='./out/llama3_8b_alpaca/', help='Path to save the fine-tuned model') 103 | args = parser.parse_args() 104 | model_path = args.model_path 105 | save_path = args.save_path 106 | 107 | set_seed(42) 108 | accelerator = Accelerator() 109 | batch_size = 16 110 | 111 | logger.info('Initializing tokenizer...') 112 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="left", model_max_length=4096, local_files_only=True, trust_remote_code=True) 113 | tokenizer.pad_token = tokenizer.unk_token 114 | 115 | logger.info('Initializing model...') 116 | model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True, trust_remote_code=True) 117 | model.config.use_cache = 
False 118 | model.gradient_checkpointing_enable() 119 | 120 | 121 | dataset = AlpacaData(json.load(open('data/alpaca_data_cleaned.json')), tokenizer) 122 | 123 | data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, 124 | batch_size=batch_size, num_workers=0, shuffle=True) 125 | 126 | dummy_optimizer = DummyOptim(model.parameters()) 127 | 128 | logger.info('accelerator preparing...') 129 | model, optimizer, data_loader = accelerator.prepare(model, dummy_optimizer, data_loader) 130 | 131 | for epoch in range(2): 132 | logger.info('=' * 10 + f'Start training {save_path} epoch {epoch + 1}' + '=' * 10) 133 | accelerator.wait_for_everyone() 134 | model.train() 135 | pbar = tqdm(enumerate(data_loader), total=len(data_loader), disable=(not accelerator.is_local_main_process)) 136 | loss_report = [] 137 | with accelerator.accumulate(model): 138 | for i, batch in pbar: 139 | out = model(**batch) 140 | loss = out.loss 141 | 142 | accelerator.backward(loss) 143 | accelerator.clip_grad_norm_(model.parameters(), 1.) 144 | optimizer.step() 145 | optimizer.zero_grad() 146 | 147 | loss_report.append(accelerator.gather(loss).mean().item()) 148 | pbar.set_description(f"epoch {epoch + 1} step {i}: train loss {sum(loss_report[-100:]) / len(loss_report[-100:]):.5f}.") 149 | 150 | accelerator.wait_for_everyone() 151 | # save model states 152 | model.save_checkpoint(f'{save_path}/{epoch}') 153 | logger.info(f'model for epoch {epoch + 1} is saved...') 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /config/deepspeed.json: -------------------------------------------------------------------------------- 1 | { 2 | "gradient_accumulation_steps": 1, 3 | "train_micro_batch_size_per_gpu": 1, 4 | "prescale_gradients": false, 5 | "zero_allow_untested_optimizer": true, 6 | "optimizer": { 7 | "type": "AdamW", 8 | "params": { 9 | "lr": 2e-5, 10 | "weight_decay": "auto", 11 | "torch_adam": true 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupCosineLR", 16 | "params": { 17 | "total_num_steps": 15000, 18 | "warmup_min_ratio": 0.04, 19 | "warmup_num_steps": 1000, 20 | "cos_min_ratio": 0.0001 21 | } 22 | }, 23 | "tensorboard": { 24 | "enabled": true, 25 | "output_path": "logs/", 26 | "job_name": "Llama-7b-pt" 27 | }, 28 | "zero_optimization": { 29 | "stage": 3, 30 | "offload_optimizer": { 31 | "device": "cpu", 32 | "pin_memory": true 33 | }, 34 | 35 | "contiguous_gradients": false, 36 | "allgather_bucket_size": 3e8, 37 | "reduce_bucket_size": 3e8, 38 | "overlap_comm": true, 39 | "reduce_scatter": true 40 | }, 41 | "steps_per_print": 16, 42 | "gradient_clipping": 1.0, 43 | "wall_clock_breakdown": true, 44 | "bf16": { 45 | "enabled": true 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /misc/layer_similarity_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metacarbon/shareAtt/2494690186802a829d351fc59753997bf9fbea2f/misc/layer_similarity_animation.gif -------------------------------------------------------------------------------- /misc/layer_similarity_animation.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metacarbon/shareAtt/2494690186802a829d351fc59753997bf9fbea2f/misc/layer_similarity_animation.mp4 -------------------------------------------------------------------------------- 
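Note on the fine-tuning output above: `alpaca_finetuning.py` saves DeepSpeed ZeRO-3 checkpoints with `model.save_checkpoint(f'{save_path}/{epoch}')` rather than Hugging Face-format weights. A possible way to consolidate such a checkpoint before evaluation is DeepSpeed's `zero_to_fp32` utility, sketched below; the paths, dtype, and `strict=False` are illustrative assumptions, not part of the repository.

```python
# Sketch: gather a ZeRO-3 checkpoint written by model.save_checkpoint() into a
# single state dict and export it in Hugging Face format for later evaluation.
import torch
from transformers import AutoModelForCausalLM
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

ckpt_dir = "./out/llama3_8b_alpaca/1"   # checkpoint of the second epoch (illustrative path)
base_model = "./models/Llama-3-8b/"     # the same (modified) model files used for training

model = AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype=torch.bfloat16, local_files_only=True, trust_remote_code=True
)
state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)  # consolidates shards on CPU
model.load_state_dict(state_dict, strict=False)
model.save_pretrained("./out/llama3_8b_alpaca_hf/")
```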
/models/Llama2-7b-hf/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "Llama-2-7b-hf", 3 | "architectures": [ 4 | "LlamaForCausalLM" 5 | ], 6 | "auto_map": { 7 | "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM" 8 | }, 9 | "bos_token_id": 1, 10 | "eos_token_id": 2, 11 | "hidden_act": "silu", 12 | "hidden_size": 4096, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 11008, 15 | "max_position_embeddings": 4096, 16 | "model_type": "llama", 17 | "num_attention_heads": 32, 18 | "num_hidden_layers": 32, 19 | "num_key_value_heads": 32, 20 | "pretraining_tp": 1, 21 | "rms_norm_eps": 1e-05, 22 | "rope_scaling": null, 23 | "tie_word_embeddings": false, 24 | "torch_dtype": "float16", 25 | "transformers_version": "4.31.0.dev0", 26 | "use_cache": true, 27 | "vocab_size": 32000 28 | } 29 | -------------------------------------------------------------------------------- /models/Llama2-7b-hf/configuration_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LLaMA model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class LlamaConfig(PretrainedConfig): 32 | r""" 33 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA 34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 35 | defaults will yield a similar configuration to that of the LLaMA-7B. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`LlamaModel`] 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 11008): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer encoder. 
51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer encoder. 53 | num_key_value_heads (`int`, *optional*): 54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 56 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 58 | by meanpooling all the original heads within that group. For more details checkout [this 59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 60 | `num_attention_heads`. 61 | pretraining_tp (`int`, *optional*, defaults to `1`): 62 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 63 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 64 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 65 | issue](https://github.com/pytorch/pytorch/issues/76232). 66 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 67 | The non-linear activation function (function or string) in the decoder. 68 | max_position_embeddings (`int`, *optional*, defaults to 2048): 69 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, 70 | Llama 2 up to 4096, CodeLlama up to 16384. 71 | initializer_range (`float`, *optional*, defaults to 0.02): 72 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 73 | rms_norm_eps (`float`, *optional*, defaults to 1e-12): 74 | The epsilon used by the rms normalization layers. 75 | use_cache (`bool`, *optional*, defaults to `True`): 76 | Whether or not the model should return the last key/values attentions (not used by all models). Only 77 | relevant if `config.is_decoder=True`. 78 | tie_word_embeddings(`bool`, *optional*, defaults to `False`): 79 | Whether to tie weight embeddings 80 | rope_theta (`float`, *optional*, defaults to 10000.0): 81 | The base period of the RoPE embeddings. 82 | rope_scaling (`Dict`, *optional*): 83 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 84 | strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format 85 | is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 86 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 87 | these scaling strategies behave: 88 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 89 | experimental feature, subject to breaking API changes in future versions. 
90 | 91 | Example: 92 | 93 | ```python 94 | >>> from transformers import LlamaModel, LlamaConfig 95 | 96 | >>> # Initializing a LLaMA llama-7b style configuration 97 | >>> configuration = LlamaConfig() 98 | 99 | >>> # Initializing a model from the llama-7b style configuration 100 | >>> model = LlamaModel(configuration) 101 | 102 | >>> # Accessing the model configuration 103 | >>> configuration = model.config 104 | ```""" 105 | model_type = "llama" 106 | keys_to_ignore_at_inference = ["past_key_values"] 107 | 108 | def __init__( 109 | self, 110 | vocab_size=32000, 111 | hidden_size=4096, 112 | intermediate_size=11008, 113 | num_hidden_layers=32, 114 | num_attention_heads=32, 115 | num_key_value_heads=None, 116 | hidden_act="silu", 117 | max_position_embeddings=2048, 118 | initializer_range=0.02, 119 | rms_norm_eps=1e-6, 120 | use_cache=True, 121 | pad_token_id=None, 122 | bos_token_id=1, 123 | eos_token_id=2, 124 | pretraining_tp=1, 125 | tie_word_embeddings=False, 126 | rope_theta=10000.0, 127 | rope_scaling=None, 128 | **kwargs, 129 | ): 130 | self.vocab_size = vocab_size 131 | self.max_position_embeddings = max_position_embeddings 132 | self.hidden_size = hidden_size 133 | self.intermediate_size = intermediate_size 134 | self.num_hidden_layers = num_hidden_layers 135 | self.num_attention_heads = num_attention_heads 136 | 137 | # for backward compatibility 138 | if num_key_value_heads is None: 139 | num_key_value_heads = num_attention_heads 140 | 141 | self.num_key_value_heads = num_key_value_heads 142 | self.hidden_act = hidden_act 143 | self.initializer_range = initializer_range 144 | self.rms_norm_eps = rms_norm_eps 145 | self.pretraining_tp = pretraining_tp 146 | self.use_cache = use_cache 147 | self.rope_theta = rope_theta 148 | self.rope_scaling = rope_scaling 149 | self._rope_scaling_validation() 150 | 151 | super().__init__( 152 | pad_token_id=pad_token_id, 153 | bos_token_id=bos_token_id, 154 | eos_token_id=eos_token_id, 155 | tie_word_embeddings=tie_word_embeddings, 156 | **kwargs, 157 | ) 158 | 159 | def _rope_scaling_validation(self): 160 | """ 161 | Validate the `rope_scaling` configuration. 
162 | """ 163 | if self.rope_scaling is None: 164 | return 165 | 166 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 167 | raise ValueError( 168 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 169 | f"got {self.rope_scaling}" 170 | ) 171 | rope_scaling_type = self.rope_scaling.get("type", None) 172 | rope_scaling_factor = self.rope_scaling.get("factor", None) 173 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 174 | raise ValueError( 175 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 176 | ) 177 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 178 | raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") 179 | -------------------------------------------------------------------------------- /models/Llama2-7b-hf/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token_id": 1, 3 | "do_sample": true, 4 | "eos_token_id": 2, 5 | "pad_token_id": 0, 6 | "temperature": 0.6, 7 | "max_length": 4096, 8 | "top_p": 0.9, 9 | "transformers_version": "4.31.0.dev0" 10 | } 11 | -------------------------------------------------------------------------------- /models/Llama2-7b-hf/modeling_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
20 | """ PyTorch LLaMA model.""" 21 | import math 22 | from typing import List, Optional, Tuple, Union 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | from torch import nn 28 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 29 | 30 | from transformers.activations import ACT2FN 31 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 32 | from transformers.modeling_utils import PreTrainedModel 33 | from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings 34 | from .configuration_llama import LlamaConfig 35 | import os 36 | 37 | 38 | logger = logging.get_logger(__name__) 39 | 40 | _CONFIG_FOR_DOC = "LlamaConfig" 41 | 42 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 43 | def _make_causal_mask( 44 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 45 | ): 46 | """ 47 | Make causal mask used for bi-directional self-attention. 48 | """ 49 | bsz, tgt_len = input_ids_shape 50 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 51 | mask_cond = torch.arange(mask.size(-1), device=device) 52 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 53 | mask = mask.to(dtype) 54 | 55 | if past_key_values_length > 0: 56 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 57 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 58 | 59 | 60 | # Copied from transformers.models.bart.modeling_bart._expand_mask 61 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 62 | """ 63 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 64 | """ 65 | bsz, src_len = mask.size() 66 | tgt_len = tgt_len if tgt_len is not None else src_len 67 | 68 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 69 | 70 | inverted_mask = 1.0 - expanded_mask 71 | 72 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 73 | 74 | 75 | class LlamaRMSNorm(nn.Module): 76 | def __init__(self, hidden_size, eps=1e-6): 77 | """ 78 | LlamaRMSNorm is equivalent to T5LayerNorm 79 | """ 80 | super().__init__() 81 | self.weight = nn.Parameter(torch.ones(hidden_size)) 82 | self.variance_epsilon = eps 83 | 84 | def forward(self, hidden_states): 85 | input_dtype = hidden_states.dtype 86 | hidden_states = hidden_states.to(torch.float32) 87 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 88 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 89 | return self.weight * hidden_states.to(input_dtype) 90 | 91 | 92 | class LlamaRotaryEmbedding(torch.nn.Module): 93 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 94 | super().__init__() 95 | 96 | self.dim = dim 97 | self.max_position_embeddings = max_position_embeddings 98 | self.base = base 99 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 100 | self.register_buffer("inv_freq", inv_freq, persistent=False) 101 | 102 | # Build here to make `torch.jit.trace` work. 
103 | self._set_cos_sin_cache( 104 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 105 | ) 106 | 107 | def _set_cos_sin_cache(self, seq_len, device, dtype): 108 | self.max_seq_len_cached = seq_len 109 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 110 | 111 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 112 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 113 | emb = torch.cat((freqs, freqs), dim=-1) 114 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 115 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 116 | 117 | def forward(self, x, seq_len=None): 118 | # x: [bs, num_attention_heads, seq_len, head_size] 119 | if seq_len > self.max_seq_len_cached: 120 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 121 | 122 | return ( 123 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 124 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 125 | ) 126 | 127 | 128 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 129 | """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 130 | 131 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 132 | self.scaling_factor = scaling_factor 133 | super().__init__(dim, max_position_embeddings, base, device) 134 | 135 | def _set_cos_sin_cache(self, seq_len, device, dtype): 136 | self.max_seq_len_cached = seq_len 137 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 138 | t = t / self.scaling_factor 139 | 140 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 141 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 142 | emb = torch.cat((freqs, freqs), dim=-1) 143 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 144 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 145 | 146 | 147 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 148 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" 149 | 150 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 151 | self.scaling_factor = scaling_factor 152 | super().__init__(dim, max_position_embeddings, base, device) 153 | 154 | def _set_cos_sin_cache(self, seq_len, device, dtype): 155 | self.max_seq_len_cached = seq_len 156 | 157 | if seq_len > self.max_position_embeddings: 158 | base = self.base * ( 159 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 160 | ) ** (self.dim / (self.dim - 2)) 161 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 162 | self.register_buffer("inv_freq", inv_freq, persistent=False) 163 | 164 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 165 | 166 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 167 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 168 | emb = torch.cat((freqs, freqs), dim=-1) 169 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 170 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 171 | 172 | 173 | def rotate_half(x): 174 | """Rotates half the hidden dims of the input.""" 175 | x1 = x[..., : x.shape[-1] // 2] 176 | x2 = x[..., x.shape[-1] // 2 :] 177 | return torch.cat((-x2, x1), dim=-1) 178 | 179 | 180 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 181 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 182 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 183 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 184 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 185 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 186 | q_embed = (q * cos) + (rotate_half(q) * sin) 187 | k_embed = (k * cos) + (rotate_half(k) * sin) 188 | return q_embed, k_embed 189 | 190 | def apply_single_rotary_pos_emb(q, cos, sin, position_ids): 191 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
192 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 193 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 194 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 195 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 196 | q_embed = (q * cos) + (rotate_half(q) * sin) 197 | return q_embed 198 | 199 | 200 | class LlamaMLP(nn.Module): 201 | def __init__(self, config): 202 | super().__init__() 203 | self.config = config 204 | self.hidden_size = config.hidden_size 205 | self.intermediate_size = config.intermediate_size 206 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 207 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 208 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 209 | self.act_fn = ACT2FN[config.hidden_act] 210 | 211 | def forward(self, x): 212 | if self.config.pretraining_tp > 1: 213 | slice = self.intermediate_size // self.config.pretraining_tp 214 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 215 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 216 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 217 | 218 | gate_proj = torch.cat( 219 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 220 | ) 221 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) 222 | 223 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 224 | down_proj = [ 225 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) 226 | ] 227 | down_proj = sum(down_proj) 228 | else: 229 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 230 | 231 | return down_proj 232 | 233 | 234 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 235 | """ 236 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 237 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 238 | """ 239 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 240 | if n_rep == 1: 241 | return hidden_states 242 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 243 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 244 | 245 | 246 | class LlamaAttention(nn.Module): 247 | """Multi-headed attention from 'Attention Is All You Need' paper""" 248 | 249 | def __init__(self, config: LlamaConfig, layer_idx): 250 | super().__init__() 251 | self.config = config 252 | self.hidden_size = config.hidden_size 253 | self.num_heads = config.num_attention_heads 254 | self.head_dim = self.hidden_size // self.num_heads 255 | self.num_key_value_heads = config.num_key_value_heads 256 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 257 | self.max_position_embeddings = config.max_position_embeddings 258 | self.rope_theta = config.rope_theta 259 | self.alpha = 0.2 260 | self.layer_idx = layer_idx 261 | # share layers: [23, 24, 25, 26,] 262 | self.share_attn = self.layer_idx not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 27, 28, 29, 30, 31] 263 | 264 | 265 | if (self.head_dim * self.num_heads) != self.hidden_size: 266 | raise ValueError( 267 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 268 | f" and `num_heads`: {self.num_heads})." 269 | ) 270 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 271 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) if not self.share_attn else None 272 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 273 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 274 | self._init_rope() 275 | 276 | def _init_rope(self): 277 | if self.config.rope_scaling is None: 278 | self.rotary_emb = LlamaRotaryEmbedding( 279 | self.head_dim, 280 | max_position_embeddings=self.max_position_embeddings, 281 | base=self.rope_theta, 282 | ) 283 | else: 284 | scaling_type = self.config.rope_scaling["type"] 285 | scaling_factor = self.config.rope_scaling["factor"] 286 | if scaling_type == "linear": 287 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding( 288 | self.head_dim, 289 | max_position_embeddings=self.max_position_embeddings, 290 | scaling_factor=scaling_factor, 291 | base=self.rope_theta, 292 | ) 293 | elif scaling_type == "dynamic": 294 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( 295 | self.head_dim, 296 | max_position_embeddings=self.max_position_embeddings, 297 | scaling_factor=scaling_factor, 298 | base=self.rope_theta, 299 | ) 300 | else: 301 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 302 | 303 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 304 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 305 | 306 | def forward( 307 | self, 308 | hidden_states: torch.Tensor, 309 | attention_mask: Optional[torch.Tensor] = None, 310 | position_ids: Optional[torch.LongTensor] = None, 311 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 312 | output_attentions: bool = False, 313 | use_cache: bool = False, 314 | reuse_layer_attn: Optional[torch.Tensor] = None, 315 | ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 316 | bsz, q_len, _ = hidden_states.size() 317 | 318 | if self.config.pretraining_tp > 1: 319 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp 320 | query_slices = self.q_proj.weight.split( 321 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 322 | ) 323 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 324 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 325 | 326 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] 327 | query_states = torch.cat(query_states, dim=-1) 328 | 329 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] 330 | key_states = torch.cat(key_states, dim=-1) 331 | 332 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] 333 | value_states = torch.cat(value_states, dim=-1) 334 | 335 | else: 336 | query_states = self.q_proj(hidden_states) 337 | key_states = self.k_proj(hidden_states) if not self.share_attn else None 338 | value_states = self.v_proj(hidden_states) 339 | 340 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 341 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if not self.share_attn else None 342 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 343 | 344 | kv_seq_len = value_states.shape[-2] 345 | if past_key_value is not None: 346 | kv_seq_len += past_key_value[0].shape[-2] 347 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 348 | 349 | 350 | if not self.share_attn: 351 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 352 | else: 353 | query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids) 354 | 355 | 356 | if past_key_value is not None: 357 | if not self.share_attn: 358 | # reuse k, v, self_attention 359 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 360 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 361 | else: 362 | value_states = torch.cat([past_key_value[0], value_states], dim=2) 363 | 364 | 365 | if not self.share_attn: 366 | past_key_value = (key_states, value_states) if use_cache else None 367 | # repeat k/v heads if n_kv_heads < n_heads 368 | key_states = repeat_kv(key_states, self.num_key_value_groups) 369 | value_states = repeat_kv(value_states, self.num_key_value_groups) 370 | else: 371 | past_key_value = ([value_states]) if use_cache else None 372 | value_states = repeat_kv(value_states, self.num_key_value_groups) 373 | 374 | 375 | if not self.share_attn: 376 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 377 | 378 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 379 | raise ValueError( 380 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 381 | f" {attn_weights.size()}" 382 | ) 383 | 384 | if attention_mask is not None: 385 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 386 | raise ValueError( 387 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 388 | ) 389 | attn_weights = attn_weights + attention_mask 390 | 391 | # upcast attention to fp32 392 | attn_weights = 
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 393 | reuse_layer_attn = attn_weights 394 | else: 395 | attn_weights = reuse_layer_attn 396 | 397 | attn_output = torch.matmul(attn_weights, value_states) 398 | 399 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 400 | raise ValueError( 401 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 402 | f" {attn_output.size()}" 403 | ) 404 | 405 | attn_output = attn_output.transpose(1, 2).contiguous() 406 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 407 | 408 | if self.config.pretraining_tp > 1: 409 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 410 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 411 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) 412 | else: 413 | attn_output = self.o_proj(attn_output) 414 | 415 | if not output_attentions: 416 | attn_weights = None 417 | 418 | return attn_output, attn_weights, past_key_value, reuse_layer_attn 419 | 420 | 421 | class LlamaDecoderLayer(nn.Module): 422 | def __init__(self, config: LlamaConfig, layer_idx): 423 | super().__init__() 424 | self.hidden_size = config.hidden_size 425 | self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) 426 | self.mlp = LlamaMLP(config) 427 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 428 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 429 | 430 | def forward( 431 | self, 432 | hidden_states: torch.Tensor, 433 | attention_mask: Optional[torch.Tensor] = None, 434 | position_ids: Optional[torch.LongTensor] = None, 435 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 436 | output_attentions: Optional[bool] = False, 437 | use_cache: Optional[bool] = False, 438 | reuse_layer_attn: Optional[List[torch.Tensor]] = None 439 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 440 | """ 441 | Args: 442 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 443 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 444 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 445 | output_attentions (`bool`, *optional*): 446 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 447 | returned tensors for more detail. 448 | use_cache (`bool`, *optional*): 449 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 450 | (see `past_key_values`). 
451 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 452 | """ 453 | #if isinstance(hidden_states, tuple): 454 | # hidden_states = hidden_states[0] 455 | residual = hidden_states 456 | 457 | hidden_states = self.input_layernorm(hidden_states) 458 | 459 | # Self Attention 460 | hidden_states, self_attn_weights, present_key_value, reuse_layer_attn = self.self_attn( 461 | hidden_states=hidden_states, 462 | attention_mask=attention_mask, 463 | position_ids=position_ids, 464 | past_key_value=past_key_value, 465 | output_attentions=output_attentions, 466 | use_cache=use_cache, 467 | reuse_layer_attn=reuse_layer_attn, 468 | ) 469 | hidden_states = residual + hidden_states 470 | 471 | # Fully Connected 472 | residual = hidden_states 473 | hidden_states = self.post_attention_layernorm(hidden_states) 474 | hidden_states = self.mlp(hidden_states) 475 | hidden_states = residual + hidden_states 476 | 477 | outputs = (hidden_states,) 478 | 479 | if output_attentions: 480 | outputs += (self_attn_weights,) 481 | 482 | if use_cache: 483 | outputs += (present_key_value,) 484 | 485 | return outputs, reuse_layer_attn 486 | 487 | 488 | LLAMA_START_DOCSTRING = r""" 489 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 490 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 491 | etc.) 492 | 493 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 494 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 495 | and behavior. 496 | 497 | Parameters: 498 | config ([`LlamaConfig`]): 499 | Model configuration class with all the parameters of the model. Initializing with a config file does not 500 | load the weights associated with the model, only the configuration. Check out the 501 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 502 | """ 503 | 504 | 505 | @add_start_docstrings( 506 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 507 | LLAMA_START_DOCSTRING, 508 | ) 509 | class LlamaPreTrainedModel(PreTrainedModel): 510 | config_class = LlamaConfig 511 | base_model_prefix = "model" 512 | supports_gradient_checkpointing = True 513 | _no_split_modules = ["LlamaDecoderLayer"] 514 | _skip_keys_device_placement = "past_key_values" 515 | 516 | def _init_weights(self, module): 517 | std = self.config.initializer_range 518 | if isinstance(module, nn.Linear): 519 | module.weight.data.normal_(mean=0.0, std=std) 520 | if module.bias is not None: 521 | module.bias.data.zero_() 522 | elif isinstance(module, nn.Embedding): 523 | module.weight.data.normal_(mean=0.0, std=std) 524 | if module.padding_idx is not None: 525 | module.weight.data[module.padding_idx].zero_() 526 | 527 | def _set_gradient_checkpointing(self, module, value=False): 528 | if isinstance(module, LlamaModel): 529 | module.gradient_checkpointing = value 530 | 531 | 532 | LLAMA_INPUTS_DOCSTRING = r""" 533 | Args: 534 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 535 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 536 | it. 537 | 538 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 539 | [`PreTrainedTokenizer.__call__`] for details. 
540 | 541 | [What are input IDs?](../glossary#input-ids) 542 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 543 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 544 | 545 | - 1 for tokens that are **not masked**, 546 | - 0 for tokens that are **masked**. 547 | 548 | [What are attention masks?](../glossary#attention-mask) 549 | 550 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 551 | [`PreTrainedTokenizer.__call__`] for details. 552 | 553 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 554 | `past_key_values`). 555 | 556 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 557 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 558 | information on the default strategy. 559 | 560 | - 1 indicates the head is **not masked**, 561 | - 0 indicates the head is **masked**. 562 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 563 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 564 | config.n_positions - 1]`. 565 | 566 | [What are position IDs?](../glossary#position-ids) 567 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 568 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 569 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 570 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 571 | 572 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 573 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 574 | 575 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 576 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 577 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 578 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 579 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 580 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 581 | model's internal embedding lookup matrix. 582 | use_cache (`bool`, *optional*): 583 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 584 | `past_key_values`). 585 | output_attentions (`bool`, *optional*): 586 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 587 | tensors for more detail. 588 | output_hidden_states (`bool`, *optional*): 589 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 590 | more detail. 591 | return_dict (`bool`, *optional*): 592 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
593 | """ 594 | 595 | 596 | @add_start_docstrings( 597 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 598 | LLAMA_START_DOCSTRING, 599 | ) 600 | class LlamaModel(LlamaPreTrainedModel): 601 | """ 602 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] 603 | 604 | Args: 605 | config: LlamaConfig 606 | """ 607 | 608 | def __init__(self, config: LlamaConfig): 609 | super().__init__(config) 610 | self.padding_idx = config.pad_token_id 611 | self.vocab_size = config.vocab_size 612 | 613 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 614 | self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) 615 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 616 | self.reuse_layer_attn = None 617 | 618 | self.gradient_checkpointing = False 619 | # Initialize weights and apply final processing 620 | self.post_init() 621 | 622 | def get_input_embeddings(self): 623 | return self.embed_tokens 624 | 625 | def set_input_embeddings(self, value): 626 | self.embed_tokens = value 627 | 628 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 629 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 630 | # create causal mask 631 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 632 | combined_attention_mask = None 633 | if input_shape[-1] > 1: 634 | combined_attention_mask = _make_causal_mask( 635 | input_shape, 636 | inputs_embeds.dtype, 637 | device=inputs_embeds.device, 638 | past_key_values_length=past_key_values_length, 639 | ) 640 | 641 | if attention_mask is not None: 642 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 643 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 644 | inputs_embeds.device 645 | ) 646 | combined_attention_mask = ( 647 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 648 | ) 649 | 650 | return combined_attention_mask 651 | 652 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 653 | def forward( 654 | self, 655 | input_ids: torch.LongTensor = None, 656 | attention_mask: Optional[torch.Tensor] = None, 657 | position_ids: Optional[torch.LongTensor] = None, 658 | past_key_values: Optional[List[torch.FloatTensor]] = None, 659 | inputs_embeds: Optional[torch.FloatTensor] = None, 660 | use_cache: Optional[bool] = None, 661 | output_attentions: Optional[bool] = None, 662 | output_hidden_states: Optional[bool] = None, 663 | return_dict: Optional[bool] = None, 664 | ) -> Union[Tuple, BaseModelOutputWithPast]: 665 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 666 | output_hidden_states = ( 667 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 668 | ) 669 | use_cache = use_cache if use_cache is not None else self.config.use_cache 670 | 671 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 672 | 673 | # retrieve input_ids and inputs_embeds 674 | if input_ids is not None and inputs_embeds is not None: 675 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 676 | elif input_ids is not None: 677 | batch_size, seq_length = input_ids.shape 
678 | elif inputs_embeds is not None: 679 | batch_size, seq_length, _ = inputs_embeds.shape 680 | else: 681 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 682 | 683 | seq_length_with_past = seq_length 684 | past_key_values_length = 0 685 | 686 | if past_key_values is not None: 687 | past_key_values_length = past_key_values[0][0].shape[2] 688 | seq_length_with_past = seq_length_with_past + past_key_values_length 689 | 690 | if position_ids is None: 691 | device = input_ids.device if input_ids is not None else inputs_embeds.device 692 | position_ids = torch.arange( 693 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 694 | ) 695 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 696 | else: 697 | position_ids = position_ids.view(-1, seq_length).long() 698 | 699 | if inputs_embeds is None: 700 | inputs_embeds = self.embed_tokens(input_ids) 701 | # embed positions 702 | if attention_mask is None: 703 | attention_mask = torch.ones( 704 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device 705 | ) 706 | attention_mask = self._prepare_decoder_attention_mask( 707 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 708 | ) 709 | 710 | hidden_states = inputs_embeds 711 | 712 | if self.gradient_checkpointing and self.training: 713 | if use_cache: 714 | logger.warning_once( 715 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 716 | ) 717 | use_cache = False 718 | 719 | # decoder layers 720 | all_hidden_states = () if output_hidden_states else None 721 | all_self_attns = () if output_attentions else None 722 | next_decoder_cache = () if use_cache else None 723 | self.reuse_layer_attn = None 724 | for idx, decoder_layer in enumerate(self.layers): 725 | if output_hidden_states: 726 | all_hidden_states += (hidden_states,) 727 | 728 | past_key_value = past_key_values[idx] if past_key_values is not None else None 729 | 730 | if self.gradient_checkpointing and self.training: 731 | def create_custom_forward(module): 732 | def custom_forward(*inputs): 733 | # None for past_key_value 734 | return module(*inputs, past_key_value, output_attentions, use_cache, self.reuse_layer_attn) 735 | 736 | return custom_forward 737 | 738 | layer_outputs, self.reuse_layer_attn = torch.utils.checkpoint.checkpoint( 739 | create_custom_forward(decoder_layer), 740 | hidden_states, 741 | attention_mask, 742 | position_ids, 743 | ) 744 | else: 745 | layer_outputs, self.reuse_layer_attn = decoder_layer( 746 | hidden_states, 747 | attention_mask=attention_mask, 748 | position_ids=position_ids, 749 | past_key_value=past_key_value, 750 | output_attentions=output_attentions, 751 | use_cache=use_cache, 752 | reuse_layer_attn=self.reuse_layer_attn 753 | ) 754 | 755 | hidden_states = layer_outputs[0] 756 | 757 | if use_cache: 758 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 759 | 760 | if output_attentions: 761 | all_self_attns += (layer_outputs[1],) 762 | del self.reuse_layer_attn 763 | hidden_states = self.norm(hidden_states) 764 | 765 | # add hidden states from the last decoder layer 766 | if output_hidden_states: 767 | all_hidden_states += (hidden_states,) 768 | 769 | next_cache = next_decoder_cache if use_cache else None 770 | if not return_dict: 771 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 772 | return BaseModelOutputWithPast( 773 | 
last_hidden_state=hidden_states, 774 | past_key_values=next_cache, 775 | hidden_states=all_hidden_states, 776 | attentions=all_self_attns, 777 | ) 778 | 779 | 780 | class LlamaForCausalLM(LlamaPreTrainedModel): 781 | _tied_weights_keys = ["lm_head.weight"] 782 | 783 | def __init__(self, config): 784 | super().__init__(config) 785 | self.model = LlamaModel(config) 786 | self.vocab_size = config.vocab_size 787 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 788 | 789 | # Initialize weights and apply final processing 790 | self.post_init() 791 | 792 | def get_input_embeddings(self): 793 | return self.model.embed_tokens 794 | 795 | def set_input_embeddings(self, value): 796 | self.model.embed_tokens = value 797 | 798 | def get_output_embeddings(self): 799 | return self.lm_head 800 | 801 | def set_output_embeddings(self, new_embeddings): 802 | self.lm_head = new_embeddings 803 | 804 | def set_decoder(self, decoder): 805 | self.model = decoder 806 | 807 | def get_decoder(self): 808 | return self.model 809 | 810 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 811 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 812 | def forward( 813 | self, 814 | input_ids: torch.LongTensor = None, 815 | attention_mask: Optional[torch.Tensor] = None, 816 | position_ids: Optional[torch.LongTensor] = None, 817 | past_key_values: Optional[List[torch.FloatTensor]] = None, 818 | inputs_embeds: Optional[torch.FloatTensor] = None, 819 | labels: Optional[torch.LongTensor] = None, 820 | use_cache: Optional[bool] = None, 821 | output_attentions: Optional[bool] = None, 822 | output_hidden_states: Optional[bool] = None, 823 | return_dict: Optional[bool] = None, 824 | ) -> Union[Tuple, CausalLMOutputWithPast]: 825 | r""" 826 | Args: 827 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 828 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 829 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 830 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 831 | 832 | Returns: 833 | 834 | Example: 835 | 836 | ```python 837 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 838 | 839 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 840 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 841 | 842 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 843 | >>> inputs = tokenizer(prompt, return_tensors="pt") 844 | 845 | >>> # Generate 846 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 847 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 848 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
849 | ```""" 850 | 851 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 852 | output_hidden_states = ( 853 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 854 | ) 855 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 856 | 857 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 858 | outputs = self.model( 859 | input_ids=input_ids, 860 | attention_mask=attention_mask, 861 | position_ids=position_ids, 862 | past_key_values=past_key_values, 863 | inputs_embeds=inputs_embeds, 864 | use_cache=use_cache, 865 | output_attentions=output_attentions, 866 | output_hidden_states=output_hidden_states, 867 | return_dict=return_dict, 868 | ) 869 | 870 | hidden_states = outputs[0] 871 | if self.config.pretraining_tp > 1: 872 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 873 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 874 | logits = torch.cat(logits, dim=-1) 875 | else: 876 | logits = self.lm_head(hidden_states) 877 | logits = logits.float() 878 | 879 | loss = None 880 | if labels is not None: 881 | # Shift so that tokens < n predict n 882 | shift_logits = logits[..., :-1, :].contiguous() 883 | shift_labels = labels[..., 1:].contiguous() 884 | # Flatten the tokens 885 | loss_fct = CrossEntropyLoss() 886 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 887 | shift_labels = shift_labels.view(-1) 888 | # Enable model parallelism 889 | shift_labels = shift_labels.to(shift_logits.device) 890 | loss = loss_fct(shift_logits, shift_labels) 891 | 892 | if not return_dict: 893 | output = (logits,) + outputs[1:] 894 | return (loss,) + output if loss is not None else output 895 | 896 | return CausalLMOutputWithPast( 897 | loss=loss, 898 | logits=logits, 899 | past_key_values=outputs.past_key_values, 900 | hidden_states=outputs.hidden_states, 901 | attentions=outputs.attentions, 902 | ) 903 | 904 | def prepare_inputs_for_generation( 905 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 906 | ): 907 | if past_key_values: 908 | input_ids = input_ids[:, -1:] 909 | 910 | position_ids = kwargs.get("position_ids", None) 911 | if attention_mask is not None and position_ids is None: 912 | # create position_ids on the fly for batch generation 913 | position_ids = attention_mask.long().cumsum(-1) - 1 914 | position_ids.masked_fill_(attention_mask == 0, 1) 915 | if past_key_values: 916 | position_ids = position_ids[:, -1].unsqueeze(-1) 917 | 918 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 919 | if inputs_embeds is not None and past_key_values is None: 920 | model_inputs = {"inputs_embeds": inputs_embeds} 921 | else: 922 | model_inputs = {"input_ids": input_ids} 923 | 924 | model_inputs.update( 925 | { 926 | "position_ids": position_ids, 927 | "past_key_values": past_key_values, 928 | "use_cache": kwargs.get("use_cache"), 929 | "attention_mask": attention_mask, 930 | } 931 | ) 932 | return model_inputs 933 | 934 | @staticmethod 935 | def _reorder_cache(past_key_values, beam_idx): 936 | reordered_past = () 937 | for layer_past in past_key_values: 938 | reordered_past += ( 939 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 940 | ) 941 | return reordered_past 942 | 943 | 944 | @add_start_docstrings( 945 | 
""" 946 | The LLaMa Model transformer with a sequence classification head on top (linear layer). 947 | 948 | [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models 949 | (e.g. GPT-2) do. 950 | 951 | Since it does classification on the last token, it requires to know the position of the last token. If a 952 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 953 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 954 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 955 | each row of the batch). 956 | """, 957 | LLAMA_START_DOCSTRING, 958 | ) 959 | class LlamaForSequenceClassification(LlamaPreTrainedModel): 960 | def __init__(self, config): 961 | super().__init__(config) 962 | self.num_labels = config.num_labels 963 | self.model = LlamaModel(config) 964 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 965 | 966 | # Initialize weights and apply final processing 967 | self.post_init() 968 | 969 | def get_input_embeddings(self): 970 | return self.model.embed_tokens 971 | 972 | def set_input_embeddings(self, value): 973 | self.model.embed_tokens = value 974 | 975 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 976 | def forward( 977 | self, 978 | input_ids: torch.LongTensor = None, 979 | attention_mask: Optional[torch.Tensor] = None, 980 | position_ids: Optional[torch.LongTensor] = None, 981 | past_key_values: Optional[List[torch.FloatTensor]] = None, 982 | inputs_embeds: Optional[torch.FloatTensor] = None, 983 | labels: Optional[torch.LongTensor] = None, 984 | use_cache: Optional[bool] = None, 985 | output_attentions: Optional[bool] = None, 986 | output_hidden_states: Optional[bool] = None, 987 | return_dict: Optional[bool] = None, 988 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 989 | r""" 990 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 991 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 992 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 993 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
994 | """ 995 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 996 | 997 | transformer_outputs = self.model( 998 | input_ids, 999 | attention_mask=attention_mask, 1000 | position_ids=position_ids, 1001 | past_key_values=past_key_values, 1002 | inputs_embeds=inputs_embeds, 1003 | use_cache=use_cache, 1004 | output_attentions=output_attentions, 1005 | output_hidden_states=output_hidden_states, 1006 | return_dict=return_dict, 1007 | ) 1008 | hidden_states = transformer_outputs[0] 1009 | logits = self.score(hidden_states) 1010 | 1011 | if input_ids is not None: 1012 | batch_size = input_ids.shape[0] 1013 | else: 1014 | batch_size = inputs_embeds.shape[0] 1015 | 1016 | if self.config.pad_token_id is None and batch_size != 1: 1017 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 1018 | if self.config.pad_token_id is None: 1019 | sequence_lengths = -1 1020 | else: 1021 | if input_ids is not None: 1022 | sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( 1023 | logits.device 1024 | ) 1025 | else: 1026 | sequence_lengths = -1 1027 | 1028 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1029 | 1030 | loss = None 1031 | if labels is not None: 1032 | labels = labels.to(logits.device) 1033 | if self.config.problem_type is None: 1034 | if self.num_labels == 1: 1035 | self.config.problem_type = "regression" 1036 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1037 | self.config.problem_type = "single_label_classification" 1038 | else: 1039 | self.config.problem_type = "multi_label_classification" 1040 | 1041 | if self.config.problem_type == "regression": 1042 | loss_fct = MSELoss() 1043 | if self.num_labels == 1: 1044 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1045 | else: 1046 | loss = loss_fct(pooled_logits, labels) 1047 | elif self.config.problem_type == "single_label_classification": 1048 | loss_fct = CrossEntropyLoss() 1049 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1050 | elif self.config.problem_type == "multi_label_classification": 1051 | loss_fct = BCEWithLogitsLoss() 1052 | loss = loss_fct(pooled_logits, labels) 1053 | if not return_dict: 1054 | output = (pooled_logits,) + transformer_outputs[1:] 1055 | return ((loss,) + output) if loss is not None else output 1056 | 1057 | return SequenceClassifierOutputWithPast( 1058 | loss=loss, 1059 | logits=pooled_logits, 1060 | past_key_values=transformer_outputs.past_key_values, 1061 | hidden_states=transformer_outputs.hidden_states, 1062 | attentions=transformer_outputs.attentions, 1063 | ) 1064 | -------------------------------------------------------------------------------- /models/Llama2-7b-hf/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "", 4 | "lstrip": false, 5 | "normalized": false, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "", 11 | "lstrip": false, 12 | "normalized": false, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "unk_token": { 17 | "content": "", 18 | "lstrip": false, 19 | "normalized": false, 20 | "rstrip": false, 21 | "single_word": false 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /models/Llama2-7b-hf/tokenizer.model: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/metacarbon/shareAtt/2494690186802a829d351fc59753997bf9fbea2f/models/Llama2-7b-hf/tokenizer.model -------------------------------------------------------------------------------- /models/Llama2-7b-hf/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "bos_token": { 5 | "__type": "AddedToken", 6 | "content": "<s>", 7 | "lstrip": false, 8 | "normalized": false, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "clean_up_tokenization_spaces": false, 13 | "eos_token": { 14 | "__type": "AddedToken", 15 | "content": "</s>", 16 | "lstrip": false, 17 | "normalized": false, 18 | "rstrip": false, 19 | "single_word": false 20 | }, 21 | "legacy": false, 22 | "model_max_length": 1000000000000000019884624838656, 23 | "pad_token": null, 24 | "padding_side": "right", 25 | "sp_model_kwargs": {}, 26 | "tokenizer_class": "LlamaTokenizer", 27 | "unk_token": { 28 | "__type": "AddedToken", 29 | "content": "<unk>", 30 | "lstrip": false, 31 | "normalized": false, 32 | "rstrip": false, 33 | "single_word": false 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | numpy==1.23.5 3 | sentencepiece==0.1.97 4 | transformers==4.29.1 5 | xformers==0.0.20 6 | deepspeed==0.13.5 7 | tensorboard==2.14.0 --------------------------------------------------------------------------------
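For readers skimming the modified `modeling_llama.py` above, the sketch below illustrates the control flow that `LlamaModel.forward` uses for Shared Attention: a running `reuse_layer_attn` value is threaded through the decoder loop, so a layer flagged to share attention can reuse previously computed attention weights instead of recomputing them. This is a minimal, hypothetical toy under stated assumptions (single head, no RoPE, no causal mask, no KV cache); `ToySharedAttention` and its `share_attn` flag are illustrative stand-ins, not the repository's `LlamaAttention`.

```python
# Minimal sketch of the shared-attention reuse pattern (NOT the repository's code).
import math
import torch
import torch.nn as nn


class ToySharedAttention(nn.Module):
    """Single-head attention that can reuse attention weights from an earlier layer."""

    def __init__(self, dim: int, share_attn: bool):
        super().__init__()
        self.share_attn = share_attn  # True -> reuse weights passed in from an earlier layer
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, hidden_states, reuse_layer_attn=None):
        value = self.v_proj(hidden_states)
        if self.share_attn and reuse_layer_attn is not None:
            # Skip Q, K and softmax entirely; reuse the shared attention weights.
            attn_weights = reuse_layer_attn
        else:
            query = self.q_proj(hidden_states)
            key = self.k_proj(hidden_states)
            scores = query @ key.transpose(-2, -1) / math.sqrt(hidden_states.size(-1))
            attn_weights = scores.softmax(dim=-1)
        # Return the weights so the caller can pass them to the next layer.
        return self.o_proj(attn_weights @ value), attn_weights


# Layer 0 computes attention weights; layers 1..3 reuse them.
layers = nn.ModuleList([ToySharedAttention(dim=16, share_attn=(i > 0)) for i in range(4)])
hidden_states = torch.randn(2, 8, 16)
reuse_layer_attn = None
for layer in layers:  # mirrors the loop over self.layers in LlamaModel.forward
    hidden_states, reuse_layer_attn = layer(hidden_states, reuse_layer_attn=reuse_layer_attn)
print(hidden_states.shape)  # torch.Size([2, 8, 16])
```

As in `LlamaModel.forward`, the loop simply overwrites `reuse_layer_attn` with whatever each layer returns, so in this sketch a non-sharing layer refreshes the shared weights and a sharing layer passes them along unchanged.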