├── README.md
├── configs
│   ├── llama_100m.json
│   ├── llama_130m.json
│   ├── llama_1b.json
│   ├── llama_20m.json
│   ├── llama_250m.json
│   ├── llama_350m.json
│   ├── llama_35m.json
│   ├── llama_3b.json
│   ├── llama_40m.json
│   ├── llama_60m.json
│   ├── llama_71m.json
│   ├── llama_7b.json
│   └── llama_9m.json
├── exp_requirements.txt
├── layer_remove.py
├── peft_pretraining
│   ├── __pycache__
│   │   ├── args_utils.cpython-39.pyc
│   │   ├── dataloader.cpython-39.pyc
│   │   ├── modeling_llama.cpython-39.pyc
│   │   └── training_utils.cpython-39.pyc
│   ├── args_utils.py
│   ├── dataloader.py
│   ├── modeling_llama.py
│   └── training_utils.py
├── requirements.txt
├── run_130m.sh
├── run_1b.sh
├── run_250m.sh
├── run_350m.sh
├── scaled_init.png
├── scaling.png
├── torchrun_main.py
└── utils
    ├── angular_distance.py
    ├── metrics.py
    └── short_hf.py


/README.md:
--------------------------------------------------------------------------------
# The Curse of Depth in Large Language Models
[[`Arxiv`](https://arxiv.org/abs/2502.05795)]
[[X(Twitter)](https://x.com/Shiwei_Liu66/status/1889257901346152844)]
[[Model](https://huggingface.co/pengxiang/LNS_1B)]
[[Talk](https://www.youtube.com/watch?v=sVN7wgmmNms)]

We present the Curse of Depth, a phenomenon in Large Language Models (LLMs) where deeper layers contribute less effectively to training due to the widespread use of Pre-Layer Normalization (Pre-LN). Our analysis identifies this issue as a key bottleneck in LLM optimization, and we propose LayerNorm Scaling as a solution to mitigate its impact.

## Abstract

In this paper, we introduce the Curse of Depth, a concept that highlights, explains, and addresses the recent observation in modern Large Language Models (LLMs) where nearly half of the layers are less effective than expected. We first confirm the wide existence of this phenomenon across the most popular families of LLMs such as Llama, Mistral, DeepSeek, and Qwen. Our analysis, theoretically and empirically, identifies that the underlying reason for the ineffectiveness of deep layers in LLMs is the widespread usage of Pre-Layer Normalization (Pre-LN). While Pre-LN stabilizes the training of Transformer LLMs, its output variance exponentially grows with the model depth, which undesirably causes the derivative of the deep Transformer blocks to be an identity matrix, and therefore barely contributes to the training. To resolve this training pitfall, we propose LayerNorm Scaling, which scales the variance of output of the layer normalization inversely by the square root of its depth. This simple modification mitigates the output variance explosion of deeper Transformer layers, improving their contribution. Our experimental results, spanning model sizes from 130M to 1B, demonstrate that LayerNorm Scaling significantly enhances LLM pre-training performance compared to Pre-LN. Moreover, this improvement seamlessly carries over to supervised fine-tuning. All these gains can be attributed to the fact that LayerNorm Scaling enables deeper layers to contribute more effectively during training.

## Caveat

Combining LNS with Scaled Initialization (which scales the initialization of W0 and W2 by $1/\sqrt{2L}$, where $L$ is the overall depth) undermines the effectiveness of LNS and performs worse than using LNS alone. This highlights the importance of eliminating conflicting initialization strategies before adopting LNS.

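As a rough illustration of the LayerNorm Scaling idea described in the abstract, the sketch below applies an RMS normalization to a plain Python list and then multiplies the result by $1/\sqrt{l}$, where $l$ is the layer's depth. This is a minimal sketch under stated assumptions, not the repository's implementation (which lives in `peft_pretraining/modeling_llama.py`): the helper names `rms_norm`, `lns_norm`, and `mean_square` are hypothetical, the learned norm weight is omitted, and the 1-based depth convention is an assumption.

```python
import math


def rms_norm(x, eps=1e-6):
    """RMS-normalize a vector given as a list of floats (no learned weight)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]


def lns_norm(x, layer_index, eps=1e-6):
    """RMSNorm followed by LayerNorm Scaling.

    Assumption: layer_index is 1-based, and the normalized output is
    multiplied by 1 / sqrt(layer_index), so deeper layers are scaled down more.
    """
    if layer_index < 1:
        raise ValueError("layer_index is assumed to be 1-based")
    scale = 1.0 / math.sqrt(layer_index)
    return [v * scale for v in rms_norm(x, eps)]


def mean_square(x):
    """Mean of squared entries, used here as a simple proxy for output variance."""
    return sum(v * v for v in x) / len(x)


if __name__ == "__main__":
    hidden = [1.0, 2.0, 3.0, 4.0]
    for depth in (1, 4, 16):
        # The mean square of the output shrinks as the depth grows.
        print(depth, mean_square(lns_norm(hidden, depth)))
```

Under these assumptions, scaling the output by $1/\sqrt{l}$ reduces its mean square by a factor of $1/l$, which is the sense in which deeper layers have their output variance damped.
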
## Hugging Face
We have uploaded the trained weights for the 1B model using LayerNorm Scaling (LNS).
You can download them from https://huggingface.co/pengxiang/LNS_1B.

## Quick Start

### Install experiment dependencies

You can configure the environment using the following command lines:

```bash
conda create -n LNS python=3.9 -y
conda activate LNS
pip install -r exp_requirements.txt
```

### Training Examples
We provide scripts to train models of different sizes using Pre-LN, Post-LN, Mix-LN, and LayerNorm Scaling (LNS).

Train a 130M Model:
```bash
bash run_130m.sh pre 3       # Pre-LN
bash run_130m.sh post 3      # Post-LN
bash run_130m.sh post_pre 3  # Mix-LN
bash run_130m.sh LNS 3       # LayerNorm Scaling (LNS)

# Note: 3 is the number of Post-LN layers in Mix-LN.
```

Train a 250M Model:
```bash
bash run_250m.sh pre 6       # Pre-LN
bash run_250m.sh post 6      # Post-LN
bash run_250m.sh post_pre 6  # Mix-LN
bash run_250m.sh LNS 6       # LayerNorm Scaling (LNS)

# Note: 6 is the number of Post-LN layers in Mix-LN.
```

Train a 350M Model:
```bash
bash run_350m.sh pre 6       # Pre-LN
bash run_350m.sh post 6      # Post-LN
bash run_350m.sh post_pre 6  # Mix-LN
bash run_350m.sh LNS 6       # LayerNorm Scaling (LNS)
```

Train a 1B Model:
```bash
bash run_1b.sh pre 6       # Pre-LN
bash run_1b.sh post 6      # Post-LN
bash run_1b.sh post_pre 6  # Mix-LN
bash run_1b.sh LNS 6       # LayerNorm Scaling (LNS)
```

### Performance Drop
Calculate the performance drop after removing different layers. We use [lm_eval](https://github.com/EleutherAI/lm-evaluation-harness) to obtain evaluation results. Please refer to its installation instructions to configure `lm_eval`.
```bash
git clone https://github.com/EleutherAI/lm-evaluation-harness
cd lm-evaluation-harness
pip install -e .
```

Then, you can run the following command to remove a layer and save the weights as a new model. The performance drop is calculated based on the new model:
```bash
# LLaMA2-7B, Remove Layer 1
python layer_remove.py \
    --model_path meta-llama/Llama-2-7b-hf \
    --layer_index 1 \
    --save_path ./llama_7b_removed_1
```

### 📚Citation

```bibtex
@article{sun2025curse,
  title={The Curse of Depth in Large Language Models},
  author={Sun, Wenfang and Song, Xinyuan and Li, Pengxiang and Yin, Lu and Zheng, Yefeng and Liu, Shiwei},
  journal={arXiv preprint arXiv:2502.05795},
  year={2025}
}
```

### Acknowledgement
This repository is built upon the [Mix-LN](https://github.com/pixeli99/MixLN/tree/main) repository. Thanks for their great work!

--------------------------------------------------------------------------------
/configs/llama_100m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 640,
  "intermediate_size": 1708,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 10,
  "num_hidden_layers": 12,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32100
}
--------------------------------------------------------------------------------
/configs/llama_130m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 768,
  "intermediate_size": 2048,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_1b.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "intermediate_size": 5461,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 24,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_20m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 256,
  "intermediate_size": 688,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_250m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 768,
  "intermediate_size": 2560,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_350m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "intermediate_size": 2736,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_35m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 384,
  "intermediate_size": 1024,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 8,
  "num_hidden_layers": 6,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_3b.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 2560,
  "intermediate_size": 6848,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_40m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 416,
  "intermediate_size": 1024,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 8,
  "num_hidden_layers": 8,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_60m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 512,
  "intermediate_size": 1376,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 8,
  "num_hidden_layers": 8,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_71m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 512,
  "intermediate_size": 1368,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 8,
  "num_hidden_layers": 12,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_7b.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "intermediate_size": 11008,
  "initializer_range": 0.02,
  "max_sequence_length": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/configs/llama_9m.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "LLaMAForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 128,
  "intermediate_size": 352,
  "initializer_range": 0.02,
  "max_sequence_length": 1024,
  "model_type": "llama",
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "pad_token_id": -1,
  "rms_norm_eps": 1e-06,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}
--------------------------------------------------------------------------------
/exp_requirements.txt:
--------------------------------------------------------------------------------
torch
transformers==4.31.0
tensorly
tokenizers
datasets
peft
wandb
loguru
nvitop
lion-pytorch
matplotlib
bitsandbytes
scipy
scikit-learn
evaluate

--------------------------------------------------------------------------------
/layer_remove.py:
--------------------------------------------------------------------------------
import torch
import os
from transformers import LlamaForCausalLM
import subprocess
import json
import argparse

def remove_layers_and_save(model_path, output_dir, layers_to_remove):
    # Load the pre-trained LLaMA model
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Remove the specified layers, highest index first, so that deleting one
    # layer does not shift the indices of the layers still to be removed
    for layer_idx in sorted(set(layers_to_remove), reverse=True):
        if 0 <= layer_idx < len(model.model.layers):
            del model.model.layers[layer_idx]

    # Renumber the remaining layers' indices
    for layer_idx, module in enumerate(model.model.layers):
        module.self_attn.layer_idx = layer_idx

    # Save the modified model
    model.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")

    # Update the config.json with the new number of hidden layers
    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "r", encoding="utf-8") as file:
        config = json.load(file)

    config['num_hidden_layers'] = len(model.model.layers)

    with open(config_path, "w", encoding="utf-8") as file:
        json.dump(config, file, indent=4, ensure_ascii=False)
    print(f"Updated config saved to {config_path}")


def run_bash_script(script_path, working_directory):
    # Execute the bash script inside the target directory; passing cwd avoids
    # changing (and having to restore) this process's working directory
    subprocess.run(["bash", script_path], cwd=working_directory)
    print(f"Executed script {script_path} in {working_directory}")


if __name__ == "__main__":
    # Argument parsing
    parser = argparse.ArgumentParser(description="Remove layers from LLaMA model and save the modified version.")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the pre-trained LLaMA model")
    parser.add_argument("--layer_index", type=int, required=True, help="Index of the layer to remove")
    parser.add_argument("--save_path", type=str, required=True, help="Path to save the modified model")

    args = parser.parse_args()

    # Remove layers and save the model
    remove_layers_and_save(args.model_path, args.save_path, [args.layer_index])

    # Path to the evaluation script and directory
    # target_directory = "/lm-evaluation-harness"
    # script_name = "run_task.sh"

    # Run the evaluation script
    # run_bash_script(script_name, target_directory)

--------------------------------------------------------------------------------
/peft_pretraining/__pycache__/args_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/peft_pretraining/__pycache__/args_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/peft_pretraining/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/peft_pretraining/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/peft_pretraining/__pycache__/modeling_llama.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/peft_pretraining/__pycache__/modeling_llama.cpython-39.pyc
--------------------------------------------------------------------------------
/peft_pretraining/__pycache__/training_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/peft_pretraining/__pycache__/training_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/peft_pretraining/args_utils.py:
--------------------------------------------------------------------------------
import os
from datetime import datetime

from loguru import logger


def check_args_torchrun_main(args):

    if args.save_dir is None:
        # use checkpoints / model name, date and time as save directory
        # removesuffix (Python 3.9+) drops only a trailing ".json"; rstrip('.json')
        # would strip any trailing run of the characters '.', 'j', 's', 'o', 'n'
        args.save_dir = f"checkpoints/{args.model_config.split('/')[-1].removesuffix('.json')}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"

    if args.tags is not None:
        args.tags = args.tags.split(",")

    if args.total_batch_size is None:
        args.gradient_accumulation = args.gradient_accumulation or 1
        args.total_batch_size = args.batch_size * args.gradient_accumulation

    assert args.total_batch_size % args.batch_size == 0, "total_batch_size must be divisible by batch_size"

    if args.max_train_tokens is not None:
        args.num_training_steps = args.max_train_tokens // args.total_batch_size
        logger.info(f"Training for {args.num_training_steps} update steps")

    if args.continue_from is not None:
        assert os.path.exists(args.continue_from), f"--continue_from={args.continue_from} does not exist"

    if args.dtype in ["fp16", "float16"]:
        raise NotImplementedError("fp16 is not supported in torchrun_main.py. Use deepspeed_main.py instead (but it seems to have bugs)")

    return args

--------------------------------------------------------------------------------
/peft_pretraining/dataloader.py:
--------------------------------------------------------------------------------
import itertools

import torch
from torch.utils.data import IterableDataset, get_worker_info


class PreprocessedIterableDataset(IterableDataset):
    def __init__(self, data, tokenizer, batch_size, max_length):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_length = max_length

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            # If no worker_info is provided, we are not using DataLoader workers, so yield all data
            iter_data = iter(self.data)
        else:
            # If using DataLoader workers, yield a subset of the data for this worker
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
            iter_data = itertools.islice(self.data, worker_id, None, num_workers)

        batch = []
        for example in iter_data:
            tokenized_example = self.tokenizer(
                example["text"],
                max_length=self.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
            batch.append(tokenized_example)

            if len(batch) == self.batch_size:
                yield self._format_batch(batch)
                batch = []

        if batch:
            yield self._format_batch(batch)

    def _format_batch(self, batch):
        input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
        attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])

        return {"input_ids": input_ids, "attention_mask": attention_mask}

--------------------------------------------------------------------------------
/peft_pretraining/modeling_llama.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
import os
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.models.llama.configuration_llama import LlamaConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        scale_mlp_output: bool = False,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

        if scale_mlp_output:
            self.down_proj.is_scaled_layer = True
        # Default to 'pre' (as LlamaAttention does) so an unset NORM_TYPE does not
        # raise AttributeError on None
        if os.getenv('NORM_TYPE', 'pre').lower() == 'deeppost':
            self.gate_proj.is_deeppost_layer = True
            self.up_proj.is_deeppost_layer = True
            self.down_proj.is_deeppost_layer = True

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, scale_attn_weights: bool = False):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

        if scale_attn_weights:
            self.o_proj.is_scaled_layer = True
        if os.getenv('NORM_TYPE', 'pre').lower() == 'deeppost':
            self.v_proj.is_deeppost_layer = True
            self.o_proj.is_deeppost_layer = True
            # ---- - - - - -
            self.q_proj.is_deeppost_layer_qk = True
            self.k_proj.is_deeppost_layer_qk = True

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # WARNING: padding mask is ignored, causal is always applied
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states, key_states, value_states, dropout_p=0.0, is_causal=True,
        )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # scaled_dot_product_attention does not return attention weights, so None
        # is returned even when output_attentions=True (previously attn_weights
        # was left undefined in that case)
        attn_weights = None

        return attn_output, attn_weights, past_key_value


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_index: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        norm_type = os.getenv('NORM_TYPE', 'pre').lower()
        self.layer_nums = config.num_hidden_layers
        self.layer_index = layer_index
        self.max_post_norm_layer = int(os.getenv('POST_NUM', '1'))
        print(f'Initializing LlamaDecoderLayer {self.layer_index + 1}/{self.layer_nums} with norm type: {norm_type}')
        scale_attn_weights = False
scale_mlp_output = False 272 | 273 | if norm_type == 'scale_post_pre': 274 | if self.layer_index < self.max_post_norm_layer: 275 | scale_attn_weights = False 276 | scale_mlp_output = False 277 | else: 278 | scale_attn_weights = True 279 | scale_mlp_output = True 280 | if norm_type == 'scale_pre': 281 | scale_attn_weights = True 282 | scale_mlp_output = True 283 | 284 | self.self_attn = LlamaAttention(config=config, scale_attn_weights=scale_attn_weights) 285 | self.mlp = LlamaMLP( 286 | hidden_size=self.hidden_size, 287 | intermediate_size=config.intermediate_size, 288 | hidden_act=config.hidden_act, 289 | scale_mlp_output=scale_mlp_output, 290 | ) 291 | 292 | if norm_type == 'pre' or norm_type == 'scale_pre' or norm_type == 'LNS': 293 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 294 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 295 | elif norm_type== 'post' or norm_type == 'deeppost': 296 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 297 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 298 | elif norm_type == 'sandwich': 299 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 300 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 301 | self.pre_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 302 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 303 | elif norm_type == 'pre_post': 304 | print("cur layer index is:", layer_index) 305 | self.max_pre_norm_layer = 7 306 | if self.layer_index < self.max_pre_norm_layer: 307 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 308 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 309 | else: 310 | self.post_feedforward_layernorm = 
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 311 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 312 | elif norm_type == 'pre_sandwich': 313 | print("cur layer index is:", layer_index) 314 | self.max_pre_norm_layer = 7 315 | if self.layer_index < self.max_pre_norm_layer: 316 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 317 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 318 | else: 319 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 320 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 321 | self.pre_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 322 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 323 | elif norm_type == 'post_pre' or norm_type == 'scale_post_pre': 324 | print("cur layer index is:", layer_index) 325 | if self.layer_index < self.max_post_norm_layer: 326 | print(f"cur layer is {layer_index}, and you're using the post norm!") 327 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 328 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 329 | else: 330 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 331 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 332 | elif norm_type == 'scale_res_pre_norm': 333 | self.raw_scaling_factor_attn = nn.Parameter(torch.tensor(0.001)) 334 | self.raw_scaling_factor_mlp = nn.Parameter(torch.tensor(0.001)) 335 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 336 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 337 | elif norm_type == 'pre_post_pre_post': 338 | print("cur layer index is:", 
layer_index) 339 | # For even-numbered layers, use pre-norm 340 | if (self.layer_index + 1) % 4 != 0: 341 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 342 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 343 | else: 344 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 345 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 346 | elif norm_type == 'post_pre_post_pre': 347 | print("cur layer index is:", layer_index) 348 | # For even-numbered layers, use pre-norm 349 | if (self.layer_index + 1) % 4 != 0: 350 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 351 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 352 | else: 353 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 354 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 355 | elif norm_type == 'mono': 356 | if (self.layer_index + 1) % 4 != 0: 357 | # Pre-LayerNorm Only 358 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 359 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 360 | else: 361 | # Post-LayerNorm & Pre-LayerNorm 362 | self.router = nn.Linear(config.hidden_size, 2, bias=False) 363 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 364 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 365 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 366 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 367 | elif norm_type == 'mono_reverse': 368 | if (self.layer_index + 1) % 4 != 0: 369 | # Post-LayerNorm Only 370 | self.post_feedforward_layernorm = 
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 371 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 372 | else: 373 | # Post-LayerNorm & Pre-LayerNorm 374 | self.router = nn.Linear(config.hidden_size, 2, bias=False) 375 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 376 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 377 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 378 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 379 | 380 | def forward( 381 | self, 382 | hidden_states: torch.Tensor, 383 | attention_mask: Optional[torch.Tensor] = None, 384 | position_ids: Optional[torch.LongTensor] = None, 385 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 386 | output_attentions: Optional[bool] = False, 387 | use_cache: Optional[bool] = False, 388 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 389 | """ 390 | Args: 391 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 392 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 393 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 394 | output_attentions (`bool`, *optional*): 395 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 396 | returned tensors for more detail. 397 | use_cache (`bool`, *optional*): 398 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 399 | (see `past_key_values`). 
400 |             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
401 |         """
402 |
403 |         norm_type = os.getenv('NORM_TYPE', 'pre').lower()
404 |
405 |         if norm_type == 'pre' or norm_type == 'scale_pre':
406 |             # Layer 1: Self-Attention
407 |             residual = hidden_states
408 |             hidden_states = self.input_layernorm(hidden_states)
409 |             hidden_states, self_attn_weights, present_key_value = self.self_attn(
410 |                 hidden_states=hidden_states,
411 |                 attention_mask=attention_mask,
412 |                 position_ids=position_ids,
413 |                 past_key_value=past_key_value,
414 |                 output_attentions=output_attentions,
415 |                 use_cache=use_cache,
416 |             )
417 |             hidden_states = residual + hidden_states
418 |
419 |             # Layer 2: Feed-Forward Network (FFN)
420 |
421 |             residual = hidden_states
422 |             hidden_states = self.post_attention_layernorm(hidden_states)
423 |             hidden_states = self.mlp(hidden_states)
424 |             hidden_states = residual + hidden_states
425 |
426 |
427 |         elif norm_type == 'lns':  # norm_type is lower-cased above, so NORM_TYPE=LNS arrives here as 'lns'
428 |             # Layer 1: Self-Attention
429 |             residual = hidden_states
430 |             hidden_states = self.input_layernorm(hidden_states)
431 |
432 |             scale_factor = 1 / math.sqrt(self.layer_index + 1)  # LayerNorm Scaling: 1/sqrt(l), l = 1-based layer index
433 |             hidden_states = scale_factor * hidden_states
434 |
435 |             hidden_states, self_attn_weights, present_key_value = self.self_attn(
436 |                 hidden_states=hidden_states,
437 |                 attention_mask=attention_mask,
438 |                 position_ids=position_ids,
439 |                 past_key_value=past_key_value,
440 |                 output_attentions=output_attentions,
441 |                 use_cache=use_cache,
442 |             )
443 |             hidden_states = residual + hidden_states
444 |
445 |             # Layer 2: Feed-Forward Network (FFN)
446 |             residual = hidden_states
447 |             hidden_states = self.post_attention_layernorm(hidden_states)
448 |
449 |             scale_factor = 1 / math.sqrt(self.layer_index + 1)  # same 1/sqrt(l) factor for the FFN input
450 |             hidden_states = scale_factor * hidden_states
451 |
452 |             hidden_states = self.mlp(hidden_states)
453 |             hidden_states = residual + hidden_states
454 |
455 |
456 |         elif norm_type == 'scale_res_pre_norm':
457 |             # Pre-LayerNorm Only
458 |             residual = hidden_states
459 |             hidden_states = self.input_layernorm(hidden_states)
460 |             hidden_states, self_attn_weights, present_key_value = self.self_attn(
461 |                 hidden_states=hidden_states,
462 |                 attention_mask=attention_mask,
463 |                 position_ids=position_ids,
464 |                 past_key_value=past_key_value,
465 |                 output_attentions=output_attentions,
466 |                 use_cache=use_cache,
467 |             )
468 |             hidden_states = residual + hidden_states * self.raw_scaling_factor_attn
469 |
470 |             residual = hidden_states
471 |             hidden_states = self.post_attention_layernorm(hidden_states)
472 |             hidden_states = self.mlp(hidden_states)
473 |             hidden_states = residual + hidden_states * self.raw_scaling_factor_mlp
474 |         elif norm_type == 'post' or norm_type == 'deeppost':
475 |             # Post-LayerNorm Only
476 |             residual = hidden_states
477 |
478 |
479 |
480 |             hidden_states, self_attn_weights, present_key_value = self.self_attn(
481 |                 hidden_states=hidden_states,
482 |                 attention_mask=attention_mask,
483 |                 position_ids=position_ids,
484 |                 past_key_value=past_key_value,
485 |                 output_attentions=output_attentions,
486 |                 use_cache=use_cache,
487 |             )
488 |             if norm_type == 'deeppost':
489 |                 residual = 2.8284271247461903 * residual  # constant equals sqrt(8)
490 |             hidden_states = residual + hidden_states
491 |
492 |
493 |
494 |
495 |             hidden_states = self.post_attention_layernorm(hidden_states)
496 |
497 |             residual = hidden_states
498 |
499 |             hidden_states = self.mlp(hidden_states)
500 |             if norm_type == 'deeppost':
501 |                 residual = 2.8284271247461903 * residual  # constant equals sqrt(8)
502 |             hidden_states = residual + hidden_states
503 |             hidden_states = self.post_feedforward_layernorm(hidden_states)
504 |
505 |         elif norm_type == 'sandwich':
506 |             # Pre + Post LayerNorm (sandwich)
507 |             residual = hidden_states
508 |             hidden_states = self.input_layernorm(hidden_states)
509 |             hidden_states, self_attn_weights, present_key_value = self.self_attn(
510 |                 hidden_states=hidden_states,
511 |                 attention_mask=attention_mask,
512 |                 position_ids=position_ids,
513 |                 past_key_value=past_key_value,
514 |                 output_attentions=output_attentions,
515 |                 use_cache=use_cache,
516 |             )
517 |             hidden_states = self.post_attention_layernorm(hidden_states)
518 |             hidden_states = residual + hidden_states
519 |
520 |             residual = hidden_states
521 |             hidden_states = self.pre_feedforward_layernorm(hidden_states)
522 |             hidden_states = self.mlp(hidden_states)
523 |             hidden_states = self.post_feedforward_layernorm(hidden_states)
524 |             hidden_states = residual + hidden_states
525 |
526 |         elif norm_type == 'pre_post':
527 |             if self.layer_index < self.max_pre_norm_layer:
528 |                 # Pre-LayerNorm Only
529 |                 residual = hidden_states
530 |                 hidden_states = self.input_layernorm(hidden_states)
531 |                 hidden_states, self_attn_weights, present_key_value = self.self_attn(
532 |                     hidden_states=hidden_states,
533 |                     attention_mask=attention_mask,
534 |                     position_ids=position_ids,
535 |                     past_key_value=past_key_value,
536 |                     output_attentions=output_attentions,
537 |                     use_cache=use_cache,
538 |                 )
539 |                 hidden_states = residual + hidden_states
540 |
541 |                 residual = hidden_states
542 |                 hidden_states = self.post_attention_layernorm(hidden_states)
543 |                 hidden_states = self.mlp(hidden_states)
544 |                 hidden_states = residual + hidden_states
545 |             else:
546 |                 # Post-LayerNorm Only
547 |                 residual = hidden_states
548 |                 hidden_states, self_attn_weights, present_key_value = self.self_attn(
549 |                     hidden_states=hidden_states,
550 |                     attention_mask=attention_mask,
551 |                     position_ids=position_ids,
552 |                     past_key_value=past_key_value,
553 |                     output_attentions=output_attentions,
554 |                     use_cache=use_cache,
555 |                 )
556 |                 hidden_states = residual + hidden_states
557 |                 hidden_states = self.post_attention_layernorm(hidden_states)
558 |
559 |                 residual = hidden_states
560 |                 hidden_states = self.mlp(hidden_states)
561 |                 hidden_states = residual + hidden_states
562 |                 hidden_states = self.post_feedforward_layernorm(hidden_states)
563 |
564 |         elif norm_type == 'pre_sandwich':
565 | if self.layer_index < self.max_pre_norm_layer: 566 | # Pre-LayerNorm Only 567 | residual = hidden_states 568 | hidden_states = self.input_layernorm(hidden_states) 569 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 570 | hidden_states=hidden_states, 571 | attention_mask=attention_mask, 572 | position_ids=position_ids, 573 | past_key_value=past_key_value, 574 | output_attentions=output_attentions, 575 | use_cache=use_cache, 576 | ) 577 | hidden_states = residual + hidden_states 578 | 579 | residual = hidden_states 580 | hidden_states = self.post_attention_layernorm(hidden_states) 581 | hidden_states = self.mlp(hidden_states) 582 | hidden_states = residual + hidden_states 583 | else: 584 | # Pre + Post LayerNorm (default) 585 | residual = hidden_states 586 | hidden_states = self.input_layernorm(hidden_states) 587 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 588 | hidden_states=hidden_states, 589 | attention_mask=attention_mask, 590 | position_ids=position_ids, 591 | past_key_value=past_key_value, 592 | output_attentions=output_attentions, 593 | use_cache=use_cache, 594 | ) 595 | hidden_states = self.post_attention_layernorm(hidden_states) 596 | hidden_states = residual + hidden_states 597 | 598 | residual = hidden_states 599 | hidden_states = self.pre_feedforward_layernorm(hidden_states) 600 | hidden_states = self.mlp(hidden_states) 601 | hidden_states = self.post_feedforward_layernorm(hidden_states) 602 | hidden_states = residual + hidden_states 603 | 604 | elif norm_type == 'post_pre' or norm_type == 'scale_post_pre': 605 | if self.layer_index < self.max_post_norm_layer: 606 | # Post-LayerNorm Only 607 | residual = hidden_states 608 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 609 | hidden_states=hidden_states, 610 | attention_mask=attention_mask, 611 | position_ids=position_ids, 612 | past_key_value=past_key_value, 613 | output_attentions=output_attentions, 614 | 
use_cache=use_cache, 615 | ) 616 | hidden_states = residual + hidden_states 617 | hidden_states = self.post_attention_layernorm(hidden_states) 618 | 619 | residual = hidden_states 620 | hidden_states = self.mlp(hidden_states) 621 | hidden_states = residual + hidden_states 622 | hidden_states = self.post_feedforward_layernorm(hidden_states) 623 | else: 624 | # Pre-LayerNorm Only 625 | residual = hidden_states 626 | hidden_states = self.input_layernorm(hidden_states) 627 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 628 | hidden_states=hidden_states, 629 | attention_mask=attention_mask, 630 | position_ids=position_ids, 631 | past_key_value=past_key_value, 632 | output_attentions=output_attentions, 633 | use_cache=use_cache, 634 | ) 635 | hidden_states = residual + hidden_states 636 | 637 | residual = hidden_states 638 | hidden_states = self.post_attention_layernorm(hidden_states) 639 | hidden_states = self.mlp(hidden_states) 640 | hidden_states = residual + hidden_states 641 | elif norm_type == 'mono': 642 | if (self.layer_index + 1) % 4 != 0: 643 | # Pre-LayerNorm Only 644 | residual = hidden_states 645 | hidden_states = self.input_layernorm(hidden_states) 646 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 647 | hidden_states=hidden_states, 648 | attention_mask=attention_mask, 649 | position_ids=position_ids, 650 | past_key_value=past_key_value, 651 | output_attentions=output_attentions, 652 | use_cache=use_cache, 653 | ) 654 | hidden_states = residual + hidden_states 655 | 656 | residual = hidden_states 657 | hidden_states = self.post_attention_layernorm(hidden_states) 658 | hidden_states = self.mlp(hidden_states) 659 | hidden_states = residual + hidden_states 660 | else: 661 | # b, n, c -> b, 2 662 | router_logits = self.router(hidden_states).mean(dim=1) 663 | router_output = torch.nn.functional.gumbel_softmax(router_logits, tau=1, hard=True) 664 | 665 | ### pre branch 666 | hidden_states_pre = 
hidden_states.clone() 667 | residual = hidden_states_pre 668 | hidden_states_pre = self.input_layernorm(hidden_states_pre) 669 | hidden_states_pre, self_attn_weights, present_key_value = self.self_attn( 670 | hidden_states=hidden_states_pre, 671 | attention_mask=attention_mask, 672 | position_ids=position_ids, 673 | past_key_value=past_key_value, 674 | output_attentions=output_attentions, 675 | use_cache=use_cache, 676 | ) 677 | hidden_states_pre = residual + hidden_states_pre 678 | 679 | residual = hidden_states_pre 680 | hidden_states_pre = self.post_attention_layernorm(hidden_states_pre) 681 | hidden_states_pre = self.mlp(hidden_states_pre) 682 | hidden_states_pre = residual + hidden_states_pre 683 | ### 684 | ### post norm 685 | hidden_states_post = hidden_states.clone() 686 | residual = hidden_states_post 687 | hidden_states_post, self_attn_weights, present_key_value = self.self_attn( 688 | hidden_states=hidden_states_post, 689 | attention_mask=attention_mask, 690 | position_ids=position_ids, 691 | past_key_value=past_key_value, 692 | output_attentions=output_attentions, 693 | use_cache=use_cache, 694 | ) 695 | hidden_states_post = residual + hidden_states_post 696 | hidden_states_post = self.post_attention_layernorm(hidden_states_post) 697 | 698 | residual = hidden_states_post 699 | hidden_states_post = self.mlp(hidden_states_post) 700 | hidden_states_post = residual + hidden_states_post 701 | hidden_states_post = self.post_feedforward_layernorm(hidden_states_post) 702 | 703 | hidden_states = router_output[:, 0:1] * hidden_states_pre + router_output[:, 1:] * hidden_states_post 704 | 705 | elif norm_type == 'mono_reverse': 706 | if (self.layer_index + 1) % 4 != 0: 707 | # Post-LayerNorm Only 708 | residual = hidden_states 709 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 710 | hidden_states=hidden_states, 711 | attention_mask=attention_mask, 712 | position_ids=position_ids, 713 | past_key_value=past_key_value, 714 | 
output_attentions=output_attentions, 715 | use_cache=use_cache, 716 | ) 717 | hidden_states = residual + hidden_states 718 | hidden_states = self.post_attention_layernorm(hidden_states) 719 | 720 | residual = hidden_states 721 | hidden_states = self.post_attention_layernorm(hidden_states) 722 | hidden_states = self.mlp(hidden_states) 723 | hidden_states = residual + hidden_states 724 | hidden_states = self.post_feedforward_layernorm(hidden_states) 725 | else: 726 | # b, n, c -> b, 2 727 | router_logits = self.router(hidden_states).mean(dim=1) 728 | router_output = torch.nn.functional.gumbel_softmax(router_logits, tau=1, hard=True) 729 | 730 | ### pre branch 731 | hidden_states_pre = hidden_states.clone() 732 | residual = hidden_states_pre 733 | hidden_states_pre = self.input_layernorm(hidden_states_pre) 734 | hidden_states_pre, self_attn_weights, present_key_value = self.self_attn( 735 | hidden_states=hidden_states_pre, 736 | attention_mask=attention_mask, 737 | position_ids=position_ids, 738 | past_key_value=past_key_value, 739 | output_attentions=output_attentions, 740 | use_cache=use_cache, 741 | ) 742 | hidden_states_pre = residual + hidden_states_pre 743 | 744 | residual = hidden_states_pre 745 | hidden_states_pre = self.post_attention_layernorm(hidden_states_pre) 746 | hidden_states_pre = self.mlp(hidden_states_pre) 747 | hidden_states_pre = residual + hidden_states_pre 748 | ### 749 | ### post norm 750 | hidden_states_post = hidden_states.clone() 751 | residual = hidden_states_post 752 | hidden_states_post, self_attn_weights, present_key_value = self.self_attn( 753 | hidden_states=hidden_states_post, 754 | attention_mask=attention_mask, 755 | position_ids=position_ids, 756 | past_key_value=past_key_value, 757 | output_attentions=output_attentions, 758 | use_cache=use_cache, 759 | ) 760 | hidden_states_post = residual + hidden_states_post 761 | hidden_states_post = self.post_attention_layernorm(hidden_states_post) 762 | 763 | residual = hidden_states_post 764 
| hidden_states_post = self.mlp(hidden_states_post) 765 | hidden_states_post = residual + hidden_states_post 766 | hidden_states_post = self.post_feedforward_layernorm(hidden_states_post) 767 | 768 | hidden_states = router_output[:, 0:1] * hidden_states_pre + router_output[:, 1:] * hidden_states_post 769 | 770 | elif norm_type == 'post_pre_post_pre': 771 | if (self.layer_index + 1) % 4 == 0: 772 | # Pre-LayerNorm Only 773 | residual = hidden_states 774 | hidden_states = self.input_layernorm(hidden_states) 775 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 776 | hidden_states=hidden_states, 777 | attention_mask=attention_mask, 778 | position_ids=position_ids, 779 | past_key_value=past_key_value, 780 | output_attentions=output_attentions, 781 | use_cache=use_cache, 782 | ) 783 | hidden_states = residual + hidden_states 784 | 785 | residual = hidden_states 786 | hidden_states = self.post_attention_layernorm(hidden_states) 787 | hidden_states = self.mlp(hidden_states) 788 | hidden_states = residual + hidden_states 789 | else: 790 | # Post-LayerNorm Only 791 | residual = hidden_states 792 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 793 | hidden_states=hidden_states, 794 | attention_mask=attention_mask, 795 | position_ids=position_ids, 796 | past_key_value=past_key_value, 797 | output_attentions=output_attentions, 798 | use_cache=use_cache, 799 | ) 800 | hidden_states = residual + hidden_states 801 | hidden_states = self.post_attention_layernorm(hidden_states) 802 | 803 | residual = hidden_states 804 | hidden_states = self.mlp(hidden_states) 805 | hidden_states = residual + hidden_states 806 | hidden_states = self.post_feedforward_layernorm(hidden_states) 807 | 808 | outputs = (hidden_states,) 809 | 810 | if output_attentions: 811 | outputs += (self_attn_weights,) 812 | 813 | if use_cache: 814 | outputs += (present_key_value,) 815 | 816 | return outputs 817 | 818 | 819 | LLAMA_START_DOCSTRING = r""" 820 | This model 
inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 821 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 822 | etc.) 823 | 824 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 825 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 826 | and behavior. 827 | 828 | Parameters: 829 | config ([`LlamaConfig`]): 830 | Model configuration class with all the parameters of the model. Initializing with a config file does not 831 | load the weights associated with the model, only the configuration. Check out the 832 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 833 | """ 834 | 835 | 836 | @add_start_docstrings( 837 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 838 | LLAMA_START_DOCSTRING, 839 | ) 840 | class LlamaPreTrainedModel(PreTrainedModel): 841 | config_class = LlamaConfig 842 | base_model_prefix = "model" 843 | supports_gradient_checkpointing = True 844 | _no_split_modules = ["LlamaDecoderLayer"] 845 | _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] 846 | 847 | # def _init_weights(self, module): 848 | # std = self.config.initializer_range 849 | # if isinstance(module, nn.Linear): 850 | # module.weight.data.normal_(mean=0.0, std=std) 851 | # if module.bias is not None: 852 | # module.bias.data.zero_() 853 | # elif isinstance(module, nn.Embedding): 854 | # module.weight.data.normal_(mean=0.0, std=std) 855 | # if module.padding_idx is not None: 856 | # module.weight.data[module.padding_idx].zero_() 857 | def _init_weights(self, module): 858 | std = self.config.initializer_range 859 | num_layers = self.config.num_hidden_layers 860 | scaled_std = std / (2 * num_layers) ** 0.5 861 | if isinstance(module, nn.Linear): 862 | if hasattr(module, 
"is_deeppost_layer") and module.is_deeppost_layer: 863 | torch.nn.init.xavier_normal_(module.weight, gain=(num_layers * 8) ** 0.25) 864 | elif hasattr(module, "is_deeppost_layer_qk") and module.is_deeppost_layer_qk: 865 | torch.nn.init.xavier_normal_(module.weight, gain=1) 866 | elif hasattr(module, "is_scaled_layer") and module.is_scaled_layer: 867 | module.weight.data.normal_(mean=0.0, std=scaled_std) 868 | print('-'*50) 869 | print('Warning: scaled init for layer:', module) 870 | print('-'*50) 871 | else: 872 | module.weight.data.normal_(mean=0.0, std=std) 873 | if module.bias is not None: 874 | module.bias.data.zero_() 875 | elif isinstance(module, nn.Embedding): 876 | module.weight.data.normal_(mean=0.0, std=(2 / 5) ** 0.5) 877 | if module.padding_idx is not None: 878 | module.weight.data[module.padding_idx].zero_() 879 | 880 | def _set_gradient_checkpointing(self, module, value=False): 881 | if isinstance(module, LlamaModel): 882 | module.gradient_checkpointing = value 883 | 884 | 885 | LLAMA_INPUTS_DOCSTRING = r""" 886 | Args: 887 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 888 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 889 | it. 890 | 891 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 892 | [`PreTrainedTokenizer.__call__`] for details. 893 | 894 | [What are input IDs?](../glossary#input-ids) 895 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 896 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 897 | 898 | - 1 for tokens that are **not masked**, 899 | - 0 for tokens that are **masked**. 900 | 901 | [What are attention masks?](../glossary#attention-mask) 902 | 903 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 904 | [`PreTrainedTokenizer.__call__`] for details. 
905 | 906 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 907 | `past_key_values`). 908 | 909 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 910 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 911 | information on the default strategy. 912 | 913 | - 1 indicates the head is **not masked**, 914 | - 0 indicates the head is **masked**. 915 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 916 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 917 | config.n_positions - 1]`. 918 | 919 | [What are position IDs?](../glossary#position-ids) 920 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 921 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 922 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 923 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 924 | 925 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 926 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 927 | 928 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 929 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 930 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 931 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 932 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This 933 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 934 | model's internal embedding lookup matrix. 935 | use_cache (`bool`, *optional*): 936 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 937 | `past_key_values`). 938 | output_attentions (`bool`, *optional*): 939 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 940 | tensors for more detail. 941 | output_hidden_states (`bool`, *optional*): 942 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 943 | more detail. 944 | return_dict (`bool`, *optional*): 945 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 946 | """ 947 | 948 | 949 | @add_start_docstrings( 950 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 951 | LLAMA_START_DOCSTRING, 952 | ) 953 | class LlamaModel(LlamaPreTrainedModel): 954 | """ 955 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] 956 | 957 | Args: 958 | config: LlamaConfig 959 | """ 960 | 961 | def __init__(self, config: LlamaConfig): 962 | super().__init__(config) 963 | self.padding_idx = config.pad_token_id 964 | self.vocab_size = config.vocab_size 965 | 966 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 967 | self.layers = nn.ModuleList([LlamaDecoderLayer(config, _idx) for _idx in range(config.num_hidden_layers)]) 968 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 969 | 970 | self.gradient_checkpointing = False 971 | # Initialize weights and apply final processing 972 | self.post_init() 973 | 974 | def get_input_embeddings(self): 975 | return self.embed_tokens 976 | 977 | def set_input_embeddings(self, value): 978 | self.embed_tokens = value 979 | 980 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 981 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 982 | # create causal mask 983 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 984 | combined_attention_mask = None 985 | if input_shape[-1] > 1: 986 | combined_attention_mask = _make_causal_mask( 987 | input_shape, 988 | inputs_embeds.dtype, 989 | device=inputs_embeds.device, 990 | past_key_values_length=past_key_values_length, 991 | ) 992 | 993 | if attention_mask is not None: 994 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 995 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 996 | inputs_embeds.device 997 | ) 998 | combined_attention_mask = ( 999 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 1000 | ) 1001 | 1002 | return combined_attention_mask 1003 | 1004 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1005 | def forward( 1006 | self, 1007 | input_ids: torch.LongTensor 
= None, 1008 | attention_mask: Optional[torch.Tensor] = None, 1009 | position_ids: Optional[torch.LongTensor] = None, 1010 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1011 | inputs_embeds: Optional[torch.FloatTensor] = None, 1012 | use_cache: Optional[bool] = None, 1013 | output_attentions: Optional[bool] = None, 1014 | output_hidden_states: Optional[bool] = None, 1015 | return_dict: Optional[bool] = None, 1016 | ) -> Union[Tuple, BaseModelOutputWithPast]: 1017 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1018 | output_hidden_states = ( 1019 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1020 | ) 1021 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1022 | 1023 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1024 | 1025 | # retrieve input_ids and inputs_embeds 1026 | if input_ids is not None and inputs_embeds is not None: 1027 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 1028 | elif input_ids is not None: 1029 | batch_size, seq_length = input_ids.shape 1030 | elif inputs_embeds is not None: 1031 | batch_size, seq_length, _ = inputs_embeds.shape 1032 | else: 1033 | raise ValueError("You have to specify either input_ids or inputs_embeds") 1034 | 1035 | seq_length_with_past = seq_length 1036 | past_key_values_length = 0 1037 | 1038 | if past_key_values is not None: 1039 | past_key_values_length = past_key_values[0][0].shape[2] 1040 | seq_length_with_past = seq_length_with_past + past_key_values_length 1041 | 1042 | if position_ids is None: 1043 | device = input_ids.device if input_ids is not None else inputs_embeds.device 1044 | position_ids = torch.arange( 1045 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 1046 | ) 1047 | position_ids = 
position_ids.unsqueeze(0).view(-1, seq_length) 1048 | else: 1049 | position_ids = position_ids.view(-1, seq_length).long() 1050 | 1051 | if inputs_embeds is None: 1052 | inputs_embeds = self.embed_tokens(input_ids) 1053 | # embed positions 1054 | if attention_mask is None: 1055 | attention_mask = torch.ones( 1056 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device 1057 | ) 1058 | attention_mask = self._prepare_decoder_attention_mask( 1059 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 1060 | ) 1061 | 1062 | hidden_states = inputs_embeds 1063 | 1064 | if self.gradient_checkpointing and self.training: 1065 | if use_cache: 1066 | logger.warning_once( 1067 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 1068 | ) 1069 | use_cache = False 1070 | 1071 | # decoder layers 1072 | all_hidden_states = () if output_hidden_states else None 1073 | all_self_attns = () if output_attentions else None 1074 | next_decoder_cache = () if use_cache else None 1075 | 1076 | for idx, decoder_layer in enumerate(self.layers): 1077 | if output_hidden_states: 1078 | all_hidden_states += (hidden_states,) 1079 | 1080 | past_key_value = past_key_values[idx] if past_key_values is not None else None 1081 | 1082 | if self.gradient_checkpointing and self.training: 1083 | 1084 | def create_custom_forward(module): 1085 | def custom_forward(*inputs): 1086 | # None for past_key_value 1087 | return module(*inputs, output_attentions, None) 1088 | 1089 | return custom_forward 1090 | 1091 | layer_outputs = torch.utils.checkpoint.checkpoint( 1092 | create_custom_forward(decoder_layer), 1093 | hidden_states, 1094 | attention_mask, 1095 | position_ids, 1096 | None, 1097 | ) 1098 | else: 1099 | layer_outputs = decoder_layer( 1100 | hidden_states, 1101 | attention_mask=attention_mask, 1102 | position_ids=position_ids, 1103 | past_key_value=past_key_value, 1104 | output_attentions=output_attentions, 
1105 | use_cache=use_cache, 1106 | ) 1107 | 1108 | hidden_states = layer_outputs[0] 1109 | 1110 | if use_cache: 1111 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 1112 | 1113 | if output_attentions: 1114 | all_self_attns += (layer_outputs[1],) 1115 | 1116 | hidden_states = self.norm(hidden_states) 1117 | 1118 | # add hidden states from the last decoder layer 1119 | if output_hidden_states: 1120 | all_hidden_states += (hidden_states,) 1121 | 1122 | next_cache = next_decoder_cache if use_cache else None 1123 | if not return_dict: 1124 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 1125 | return BaseModelOutputWithPast( 1126 | last_hidden_state=hidden_states, 1127 | past_key_values=next_cache, 1128 | hidden_states=all_hidden_states, 1129 | attentions=all_self_attns, 1130 | ) 1131 | 1132 | 1133 | class LlamaForCausalLM(LlamaPreTrainedModel): 1134 | def __init__(self, config): 1135 | super().__init__(config) 1136 | self.model = LlamaModel(config) 1137 | 1138 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1139 | 1140 | # Initialize weights and apply final processing 1141 | self.post_init() 1142 | 1143 | def get_input_embeddings(self): 1144 | return self.model.embed_tokens 1145 | 1146 | def set_input_embeddings(self, value): 1147 | self.model.embed_tokens = value 1148 | 1149 | def get_output_embeddings(self): 1150 | return self.lm_head 1151 | 1152 | def set_output_embeddings(self, new_embeddings): 1153 | self.lm_head = new_embeddings 1154 | 1155 | def set_decoder(self, decoder): 1156 | self.model = decoder 1157 | 1158 | def get_decoder(self): 1159 | return self.model 1160 | 1161 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1162 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 1163 | def forward( 1164 | self, 1165 | input_ids: torch.LongTensor = None, 1166 | attention_mask: 
Optional[torch.Tensor] = None, 1167 | position_ids: Optional[torch.LongTensor] = None, 1168 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1169 | inputs_embeds: Optional[torch.FloatTensor] = None, 1170 | labels: Optional[torch.LongTensor] = None, 1171 | use_cache: Optional[bool] = None, 1172 | output_attentions: Optional[bool] = None, 1173 | output_hidden_states: Optional[bool] = None, 1174 | return_dict: Optional[bool] = None, 1175 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1176 | r""" 1177 | Args: 1178 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1179 | Labels for computing the causal language modeling loss. Indices should either be in `[0, ..., 1180 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1181 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1182 | 1183 | Returns: 1184 | 1185 | Example: 1186 | 1187 | ```python 1188 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 1189 | 1190 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 1191 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 1192 | 1193 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 1194 | >>> inputs = tokenizer(prompt, return_tensors="pt") 1195 | 1196 | >>> # Generate 1197 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1198 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1199 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
1200 | ```""" 1201 | 1202 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1203 | output_hidden_states = ( 1204 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1205 | ) 1206 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1207 | 1208 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1209 | outputs = self.model( 1210 | input_ids=input_ids, 1211 | attention_mask=attention_mask, 1212 | position_ids=position_ids, 1213 | past_key_values=past_key_values, 1214 | inputs_embeds=inputs_embeds, 1215 | use_cache=use_cache, 1216 | output_attentions=output_attentions, 1217 | output_hidden_states=output_hidden_states, 1218 | return_dict=return_dict, 1219 | ) 1220 | 1221 | hidden_states = outputs[0] 1222 | logits = self.lm_head(hidden_states) 1223 | 1224 | loss = None 1225 | if labels is not None: 1226 | # NOTE: big optimization could be done here (?) 
1227 | # maybe the copy operation that you saw in the debugger was happening here 1228 | 1229 | # Shift so that tokens < n predict n 1230 | shift_logits = logits[..., :-1, :].contiguous() 1231 | shift_labels = labels[..., 1:].contiguous() 1232 | # Flatten the tokens 1233 | loss_fct = CrossEntropyLoss() 1234 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1235 | shift_labels = shift_labels.view(-1) 1236 | # Enable model parallelism 1237 | shift_labels = shift_labels.to(shift_logits.device) 1238 | loss = loss_fct(shift_logits, shift_labels) 1239 | 1240 | if not return_dict: 1241 | output = (logits,) + outputs[1:] 1242 | return (loss,) + output if loss is not None else output 1243 | 1244 | return CausalLMOutputWithPast( 1245 | loss=loss, 1246 | logits=logits, 1247 | past_key_values=outputs.past_key_values, 1248 | hidden_states=outputs.hidden_states, 1249 | attentions=outputs.attentions, 1250 | ) 1251 | 1252 | def prepare_inputs_for_generation( 1253 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 1254 | ): 1255 | if past_key_values: 1256 | input_ids = input_ids[:, -1:] 1257 | 1258 | position_ids = kwargs.get("position_ids", None) 1259 | if attention_mask is not None and position_ids is None: 1260 | # create position_ids on the fly for batch generation 1261 | position_ids = attention_mask.long().cumsum(-1) - 1 1262 | position_ids.masked_fill_(attention_mask == 0, 1) 1263 | if past_key_values: 1264 | position_ids = position_ids[:, -1].unsqueeze(-1) 1265 | 1266 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1267 | if inputs_embeds is not None and past_key_values is None: 1268 | model_inputs = {"inputs_embeds": inputs_embeds} 1269 | else: 1270 | model_inputs = {"input_ids": input_ids} 1271 | 1272 | model_inputs.update( 1273 | { 1274 | "position_ids": position_ids, 1275 | "past_key_values": past_key_values, 1276 | "use_cache": kwargs.get("use_cache"), 1277 | "attention_mask": 
attention_mask, 1278 | } 1279 | ) 1280 | return model_inputs 1281 | 1282 | @staticmethod 1283 | def _reorder_cache(past_key_values, beam_idx): 1284 | reordered_past = () 1285 | for layer_past in past_key_values: 1286 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 1287 | return reordered_past 1288 | 1289 | 1290 | @add_start_docstrings( 1291 | """ 1292 | The LLaMa Model transformer with a sequence classification head on top (linear layer). 1293 | 1294 | [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1295 | (e.g. GPT-2) do. 1296 | 1297 | Since it does classification on the last token, it requires to know the position of the last token. If a 1298 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1299 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1300 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1301 | each row of the batch). 
1302 | """, 1303 | LLAMA_START_DOCSTRING, 1304 | ) 1305 | class LlamaForSequenceClassification(LlamaPreTrainedModel): 1306 | _keys_to_ignore_on_load_missing = [r"lm_head.weight"] 1307 | 1308 | def __init__(self, config): 1309 | super().__init__(config) 1310 | self.num_labels = config.num_labels 1311 | self.model = LlamaModel(config) 1312 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1313 | 1314 | # Initialize weights and apply final processing 1315 | self.post_init() 1316 | 1317 | def get_input_embeddings(self): 1318 | return self.model.embed_tokens 1319 | 1320 | def set_input_embeddings(self, value): 1321 | self.model.embed_tokens = value 1322 | 1323 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1324 | def forward( 1325 | self, 1326 | input_ids: torch.LongTensor = None, 1327 | attention_mask: Optional[torch.Tensor] = None, 1328 | position_ids: Optional[torch.LongTensor] = None, 1329 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1330 | inputs_embeds: Optional[torch.FloatTensor] = None, 1331 | labels: Optional[torch.LongTensor] = None, 1332 | use_cache: Optional[bool] = None, 1333 | output_attentions: Optional[bool] = None, 1334 | output_hidden_states: Optional[bool] = None, 1335 | return_dict: Optional[bool] = None, 1336 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1337 | r""" 1338 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1339 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1340 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1341 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1342 | """ 1343 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1344 | 1345 | transformer_outputs = self.model( 1346 | input_ids, 1347 | attention_mask=attention_mask, 1348 | position_ids=position_ids, 1349 | past_key_values=past_key_values, 1350 | inputs_embeds=inputs_embeds, 1351 | use_cache=use_cache, 1352 | output_attentions=output_attentions, 1353 | output_hidden_states=output_hidden_states, 1354 | return_dict=return_dict, 1355 | ) 1356 | hidden_states = transformer_outputs[0] 1357 | logits = self.score(hidden_states) 1358 | 1359 | if input_ids is not None: 1360 | batch_size = input_ids.shape[0] 1361 | else: 1362 | batch_size = inputs_embeds.shape[0] 1363 | 1364 | if self.config.pad_token_id is None and batch_size != 1: 1365 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 1366 | if self.config.pad_token_id is None: 1367 | sequence_lengths = -1 1368 | else: 1369 | if input_ids is not None: 1370 | sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) 1371 | else: 1372 | sequence_lengths = -1 1373 | 1374 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1375 | 1376 | loss = None 1377 | if labels is not None: 1378 | labels = labels.to(logits.device) 1379 | if self.config.problem_type is None: 1380 | if self.num_labels == 1: 1381 | self.config.problem_type = "regression" 1382 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1383 | self.config.problem_type = "single_label_classification" 1384 | else: 1385 | self.config.problem_type = "multi_label_classification" 1386 | 1387 | if self.config.problem_type == "regression": 1388 | loss_fct = MSELoss() 1389 | if self.num_labels == 1: 1390 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1391 | else: 1392 | loss = loss_fct(pooled_logits, labels) 1393 | elif self.config.problem_type == 
"single_label_classification": 1394 | loss_fct = CrossEntropyLoss() 1395 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1396 | elif self.config.problem_type == "multi_label_classification": 1397 | loss_fct = BCEWithLogitsLoss() 1398 | loss = loss_fct(pooled_logits, labels) 1399 | if not return_dict: 1400 | output = (pooled_logits,) + transformer_outputs[1:] 1401 | return ((loss,) + output) if loss is not None else output 1402 | 1403 | return SequenceClassifierOutputWithPast( 1404 | loss=loss, 1405 | logits=pooled_logits, 1406 | past_key_values=transformer_outputs.past_key_values, 1407 | hidden_states=transformer_outputs.hidden_states, 1408 | attentions=transformer_outputs.attentions, 1409 | ) 1410 | -------------------------------------------------------------------------------- /peft_pretraining/training_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | from torch.optim.lr_scheduler import LambdaLR 6 | import transformers 7 | 8 | 9 | def get_scheculer( 10 | optimizer, 11 | *, 12 | scheduler_type, 13 | num_training_steps, 14 | warmup_steps, 15 | min_lr_ratio, 16 | cycle_length=None, 17 | restart_warmup_steps=None, 18 | adjust_step=0, 19 | last_epoch=-1, 20 | ): 21 | if adjust_step != 0 and scheduler_type != "cosine_restarts": 22 | raise ValueError("adjust_step is only supported for cosine_restarts scheduler") 23 | 24 | if scheduler_type == "linear": 25 | return transformers.get_linear_schedule_with_warmup( 26 | optimizer, 27 | num_warmup_steps=warmup_steps, 28 | num_training_steps=num_training_steps, 29 | last_epoch=last_epoch, 30 | ) 31 | if scheduler_type == "cosine": 32 | return get_cyclical_cosine_schedule_with_min_lr( 33 | optimizer, 34 | num_warmup_steps=warmup_steps, 35 | num_training_steps=num_training_steps, 36 | cycle_length=cycle_length, 37 | min_lr_ratio=min_lr_ratio, 38 | last_epoch=last_epoch, 39 | ) 40 | if 
scheduler_type == "cosine_restarts": 41 | assert restart_warmup_steps is not None, "restart_warmup_steps must be specified for cosine_restarts scheduler" 42 | return get_cosine_schedule_with_multiple_warmups( 43 | optimizer, 44 | num_training_steps=num_training_steps, 45 | first_warmup_steps=warmup_steps, 46 | restart_warmup_steps=restart_warmup_steps, 47 | restart_every=cycle_length, 48 | min_lr_ratio=min_lr_ratio, 49 | last_epoch=last_epoch, 50 | adjust_step=adjust_step, 51 | ) 52 | 53 | raise NotImplementedError(f"Scheduler {scheduler_type} is not implemented") 54 | 55 | 56 | def get_cyclical_cosine_schedule_with_min_lr(optimizer, num_warmup_steps, num_training_steps, cycle_length, min_lr_ratio=0.1, last_epoch=-1): 57 | assert cycle_length is not None or num_training_steps is not None, "You must specify either cycle_length or num_training_steps" 58 | 59 | if cycle_length is None: 60 | cycle_length = num_training_steps 61 | 62 | if num_training_steps % cycle_length != 0: 63 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by cycle_length ({cycle_length})") 64 | 65 | lr_lambda = partial( 66 | _get_cyclical_cosine_schedule_with_min_lr_lambda, 67 | num_warmup_steps=num_warmup_steps, 68 | cycle_length=cycle_length, 69 | min_lr_ratio=min_lr_ratio, 70 | ) 71 | return LambdaLR(optimizer, lr_lambda, last_epoch) 72 | 73 | 74 | def get_cosine_schedule_with_multiple_warmups( 75 | optimizer, 76 | *, 77 | num_training_steps, 78 | first_warmup_steps, 79 | restart_warmup_steps, 80 | restart_every, 81 | min_lr_ratio=0.1, 82 | adjust_step=0, 83 | last_epoch=-1, 84 | ): 85 | if restart_every is None: 86 | raise ValueError("restart_every must be specified for cosine_restarts scheduler") 87 | 88 | if num_training_steps % restart_every != 0: 89 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by restart_every ({restart_every})") 90 | 91 | lr_lambda = partial( 92 | _get_cosine_schedule_with_multiple_warmups_lambda, 93 | 
num_training_steps=num_training_steps, 94 | first_warmup_steps=first_warmup_steps, 95 | restart_warmup_steps=restart_warmup_steps, 96 | restart_every=restart_every, 97 | min_lr_ratio=min_lr_ratio, 98 | adjust_step=adjust_step, 99 | ) 100 | return LambdaLR(optimizer, lr_lambda, last_epoch) 101 | 102 | 103 | @torch.no_grad() 104 | def random_pruning(tensor, prune_ratio): 105 | """ 106 | Performs random pruning dimensionality reduction. 107 | Only reduces the inner dimensionality, does not affect the shape of the tensor 108 | """ 109 | random_pruning_mask = torch.rand_like(tensor) > prune_ratio 110 | tensor = tensor * random_pruning_mask 111 | return tensor 112 | 113 | 114 | @torch.no_grad() 115 | def magnitude_pruning(tensor, prune_ratio): 116 | """ 117 | Performs magnitude pruning dimensionality reduction. 118 | Only reduces the inner dimensionality, does not affect the shape of the tensor 119 | """ 120 | tensor_magnitude = torch.abs(tensor) 121 | threshold = torch.quantile(tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio).to(dtype=tensor.dtype) 122 | 123 | mask = tensor_magnitude > threshold 124 | tensor = tensor * mask.to(dtype=tensor.dtype) 125 | return tensor 126 | 127 | 128 | def _get_cyclical_cosine_schedule_with_min_lr_lambda(current_step, *, num_warmup_steps, cycle_length, min_lr_ratio): 129 | assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]" 130 | 131 | # compute where we are in the current cycle 132 | cycle_step = current_step % cycle_length 133 | 134 | if cycle_step < num_warmup_steps: 135 | if current_step != cycle_step: 136 | if cycle_step < 2: 137 | return 1e-7 138 | return float(cycle_step) / float(max(1, num_warmup_steps)) 139 | 140 | progress = float(cycle_step - num_warmup_steps) / float(max(1, cycle_length - num_warmup_steps)) 141 | cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) 142 | 143 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay 144 | 145 | 146 | def 
_get_cosine_schedule_with_multiple_warmups_lambda( 147 | current_step, 148 | *, 149 | num_training_steps, 150 | first_warmup_steps, 151 | restart_warmup_steps, 152 | restart_every, 153 | min_lr_ratio, 154 | adjust_step, 155 | ): 156 | """ 157 | Args: 158 | adjust_step: useful when continuing training from a warmed-up checkpoint; 159 | it lets you sync the resets by reducing the number of steps 160 | after the first warmup and before the first reset. 161 | Thus, your ReLoRA resets can be synced with the optimizer resets. 162 | """ 163 | assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]" 164 | assert restart_every > 0, "restart_every must be positive" 165 | assert adjust_step + first_warmup_steps < num_training_steps, "warmup + adjust_step is more than full training steps" 166 | assert adjust_step + first_warmup_steps < restart_every, "the first reset will happen before the warmup is done" 167 | 168 | if current_step < first_warmup_steps: 169 | return float(current_step) / float(max(1, first_warmup_steps)) 170 | 171 | _current_step = current_step + adjust_step 172 | 173 | restart_step = _current_step % restart_every 174 | restart_number = _current_step // restart_every 175 | 176 | if restart_step < restart_warmup_steps: 177 | # get the expected lr multiplier at the end of the warmup 178 | end_of_warmup_progress = ( 179 | float(restart_number * restart_every) / 180 | float(max(1, num_training_steps - first_warmup_steps)) 181 | ) 182 | 183 | _cosine_decay = 0.5 * (1.0 + math.cos(math.pi * end_of_warmup_progress)) 184 | warmup_lr_multiplier = min_lr_ratio + (1.0 - min_lr_ratio) * _cosine_decay 185 | 186 | return float(restart_step) / float(max(1, restart_warmup_steps)) * warmup_lr_multiplier 187 | 188 | progress = float(_current_step - first_warmup_steps) / float(max(1, num_training_steps - first_warmup_steps)) 189 | cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) 190 | 191 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay 192 | 193 | 194 | 
def collate_fn(batch_list): 195 | batch = { 196 | "input_ids": torch.stack([torch.Tensor(example["input_ids"]).long() for example in batch_list]), 197 | "attention_mask": torch.stack([torch.Tensor(example["attention_mask"]).long() for example in batch_list]), 198 | } 199 | return batch 200 | 201 | 202 | def batch_fn(dataset, batch_size): 203 | batch = [] 204 | for example in dataset: 205 | batch.append(example) 206 | if len(batch) == batch_size: 207 | batch = collate_fn(batch) 208 | yield batch 209 | batch = [] 210 | if len(batch) > 0: 211 | yield collate_fn(batch)  # collate the final partial batch as well 212 | 213 | 214 | def max_train_tokens_to_number(max_train_tokens): 215 | if max_train_tokens.endswith("M"): 216 | return int(max_train_tokens.rstrip("M")) * 1_000_000 217 | elif max_train_tokens.endswith("B"): 218 | return int(max_train_tokens.rstrip("B")) * 1_000_000_000 219 | else: 220 | return int(max_train_tokens) 221 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | bitsandbytes -------------------------------------------------------------------------------- /run_130m.sh: -------------------------------------------------------------------------------- 1 | # Learning rate and normalization type for this run 2 | norm_type=$1 3 | learning_rates=1e-3 4 | export NORM_TYPE=$norm_type 5 | export POST_NUM=$2 6 | 7 | # Launch a single training run 8 | 9 | echo "Training with learning rate: $learning_rates, norm type: $norm_type" 10 | 11 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node 1 --master_port=29510 torchrun_main.py \ 12 | --model_config configs/llama_130m.json \ 13 | --lr $learning_rates \ 14 | --batch_size 64 \ 15 | --total_batch_size 512 \ 16 | --num_training_steps 20000 \ 17 | --warmup_steps 2000 \ 18 | --weight_decay 0 \ 19 | --dtype bfloat16 \ 20 | --eval_every 1000 \ 21 | --optimizer adam \ 22 | --grad_clipping 
0.0 \ 23 | --run_name "130m_res_${norm_type}_lr${learning_rates}_layer_scale" \ 24 | --save_dir "130m_res_${norm_type}_lr${learning_rates}" -------------------------------------------------------------------------------- /run_1b.sh: -------------------------------------------------------------------------------- 1 | # Learning rate and normalization type for this run 2 | norm_type=$1 3 | learning_rates=5e-4 4 | export NORM_TYPE=$norm_type 5 | export POST_NUM=$2 6 | 7 | # Launch a single training run 8 | 9 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node 8 --master_port=29500 torchrun_main.py \ 10 | --model_config configs/llama_1b.json \ 11 | --lr $learning_rates \ 12 | --batch_size 32 \ 13 | --total_batch_size 512 \ 14 | --num_training_steps 100000 \ 15 | --warmup_steps 1000 \ 16 | --weight_decay 0 \ 17 | --dtype bfloat16 \ 18 | --eval_every 1000 \ 19 | --optimizer adam \ 20 | --grad_clipping 0.0 \ 21 | --run_name "1b_res_${norm_type}_lr${learning_rates}" \ 22 | --save_dir "1b_res_${norm_type}_lr${learning_rates}" -------------------------------------------------------------------------------- /run_250m.sh: -------------------------------------------------------------------------------- 1 | # Learning rate and normalization type for this run 2 | norm_type=$1 3 | learning_rates=1e-3 4 | export NORM_TYPE=$norm_type 5 | export POST_NUM=$2 6 | 7 | # Launch a single training run 8 | 9 | echo "Training with learning rate: $learning_rates, norm type: $norm_type" 10 | 11 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 --master_port=29500 torchrun_main.py \ 12 | --model_config configs/llama_250m.json \ 13 | --lr $learning_rates \ 14 | --batch_size 128 \ 15 | --total_batch_size 512 \ 16 | --num_training_steps 40000 \ 17 | --warmup_steps 4000 \ 18 | --weight_decay 0 \ 19 | --dtype bfloat16 \ 20 | --eval_every 1000 \ 21 | --optimizer adam \ 22 | --grad_clipping 0.0 \ 23 | --run_name 
"250m_res_${norm_type}_lr${learning_rates}" \ 24 | --save_dir "250m_res_${norm_type}_lr${learning_rates}" -------------------------------------------------------------------------------- /run_350m.sh: -------------------------------------------------------------------------------- 1 | # Define the set of learning rates and normalization types 2 | norm_type=$1 3 | learning_rates=5e-4 4 | export NORM_TYPE=$norm_type 5 | export POST_NUM=$2 6 | 7 | # Function to run a single training task 8 | 9 | echo "Training with learning rate: $learning_rates, norm type: $norm_type on GPU $gpu" 10 | 11 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 --master_port=29503 torchrun_main.py \ 12 | --model_config configs/llama_350m.json \ 13 | --lr $learning_rates \ 14 | --batch_size 128 \ 15 | --total_batch_size 512 \ 16 | --num_training_steps 60000 \ 17 | --warmup_steps 6000 \ 18 | --weight_decay 0 \ 19 | --dtype bfloat16 \ 20 | --eval_every 1000 \ 21 | --optimizer adam \ 22 | --grad_clipping 0.0 \ 23 | --run_name "350m_res_${norm_type}_lr${learning_rates}" \ 24 | --save_dir "350m_res_${norm_type}_lr${learning_rates}" -------------------------------------------------------------------------------- /scaled_init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/scaled_init.png -------------------------------------------------------------------------------- /scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/scaling.png -------------------------------------------------------------------------------- /torchrun_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import random 5 | import argparse 6 | import numpy as np 7 
| 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.data 11 | import torch.distributed as dist 12 | 13 | import transformers 14 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM 15 | from transformers import LlamaForCausalLM as HF_LlamaForCausalLM 16 | 17 | import datasets 18 | import datasets.distributed 19 | import wandb 20 | 21 | from tqdm import tqdm 22 | from loguru import logger 23 | 24 | from peft_pretraining import training_utils, args_utils 25 | from peft_pretraining.dataloader import PreprocessedIterableDataset 26 | from peft_pretraining.modeling_llama import LlamaForCausalLM 27 | 28 | import bitsandbytes as bnb 29 | 30 | import matplotlib.pyplot as plt 31 | transformers.logging.set_verbosity_error() 32 | 33 | def parse_args(args): 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument("--model_config", type=str, required=True) 37 | parser.add_argument("--use_hf_model", default=False, action="store_true") 38 | parser.add_argument("--continue_from", type=str, default=None) 39 | parser.add_argument("--batch_size", type=int, required=True) 40 | parser.add_argument("--gradient_accumulation", type=int, default=None) 41 | parser.add_argument("--total_batch_size", type=int, default=None) 42 | parser.add_argument("--max_length", type=int, default=256) 43 | parser.add_argument("--optimizer", default="Adam") 44 | parser.add_argument("--lr", type=float, default=1e-4) 45 | parser.add_argument("--scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_restarts"]) 46 | parser.add_argument("--min_lr_ratio", type=float, default=0.1) 47 | parser.add_argument("--activation_checkpointing", action="store_true") 48 | parser.add_argument("--weight_decay", type=float, default=0.0) 49 | parser.add_argument("--warmup_steps", type=int, default=1_000) 50 | parser.add_argument("--eval_every", type=int, default=2_000) 51 | parser.add_argument("--num_training_steps", type=int, default=10_000, 52 | 
help="Number of **update steps** to train for. " 53 | "Notice that gradient accumulation is taken into account.") 54 | parser.add_argument("--max_train_tokens", type=training_utils.max_train_tokens_to_number, default=None, 55 | help="Number of tokens to train on. Overwrites num_training_steps. " 56 | "You can use M and B suffixes, e.g. 100M or 1B.") 57 | parser.add_argument("--save_every", type=int, default=10000) 58 | parser.add_argument("--save_dir", type=str, default=None) 59 | parser.add_argument("--tags", type=str, default=None) 60 | parser.add_argument("--dtype", type=str, default="bfloat16" if torch.cuda.is_bf16_supported() else "float32") 61 | parser.add_argument("--workers", type=int, default=8) 62 | parser.add_argument("--seed", type=int, default=1) 63 | parser.add_argument("--name", type=str, default="test") 64 | parser.add_argument("--grad_clipping", type=float, default=1.0) 65 | parser.add_argument("--run_name", type=str, default="default") 66 | # beta1 for adafactor 67 | parser.add_argument("--beta1", type=float, default=0.0) 68 | 69 | # GaLore parameters 70 | parser.add_argument("--rank", type=int, default=128) 71 | parser.add_argument("--update_proj_gap", type=int, default=50) 72 | parser.add_argument("--galore_scale", type=float, default=1.0) 73 | parser.add_argument("--proj_type", type=str, default="std") 74 | 75 | # disable ddp, single_gpu 76 | parser.add_argument("--single_gpu", default=False, action="store_true") 77 | 78 | args = parser.parse_args(args) 79 | 80 | args = args_utils.check_args_torchrun_main(args) 81 | return args 82 | 83 | 84 | @torch.no_grad() 85 | def evaluate_model(model, preprocess_batched, pad_idx, global_rank, world_size, device, batch_size): 86 | _time = time.time() 87 | val_data = datasets.load_dataset("c4", "en", split="validation", streaming=True, trust_remote_code=True) #DGX 88 | val_data = val_data.shuffle(seed=42) 89 | logger.info(f"Loaded validation dataset in {time.time() - _time:.2f} seconds") 90 | 91 | if not 
args.single_gpu: 92 | val_data = datasets.distributed.split_dataset_by_node(val_data, rank=global_rank, world_size=world_size) 93 | 94 | val_data_mapped = val_data.map( 95 | preprocess_batched, 96 | batched=True, 97 | remove_columns=["text", "timestamp", "url"], 98 | ) 99 | val_data_mapped.batch = lambda batch_size: training_utils.batch_fn(val_data_mapped, batch_size) 100 | 101 | target_eval_tokens = 10_000_000 102 | evaluated_on_tokens = 0 103 | total_loss = torch.tensor(0.0).to(device) 104 | total_batches = 1 105 | logger.info(f"Eval set prepared in {time.time() - _time:.2f} seconds") 106 | 107 | for batch in val_data_mapped.batch(batch_size=batch_size): 108 | if evaluated_on_tokens > target_eval_tokens: 109 | break 110 | total_batches += 1 111 | 112 | batch = {k: v.to(device) for k, v in batch.items()} 113 | labels = batch["input_ids"].clone() 114 | labels[labels == pad_idx] = -100 115 | loss = model(**batch, labels=labels).loss 116 | total_loss += loss.detach() 117 | 118 | evaluated_on_tokens += (batch["input_ids"] != pad_idx).sum().item() * world_size 119 | 120 | total_loss = total_loss / total_batches 121 | 122 | # Gather losses across all GPUs 123 | gathered_losses = [torch.zeros_like(total_loss) for _ in range(world_size)] 124 | dist.all_gather(gathered_losses, total_loss) 125 | total_loss = sum([t.item() for t in gathered_losses]) / world_size 126 | 127 | return total_loss, evaluated_on_tokens 128 | 129 | 130 | def main(args): 131 | torch.manual_seed(args.seed) 132 | np.random.seed(args.seed) 133 | random.seed(args.seed) 134 | 135 | assert "LOCAL_RANK" in os.environ, "torchrun should set LOCAL_RANK" 136 | global_rank = int(os.environ['RANK']) 137 | local_rank = int(os.environ["LOCAL_RANK"]) 138 | world_size = int(os.environ["WORLD_SIZE"]) 139 | torch.cuda.set_device(local_rank) 140 | 141 | logger.info(f"Global rank {global_rank}, local rank {local_rank}, device: {torch.cuda.current_device()}") 142 | 143 | dist.init_process_group(backend="nccl", 
rank=global_rank, world_size=world_size) 144 | 145 | logger.info("Process group initialized") 146 | device = f"cuda:{local_rank}" 147 | 148 | if args.total_batch_size is not None: 149 | if args.gradient_accumulation is None: 150 | assert args.total_batch_size % world_size == 0, "total_batch_size must be divisible by world_size" 151 | args.gradient_accumulation = args.total_batch_size // (args.batch_size * world_size) 152 | assert args.gradient_accumulation > 0, "gradient_accumulation must be greater than 0" 153 | 154 | assert args.gradient_accumulation * args.batch_size * world_size == args.total_batch_size, \ 155 | "gradient_accumulation * batch_size * world_size must be equal to total_batch_size" 156 | 157 | # turn off logger 158 | if global_rank != 0: logger.remove() 159 | 160 | # initialize wandb without config (it is passed later) 161 | if global_rank == 0: 162 | wandb.init(project="cod", name=args.run_name) 163 | 164 | logger.info(f"Using dist with rank {global_rank} (only rank 0 will log)") 165 | logger.info("*" * 40) 166 | logger.info(f"Starting training with the arguments") 167 | for k, v in vars(args).items(): 168 | logger.info(f"{k:30} {v}") 169 | logger.info("*" * 40) 170 | 171 | data = datasets.load_dataset("allenai/c4", "en", split="train", streaming=True) 172 | 173 | 174 | 175 | seed_for_shuffle = 32 176 | 177 | logger.info(f"Shuffling data with seed {seed_for_shuffle}") 178 | data: datasets.Dataset = data.shuffle(seed=seed_for_shuffle) 179 | if not args.single_gpu: 180 | data = datasets.distributed.split_dataset_by_node( 181 | data, rank=global_rank, world_size=world_size, 182 | ) 183 | 184 | # it doesn't matter which tokenizer we use, because we train from scratch 185 | # T5 tokenizer was trained on C4 and we are also training on C4, so it's a good choice 186 | tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=args.max_length) 187 | 188 | def preprocess_batched(batch): 189 | batch = tokenizer( 190 | batch["text"], 191 | 
max_length=args.max_length, 192 | truncation=True, 193 | padding="max_length", 194 | return_tensors="pt", 195 | ) 196 | return batch 197 | 198 | dataset = PreprocessedIterableDataset(data, tokenizer, batch_size=args.batch_size, max_length=args.max_length) 199 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=args.workers) 200 | 201 | model_config = AutoConfig.from_pretrained(args.model_config) 202 | if args.use_hf_model: 203 | model: HF_LlamaForCausalLM = AutoModelForCausalLM.from_config(model_config) 204 | else: 205 | model = LlamaForCausalLM(model_config) 206 | 207 | if args.activation_checkpointing: 208 | model.gradient_checkpointing_enable() 209 | 210 | global_step = 0 211 | update_step = 0 212 | beginning_step = 0 213 | tokens_seen = 0 214 | tokens_seen_before = 0 215 | 216 | if args.continue_from is not None: 217 | logger.info("*" * 40) 218 | logger.info(f"Loading model from {args.continue_from}") 219 | 220 | from safetensors.torch import load_file 221 | state_dict = load_file(f"{args.continue_from}/model.safetensors") 222 | model.load_state_dict(state_dict) 223 | 224 | # checkpoint_path = os.path.join(args.continue_from, "pytorch_model.bin") 225 | # model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True) 226 | logger.info(f"Model successfully loaded (strict=True policy)") 227 | 228 | if os.path.exists(os.path.join(args.continue_from, "training_state.json")): 229 | logger.info(f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}") 230 | with open(os.path.join(args.continue_from, "training_state.json")) as f: 231 | _old_state = json.load(f) 232 | global_step = _old_state["global_step"] 233 | update_step = _old_state["update_step"] 234 | tokens_seen = _old_state["tokens_seen"] 235 | tokens_seen_before = _old_state["tokens_seen_before"] 236 | logger.info(f"global_step : {global_step}") 237 | logger.info(f"update_step : {update_step}") 238 | 
logger.info(f"tokens_seen : {tokens_seen}") 239 | logger.info(f"tokens_seen_before: {tokens_seen_before}") 240 | logger.info(f"Will train for {args.num_training_steps - update_step} update steps") 241 | else: 242 | logger.warning(f"Did not find training state in {args.continue_from}, global step will start from zero") 243 | logger.info("*" * 40) 244 | 245 | 246 | if args.dtype in ["bf16", "bfloat16"]: 247 | model = model.to(device=device, dtype=torch.bfloat16) 248 | else: 249 | model = model.to(device=device) 250 | 251 | n_total_params = sum(p.numel() for p in model.parameters()) 252 | trainable_params = [p for p in model.parameters() if p.requires_grad] 253 | # Initialize wandb 254 | run_config = dict(vars(args)) 255 | run_config.update({ 256 | "max_lr": run_config.pop("lr"), # rename lr to max_lr to avoid conflicts with scheduler 257 | "total_params_M": n_total_params / 1_000_000, 258 | "dataset": 'c4', 259 | "model": model_config.to_dict(), 260 | "world_size": world_size, 261 | "device": str(device), 262 | }) 263 | 264 | if global_rank == 0: 265 | wandb.config.update(run_config, allow_val_change=True) 266 | wandb.save(os.path.abspath(__file__), policy="now") # save current script 267 | # fix tqdm visual length to 80 so that the progress bar 268 | # doesn't jump around when changing from external display to laptop 269 | pbar = tqdm(total=args.num_training_steps - update_step, desc="Update steps", ncols=80) 270 | 271 | if 'galore' in args.optimizer.lower(): 272 | # make parameters with "rank" to a single group, if param_name has "mlp" or "attn" 273 | galore_params = [] 274 | target_modules_list = ["attn", "mlp"] 275 | for module_name, module in model.named_modules(): 276 | if not isinstance(module, nn.Linear): 277 | continue 278 | 279 | if not any(target_key in module_name for target_key in target_modules_list): 280 | continue 281 | 282 | print('enable GaLore for weights in module: ', module_name) 283 | galore_params.append(module.weight) 284 | id_galore_params = 
[id(p) for p in galore_params] 285 | # make parameters without "rank" to another group 286 | regular_params = [p for p in model.parameters() if id(p) not in id_galore_params] 287 | # then call galore_adamw 288 | param_groups = [{'params': regular_params}, 289 | {'params': galore_params, 'rank': args.rank, 'update_proj_gap': args.update_proj_gap, 'scale': args.galore_scale, 'proj_type': args.proj_type}] 290 | 291 | # print params and trainable params 292 | logger.info(f"\n{model}\n") 293 | logger.info(f"Total params: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M") 294 | logger.info(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000:.2f}M") 295 | if 'galore' in args.optimizer.lower(): 296 | logger.info(f"Total params with GaLore enabled: {sum(p.numel() for p in galore_params) / 1_000_000:.2f}M") 297 | logger.info(f"Saving model to {args.save_dir} every {args.save_every} update steps") 298 | 299 | layer_wise_flag = False 300 | if args.optimizer.lower() == "adam": 301 | optimizer = torch.optim.Adam(trainable_params, lr=args.lr, weight_decay=args.weight_decay) 302 | else: 303 | raise ValueError(f"Optimizer {args.optimizer} not supported") 304 | 305 | if not layer_wise_flag: 306 | scheduler = training_utils.get_scheculer( 307 | optimizer=optimizer, 308 | scheduler_type=args.scheduler, 309 | num_training_steps=args.num_training_steps, 310 | warmup_steps=args.warmup_steps, 311 | min_lr_ratio=args.min_lr_ratio, 312 | ) 313 | 314 | if not args.single_gpu: 315 | model: LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel( 316 | model, 317 | device_ids=[local_rank], 318 | output_device=local_rank, 319 | broadcast_buffers=False, 320 | ) 321 | 322 | # global steps and others are defined above 323 | pad_idx = tokenizer.pad_token_id 324 | update_time = time.time() 325 | local_step = 0 # when continue_from is used, local_step != global_step 326 | 327 | # ############################## 328 | # TRAINING LOOP 329 | # 
we'll never go through all the data, so no need for epochs 330 | # ############################## 331 | 332 | for batch_idx, batch in enumerate(dataloader): 333 | 334 | global_step += 1 335 | local_step += 1 336 | if update_step > args.num_training_steps: 337 | logger.info(f"Reached max number of update steps ({args.num_training_steps}). Stopping training.") 338 | print(f"Rank {global_rank} stopping training.") 339 | break 340 | 341 | batch = {k: v.to(device) for k, v in batch.items()} 342 | labels = batch["input_ids"].clone() 343 | labels[labels == pad_idx] = -100 344 | tokens_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size 345 | 346 | loss = model(**batch, labels=labels).loss 347 | scaled_loss = loss / args.gradient_accumulation 348 | scaled_loss.backward() 349 | 350 | if global_step % args.gradient_accumulation != 0: 351 | continue 352 | 353 | # The below code is only executed during the update step 354 | 355 | # add grad clipping 356 | if args.grad_clipping != 0.0: torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clipping) 357 | 358 | if global_rank == 0: pbar.update(1) 359 | 360 | if not layer_wise_flag: 361 | optimizer.step() 362 | scheduler.step() 363 | optimizer.zero_grad() 364 | 365 | update_step += 1 366 | update_time = time.time() - update_time 367 | 368 | # save checkpoint by save_every 369 | if local_step > args.gradient_accumulation and update_step % args.save_every == 0 and global_rank == 0: 370 | current_model_directory = f"{args.save_dir}/model_{update_step}" 371 | logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}") 372 | tokenizer.save_pretrained(current_model_directory) 373 | os.makedirs(args.save_dir, exist_ok=True) 374 | model.module.save_pretrained(current_model_directory, max_shard_size='100GB') 375 | optimizer_checkpoint = { 376 | "optimizer": optimizer.state_dict(), 377 | "scheduler": scheduler.state_dict(), 378 | "update_step": update_step, 379 | "global_step": 
global_step, 380 | "config": run_config, 381 | "wandb": wandb.run.dir, 382 | "dtype": args.dtype, 383 | } 384 | torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt") 385 | 386 | training_state_checkpoint = { 387 | "global_step": global_step, 388 | "update_step": update_step, 389 | "tokens_seen": tokens_seen, 390 | "tokens_seen_before": tokens_seen_before, 391 | "update_time": update_time, 392 | } 393 | with open(f"{current_model_directory}/training_state.json", "w") as f: 394 | json.dump(training_state_checkpoint, f, indent=4) 395 | 396 | # save wandb related info 397 | wandb_info = { 398 | "wandb_id": wandb.run.id, 399 | } 400 | with open(f"{args.save_dir}/wandb.json", "w") as f: 401 | json.dump(wandb_info, f, indent=4) 402 | 403 | # evaluation 404 | if update_step % args.eval_every == 0: 405 | logger.info(f"Performing evaluation at step {update_step}") 406 | total_loss, evaluated_on_tokens = evaluate_model( 407 | model, preprocess_batched, pad_idx, global_rank, world_size, device, args.batch_size 408 | ) 409 | if global_rank == 0: 410 | wandb.log({ 411 | "final_eval_loss": total_loss, 412 | "final_eval_tokens": evaluated_on_tokens, 413 | }, 414 | step=global_step, 415 | ) 416 | logger.info(f"Eval loss at step {update_step}: {total_loss}") 417 | 418 | if not layer_wise_flag: 419 | lr = optimizer.param_groups[0]["lr"] 420 | else: 421 | pass 422 | tokens_in_update = tokens_seen - tokens_seen_before 423 | tokens_seen_before = tokens_seen 424 | batches_in_update = args.gradient_accumulation * world_size 425 | 426 | if global_rank == 0: 427 | wandb.log({ 428 | "loss": loss.item(), 429 | "lr": lr, 430 | "update_step": update_step, 431 | "tokens_seen": tokens_seen, 432 | "throughput_tokens": tokens_in_update / update_time, 433 | "throughput_examples": args.total_batch_size / update_time, 434 | "throughput_batches": batches_in_update / update_time, 435 | }, 436 | step=global_step, 437 | ) 438 | update_time = time.time() 439 | 440 | # 
############################## 441 | # END of training loop 442 | # ############################## 443 | logger.info("Training finished") 444 | if global_rank == 0: pbar.close() 445 | 446 | current_model_directory = f"{args.save_dir}/model_{update_step}" 447 | if global_rank == 0 and not os.path.exists(current_model_directory): 448 | logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}") 449 | os.makedirs(args.save_dir, exist_ok=True) 450 | model.module.save_pretrained(current_model_directory) 451 | 452 | optimizer_checkpoint = { 453 | "optimizer": optimizer.state_dict(), 454 | "scheduler": scheduler.state_dict(), 455 | "update_step": update_step, 456 | "global_step": global_step, 457 | "config": run_config, 458 | "wandb": wandb.run.dir, 459 | "dtype": args.dtype, 460 | } 461 | torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt") 462 | 463 | training_state_checkpoint = { 464 | "global_step": global_step, 465 | "update_step": update_step, 466 | "tokens_seen": tokens_seen, 467 | "tokens_seen_before": tokens_seen_before, 468 | "update_time": update_time, 469 | } 470 | with open(f"{current_model_directory}/training_state.json", "w") as f: 471 | json.dump(training_state_checkpoint, f, indent=4) 472 | 473 | # Final evaluation 474 | logger.info("Running final evaluation") 475 | model.eval() 476 | del loss, optimizer, scheduler 477 | import gc; gc.collect() 478 | torch.cuda.empty_cache() 479 | 480 | total_loss, evaluated_on_tokens = evaluate_model( 481 | model, preprocess_batched, pad_idx, global_rank, world_size, device, args.batch_size 482 | ) 483 | 484 | if global_rank == 0: 485 | wandb.log({ 486 | "final_eval_loss": total_loss, 487 | "final_eval_tokens": evaluated_on_tokens, 488 | }, 489 | step=global_step, 490 | ) 491 | logger.info(f"Final eval loss: {total_loss}") 492 | 493 | logger.info("Script finished successfully") 494 | print(f"Rank {global_rank} finished successfully") 495 | 496 | 497 | if 
__name__ == "__main__": 498 | print("Starting script") 499 | args = parse_args(None) 500 | main(args) 501 | -------------------------------------------------------------------------------- /utils/angular_distance.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | from datasets import load_dataset 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | 7 | from short_hf import ShortHFModel 8 | 9 | def compute_angular_distance(model, data, max_seq_len=1024, stride=256, n_samples=50): 10 | angular_distances = [] 11 | for i, batch in enumerate(tqdm(data)): 12 | if i >= n_samples: 13 | break 14 | prompts = batch['text'] 15 | 16 | model.eval_importance( 17 | prompts=prompts, 18 | max_seq_len=max_seq_len, 19 | stride=stride, 20 | max_gen_len=0, 21 | angular=True 22 | ) 23 | angular_distances.extend(model.importances) 24 | return angular_distances 25 | 26 | def plot_angular_distances(distances, output_path): 27 | plt.figure(figsize=(10, 6)) 28 | plt.plot(distances, label='Angular Distances') 29 | plt.xlabel('Sample Index') 30 | plt.ylabel('Angular Distance') 31 | plt.title('Angular Distances Across Samples') 32 | plt.legend() 33 | plt.savefig(output_path) 34 | plt.show() 35 | 36 | def main(args): 37 | model = ShortHFModel( 38 | model_name=args.model_path, 39 | layers_path="model.layers", 40 | n_prune_layers=1, # this is a dummy value, don't worry about it 41 | ) 42 | 43 | data = load_dataset("allenai/c4", "en", split="train", streaming=True) 44 | 45 | angular_distances = compute_angular_distance(model, data, args.max_seq_len, args.stride, args.n_samples) 46 | 47 | plot_angular_distances(angular_distances, args.output_path) 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser(description="Compute and plot angular distances") 51 | parser.add_argument("--model_path", type=str, required=True, help="Path to the model") 52 | 
parser.add_argument("--n_samples", type=int, default=50, help="Number of samples to process") 53 | parser.add_argument("--max_seq_len", type=int, default=1024, help="Maximum sequence length") 54 | parser.add_argument("--stride", type=int, default=256, help="Stride for processing sequences") 55 | parser.add_argument("--output_path", type=str, default="angular_distances.png", help="Output path for the plot") 56 | 57 | args = parser.parse_args() 58 | 59 | main(args) -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def block_influence( 5 | input_hidden_state: torch.Tensor, 6 | output_hidden_state: torch.Tensor, 7 | angular=False, 8 | ): 9 | """ 10 | input_hidden_state: B, S, D 11 | output_hidden_state: B, S, D 12 | """ 13 | _, _, d = input_hidden_state.shape 14 | input_hidden_state = input_hidden_state.reshape(-1, d) 15 | output_hidden_state = output_hidden_state.reshape(-1, d) 16 | 17 | norm_input = input_hidden_state.norm(dim=-1, keepdim=True) 18 | norm_output = output_hidden_state.norm(dim=-1, keepdim=True) 19 | 20 | sim = (input_hidden_state @ output_hidden_state.T) / (norm_input * norm_output) 21 | sim = sim.diagonal().nan_to_num(nan=0.5) 22 | 23 | if angular: 24 | return (torch.arccos(sim) / torch.pi) 25 | 26 | return 1 - sim -------------------------------------------------------------------------------- /utils/short_hf.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from transformers import AutoTokenizer, AutoModelForCausalLM 7 | 8 | from metrics import * 9 | 10 | 11 | class ShortHFModel(): 12 | 13 | def __init__(self, model_name: str, layers_path: str, n_prune_layers: Optional[int] = None): 14 | """ 15 | HuggingFace Model Wrapper 16 | 17 | Args: 18 | model_name (str): 
HuggingFace model name 19 | layers_path (str): String in dot notation demonstrating how to access layers of the model. Ex: "model.layers" 20 | (Optional) n_prune_layers (int): Number of layers to prune. Defaults to None. 21 | """ 22 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 23 | self.tokenizer.pad_token = self.tokenizer.eos_token 24 | self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) 25 | # self.model.params = self.model.to_fp16(self.model.params) 26 | self.model.to("cuda") 27 | 28 | modules = layers_path.split(".") 29 | mod = self.model 30 | for m in modules: 31 | mod = getattr(mod, m) 32 | self.layers = mod 33 | 34 | self.n_prune_layers = n_prune_layers 35 | self.importances = [0 for _ in self.layers] # layer-wise importance scores 36 | 37 | def remove_layers( 38 | self, 39 | layers_to_remove: Optional[List[int]] = [], 40 | angular: Optional[bool] = False 41 | ): 42 | if angular: 43 | assert self.importances, "Need to compute importances with eval_importance()" 44 | assert self.n_prune_layers, "Need number of layers to prune, set `n_prune_layers`" 45 | start_layer = np.argsort(np.array(self.importances[:len(self.importances) - self.n_prune_layers + 1]))[0] 46 | layers_to_remove = list(range(start_layer, start_layer + self.n_prune_layers)) 47 | elif not layers_to_remove and self.n_prune_layers: 48 | assert self.importances, "Need to compute importances with eval_importance()" 49 | layers_to_remove = np.argsort(np.array(self.importances))[:self.n_prune_layers].tolist() 50 | 51 | # remove layers in reverse to avoid indexing errors 52 | for layer_idx in sorted(layers_to_remove, reverse=True): 53 | try: 54 | del self.layers[layer_idx] 55 | except IndexError: 56 | print(f"layer {layer_idx} does not exist, function may have already been called") 57 | return [] 58 | 59 | return layers_to_remove 60 | 61 | def compute_bi(self, hiddens: List[torch.Tensor], angular: bool): 62 | n = 1 63 | if angular: 64 | assert self.n_prune_layers is not 
None, "Set number of layers to prune to use angular importance" 65 | n = self.n_prune_layers 66 | 67 | for i in range(len(hiddens) - n): 68 | in_hidden = hiddens[i] 69 | out_hidden = hiddens[i+n] 70 | if angular: 71 | # use only last token for angular distance as described in section 3.2 72 | # https://arxiv.org/pdf/2403.17887.pdf 73 | in_hidden = in_hidden[:,-1:] 74 | out_hidden = out_hidden[:,-1:] 75 | 76 | self.importances[i] += block_influence( 77 | in_hidden, 78 | out_hidden, 79 | angular=angular 80 | ).sum().cpu().item() 81 | 82 | @torch.inference_mode() 83 | def eval_importance( 84 | self, 85 | prompts: List[str], 86 | max_seq_len: int, 87 | stride: int = 256, 88 | max_gen_len: int = 0, 89 | temperature: float = 0.6, 90 | top_p: float = 0.9, 91 | angular: Optional[bool] = False 92 | ): 93 | """ 94 | Computes layer-wise importances over input texts. 95 | 96 | NOTE: The ShortGPT paper performs no generation during importance computation, which suggests `max_gen_len = 0`. 97 | 98 | Args: 99 | prompts (List[str]): List of prompts. 100 | max_seq_len (int): Maximum sequence length for model input, the sliding window size. 101 | (Optional) stride (int): Number of tokens to skip/shift between each window inference. 102 | (Optional) max_gen_len (int): Maximum length of the generated text sequence. 103 | (Optional) temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6. 104 | (Optional) top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 105 | (Optional) angular (bool): Whether to use angular distance. Defaults to False. 
106 | 107 | Returns: 108 | None 109 | """ 110 | prompt_tokens = self.tokenizer( 111 | prompts, 112 | padding=True, 113 | return_attention_mask=True, 114 | return_tensors='pt' 115 | ) 116 | input_ids = prompt_tokens.input_ids 117 | attn_mask = prompt_tokens.attention_mask 118 | 119 | max_prompt_len = max(len(t) for t in input_ids) 120 | 121 | # authors use a sliding window of size 1024 with a shift of 256 122 | for start in range(0, max_prompt_len, stride): 123 | seq_ids = (attn_mask.sum(dim=-1) > start).nonzero().squeeze() 124 | seq_ids = seq_ids.unsqueeze(0) if seq_ids.dim() == 0 else seq_ids # ensure 2d 125 | inputs = input_ids[seq_ids, start:start+max_seq_len] 126 | attn = attn_mask[seq_ids, start:start+max_seq_len] 127 | 128 | if max_gen_len == 0: 129 | outputs = self.model( 130 | input_ids=inputs.to("cuda"), 131 | attention_mask=attn.to("cuda"), 132 | output_hidden_states=True, 133 | ) 134 | else: 135 | outputs = self.model.generate( 136 | input_ids=inputs.to("cuda"), 137 | attention_mask=attn.to("cuda"), 138 | max_new_tokens=max_gen_len, 139 | do_sample=True, 140 | temperature=temperature, 141 | top_p=top_p, 142 | output_hidden_states=True, 143 | return_dict_in_generate=True, 144 | ) 145 | 146 | self.compute_bi(outputs.hidden_states, angular=angular) 147 | 148 | return --------------------------------------------------------------------------------
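The `block_influence` metric in `utils/metrics.py` compares the hidden state entering a block with the hidden state leaving it, either as `1 - cosine_similarity` or as `arccos(cosine_similarity) / pi` when `angular=True`. The sketch below illustrates the same arithmetic on a single pair of vectors using only the Python standard library. It is not the repository's implementation: the name `block_influence_1d` is hypothetical, and the clamp to `[-1, 1]` is an added assumption to guard `math.acos` against floating-point overshoot.

```python
import math
from typing import Sequence


def block_influence_1d(
    x_in: Sequence[float],
    x_out: Sequence[float],
    angular: bool = False,
) -> float:
    """Distance between one input and one output hidden-state vector.

    Hypothetical single-vector analogue of block_influence in utils/metrics.py.
    """
    dot = sum(a * b for a, b in zip(x_in, x_out))
    norm_in = math.sqrt(sum(a * a for a in x_in))
    norm_out = math.sqrt(sum(b * b for b in x_out))

    if norm_in == 0.0 or norm_out == 0.0:
        # The tensor version yields NaN here and replaces it with 0.5
        # via nan_to_num(nan=0.5); this branch mirrors that choice.
        sim = 0.5
    else:
        sim = dot / (norm_in * norm_out)
        # Assumption (not in the original): clamp so acos stays in its domain.
        sim = max(-1.0, min(1.0, sim))

    if angular:
        # Angular distance, normalized to the range [0, 1].
        return math.acos(sim) / math.pi
    # Cosine distance, in the range [0, 2].
    return 1.0 - sim


if __name__ == "__main__":
    same = block_influence_1d([1.0, 0.0], [1.0, 0.0])
    orthogonal = block_influence_1d([1.0, 0.0], [0.0, 1.0], angular=True)
    print(f"identical vectors, cosine distance: {same:.3f}")
    print(f"orthogonal vectors, angular distance: {orthogonal:.3f}")
```

A block whose output points in nearly the same direction as its input scores close to 0 under either variant, which is the sense in which a low score marks a layer as contributing little.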