├── README.md
├── configs
│   ├── llama_100m.json
│   ├── llama_130m.json
│   ├── llama_1b.json
│   ├── llama_20m.json
│   ├── llama_250m.json
│   ├── llama_350m.json
│   ├── llama_35m.json
│   ├── llama_3b.json
│   ├── llama_40m.json
│   ├── llama_60m.json
│   ├── llama_71m.json
│   ├── llama_7b.json
│   └── llama_9m.json
├── exp_requirements.txt
├── layer_remove.py
├── peft_pretraining
│   ├── __pycache__
│   │   ├── args_utils.cpython-39.pyc
│   │   ├── dataloader.cpython-39.pyc
│   │   ├── modeling_llama.cpython-39.pyc
│   │   └── training_utils.cpython-39.pyc
│   ├── args_utils.py
│   ├── dataloader.py
│   ├── modeling_llama.py
│   └── training_utils.py
├── requirements.txt
├── run_130m.sh
├── run_1b.sh
├── run_250m.sh
├── run_350m.sh
├── scaled_init.png
├── scaling.png
├── torchrun_main.py
└── utils
├── angular_distance.py
├── metrics.py
└── short_hf.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # The Curse of Depth in Large Language Models
4 | [[`Arxiv`](https://arxiv.org/abs/2502.05795)]
5 | [[X(Twitter)](https://x.com/Shiwei_Liu66/status/1889257901346152844)]
6 | [[Model](https://huggingface.co/pengxiang/LNS_1B)]
7 | [[Talk](https://www.youtube.com/watch?v=sVN7wgmmNms)]
8 |
9 | We present the Curse of Depth, a phenomenon in Large Language Models (LLMs) where deeper layers contribute less effectively to training due to the widespread use of Pre-Layer Normalization (Pre-LN). Our analysis identifies this issue as a key bottleneck in LLM optimization and proposes LayerNorm Scaling as a solution to mitigate its impact.
10 |
11 |
12 |

13 |
14 |
15 |
16 | ## Abstract
17 |
18 | In this paper, we introduce the Curse of Depth, a concept that highlights, explains, and addresses the recent observation in modern Large Language Models (LLMs) where nearly half of the layers are less effective than expected. We first confirm the wide existence of this phenomenon across the most popular families of LLMs such as Llama, Mistral, DeepSeek, and Qwen. Our analysis, theoretically and empirically, identifies that the underlying reason for the ineffectiveness of deep layers in LLMs is the widespread usage of Pre-Layer Normalization (Pre-LN). While Pre-LN stabilizes the training of Transformer LLMs, its output variance exponentially grows with the model depth, which undesirably causes the derivative of the deep Transformer blocks to be an identity matrix, and therefore barely contributes to the training. To resolve this training pitfall, we propose LayerNorm Scaling, which scales the variance of output of the layer normalization inversely by the square root of its depth. This simple modification mitigates the output variance explosion of deeper Transformer layers, improving their contribution. Our experimental results, spanning model sizes from 130M to 1B, demonstrate that LayerNorm Scaling significantly enhances LLM pre-training performance compared to Pre-LN. Moreover, this improvement seamlessly carries over to supervised fine-tuning. All these gains can be attributed to the fact that LayerNorm Scaling enables deeper layers to contribute more effectively during training.
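At its core, LNS simply divides the output of each layer normalization by the square root of its (1-based) depth before it enters the attention or MLP sub-block, as done in the `LNS` branch of `peft_pretraining/modeling_llama.py`. A minimal sketch of the rule (the helper below is illustrative, not part of the repository):

```python
import math
import torch

def layernorm_scaling(normed: torch.Tensor, layer_index: int) -> torch.Tensor:
    """Scale the RMSNorm/LayerNorm output of layer `layer_index` (0-based) by 1/sqrt(depth)."""
    return normed / math.sqrt(layer_index + 1)

# Inside a Pre-LN residual block this is applied right after each normalization:
#   residual = x
#   h = layernorm_scaling(rms_norm(x), layer_index)
#   x = residual + self_attn(h)   # and likewise before the MLP sub-block
```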
19 |
20 | ## Caveat
21 |
22 | Combining LNS with Scaled Initialization (which scales the initialization of $W_o$ and $W_2$ by $1/\sqrt{2L}$, where $L$ is the total number of layers) undermines the effectiveness of LNS and performs worse than using LNS alone. This highlights the importance of removing conflicting initialization strategies before adopting LNS.
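For reference, the Scaled Initialization referred to here shrinks the initialization of the residual output projections (the attention output projection and the second MLP projection) by $1/\sqrt{2L}$. A hypothetical sketch, shown only to make the caveat concrete; the function name and default std are assumptions:

```python
import math
import torch.nn as nn

def scaled_init_(proj: nn.Linear, num_layers: int, base_std: float = 0.02) -> None:
    # Shrink the usual N(0, base_std^2) initialization by 1/sqrt(2L).
    # The caveat above recommends NOT combining this with LNS.
    nn.init.normal_(proj.weight, mean=0.0, std=base_std / math.sqrt(2 * num_layers))
```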
23 |
24 |
25 |

26 |
27 |
28 |
29 | ## Hugging Face
30 | We have uploaded the trained weights of the 1B model trained with LayerNorm Scaling (LNS).
31 | You can download them from [pengxiang/LNS_1B](https://huggingface.co/pengxiang/LNS_1B).
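A minimal loading sketch, assuming the checkpoint works with the standard `transformers` LLaMA classes (if it instead depends on this repository's `peft_pretraining/modeling_llama.py`, load the weights with that model class):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("pengxiang/LNS_1B", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("pengxiang/LNS_1B")  # assumes a tokenizer is bundled with the checkpoint
```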
32 |
33 | ## Quick Start
34 |
35 | ### Install experiment dependencies
36 |
37 | You can set up the environment with the following commands:
38 |
39 | ```bash
40 | conda create -n LNS python=3.9 -y
41 | conda activate LNS
42 | pip install -r exp_requirements.txt
43 | ```
44 |
45 | ### Training Examples
46 | We provide scripts to train models of different sizes using Pre-LN, Post-LN, Mix-LN, and LayerNorm Scaling (LNS).
47 |
48 | Train a 130M Model:
49 | ```bash
50 | bash run_130m.sh pre 3 # Pre-LN
51 | bash run_130m.sh post 3 # Post-LN
52 | bash run_130m.sh post_pre 3 # Mix-LN
53 | bash run_130m.sh LNS 3 # LayerNorm Scaling (LNS)
54 |
55 | # Note: 3 is the number of Post-LN layers used by Mix-LN.
56 | ```
57 |
58 |
59 | Train a 250M Model:
60 | ```bash
61 | bash run_250m.sh pre 6 # Pre-LN
62 | bash run_250m.sh post 6 # Post-LN
63 | bash run_250m.sh post_pre 6 # Mix-LN
64 | bash run_250m.sh LNS 6 # LayerNorm Scaling (LNS)
65 |
66 | # Note: 6 is the number of Post-LN layers used by Mix-LN.
67 | ```
68 |
69 | Train a 350M Model:
70 | ```bash
71 | bash run_350m.sh pre 6 # Pre-LN
72 | bash run_350m.sh post 6 # Post-LN
73 | bash run_350m.sh post_pre 6 # Mix-LN
74 | bash run_350m.sh LNS 6 # LayerNorm Scaling (LNS)
75 | ```
76 |
77 | Train a 1B Model:
78 | ```bash
79 | bash run_1b.sh pre 6 # Pre-LN
80 | bash run_1b.sh post 6 # Post-LN
81 | bash run_1b.sh post_pre 6 # Mix-LN
82 | bash run_1b.sh LNS 6 # LayerNorm Scaling (LNS)
83 | ```
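In each script, the first argument selects the normalization variant and the second the number of Post-LN layers. Inside `peft_pretraining/modeling_llama.py` these choices are read from the `NORM_TYPE` and `POST_NUM` environment variables, which the run scripts presumably export before launching `torchrun_main.py`; an illustrative equivalent:

```bash
export NORM_TYPE=LNS   # e.g. pre, post, post_pre (Mix-LN), LNS
export POST_NUM=6      # number of Post-LN layers used by Mix-LN
```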
84 |
85 | ### Performance Drop
86 | Calculate the performance drop after removing individual layers. We use [lm_eval](https://github.com/EleutherAI/lm-evaluation-harness) to obtain evaluation results; please refer to its installation instructions to configure `lm_eval`.
87 | ```bash
88 | git clone https://github.com/EleutherAI/lm-evaluation-harness
89 | cd lm-evaluation-harness
90 | pip install -e .
91 | ```
92 |
93 | Then, run the following command to remove a given layer and save the weights as a new model. The performance drop is computed by evaluating this new model:
94 | ```bash
95 | # LLaMA2-7B, Remove Layer 1
96 | python layer_remove.py \
97 | --model_path meta-llama/Llama-2-7b-hf \
98 | --layer_index 1 \
99 | --save_path ./llama_7b_removed_1
100 | ```
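Once the pruned checkpoint is saved, it can be evaluated with `lm_eval` and compared against the original model to quantify the drop (the task list below is illustrative):

```bash
lm_eval --model hf \
    --model_args pretrained=./llama_7b_removed_1,dtype=float16 \
    --tasks arc_easy,hellaswag,piqa \
    --batch_size 8
```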
101 |
102 | ### 📚Citation
103 |
104 | ```bibtex
105 | @article{sun2025curse,
106 | title={The Curse of Depth in Large Language Models},
107 | author={Sun, Wenfang and Song, Xinyuan and Li, Pengxiang and Yin, Lu and Zheng, Yefeng and Liu, Shiwei},
108 | journal={arXiv preprint arXiv:2502.05795},
109 | year={2025}
110 | }
111 | ```
112 |
113 |
114 | ### Acknowledgement
115 | This repository is built upon the [Mix-LN](https://github.com/pixeli99/MixLN/tree/main) repository. Thanks for their great work!
116 |
--------------------------------------------------------------------------------
/configs/llama_100m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 640,
9 | "intermediate_size": 1708,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 10,
14 | "num_hidden_layers": 12,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32100
20 | }
--------------------------------------------------------------------------------
/configs/llama_130m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 768,
9 | "intermediate_size": 2048,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 12,
14 | "num_hidden_layers": 12,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_1b.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 2048,
9 | "intermediate_size": 5461,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 32,
14 | "num_hidden_layers": 24,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_20m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 256,
9 | "intermediate_size": 688,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 4,
14 | "num_hidden_layers": 4,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_250m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 768,
9 | "intermediate_size": 2560,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 16,
14 | "num_hidden_layers": 24,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_350m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 1024,
9 | "intermediate_size": 2736,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 16,
14 | "num_hidden_layers": 24,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_35m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 384,
9 | "intermediate_size": 1024,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 8,
14 | "num_hidden_layers": 6,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_3b.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 2560,
9 | "intermediate_size": 6848,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 32,
14 | "num_hidden_layers": 32,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_40m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 416,
9 | "intermediate_size": 1024,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 8,
14 | "num_hidden_layers": 8,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_60m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 512,
9 | "intermediate_size": 1376,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 8,
14 | "num_hidden_layers": 8,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_71m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 512,
9 | "intermediate_size": 1368,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 8,
14 | "num_hidden_layers": 12,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 4096,
9 | "intermediate_size": 11008,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 2048,
12 | "model_type": "llama",
13 | "num_attention_heads": 32,
14 | "num_hidden_layers": 32,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_9m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 128,
9 | "intermediate_size": 352,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 4,
14 | "num_hidden_layers": 4,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/exp_requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers==4.31.0
3 | tensorly
4 | tokenizers
5 | datasets
6 | peft
7 | wandb
8 | loguru
9 | nvitop
10 | lion-pytorch
11 | matplotlib
12 | bitsandbytes
13 | scipy
14 | scikit-learn
15 | evaluate
16 |
--------------------------------------------------------------------------------
/layer_remove.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from transformers import LlamaForCausalLM
4 | import subprocess
5 | import json
6 | import argparse
7 |
8 | def remove_layers_and_save(model_path, output_dir, layers_to_remove):
9 | # Load the pre-trained LLaMA model
10 | model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
11 |
12 | # Ensure the output directory exists
13 | if not os.path.exists(output_dir):
14 | os.makedirs(output_dir)
15 |
16 |     # Remove the specified layers, highest index first so deletions don't shift the remaining indices
17 |     for layer_idx in sorted(layers_to_remove, reverse=True):
18 |         if 0 <= layer_idx < len(model.model.layers):
19 |             del model.model.layers[layer_idx]
20 |
21 | # Renumber the remaining layers' indices
22 | for layer_idx, module in enumerate(model.model.layers):
23 | module.self_attn.layer_idx = layer_idx
24 |
25 | # Save the modified model
26 | model.save_pretrained(output_dir)
27 | print(f"Model saved to {output_dir}")
28 |
29 | # Update the config.json with the new number of hidden layers
30 | config_path = os.path.join(output_dir, "config.json")
31 | with open(config_path, "r", encoding="utf-8") as file:
32 | config = json.load(file)
33 |
34 | config['num_hidden_layers'] = len(model.model.layers)
35 |
36 | with open(config_path, "w", encoding="utf-8") as file:
37 | json.dump(config, file, indent=4, ensure_ascii=False)
38 | print(f"Updated config saved to {config_path}")
39 |
40 |
41 | def run_bash_script(script_path, working_directory):
42 | # Switch to the target directory
43 | os.chdir(working_directory)
44 |
45 | # Execute the bash script
46 | subprocess.run(["bash", script_path])
47 |
48 | # Switch back to the original directory
49 | os.chdir("..")
50 | print(f"Executed script {script_path} in {working_directory}")
51 |
52 |
53 | if __name__ == "__main__":
54 | # Argument parsing
55 | parser = argparse.ArgumentParser(description="Remove layers from LLaMA model and save the modified version.")
56 | parser.add_argument("--model_path", type=str, required=True, help="Path to the pre-trained LLaMA model")
57 | parser.add_argument("--layer_index", type=int, required=True, help="Index of the layer to remove")
58 | parser.add_argument("--save_path", type=str, required=True, help="Path to save the modified model")
59 |
60 | args = parser.parse_args()
61 |
62 | # Remove layers and save the model
63 | remove_layers_and_save(args.model_path, args.save_path, [args.layer_index])
64 |
65 | # Path to the evaluation script and directory
66 | # target_directory = "/lm-evaluation-harness"
67 | # script_name = "run_task.sh"
68 |
69 | # Run the evaluation script
70 | # run_bash_script(script_name, target_directory)
71 |
--------------------------------------------------------------------------------
/peft_pretraining/__pycache__/args_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/peft_pretraining/__pycache__/args_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/peft_pretraining/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/peft_pretraining/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/peft_pretraining/__pycache__/modeling_llama.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/peft_pretraining/__pycache__/modeling_llama.cpython-39.pyc
--------------------------------------------------------------------------------
/peft_pretraining/__pycache__/training_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/peft_pretraining/__pycache__/training_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/peft_pretraining/args_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 |
4 | from loguru import logger
5 |
6 |
7 | def check_args_torchrun_main(args):
8 |
9 | if args.save_dir is None:
10 | # use checkpoints / model name, date and time as save directory
11 |         args.save_dir = f"checkpoints/{args.model_config.split('/')[-1].removesuffix('.json')}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"  # removesuffix strips the extension as a suffix, unlike rstrip
12 |
13 | if args.tags is not None:
14 | args.tags = args.tags.split(",")
15 |
16 | if args.total_batch_size is None:
17 | args.gradient_accumulation = args.gradient_accumulation or 1
18 | args.total_batch_size = args.batch_size * args.gradient_accumulation
19 |
20 | assert args.total_batch_size % args.batch_size == 0, "total_batch_size must be divisible by batch_size"
21 |
22 | if args.max_train_tokens is not None:
23 | args.num_training_steps = args.max_train_tokens // args.total_batch_size
24 | logger.info(f"Training for {args.num_training_steps} update steps")
25 |
26 | if args.continue_from is not None:
27 | assert os.path.exists(args.continue_from), f"--continue_from={args.continue_from} does not exist"
28 |
29 | if args.dtype in ["fp16", "float16"]:
30 | raise NotImplementedError("fp16 is not supported in torchrun_main.py. Use deepspeed_main.py instead (but it seems to have bugs)")
31 |
32 | return args
33 |
--------------------------------------------------------------------------------
/peft_pretraining/dataloader.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import torch
4 | from torch.utils.data import IterableDataset, get_worker_info
5 |
6 |
7 | class PreprocessedIterableDataset(IterableDataset):
8 | def __init__(self, data, tokenizer, batch_size, max_length):
9 | super().__init__()
10 | self.data = data
11 | self.tokenizer = tokenizer
12 | self.batch_size = batch_size
13 | self.max_length = max_length
14 |
15 | def __iter__(self):
16 | worker_info = get_worker_info()
17 | if worker_info is None:
18 | # If no worker_info is provided, we are not using DataLoader workers, so yield all data
19 | iter_data = iter(self.data)
20 | else:
21 | # If using DataLoader workers, yield a subset of the data for this worker
22 | worker_id = worker_info.id
23 | num_workers = worker_info.num_workers
24 | iter_data = itertools.islice(self.data, worker_id, None, num_workers)
25 |
26 | batch = []
27 | for example in iter_data:
28 | tokenized_example = self.tokenizer(
29 | example["text"],
30 | max_length=self.max_length,
31 | truncation=True,
32 | padding="max_length",
33 | return_tensors="pt",
34 | )
35 | batch.append(tokenized_example)
36 |
37 | if len(batch) == self.batch_size:
38 | yield self._format_batch(batch)
39 | batch = []
40 |
41 | if batch:
42 | yield self._format_batch(batch)
43 |
44 | def _format_batch(self, batch):
45 | input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
46 | attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
47 |
48 | return {"input_ids": input_ids, "attention_mask": attention_mask}
49 |
--------------------------------------------------------------------------------
/peft_pretraining/modeling_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ PyTorch LLaMA model."""
21 | import math
22 | import os
23 | from typing import List, Optional, Tuple, Union
24 |
25 | import torch
26 | import torch.utils.checkpoint
27 | from torch import nn
28 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29 |
30 | from transformers.activations import ACT2FN
31 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
32 | from transformers.modeling_utils import PreTrainedModel
33 | from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
34 | from transformers.models.llama.configuration_llama import LlamaConfig
35 |
36 |
37 | logger = logging.get_logger(__name__)
38 |
39 | _CONFIG_FOR_DOC = "LlamaConfig"
40 |
41 |
42 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask
43 | def _make_causal_mask(
44 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
45 | ):
46 | """
47 | Make causal mask used for bi-directional self-attention.
48 | """
49 | bsz, tgt_len = input_ids_shape
50 | mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
51 | mask_cond = torch.arange(mask.size(-1), device=device)
52 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
53 | mask = mask.to(dtype)
54 |
55 | if past_key_values_length > 0:
56 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
57 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
58 |
59 |
60 | # Copied from transformers.models.bart.modeling_bart._expand_mask
61 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
62 | """
63 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
64 | """
65 | bsz, src_len = mask.size()
66 | tgt_len = tgt_len if tgt_len is not None else src_len
67 |
68 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
69 |
70 | inverted_mask = 1.0 - expanded_mask
71 |
72 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
73 |
74 |
75 | class LlamaRMSNorm(nn.Module):
76 | def __init__(self, hidden_size, eps=1e-6):
77 | """
78 | LlamaRMSNorm is equivalent to T5LayerNorm
79 | """
80 | super().__init__()
81 | self.weight = nn.Parameter(torch.ones(hidden_size))
82 | self.variance_epsilon = eps
83 |
84 | def forward(self, hidden_states):
85 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
86 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
87 |
88 | # convert into half-precision if necessary
89 | if self.weight.dtype in [torch.float16, torch.bfloat16]:
90 | hidden_states = hidden_states.to(self.weight.dtype)
91 |
92 | return self.weight * hidden_states
93 |
94 |
95 | class LlamaRotaryEmbedding(torch.nn.Module):
96 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
97 | super().__init__()
98 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
99 | self.register_buffer("inv_freq", inv_freq)
100 |
101 | # Build here to make `torch.jit.trace` work.
102 | self.max_seq_len_cached = max_position_embeddings
103 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
104 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
105 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
106 | emb = torch.cat((freqs, freqs), dim=-1)
107 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
108 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
109 |
110 | def forward(self, x, seq_len=None):
111 | # x: [bs, num_attention_heads, seq_len, head_size]
112 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
113 | if seq_len > self.max_seq_len_cached:
114 | self.max_seq_len_cached = seq_len
115 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
116 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
117 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
118 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
119 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
120 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
121 | return (
122 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
123 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
124 | )
125 |
126 |
127 | def rotate_half(x):
128 | """Rotates half the hidden dims of the input."""
129 | x1 = x[..., : x.shape[-1] // 2]
130 | x2 = x[..., x.shape[-1] // 2 :]
131 | return torch.cat((-x2, x1), dim=-1)
132 |
133 |
134 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
135 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
136 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
137 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
138 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
139 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
140 | q_embed = (q * cos) + (rotate_half(q) * sin)
141 | k_embed = (k * cos) + (rotate_half(k) * sin)
142 | return q_embed, k_embed
143 |
144 |
145 | class LlamaMLP(nn.Module):
146 | def __init__(
147 | self,
148 | hidden_size: int,
149 | intermediate_size: int,
150 | hidden_act: str,
151 | scale_mlp_output: bool = False,
152 | ):
153 | super().__init__()
154 | self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
155 | self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
156 | self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
157 | self.act_fn = ACT2FN[hidden_act]
158 |
159 | if scale_mlp_output:
160 | self.down_proj.is_scaled_layer = True
161 |         if os.getenv('NORM_TYPE', 'pre').lower() == 'deeppost':
162 | self.gate_proj.is_deeppost_layer = True
163 | self.up_proj.is_deeppost_layer = True
164 | self.down_proj.is_deeppost_layer = True
165 | def forward(self, x):
166 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
167 |
168 |
169 | class LlamaAttention(nn.Module):
170 | """Multi-headed attention from 'Attention Is All You Need' paper"""
171 |
172 | def __init__(self, config: LlamaConfig, scale_attn_weights: bool = False):
173 | super().__init__()
174 | self.config = config
175 | self.hidden_size = config.hidden_size
176 | self.num_heads = config.num_attention_heads
177 | self.head_dim = self.hidden_size // self.num_heads
178 | self.max_position_embeddings = config.max_position_embeddings
179 |
180 | if (self.head_dim * self.num_heads) != self.hidden_size:
181 | raise ValueError(
182 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
183 | f" and `num_heads`: {self.num_heads})."
184 | )
185 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
186 | self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
187 | self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
188 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
189 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
190 |
191 | if scale_attn_weights:
192 | self.o_proj.is_scaled_layer = True
193 | if os.getenv('NORM_TYPE', 'pre').lower() == 'deeppost':
194 | self.v_proj.is_deeppost_layer = True
195 | self.o_proj.is_deeppost_layer = True
196 |             # q/k projections get a separate flag, presumably for a different init treatment
197 | self.q_proj.is_deeppost_layer_qk = True
198 | self.k_proj.is_deeppost_layer_qk = True
199 |
200 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
201 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
202 |
203 | def forward(
204 | self,
205 | hidden_states: torch.Tensor,
206 | attention_mask: Optional[torch.Tensor] = None,
207 | position_ids: Optional[torch.LongTensor] = None,
208 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
209 | output_attentions: bool = False,
210 | use_cache: bool = False,
211 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
212 | bsz, q_len, _ = hidden_states.size()
213 |
214 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
215 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
216 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
217 |
218 | kv_seq_len = key_states.shape[-2]
219 | if past_key_value is not None:
220 | kv_seq_len += past_key_value[0].shape[-2]
221 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
222 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
223 | # [bsz, nh, t, hd]
224 |
225 | if past_key_value is not None:
226 | # reuse k, v, self_attention
227 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
228 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
229 |
230 | past_key_value = (key_states, value_states) if use_cache else None
231 |
232 | if attention_mask is not None:
233 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
234 | raise ValueError(
235 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
236 | )
237 |
238 | # WARNING: padding mask is ignored, causal is always applied
239 | attn_output = torch.nn.functional.scaled_dot_product_attention(
240 | query_states, key_states, value_states, dropout_p=0.0, is_causal=True,
241 | )
242 |
243 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
244 | raise ValueError(
245 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
246 | f" {attn_output.size()}"
247 | )
248 |
249 | attn_output = attn_output.transpose(1, 2)
250 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
251 |
252 | attn_output = self.o_proj(attn_output)
253 |
254 | if not output_attentions:
255 | attn_weights = None
256 |
257 | return attn_output, attn_weights, past_key_value
258 |
259 |
260 | class LlamaDecoderLayer(nn.Module):
261 | def __init__(self, config: LlamaConfig, layer_index: int):
262 | super().__init__()
263 | self.hidden_size = config.hidden_size
264 |
265 | norm_type = os.getenv('NORM_TYPE', 'pre').lower()
266 | self.layer_nums = config.num_hidden_layers
267 | self.layer_index = layer_index
268 | self.max_post_norm_layer = int(os.getenv('POST_NUM', '1'))
269 | print(f'Initializing LlamaDecoderLayer {self.layer_index + 1}/{self.layer_nums} with norm type: {norm_type}')
270 | scale_attn_weights = False
271 | scale_mlp_output = False
272 |
273 | if norm_type == 'scale_post_pre':
274 | if self.layer_index < self.max_post_norm_layer:
275 | scale_attn_weights = False
276 | scale_mlp_output = False
277 | else:
278 | scale_attn_weights = True
279 | scale_mlp_output = True
280 | if norm_type == 'scale_pre':
281 | scale_attn_weights = True
282 | scale_mlp_output = True
283 |
284 | self.self_attn = LlamaAttention(config=config, scale_attn_weights=scale_attn_weights)
285 | self.mlp = LlamaMLP(
286 | hidden_size=self.hidden_size,
287 | intermediate_size=config.intermediate_size,
288 | hidden_act=config.hidden_act,
289 | scale_mlp_output=scale_mlp_output,
290 | )
291 |
292 |         if norm_type == 'pre' or norm_type == 'scale_pre' or norm_type == 'lns':  # norm_type is lowercased above
293 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
294 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
295 | elif norm_type== 'post' or norm_type == 'deeppost':
296 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
297 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
298 | elif norm_type == 'sandwich':
299 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
300 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
301 | self.pre_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
302 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
303 | elif norm_type == 'pre_post':
304 | print("cur layer index is:", layer_index)
305 | self.max_pre_norm_layer = 7
306 | if self.layer_index < self.max_pre_norm_layer:
307 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
308 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
309 | else:
310 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
311 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
312 | elif norm_type == 'pre_sandwich':
313 | print("cur layer index is:", layer_index)
314 | self.max_pre_norm_layer = 7
315 | if self.layer_index < self.max_pre_norm_layer:
316 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
317 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
318 | else:
319 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
320 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
321 | self.pre_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
322 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
323 | elif norm_type == 'post_pre' or norm_type == 'scale_post_pre':
324 | print("cur layer index is:", layer_index)
325 | if self.layer_index < self.max_post_norm_layer:
326 | print(f"cur layer is {layer_index}, and you're using the post norm!")
327 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
328 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
329 | else:
330 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
331 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
332 | elif norm_type == 'scale_res_pre_norm':
333 | self.raw_scaling_factor_attn = nn.Parameter(torch.tensor(0.001))
334 | self.raw_scaling_factor_mlp = nn.Parameter(torch.tensor(0.001))
335 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
336 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
337 | elif norm_type == 'pre_post_pre_post':
338 | print("cur layer index is:", layer_index)
339 |             # Use Pre-LN except on every 4th layer ((layer_index + 1) % 4 == 0), which uses Post-LN
340 | if (self.layer_index + 1) % 4 != 0:
341 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
342 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
343 | else:
344 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
345 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
346 | elif norm_type == 'post_pre_post_pre':
347 | print("cur layer index is:", layer_index)
348 |             # Use Post-LN except on every 4th layer ((layer_index + 1) % 4 == 0), which uses Pre-LN
349 | if (self.layer_index + 1) % 4 != 0:
350 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
351 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
352 | else:
353 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
354 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
355 | elif norm_type == 'mono':
356 | if (self.layer_index + 1) % 4 != 0:
357 | # Pre-LayerNorm Only
358 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
359 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
360 | else:
361 | # Post-LayerNorm & Pre-LayerNorm
362 | self.router = nn.Linear(config.hidden_size, 2, bias=False)
363 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
364 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
365 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
366 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
367 | elif norm_type == 'mono_reverse':
368 | if (self.layer_index + 1) % 4 != 0:
369 | # Post-LayerNorm Only
370 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
371 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
372 | else:
373 | # Post-LayerNorm & Pre-LayerNorm
374 | self.router = nn.Linear(config.hidden_size, 2, bias=False)
375 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
376 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
377 | self.post_feedforward_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
378 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
379 |
380 | def forward(
381 | self,
382 | hidden_states: torch.Tensor,
383 | attention_mask: Optional[torch.Tensor] = None,
384 | position_ids: Optional[torch.LongTensor] = None,
385 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
386 | output_attentions: Optional[bool] = False,
387 | use_cache: Optional[bool] = False,
388 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
389 | """
390 | Args:
391 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
392 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
393 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
394 | output_attentions (`bool`, *optional*):
395 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
396 | returned tensors for more detail.
397 | use_cache (`bool`, *optional*):
398 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
399 | (see `past_key_values`).
400 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
401 | """
402 |
403 | norm_type = os.getenv('NORM_TYPE', 'pre').lower()
404 |
405 | if norm_type == 'pre' or norm_type == 'scale_pre':
406 | # # Layer 1: Self-Attention
407 | residual = hidden_states
408 | hidden_states = self.input_layernorm(hidden_states)
409 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
410 | hidden_states=hidden_states,
411 | attention_mask=attention_mask,
412 | position_ids=position_ids,
413 | past_key_value=past_key_value,
414 | output_attentions=output_attentions,
415 | use_cache=use_cache,
416 | )
417 | hidden_states = residual + hidden_states
418 |
419 | # Layer 2: Feed-Forward Network (FFN)
420 |
421 | residual = hidden_states
422 | hidden_states = self.post_attention_layernorm(hidden_states)
423 | hidden_states = self.mlp(hidden_states)
424 | hidden_states = residual + hidden_states
425 |
426 |
427 |         elif norm_type == 'lns':  # norm_type is lowercased above
428 | # Layer 1: Self-Attention
429 | residual = hidden_states
430 | hidden_states = self.input_layernorm(hidden_states)
431 |
432 | scale_factor = 1 / math.sqrt(self.layer_index + 1) # scale
433 | hidden_states = scale_factor * hidden_states
434 |
435 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
436 | hidden_states=hidden_states,
437 | attention_mask=attention_mask,
438 | position_ids=position_ids,
439 | past_key_value=past_key_value,
440 | output_attentions=output_attentions,
441 | use_cache=use_cache,
442 | )
443 | hidden_states = residual + hidden_states
444 |
445 | # Layer 2: Feed-Forward Network (FFN)
446 | residual = hidden_states
447 | hidden_states = self.post_attention_layernorm(hidden_states)
448 |
449 | scale_factor = 1 / math.sqrt(self.layer_index + 1) # scale
450 | hidden_states = scale_factor * hidden_states
451 |
452 | hidden_states = self.mlp(hidden_states)
453 | hidden_states = residual + hidden_states
454 |
455 |
456 | elif norm_type == 'scale_res_pre_norm':
457 | # Pre-LayerNorm Only
458 | residual = hidden_states
459 | hidden_states = self.input_layernorm(hidden_states)
460 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
461 | hidden_states=hidden_states,
462 | attention_mask=attention_mask,
463 | position_ids=position_ids,
464 | past_key_value=past_key_value,
465 | output_attentions=output_attentions,
466 | use_cache=use_cache,
467 | )
468 | hidden_states = residual + hidden_states * self.raw_scaling_factor_attn
469 |
470 | residual = hidden_states
471 | hidden_states = self.post_attention_layernorm(hidden_states)
472 | hidden_states = self.mlp(hidden_states)
473 | hidden_states = residual + hidden_states * self.raw_scaling_factor_mlp
474 | elif norm_type == 'post' or norm_type == 'deeppost':
475 | # Post-LayerNorm Only
476 | residual = hidden_states
477 |
478 |
479 |
480 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
481 | hidden_states=hidden_states,
482 | attention_mask=attention_mask,
483 | position_ids=position_ids,
484 | past_key_value=past_key_value,
485 | output_attentions=output_attentions,
486 | use_cache=use_cache,
487 | )
488 | if norm_type == 'deeppost':
489 |                 residual = 2.8284271247461903 * residual  # = (2 * 32) ** 0.25; presumably a DeepNet-style residual scale hardcoded for 32 layers
490 | hidden_states = residual + hidden_states
491 |
492 |
493 |
494 |
495 | hidden_states = self.post_attention_layernorm(hidden_states)
496 |
497 | residual = hidden_states
498 |
499 | hidden_states = self.mlp(hidden_states)
500 | if norm_type == 'deeppost':
501 |                 residual = 2.8284271247461903 * residual  # = (2 * 32) ** 0.25; presumably a DeepNet-style residual scale hardcoded for 32 layers
502 | hidden_states = residual + hidden_states
503 | hidden_states = self.post_feedforward_layernorm(hidden_states)
504 |
505 | elif norm_type == 'sandwich':
506 | # Pre + Post LayerNorm (default)
507 | residual = hidden_states
508 | hidden_states = self.input_layernorm(hidden_states)
509 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
510 | hidden_states=hidden_states,
511 | attention_mask=attention_mask,
512 | position_ids=position_ids,
513 | past_key_value=past_key_value,
514 | output_attentions=output_attentions,
515 | use_cache=use_cache,
516 | )
517 | hidden_states = self.post_attention_layernorm(hidden_states)
518 | hidden_states = residual + hidden_states
519 |
520 | residual = hidden_states
521 | hidden_states = self.pre_feedforward_layernorm(hidden_states)
522 | hidden_states = self.mlp(hidden_states)
523 | hidden_states = self.post_feedforward_layernorm(hidden_states)
524 | hidden_states = residual + hidden_states
525 |
526 | elif norm_type == 'pre_post':
527 | if self.layer_index < self.max_pre_norm_layer:
528 | # Pre-LayerNorm Only
529 | residual = hidden_states
530 | hidden_states = self.input_layernorm(hidden_states)
531 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
532 | hidden_states=hidden_states,
533 | attention_mask=attention_mask,
534 | position_ids=position_ids,
535 | past_key_value=past_key_value,
536 | output_attentions=output_attentions,
537 | use_cache=use_cache,
538 | )
539 | hidden_states = residual + hidden_states
540 |
541 | residual = hidden_states
542 | hidden_states = self.post_attention_layernorm(hidden_states)
543 | hidden_states = self.mlp(hidden_states)
544 | hidden_states = residual + hidden_states
545 | else:
546 | # Post-LayerNorm Only
547 | residual = hidden_states
548 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
549 | hidden_states=hidden_states,
550 | attention_mask=attention_mask,
551 | position_ids=position_ids,
552 | past_key_value=past_key_value,
553 | output_attentions=output_attentions,
554 | use_cache=use_cache,
555 | )
556 | hidden_states = residual + hidden_states
557 | hidden_states = self.post_attention_layernorm(hidden_states)
558 |
559 | residual = hidden_states
560 | hidden_states = self.mlp(hidden_states)
561 | hidden_states = residual + hidden_states
562 | hidden_states = self.post_feedforward_layernorm(hidden_states)
563 |
564 | elif norm_type == 'pre_sandwich':
565 | if self.layer_index < self.max_pre_norm_layer:
566 | # Pre-LayerNorm Only
567 | residual = hidden_states
568 | hidden_states = self.input_layernorm(hidden_states)
569 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
570 | hidden_states=hidden_states,
571 | attention_mask=attention_mask,
572 | position_ids=position_ids,
573 | past_key_value=past_key_value,
574 | output_attentions=output_attentions,
575 | use_cache=use_cache,
576 | )
577 | hidden_states = residual + hidden_states
578 |
579 | residual = hidden_states
580 | hidden_states = self.post_attention_layernorm(hidden_states)
581 | hidden_states = self.mlp(hidden_states)
582 | hidden_states = residual + hidden_states
583 | else:
584 | # Pre + Post LayerNorm (default)
585 | residual = hidden_states
586 | hidden_states = self.input_layernorm(hidden_states)
587 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
588 | hidden_states=hidden_states,
589 | attention_mask=attention_mask,
590 | position_ids=position_ids,
591 | past_key_value=past_key_value,
592 | output_attentions=output_attentions,
593 | use_cache=use_cache,
594 | )
595 | hidden_states = self.post_attention_layernorm(hidden_states)
596 | hidden_states = residual + hidden_states
597 |
598 | residual = hidden_states
599 | hidden_states = self.pre_feedforward_layernorm(hidden_states)
600 | hidden_states = self.mlp(hidden_states)
601 | hidden_states = self.post_feedforward_layernorm(hidden_states)
602 | hidden_states = residual + hidden_states
603 |
604 | elif norm_type == 'post_pre' or norm_type == 'scale_post_pre':
605 | if self.layer_index < self.max_post_norm_layer:
606 | # Post-LayerNorm Only
607 | residual = hidden_states
608 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
609 | hidden_states=hidden_states,
610 | attention_mask=attention_mask,
611 | position_ids=position_ids,
612 | past_key_value=past_key_value,
613 | output_attentions=output_attentions,
614 | use_cache=use_cache,
615 | )
616 | hidden_states = residual + hidden_states
617 | hidden_states = self.post_attention_layernorm(hidden_states)
618 |
619 | residual = hidden_states
620 | hidden_states = self.mlp(hidden_states)
621 | hidden_states = residual + hidden_states
622 | hidden_states = self.post_feedforward_layernorm(hidden_states)
623 | else:
624 | # Pre-LayerNorm Only
625 | residual = hidden_states
626 | hidden_states = self.input_layernorm(hidden_states)
627 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
628 | hidden_states=hidden_states,
629 | attention_mask=attention_mask,
630 | position_ids=position_ids,
631 | past_key_value=past_key_value,
632 | output_attentions=output_attentions,
633 | use_cache=use_cache,
634 | )
635 | hidden_states = residual + hidden_states
636 |
637 | residual = hidden_states
638 | hidden_states = self.post_attention_layernorm(hidden_states)
639 | hidden_states = self.mlp(hidden_states)
640 | hidden_states = residual + hidden_states
641 | elif norm_type == 'mono':
642 | if (self.layer_index + 1) % 4 != 0:
643 | # Pre-LayerNorm Only
644 | residual = hidden_states
645 | hidden_states = self.input_layernorm(hidden_states)
646 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
647 | hidden_states=hidden_states,
648 | attention_mask=attention_mask,
649 | position_ids=position_ids,
650 | past_key_value=past_key_value,
651 | output_attentions=output_attentions,
652 | use_cache=use_cache,
653 | )
654 | hidden_states = residual + hidden_states
655 |
656 | residual = hidden_states
657 | hidden_states = self.post_attention_layernorm(hidden_states)
658 | hidden_states = self.mlp(hidden_states)
659 | hidden_states = residual + hidden_states
660 | else:
661 | # b, n, c -> b, 2
662 | router_logits = self.router(hidden_states).mean(dim=1)
663 | router_output = torch.nn.functional.gumbel_softmax(router_logits, tau=1, hard=True)
664 |
665 | ### pre branch
666 | hidden_states_pre = hidden_states.clone()
667 | residual = hidden_states_pre
668 | hidden_states_pre = self.input_layernorm(hidden_states_pre)
669 | hidden_states_pre, self_attn_weights, present_key_value = self.self_attn(
670 | hidden_states=hidden_states_pre,
671 | attention_mask=attention_mask,
672 | position_ids=position_ids,
673 | past_key_value=past_key_value,
674 | output_attentions=output_attentions,
675 | use_cache=use_cache,
676 | )
677 | hidden_states_pre = residual + hidden_states_pre
678 |
679 | residual = hidden_states_pre
680 | hidden_states_pre = self.post_attention_layernorm(hidden_states_pre)
681 | hidden_states_pre = self.mlp(hidden_states_pre)
682 | hidden_states_pre = residual + hidden_states_pre
683 | ###
684 | ### post norm
685 | hidden_states_post = hidden_states.clone()
686 | residual = hidden_states_post
687 | hidden_states_post, self_attn_weights, present_key_value = self.self_attn(
688 | hidden_states=hidden_states_post,
689 | attention_mask=attention_mask,
690 | position_ids=position_ids,
691 | past_key_value=past_key_value,
692 | output_attentions=output_attentions,
693 | use_cache=use_cache,
694 | )
695 | hidden_states_post = residual + hidden_states_post
696 | hidden_states_post = self.post_attention_layernorm(hidden_states_post)
697 |
698 | residual = hidden_states_post
699 | hidden_states_post = self.mlp(hidden_states_post)
700 | hidden_states_post = residual + hidden_states_post
701 | hidden_states_post = self.post_feedforward_layernorm(hidden_states_post)
702 |
703 | hidden_states = router_output[:, 0:1] * hidden_states_pre + router_output[:, 1:] * hidden_states_post
704 |
705 | elif norm_type == 'mono_reverse':
706 | if (self.layer_index + 1) % 4 != 0:
707 | # Post-LayerNorm Only
708 | residual = hidden_states
709 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
710 | hidden_states=hidden_states,
711 | attention_mask=attention_mask,
712 | position_ids=position_ids,
713 | past_key_value=past_key_value,
714 | output_attentions=output_attentions,
715 | use_cache=use_cache,
716 | )
717 | hidden_states = residual + hidden_states
718 | hidden_states = self.post_attention_layernorm(hidden_states)
719 |
720 | residual = hidden_states
721 | hidden_states = self.post_attention_layernorm(hidden_states)
722 | hidden_states = self.mlp(hidden_states)
723 | hidden_states = residual + hidden_states
724 | hidden_states = self.post_feedforward_layernorm(hidden_states)
725 | else:
726 | # b, n, c -> b, 2
727 | router_logits = self.router(hidden_states).mean(dim=1)
728 | router_output = torch.nn.functional.gumbel_softmax(router_logits, tau=1, hard=True)
729 |
730 | ### pre branch
731 | hidden_states_pre = hidden_states.clone()
732 | residual = hidden_states_pre
733 | hidden_states_pre = self.input_layernorm(hidden_states_pre)
734 | hidden_states_pre, self_attn_weights, present_key_value = self.self_attn(
735 | hidden_states=hidden_states_pre,
736 | attention_mask=attention_mask,
737 | position_ids=position_ids,
738 | past_key_value=past_key_value,
739 | output_attentions=output_attentions,
740 | use_cache=use_cache,
741 | )
742 | hidden_states_pre = residual + hidden_states_pre
743 |
744 | residual = hidden_states_pre
745 | hidden_states_pre = self.post_attention_layernorm(hidden_states_pre)
746 | hidden_states_pre = self.mlp(hidden_states_pre)
747 | hidden_states_pre = residual + hidden_states_pre
748 | ###
749 | ### post norm
750 | hidden_states_post = hidden_states.clone()
751 | residual = hidden_states_post
752 | hidden_states_post, self_attn_weights, present_key_value = self.self_attn(
753 | hidden_states=hidden_states_post,
754 | attention_mask=attention_mask,
755 | position_ids=position_ids,
756 | past_key_value=past_key_value,
757 | output_attentions=output_attentions,
758 | use_cache=use_cache,
759 | )
760 | hidden_states_post = residual + hidden_states_post
761 | hidden_states_post = self.post_attention_layernorm(hidden_states_post)
762 |
763 | residual = hidden_states_post
764 | hidden_states_post = self.mlp(hidden_states_post)
765 | hidden_states_post = residual + hidden_states_post
766 | hidden_states_post = self.post_feedforward_layernorm(hidden_states_post)
767 |
768 | hidden_states = router_output[:, 0:1, None] * hidden_states_pre + router_output[:, 1:, None] * hidden_states_post
769 |
770 | elif norm_type == 'post_pre_post_pre':
771 | if (self.layer_index + 1) % 4 == 0:
772 | # Pre-LayerNorm Only
773 | residual = hidden_states
774 | hidden_states = self.input_layernorm(hidden_states)
775 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
776 | hidden_states=hidden_states,
777 | attention_mask=attention_mask,
778 | position_ids=position_ids,
779 | past_key_value=past_key_value,
780 | output_attentions=output_attentions,
781 | use_cache=use_cache,
782 | )
783 | hidden_states = residual + hidden_states
784 |
785 | residual = hidden_states
786 | hidden_states = self.post_attention_layernorm(hidden_states)
787 | hidden_states = self.mlp(hidden_states)
788 | hidden_states = residual + hidden_states
789 | else:
790 | # Post-LayerNorm Only
791 | residual = hidden_states
792 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
793 | hidden_states=hidden_states,
794 | attention_mask=attention_mask,
795 | position_ids=position_ids,
796 | past_key_value=past_key_value,
797 | output_attentions=output_attentions,
798 | use_cache=use_cache,
799 | )
800 | hidden_states = residual + hidden_states
801 | hidden_states = self.post_attention_layernorm(hidden_states)
802 |
803 | residual = hidden_states
804 | hidden_states = self.mlp(hidden_states)
805 | hidden_states = residual + hidden_states
806 | hidden_states = self.post_feedforward_layernorm(hidden_states)
807 |
808 | outputs = (hidden_states,)
809 |
810 | if output_attentions:
811 | outputs += (self_attn_weights,)
812 |
813 | if use_cache:
814 | outputs += (present_key_value,)
815 |
816 | return outputs
817 |
818 |
819 | LLAMA_START_DOCSTRING = r"""
820 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
821 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
822 | etc.)
823 |
824 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
825 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
826 | and behavior.
827 |
828 | Parameters:
829 | config ([`LlamaConfig`]):
830 | Model configuration class with all the parameters of the model. Initializing with a config file does not
831 | load the weights associated with the model, only the configuration. Check out the
832 | [`~PreTrainedModel.from_pretrained`] method to load the model weights.
833 | """
834 |
835 |
836 | @add_start_docstrings(
837 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
838 | LLAMA_START_DOCSTRING,
839 | )
840 | class LlamaPreTrainedModel(PreTrainedModel):
841 | config_class = LlamaConfig
842 | base_model_prefix = "model"
843 | supports_gradient_checkpointing = True
844 | _no_split_modules = ["LlamaDecoderLayer"]
845 | _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
846 |
847 | # def _init_weights(self, module):
848 | # std = self.config.initializer_range
849 | # if isinstance(module, nn.Linear):
850 | # module.weight.data.normal_(mean=0.0, std=std)
851 | # if module.bias is not None:
852 | # module.bias.data.zero_()
853 | # elif isinstance(module, nn.Embedding):
854 | # module.weight.data.normal_(mean=0.0, std=std)
855 | # if module.padding_idx is not None:
856 | # module.weight.data[module.padding_idx].zero_()
857 | def _init_weights(self, module):
858 | std = self.config.initializer_range
859 | num_layers = self.config.num_hidden_layers
860 | scaled_std = std / (2 * num_layers) ** 0.5  # 1/sqrt(2L) scaled initialization (see the README caveat); used only for modules flagged is_scaled_layer
861 | if isinstance(module, nn.Linear):
862 | if hasattr(module, "is_deeppost_layer") and module.is_deeppost_layer:
863 | torch.nn.init.xavier_normal_(module.weight, gain=(num_layers * 8) ** 0.25)
864 | elif hasattr(module, "is_deeppost_layer_qk") and module.is_deeppost_layer_qk:
865 | torch.nn.init.xavier_normal_(module.weight, gain=1)
866 | elif hasattr(module, "is_scaled_layer") and module.is_scaled_layer:
867 | module.weight.data.normal_(mean=0.0, std=scaled_std)
868 | print('-'*50)
869 | print('Warning: scaled init for layer:', module)
870 | print('-'*50)
871 | else:
872 | module.weight.data.normal_(mean=0.0, std=std)
873 | if module.bias is not None:
874 | module.bias.data.zero_()
875 | elif isinstance(module, nn.Embedding):
876 | module.weight.data.normal_(mean=0.0, std=(2 / 5) ** 0.5)
877 | if module.padding_idx is not None:
878 | module.weight.data[module.padding_idx].zero_()
879 |
880 | def _set_gradient_checkpointing(self, module, value=False):
881 | if isinstance(module, LlamaModel):
882 | module.gradient_checkpointing = value
883 |
884 |
885 | LLAMA_INPUTS_DOCSTRING = r"""
886 | Args:
887 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
888 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
889 | it.
890 |
891 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
892 | [`PreTrainedTokenizer.__call__`] for details.
893 |
894 | [What are input IDs?](../glossary#input-ids)
895 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
896 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
897 |
898 | - 1 for tokens that are **not masked**,
899 | - 0 for tokens that are **masked**.
900 |
901 | [What are attention masks?](../glossary#attention-mask)
902 |
903 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
904 | [`PreTrainedTokenizer.__call__`] for details.
905 |
906 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
907 | `past_key_values`).
908 |
909 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
910 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
911 | information on the default strategy.
912 |
913 | - 1 indicates the head is **not masked**,
914 | - 0 indicates the head is **masked**.
915 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
916 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
917 | config.n_positions - 1]`.
918 |
919 | [What are position IDs?](../glossary#position-ids)
920 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
921 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
922 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
923 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
924 |
925 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
926 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
927 |
928 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
929 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
930 | `decoder_input_ids` of shape `(batch_size, sequence_length)`.
931 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
932 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
933 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
934 | model's internal embedding lookup matrix.
935 | use_cache (`bool`, *optional*):
936 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
937 | `past_key_values`).
938 | output_attentions (`bool`, *optional*):
939 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
940 | tensors for more detail.
941 | output_hidden_states (`bool`, *optional*):
942 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
943 | more detail.
944 | return_dict (`bool`, *optional*):
945 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
946 | """
947 |
948 |
949 | @add_start_docstrings(
950 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
951 | LLAMA_START_DOCSTRING,
952 | )
953 | class LlamaModel(LlamaPreTrainedModel):
954 | """
955 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
956 |
957 | Args:
958 | config: LlamaConfig
959 | """
960 |
961 | def __init__(self, config: LlamaConfig):
962 | super().__init__(config)
963 | self.padding_idx = config.pad_token_id
964 | self.vocab_size = config.vocab_size
965 |
966 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
967 | self.layers = nn.ModuleList([LlamaDecoderLayer(config, _idx) for _idx in range(config.num_hidden_layers)])
968 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
969 |
970 | self.gradient_checkpointing = False
971 | # Initialize weights and apply final processing
972 | self.post_init()
973 |
974 | def get_input_embeddings(self):
975 | return self.embed_tokens
976 |
977 | def set_input_embeddings(self, value):
978 | self.embed_tokens = value
979 |
980 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
981 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
982 | # create causal mask
983 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
984 | combined_attention_mask = None
985 | if input_shape[-1] > 1:
986 | combined_attention_mask = _make_causal_mask(
987 | input_shape,
988 | inputs_embeds.dtype,
989 | device=inputs_embeds.device,
990 | past_key_values_length=past_key_values_length,
991 | )
992 |
993 | if attention_mask is not None:
994 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
995 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
996 | inputs_embeds.device
997 | )
998 | combined_attention_mask = (
999 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1000 | )
1001 |
1002 | return combined_attention_mask
1003 |
1004 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1005 | def forward(
1006 | self,
1007 | input_ids: torch.LongTensor = None,
1008 | attention_mask: Optional[torch.Tensor] = None,
1009 | position_ids: Optional[torch.LongTensor] = None,
1010 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1011 | inputs_embeds: Optional[torch.FloatTensor] = None,
1012 | use_cache: Optional[bool] = None,
1013 | output_attentions: Optional[bool] = None,
1014 | output_hidden_states: Optional[bool] = None,
1015 | return_dict: Optional[bool] = None,
1016 | ) -> Union[Tuple, BaseModelOutputWithPast]:
1017 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1018 | output_hidden_states = (
1019 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1020 | )
1021 | use_cache = use_cache if use_cache is not None else self.config.use_cache
1022 |
1023 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1024 |
1025 | # retrieve input_ids and inputs_embeds
1026 | if input_ids is not None and inputs_embeds is not None:
1027 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1028 | elif input_ids is not None:
1029 | batch_size, seq_length = input_ids.shape
1030 | elif inputs_embeds is not None:
1031 | batch_size, seq_length, _ = inputs_embeds.shape
1032 | else:
1033 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1034 |
1035 | seq_length_with_past = seq_length
1036 | past_key_values_length = 0
1037 |
1038 | if past_key_values is not None:
1039 | past_key_values_length = past_key_values[0][0].shape[2]
1040 | seq_length_with_past = seq_length_with_past + past_key_values_length
1041 |
1042 | if position_ids is None:
1043 | device = input_ids.device if input_ids is not None else inputs_embeds.device
1044 | position_ids = torch.arange(
1045 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1046 | )
1047 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1048 | else:
1049 | position_ids = position_ids.view(-1, seq_length).long()
1050 |
1051 | if inputs_embeds is None:
1052 | inputs_embeds = self.embed_tokens(input_ids)
1053 | # embed positions
1054 | if attention_mask is None:
1055 | attention_mask = torch.ones(
1056 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
1057 | )
1058 | attention_mask = self._prepare_decoder_attention_mask(
1059 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1060 | )
1061 |
1062 | hidden_states = inputs_embeds
1063 |
1064 | if self.gradient_checkpointing and self.training:
1065 | if use_cache:
1066 | logger.warning_once(
1067 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1068 | )
1069 | use_cache = False
1070 |
1071 | # decoder layers
1072 | all_hidden_states = () if output_hidden_states else None
1073 | all_self_attns = () if output_attentions else None
1074 | next_decoder_cache = () if use_cache else None
1075 |
1076 | for idx, decoder_layer in enumerate(self.layers):
1077 | if output_hidden_states:
1078 | all_hidden_states += (hidden_states,)
1079 |
1080 | past_key_value = past_key_values[idx] if past_key_values is not None else None
1081 |
1082 | if self.gradient_checkpointing and self.training:
1083 |
1084 | def create_custom_forward(module):
1085 | def custom_forward(*inputs):
1086 | # None for past_key_value
1087 | return module(*inputs, output_attentions, None)
1088 |
1089 | return custom_forward
1090 |
1091 | layer_outputs = torch.utils.checkpoint.checkpoint(
1092 | create_custom_forward(decoder_layer),
1093 | hidden_states,
1094 | attention_mask,
1095 | position_ids,
1096 | None,
1097 | )
1098 | else:
1099 | layer_outputs = decoder_layer(
1100 | hidden_states,
1101 | attention_mask=attention_mask,
1102 | position_ids=position_ids,
1103 | past_key_value=past_key_value,
1104 | output_attentions=output_attentions,
1105 | use_cache=use_cache,
1106 | )
1107 |
1108 | hidden_states = layer_outputs[0]
1109 |
1110 | if use_cache:
1111 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1112 |
1113 | if output_attentions:
1114 | all_self_attns += (layer_outputs[1],)
1115 |
1116 | hidden_states = self.norm(hidden_states)
1117 |
1118 | # add hidden states from the last decoder layer
1119 | if output_hidden_states:
1120 | all_hidden_states += (hidden_states,)
1121 |
1122 | next_cache = next_decoder_cache if use_cache else None
1123 | if not return_dict:
1124 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1125 | return BaseModelOutputWithPast(
1126 | last_hidden_state=hidden_states,
1127 | past_key_values=next_cache,
1128 | hidden_states=all_hidden_states,
1129 | attentions=all_self_attns,
1130 | )
1131 |
1132 |
1133 | class LlamaForCausalLM(LlamaPreTrainedModel):
1134 | def __init__(self, config):
1135 | super().__init__(config)
1136 | self.model = LlamaModel(config)
1137 |
1138 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1139 |
1140 | # Initialize weights and apply final processing
1141 | self.post_init()
1142 |
1143 | def get_input_embeddings(self):
1144 | return self.model.embed_tokens
1145 |
1146 | def set_input_embeddings(self, value):
1147 | self.model.embed_tokens = value
1148 |
1149 | def get_output_embeddings(self):
1150 | return self.lm_head
1151 |
1152 | def set_output_embeddings(self, new_embeddings):
1153 | self.lm_head = new_embeddings
1154 |
1155 | def set_decoder(self, decoder):
1156 | self.model = decoder
1157 |
1158 | def get_decoder(self):
1159 | return self.model
1160 |
1161 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1162 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1163 | def forward(
1164 | self,
1165 | input_ids: torch.LongTensor = None,
1166 | attention_mask: Optional[torch.Tensor] = None,
1167 | position_ids: Optional[torch.LongTensor] = None,
1168 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1169 | inputs_embeds: Optional[torch.FloatTensor] = None,
1170 | labels: Optional[torch.LongTensor] = None,
1171 | use_cache: Optional[bool] = None,
1172 | output_attentions: Optional[bool] = None,
1173 | output_hidden_states: Optional[bool] = None,
1174 | return_dict: Optional[bool] = None,
1175 | ) -> Union[Tuple, CausalLMOutputWithPast]:
1176 | r"""
1177 | Args:
1178 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1179 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1180 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1181 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1182 |
1183 | Returns:
1184 |
1185 | Example:
1186 |
1187 | ```python
1188 | >>> from transformers import AutoTokenizer, LlamaForCausalLM
1189 |
1190 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1191 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1192 |
1193 | >>> prompt = "Hey, are you consciours? Can you talk to me?"
1194 | >>> inputs = tokenizer(prompt, return_tensors="pt")
1195 |
1196 | >>> # Generate
1197 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1198 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1199 | "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
1200 | ```"""
1201 |
1202 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1203 | output_hidden_states = (
1204 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1205 | )
1206 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1207 |
1208 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1209 | outputs = self.model(
1210 | input_ids=input_ids,
1211 | attention_mask=attention_mask,
1212 | position_ids=position_ids,
1213 | past_key_values=past_key_values,
1214 | inputs_embeds=inputs_embeds,
1215 | use_cache=use_cache,
1216 | output_attentions=output_attentions,
1217 | output_hidden_states=output_hidden_states,
1218 | return_dict=return_dict,
1219 | )
1220 |
1221 | hidden_states = outputs[0]
1222 | logits = self.lm_head(hidden_states)
1223 |
1224 | loss = None
1225 | if labels is not None:
1226 | # NOTE: big optimization could be done here (?)
1227 | # maybe the copy operation that you saw in the debugger was happening here
1228 |
1229 | # Shift so that tokens < n predict n
1230 | shift_logits = logits[..., :-1, :].contiguous()
1231 | shift_labels = labels[..., 1:].contiguous()
1232 | # Flatten the tokens
1233 | loss_fct = CrossEntropyLoss()
1234 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
1235 | shift_labels = shift_labels.view(-1)
1236 | # Enable model parallelism
1237 | shift_labels = shift_labels.to(shift_logits.device)
1238 | loss = loss_fct(shift_logits, shift_labels)
1239 |
1240 | if not return_dict:
1241 | output = (logits,) + outputs[1:]
1242 | return (loss,) + output if loss is not None else output
1243 |
1244 | return CausalLMOutputWithPast(
1245 | loss=loss,
1246 | logits=logits,
1247 | past_key_values=outputs.past_key_values,
1248 | hidden_states=outputs.hidden_states,
1249 | attentions=outputs.attentions,
1250 | )
1251 |
1252 | def prepare_inputs_for_generation(
1253 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1254 | ):
1255 | if past_key_values:
1256 | input_ids = input_ids[:, -1:]
1257 |
1258 | position_ids = kwargs.get("position_ids", None)
1259 | if attention_mask is not None and position_ids is None:
1260 | # create position_ids on the fly for batch generation
1261 | position_ids = attention_mask.long().cumsum(-1) - 1
1262 | position_ids.masked_fill_(attention_mask == 0, 1)
1263 | if past_key_values:
1264 | position_ids = position_ids[:, -1].unsqueeze(-1)
1265 |
1266 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1267 | if inputs_embeds is not None and past_key_values is None:
1268 | model_inputs = {"inputs_embeds": inputs_embeds}
1269 | else:
1270 | model_inputs = {"input_ids": input_ids}
1271 |
1272 | model_inputs.update(
1273 | {
1274 | "position_ids": position_ids,
1275 | "past_key_values": past_key_values,
1276 | "use_cache": kwargs.get("use_cache"),
1277 | "attention_mask": attention_mask,
1278 | }
1279 | )
1280 | return model_inputs
1281 |
1282 | @staticmethod
1283 | def _reorder_cache(past_key_values, beam_idx):
1284 | reordered_past = ()
1285 | for layer_past in past_key_values:
1286 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1287 | return reordered_past
1288 |
1289 |
1290 | @add_start_docstrings(
1291 | """
1292 | The LLaMa Model transformer with a sequence classification head on top (linear layer).
1293 |
1294 | [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1295 | (e.g. GPT-2) do.
1296 |
1297 | Since it does classification on the last token, it requires to know the position of the last token. If a
1298 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1299 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1300 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1301 | each row of the batch).
1302 | """,
1303 | LLAMA_START_DOCSTRING,
1304 | )
1305 | class LlamaForSequenceClassification(LlamaPreTrainedModel):
1306 | _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
1307 |
1308 | def __init__(self, config):
1309 | super().__init__(config)
1310 | self.num_labels = config.num_labels
1311 | self.model = LlamaModel(config)
1312 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1313 |
1314 | # Initialize weights and apply final processing
1315 | self.post_init()
1316 |
1317 | def get_input_embeddings(self):
1318 | return self.model.embed_tokens
1319 |
1320 | def set_input_embeddings(self, value):
1321 | self.model.embed_tokens = value
1322 |
1323 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1324 | def forward(
1325 | self,
1326 | input_ids: torch.LongTensor = None,
1327 | attention_mask: Optional[torch.Tensor] = None,
1328 | position_ids: Optional[torch.LongTensor] = None,
1329 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1330 | inputs_embeds: Optional[torch.FloatTensor] = None,
1331 | labels: Optional[torch.LongTensor] = None,
1332 | use_cache: Optional[bool] = None,
1333 | output_attentions: Optional[bool] = None,
1334 | output_hidden_states: Optional[bool] = None,
1335 | return_dict: Optional[bool] = None,
1336 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1337 | r"""
1338 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1339 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1340 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1341 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1342 | """
1343 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1344 |
1345 | transformer_outputs = self.model(
1346 | input_ids,
1347 | attention_mask=attention_mask,
1348 | position_ids=position_ids,
1349 | past_key_values=past_key_values,
1350 | inputs_embeds=inputs_embeds,
1351 | use_cache=use_cache,
1352 | output_attentions=output_attentions,
1353 | output_hidden_states=output_hidden_states,
1354 | return_dict=return_dict,
1355 | )
1356 | hidden_states = transformer_outputs[0]
1357 | logits = self.score(hidden_states)
1358 |
1359 | if input_ids is not None:
1360 | batch_size = input_ids.shape[0]
1361 | else:
1362 | batch_size = inputs_embeds.shape[0]
1363 |
1364 | if self.config.pad_token_id is None and batch_size != 1:
1365 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1366 | if self.config.pad_token_id is None:
1367 | sequence_lengths = -1
1368 | else:
1369 | if input_ids is not None:
1370 | sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
1371 | else:
1372 | sequence_lengths = -1
1373 |
1374 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1375 |
1376 | loss = None
1377 | if labels is not None:
1378 | labels = labels.to(logits.device)
1379 | if self.config.problem_type is None:
1380 | if self.num_labels == 1:
1381 | self.config.problem_type = "regression"
1382 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1383 | self.config.problem_type = "single_label_classification"
1384 | else:
1385 | self.config.problem_type = "multi_label_classification"
1386 |
1387 | if self.config.problem_type == "regression":
1388 | loss_fct = MSELoss()
1389 | if self.num_labels == 1:
1390 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1391 | else:
1392 | loss = loss_fct(pooled_logits, labels)
1393 | elif self.config.problem_type == "single_label_classification":
1394 | loss_fct = CrossEntropyLoss()
1395 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1396 | elif self.config.problem_type == "multi_label_classification":
1397 | loss_fct = BCEWithLogitsLoss()
1398 | loss = loss_fct(pooled_logits, labels)
1399 | if not return_dict:
1400 | output = (pooled_logits,) + transformer_outputs[1:]
1401 | return ((loss,) + output) if loss is not None else output
1402 |
1403 | return SequenceClassifierOutputWithPast(
1404 | loss=loss,
1405 | logits=pooled_logits,
1406 | past_key_values=transformer_outputs.past_key_values,
1407 | hidden_states=transformer_outputs.hidden_states,
1408 | attentions=transformer_outputs.attentions,
1409 | )
1410 |
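
The `norm_type` branches above cover the Pre-LN, Post-LN, and mixed placements visible in this excerpt (the value is exported as NORM_TYPE by the run_*.sh scripts). For orientation, the paper's LayerNorm Scaling (LNS) additionally damps each normalization output according to the layer's depth. Below is a minimal sketch of that idea, not the repository's exact implementation; it assumes the scaling is a 1/sqrt(l) factor on the normalized output and a 0-based `layer_index` as passed in `LlamaDecoderLayer(config, _idx)`:

```python
import torch

def layernorm_scaling(normed: torch.Tensor, layer_index: int) -> torch.Tensor:
    """Scale a LayerNorm/RMSNorm output by 1/sqrt(l), with l the 1-based layer depth (LNS sketch)."""
    return normed / (layer_index + 1) ** 0.5

# In a Pre-LN block this would sit right after the norm, e.g.
#   hidden_states = layernorm_scaling(self.input_layernorm(hidden_states), self.layer_index)
# before the attention call; the actual LNS branch lives in modeling_llama.py and is
# selected at runtime through the norm_type switch.
```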
--------------------------------------------------------------------------------
/peft_pretraining/training_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | from torch.optim.lr_scheduler import LambdaLR
6 | import transformers
7 |
8 |
9 | def get_scheculer(
10 | optimizer,
11 | *,
12 | scheduler_type,
13 | num_training_steps,
14 | warmup_steps,
15 | min_lr_ratio,
16 | cycle_length=None,
17 | restart_warmup_steps=None,
18 | adjust_step=0,
19 | last_epoch=-1,
20 | ):
21 | if adjust_step != 0 and scheduler_type != "cosine_restarts":
22 | raise ValueError("adjust_step is only supported for cosine_restarts scheduler")
23 |
24 | if scheduler_type == "linear":
25 | return transformers.get_linear_schedule_with_warmup(
26 | optimizer,
27 | num_warmup_steps=warmup_steps,
28 | num_training_steps=num_training_steps,
29 | last_epoch=last_epoch,
30 | )
31 | if scheduler_type == "cosine":
32 | return get_cyclical_cosine_schedule_with_min_lr(
33 | optimizer,
34 | num_warmup_steps=warmup_steps,
35 | num_training_steps=num_training_steps,
36 | cycle_length=cycle_length,
37 | min_lr_ratio=min_lr_ratio,
38 | last_epoch=last_epoch,
39 | )
40 | if scheduler_type == "cosine_restarts":
41 | assert restart_warmup_steps is not None, "restart_warmup_steps must be specified for cosine_restarts scheduler"
42 | return get_cosine_schedule_with_multiple_warmups(
43 | optimizer,
44 | num_training_steps=num_training_steps,
45 | first_warmup_steps=warmup_steps,
46 | restart_warmup_steps=restart_warmup_steps,
47 | restart_every=cycle_length,
48 | min_lr_ratio=min_lr_ratio,
49 | last_epoch=last_epoch,
50 | adjust_step=adjust_step,
51 | )
52 |
53 | raise NotImplementedError(f"Scheduler {scheduler_type} is not implemented")
54 |
55 |
56 | def get_cyclical_cosine_schedule_with_min_lr(optimizer, num_warmup_steps, num_training_steps, cycle_length, min_lr_ratio=0.1, last_epoch=-1):
57 | assert cycle_length is not None or num_training_steps is not None, "You must specify either cycle_length or num_training_steps"
58 |
59 | if cycle_length is None:
60 | cycle_length = num_training_steps
61 |
62 | if num_training_steps % cycle_length != 0:
63 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by cycle_length ({cycle_length})")
64 |
65 | lr_lambda = partial(
66 | _get_cyclical_cosine_schedule_with_min_lr_lambda,
67 | num_warmup_steps=num_warmup_steps,
68 | cycle_length=cycle_length,
69 | min_lr_ratio=min_lr_ratio,
70 | )
71 | return LambdaLR(optimizer, lr_lambda, last_epoch)
72 |
73 |
74 | def get_cosine_schedule_with_multiple_warmups(
75 | optimizer,
76 | *,
77 | num_training_steps,
78 | first_warmup_steps,
79 | restart_warmup_steps,
80 | restart_every,
81 | min_lr_ratio=0.1,
82 | adjust_step=0,
83 | last_epoch=-1,
84 | ):
85 | if restart_every is None:
86 | raise ValueError("restart_every must be specified for cosine_restarts scheduler")
87 |
88 | if num_training_steps % restart_every != 0:
89 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by restart_every ({restart_every})")
90 |
91 | lr_lambda = partial(
92 | _get_cosine_schedule_with_multiple_warmups_lambda,
93 | num_training_steps=num_training_steps,
94 | first_warmup_steps=first_warmup_steps,
95 | restart_warmup_steps=restart_warmup_steps,
96 | restart_every=restart_every,
97 | min_lr_ratio=min_lr_ratio,
98 | adjust_step=adjust_step,
99 | )
100 | return LambdaLR(optimizer, lr_lambda, last_epoch)
101 |
102 |
103 | @torch.no_grad()
104 | def random_pruning(tensor, prune_ratio):
105 | """
106 | Randomly zeroes out (approximately) a `prune_ratio` fraction of the tensor's elements.
107 | Only the values are sparsified; the shape of the tensor is unchanged.
108 | """
109 | random_pruning_mask = torch.rand_like(tensor) > prune_ratio
110 | tensor = tensor * random_pruning_mask
111 | return tensor
112 |
113 |
114 | @torch.no_grad()
115 | def magnitude_pruning(tensor, prune_ratio):
116 | """
117 | Zeroes out the `prune_ratio` fraction of elements with the smallest magnitudes.
118 | Only the values are sparsified; the shape of the tensor is unchanged.
119 | """
120 | tensor_magnitude = torch.abs(tensor)
121 | threshold = torch.quantile(tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio).to(dtype=tensor.dtype)
122 |
123 | mask = tensor_magnitude > threshold
124 | tensor = tensor * mask.to(dtype=tensor.dtype)
125 | return tensor
126 |
127 |
128 | def _get_cyclical_cosine_schedule_with_min_lr_lambda(current_step, *, num_warmup_steps, cycle_length, min_lr_ratio):
129 | assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]"
130 |
131 | # compute where we are in the current cycle
132 | cycle_step = current_step % cycle_length
133 |
134 | if cycle_step < num_warmup_steps:
135 | if current_step != cycle_step:
136 | if cycle_step < 2:
137 | return 1e-7
138 | return float(cycle_step) / float(max(1, num_warmup_steps))
139 |
140 | progress = float(cycle_step - num_warmup_steps) / float(max(1, cycle_length - num_warmup_steps))
141 | cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
142 |
143 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
144 |
145 |
146 | def _get_cosine_schedule_with_multiple_warmups_lambda(
147 | current_step,
148 | *,
149 | num_training_steps,
150 | first_warmup_steps,
151 | restart_warmup_steps,
152 | restart_every,
153 | min_lr_ratio,
154 | adjust_step,
155 | ):
156 | """
157 | Args:
158 | adjust_step: useful when continuing training from a warmed up checkpoint,
159 | it allows to sync the resets by reducing the number of steps
160 | after the first warmup and before the first reset.
161 | Thus, your ReLoRA resets can be synced with the optimizer resets.
162 | """
163 | assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]"
164 | assert restart_every > 0, "restart_every must be positive"
165 | assert adjust_step + first_warmup_steps < num_training_steps, "warmup + adjust_step is more than full training steps"
166 | assert adjust_step + first_warmup_steps < restart_every, "the first reset will happen before the warmup is done"
167 |
168 | if current_step < first_warmup_steps:
169 | return float(current_step) / float(max(1, first_warmup_steps))
170 |
171 | _current_step = current_step + adjust_step
172 |
173 | restart_step = _current_step % restart_every
174 | restart_number = _current_step // restart_every
175 |
176 | if restart_step < restart_warmup_steps:
177 | # get the expected LR multiplier at the end of the warmup
178 | end_of_warmup_progress = (
179 | float(restart_number * restart_every) /
180 | float(max(1, num_training_steps - first_warmup_steps))
181 | )
182 |
183 | _cosine_decay = 0.5 * (1.0 + math.cos(math.pi * end_of_warmup_progress))
184 | warmup_lr_multiplier = min_lr_ratio + (1.0 - min_lr_ratio) * _cosine_decay
185 |
186 | return float(restart_step) / float(max(1, restart_warmup_steps)) * warmup_lr_multiplier
187 |
188 | progress = float(_current_step - first_warmup_steps) / float(max(1, num_training_steps - first_warmup_steps))
189 | cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
190 |
191 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
192 |
193 |
194 | def collate_fn(batch_list):
195 | batch = {
196 | "input_ids": torch.stack([torch.Tensor(example["input_ids"]).long() for example in batch_list]),
197 | "attention_mask": torch.stack([torch.Tensor(example["attention_mask"]).long() for example in batch_list]),
198 | }
199 | return batch
200 |
201 |
202 | def batch_fn(dataset, batch_size):
203 | batch = []
204 | for example in dataset:
205 | batch.append(example)
206 | if len(batch) == batch_size:
207 | batch = collate_fn(batch)
208 | yield batch
209 | batch = []
210 | if len(batch) > 0:
211 | yield collate_fn(batch)  # collate the trailing partial batch as well
212 |
213 |
214 | def max_train_tokens_to_number(max_train_tokens):
215 | if max_train_tokens.endswith("M"):
216 | return int(max_train_tokens.rstrip("M")) * 1_000_000
217 | elif max_train_tokens.endswith("B"):
218 | return int(max_train_tokens.rstrip("B")) * 1_000_000_000
219 | else:
220 | return int(max_train_tokens)
221 |
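
A minimal usage sketch of the scheduler factory above (`get_scheculer`, spelled as defined in this file), wired to a stand-in module and assumed to run from the repository root; the hyperparameters mirror run_130m.sh (lr 1e-3, 20k steps, 2k warmup):

```python
import torch
from peft_pretraining.training_utils import get_scheculer

model = torch.nn.Linear(8, 8)                       # stand-in module
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = get_scheculer(
    optimizer,
    scheduler_type="cosine",      # cycle_length defaults to num_training_steps
    num_training_steps=20_000,
    warmup_steps=2_000,
    min_lr_ratio=0.1,             # decay to 10% of the peak LR
)
for _ in range(5):                # one scheduler.step() per update step
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])
```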
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers
3 | bitsandbytes
--------------------------------------------------------------------------------
/run_130m.sh:
--------------------------------------------------------------------------------
1 | # Define the set of learning rates and normalization types
2 | norm_type=$1
3 | learning_rates=1e-3
4 | export NORM_TYPE=$norm_type
5 | export POST_NUM=$2
6 |
7 | # Launch a single training run
8 |
9 | echo "Training with learning rate: $learning_rates, norm type: $norm_type"
10 |
11 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node 1 --master_port=29510 torchrun_main.py \
12 | --model_config configs/llama_130m.json \
13 | --lr $learning_rates \
14 | --batch_size 64 \
15 | --total_batch_size 512 \
16 | --num_training_steps 20000 \
17 | --warmup_steps 2000 \
18 | --weight_decay 0 \
19 | --dtype bfloat16 \
20 | --eval_every 1000 \
21 | --optimizer adam \
22 | --grad_clipping 0.0 \
23 | --run_name "130m_res_${norm_type}_lr${learning_rates}_layer_scale" \
24 | --save_dir "130m_res_${norm_type}_lr${learning_rates}"
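
torchrun_main.py derives gradient accumulation from --total_batch_size (see the assertions around `args.gradient_accumulation`). A quick arithmetic check for the 130M configuration above, using the default --max_length of 256; the token figure is an upper bound because padded positions are excluded from `tokens_seen`:

```python
# Effective-batch and token-budget arithmetic for run_130m.sh (1 GPU).
total_batch_size, batch_size, world_size = 512, 64, 1
gradient_accumulation = total_batch_size // (batch_size * world_size)
assert gradient_accumulation * batch_size * world_size == total_batch_size
print(gradient_accumulation)                                # 8 micro-batches per update

num_training_steps, max_length = 20_000, 256
print(num_training_steps * total_batch_size * max_length)   # 2,621,440,000 ≈ 2.6B tokens (upper bound)
```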
--------------------------------------------------------------------------------
/run_1b.sh:
--------------------------------------------------------------------------------
1 | # Define the set of learning rates and normalization types
2 | norm_type=$1
3 | learning_rates=5e-4
4 | export NORM_TYPE=$norm_type
5 | export POST_NUM=$2
6 |
7 | # Launch a single training run
8 |
9 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node 8 --master_port=29500 torchrun_main.py \
10 | --model_config configs/llama_1b.json \
11 | --lr $learning_rates \
12 | --batch_size 32 \
13 | --total_batch_size 512 \
14 | --num_training_steps 100000 \
15 | --warmup_steps 1000 \
16 | --weight_decay 0 \
17 | --dtype bfloat16 \
18 | --eval_every 1000 \
19 | --optimizer adam \
20 | --grad_clipping 0.0 \
21 | --run_name "1b_res_${norm_type}_lr${learning_rates}" \
22 | --save_dir "1b_res_${norm_type}_lr${learning_rates}"
--------------------------------------------------------------------------------
/run_250m.sh:
--------------------------------------------------------------------------------
1 | # Define the set of learning rates and normalization types
2 | norm_type=$1
3 | learning_rates=1e-3
4 | export NORM_TYPE=$norm_type
5 | export POST_NUM=$2
6 |
7 | # Launch a single training run
8 |
9 | echo "Training with learning rate: $learning_rates, norm type: $norm_type"
10 |
11 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 --master_port=29500 torchrun_main.py \
12 | --model_config configs/llama_250m.json \
13 | --lr $learning_rates \
14 | --batch_size 128 \
15 | --total_batch_size 512 \
16 | --num_training_steps 40000 \
17 | --warmup_steps 4000 \
18 | --weight_decay 0 \
19 | --dtype bfloat16 \
20 | --eval_every 1000 \
21 | --optimizer adam \
22 | --grad_clipping 0.0 \
23 | --run_name "250m_res_${norm_type}_lr${learning_rates}" \
24 | --save_dir "250m_res_${norm_type}_lr${learning_rates}"
--------------------------------------------------------------------------------
/run_350m.sh:
--------------------------------------------------------------------------------
1 | # Define the set of learning rates and normalization types
2 | norm_type=$1
3 | learning_rates=5e-4
4 | export NORM_TYPE=$norm_type
5 | export POST_NUM=$2
6 |
7 | # Launch a single training run
8 |
9 | echo "Training with learning rate: $learning_rates, norm type: $norm_type"
10 |
11 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 --master_port=29503 torchrun_main.py \
12 | --model_config configs/llama_350m.json \
13 | --lr $learning_rates \
14 | --batch_size 128 \
15 | --total_batch_size 512 \
16 | --num_training_steps 60000 \
17 | --warmup_steps 6000 \
18 | --weight_decay 0 \
19 | --dtype bfloat16 \
20 | --eval_every 1000 \
21 | --optimizer adam \
22 | --grad_clipping 0.0 \
23 | --run_name "350m_res_${norm_type}_lr${learning_rates}" \
24 | --save_dir "350m_res_${norm_type}_lr${learning_rates}"
--------------------------------------------------------------------------------
/scaled_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/scaled_init.png
--------------------------------------------------------------------------------
/scaling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmsdss/LayerNorm-Scaling/da34a90d7392fe01a7dcdae6715ea6e6706bdd81/scaling.png
--------------------------------------------------------------------------------
/torchrun_main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | import random
5 | import argparse
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.data
11 | import torch.distributed as dist
12 |
13 | import transformers
14 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
15 | from transformers import LlamaForCausalLM as HF_LlamaForCausalLM
16 |
17 | import datasets
18 | import datasets.distributed
19 | import wandb
20 |
21 | from tqdm import tqdm
22 | from loguru import logger
23 |
24 | from peft_pretraining import training_utils, args_utils
25 | from peft_pretraining.dataloader import PreprocessedIterableDataset
26 | from peft_pretraining.modeling_llama import LlamaForCausalLM
27 |
28 | import bitsandbytes as bnb
29 |
30 | import matplotlib.pyplot as plt
31 | transformers.logging.set_verbosity_error()
32 |
33 | def parse_args(args):
34 | parser = argparse.ArgumentParser()
35 |
36 | parser.add_argument("--model_config", type=str, required=True)
37 | parser.add_argument("--use_hf_model", default=False, action="store_true")
38 | parser.add_argument("--continue_from", type=str, default=None)
39 | parser.add_argument("--batch_size", type=int, required=True)
40 | parser.add_argument("--gradient_accumulation", type=int, default=None)
41 | parser.add_argument("--total_batch_size", type=int, default=None)
42 | parser.add_argument("--max_length", type=int, default=256)
43 | parser.add_argument("--optimizer", default="Adam")
44 | parser.add_argument("--lr", type=float, default=1e-4)
45 | parser.add_argument("--scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_restarts"])
46 | parser.add_argument("--min_lr_ratio", type=float, default=0.1)
47 | parser.add_argument("--activation_checkpointing", action="store_true")
48 | parser.add_argument("--weight_decay", type=float, default=0.0)
49 | parser.add_argument("--warmup_steps", type=int, default=1_000)
50 | parser.add_argument("--eval_every", type=int, default=2_000)
51 | parser.add_argument("--num_training_steps", type=int, default=10_000,
52 | help="Number of **update steps** to train for. "
53 | "Notice that gradient accumulation is taken into account.")
54 | parser.add_argument("--max_train_tokens", type=training_utils.max_train_tokens_to_number, default=None,
55 | help="Number of tokens to train on. Overwrites num_training_steps. "
56 | "You can use M and B suffixes, e.g. 100M or 1B.")
57 | parser.add_argument("--save_every", type=int, default=10000)
58 | parser.add_argument("--save_dir", type=str, default=None)
59 | parser.add_argument("--tags", type=str, default=None)
60 | parser.add_argument("--dtype", type=str, default="bfloat16" if torch.cuda.is_bf16_supported() else "float32")
61 | parser.add_argument("--workers", type=int, default=8)
62 | parser.add_argument("--seed", type=int, default=1)
63 | parser.add_argument("--name", type=str, default="test")
64 | parser.add_argument("--grad_clipping", type=float, default=1.0)
65 | parser.add_argument("--run_name", type=str, default="default")
66 | # beta1 for adafactor
67 | parser.add_argument("--beta1", type=float, default=0.0)
68 |
69 | # GaLore parameters
70 | parser.add_argument("--rank", type=int, default=128)
71 | parser.add_argument("--update_proj_gap", type=int, default=50)
72 | parser.add_argument("--galore_scale", type=float, default=1.0)
73 | parser.add_argument("--proj_type", type=str, default="std")
74 |
75 | # disable ddp, single_gpu
76 | parser.add_argument("--single_gpu", default=False, action="store_true")
77 |
78 | args = parser.parse_args(args)
79 |
80 | args = args_utils.check_args_torchrun_main(args)
81 | return args
82 |
83 |
84 | @torch.no_grad()
85 | def evaluate_model(model, preprocess_batched, pad_idx, global_rank, world_size, device, batch_size):
86 | _time = time.time()
87 | val_data = datasets.load_dataset("allenai/c4", "en", split="validation", streaming=True, trust_remote_code=True) #DGX
88 | val_data = val_data.shuffle(seed=42)
89 | logger.info(f"Loaded validation dataset in {time.time() - _time:.2f} seconds")
90 |
91 | if not args.single_gpu:
92 | val_data = datasets.distributed.split_dataset_by_node(val_data, rank=global_rank, world_size=world_size)
93 |
94 | val_data_mapped = val_data.map(
95 | preprocess_batched,
96 | batched=True,
97 | remove_columns=["text", "timestamp", "url"],
98 | )
99 | val_data_mapped.batch = lambda batch_size: training_utils.batch_fn(val_data_mapped, batch_size)
100 |
101 | target_eval_tokens = 10_000_000
102 | evaluated_on_tokens = 0
103 | total_loss = torch.tensor(0.0).to(device)
104 | total_batches = 1
105 | logger.info(f"Eval set prepared in {time.time() - _time:.2f} seconds")
106 |
107 | for batch in val_data_mapped.batch(batch_size=batch_size):
108 | if evaluated_on_tokens > target_eval_tokens:
109 | break
110 | total_batches += 1
111 |
112 | batch = {k: v.to(device) for k, v in batch.items()}
113 | labels = batch["input_ids"].clone()
114 | labels[labels == pad_idx] = -100
115 | loss = model(**batch, labels=labels).loss
116 | total_loss += loss.detach()
117 |
118 | evaluated_on_tokens += (batch["input_ids"] != pad_idx).sum().item() * world_size
119 |
120 | total_loss = total_loss / total_batches
121 |
122 | # Gather losses across all GPUs
123 | gathered_losses = [torch.zeros_like(total_loss) for _ in range(world_size)]
124 | dist.all_gather(gathered_losses, total_loss)
125 | total_loss = sum([t.item() for t in gathered_losses]) / world_size
126 |
127 | return total_loss, evaluated_on_tokens
128 |
129 |
130 | def main(args):
131 | torch.manual_seed(args.seed)
132 | np.random.seed(args.seed)
133 | random.seed(args.seed)
134 |
135 | assert "LOCAL_RANK" in os.environ, "torchrun should set LOCAL_RANK"
136 | global_rank = int(os.environ['RANK'])
137 | local_rank = int(os.environ["LOCAL_RANK"])
138 | world_size = int(os.environ["WORLD_SIZE"])
139 | torch.cuda.set_device(local_rank)
140 |
141 | logger.info(f"Global rank {global_rank}, local rank {local_rank}, device: {torch.cuda.current_device()}")
142 |
143 | dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
144 |
145 | logger.info("Process group initialized")
146 | device = f"cuda:{local_rank}"
147 |
148 | if args.total_batch_size is not None:
149 | if args.gradient_accumulation is None:
150 | assert args.total_batch_size % world_size == 0, "total_batch_size must be divisible by world_size"
151 | args.gradient_accumulation = args.total_batch_size // (args.batch_size * world_size)
152 | assert args.gradient_accumulation > 0, "gradient_accumulation must be greater than 0"
153 |
154 | assert args.gradient_accumulation * args.batch_size * world_size == args.total_batch_size, \
155 | "gradient_accumulation * batch_size * world_size must be equal to total_batch_size"
156 |
157 | # turn off logger
158 | if global_rank != 0: logger.remove()
159 |
160 | # initialize wandb without config (it is passed later)
161 | if global_rank == 0:
162 | wandb.init(project="cod", name=args.run_name)
163 |
164 | logger.info(f"Using dist with rank {global_rank} (only rank 0 will log)")
165 | logger.info("*" * 40)
166 | logger.info(f"Starting training with the arguments")
167 | for k, v in vars(args).items():
168 | logger.info(f"{k:30} {v}")
169 | logger.info("*" * 40)
170 |
171 | data = datasets.load_dataset("allenai/c4", "en", split="train", streaming=True)
172 |
173 |
174 |
175 | seed_for_shuffle = 32
176 |
177 | logger.info(f"Shuffling data with seed {seed_for_shuffle}")
178 | data: datasets.Dataset = data.shuffle(seed=seed_for_shuffle)
179 | if not args.single_gpu:
180 | data = datasets.distributed.split_dataset_by_node(
181 | data, rank=global_rank, world_size=world_size,
182 | )
183 |
184 | # it doesn't matter which tokenizer we use, because we train from scratch
185 | # T5 tokenizer was trained on C4 and we are also training on C4, so it's a good choice
186 | tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=args.max_length)
187 |
188 | def preprocess_batched(batch):
189 | batch = tokenizer(
190 | batch["text"],
191 | max_length=args.max_length,
192 | truncation=True,
193 | padding="max_length",
194 | return_tensors="pt",
195 | )
196 | return batch
197 |
198 | dataset = PreprocessedIterableDataset(data, tokenizer, batch_size=args.batch_size, max_length=args.max_length)
199 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=args.workers)
200 |
201 | model_config = AutoConfig.from_pretrained(args.model_config)
202 | if args.use_hf_model:
203 | model: HF_LlamaForCausalLM = AutoModelForCausalLM.from_config(model_config)
204 | else:
205 | model = LlamaForCausalLM(model_config)
206 |
207 | if args.activation_checkpointing:
208 | model.gradient_checkpointing_enable()
209 |
210 | global_step = 0
211 | update_step = 0
212 | beginning_step = 0
213 | tokens_seen = 0
214 | tokens_seen_before = 0
215 |
216 | if args.continue_from is not None:
217 | logger.info("*" * 40)
218 | logger.info(f"Loading model from {args.continue_from}")
219 |
220 | from safetensors.torch import load_file
221 | state_dict = load_file(f"{args.continue_from}/model.safetensors")
222 | model.load_state_dict(state_dict)
223 |
224 | # checkpoint_path = os.path.join(args.continue_from, "pytorch_model.bin")
225 | # model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True)
226 | logger.info(f"Model successfully loaded (strict=True policy)")
227 |
228 | if os.path.exists(os.path.join(args.continue_from, "training_state.json")):
229 | logger.info(f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}")
230 | with open(os.path.join(args.continue_from, "training_state.json")) as f:
231 | _old_state = json.load(f)
232 | global_step = _old_state["global_step"]
233 | update_step = _old_state["update_step"]
234 | tokens_seen = _old_state["tokens_seen"]
235 | tokens_seen_before = _old_state["tokens_seen_before"]
236 | logger.info(f"global_step : {global_step}")
237 | logger.info(f"update_step : {update_step}")
238 | logger.info(f"tokens_seen : {tokens_seen}")
239 | logger.info(f"tokens_seen_before: {tokens_seen_before}")
240 | logger.info(f"Will train for {args.num_training_steps - update_step} update steps")
241 | else:
242 | logger.warning(f"Did not find training state in {args.continue_from}, global step will start from zero")
243 | logger.info("*" * 40)
244 |
245 |
246 | if args.dtype in ["bf16", "bfloat16"]:
247 | model = model.to(device=device, dtype=torch.bfloat16)
248 | else:
249 | model = model.to(device=device)
250 |
251 | n_total_params = sum(p.numel() for p in model.parameters())
252 | trainable_params = [p for p in model.parameters() if p.requires_grad]
253 | # Initialize wandb
254 | run_config = dict(vars(args))
255 | run_config.update({
256 | "max_lr": run_config.pop("lr"), # rename lr to max_lr to avoid conflicts with scheduler
257 | "total_params_M": n_total_params / 1_000_000,
258 | "dataset": 'c4',
259 | "model": model_config.to_dict(),
260 | "world_size": world_size,
261 | "device": str(device),
262 | })
263 |
264 | if global_rank == 0:
265 | wandb.config.update(run_config, allow_val_change=True)
266 | wandb.save(os.path.abspath(__file__), policy="now") # save current script
267 | # fix tqdm visual length to 80 so that the progress bar
268 | # doesn't jump around when changing from external display to laptop
269 | pbar = tqdm(total=args.num_training_steps - update_step, desc="Update steps", ncols=80)
270 |
271 | if 'galore' in args.optimizer.lower():
272 | # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
273 | galore_params = []
274 | target_modules_list = ["attn", "mlp"]
275 | for module_name, module in model.named_modules():
276 | if not isinstance(module, nn.Linear):
277 | continue
278 |
279 | if not any(target_key in module_name for target_key in target_modules_list):
280 | continue
281 |
282 | print('enable GaLore for weights in module: ', module_name)
283 | galore_params.append(module.weight)
284 | id_galore_params = [id(p) for p in galore_params]
286 | # the remaining parameters form a regular (non-GaLore) group
286 | regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]
287 | # then call galore_adamw
288 | param_groups = [{'params': regular_params},
289 | {'params': galore_params, 'rank': args.rank, 'update_proj_gap': args.update_proj_gap, 'scale': args.galore_scale, 'proj_type': args.proj_type}]
290 |
291 | # print params and trainable params
292 | logger.info(f"\n{model}\n")
293 | logger.info(f"Total params: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M")
294 | logger.info(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000:.2f}M")
295 | if 'galore' in args.optimizer.lower():
296 | logger.info(f"Total params with GaLore enabled: {sum(p.numel() for p in galore_params) / 1_000_000:.2f}M")
297 | logger.info(f"Saving model to {args.save_dir} every {args.save_every} update steps")
298 |
299 | layer_wise_flag = False
300 | if args.optimizer.lower() == "adam":
301 | optimizer = torch.optim.Adam(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
302 | else:
303 | raise ValueError(f"Optimizer {args.optimizer} not supported")
304 |
305 | if not layer_wise_flag:
306 | scheduler = training_utils.get_scheculer(
307 | optimizer=optimizer,
308 | scheduler_type=args.scheduler,
309 | num_training_steps=args.num_training_steps,
310 | warmup_steps=args.warmup_steps,
311 | min_lr_ratio=args.min_lr_ratio,
312 | )
313 |
314 | if not args.single_gpu:
315 | model: LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel(
316 | model,
317 | device_ids=[local_rank],
318 | output_device=local_rank,
319 | broadcast_buffers=False,
320 | )
321 |
322 | # global steps and others are defined above
323 | pad_idx = tokenizer.pad_token_id
324 | update_time = time.time()
325 | local_step = 0 # when continue_from is used, local_step != global_step
326 |
327 | # ##############################
328 | # TRAINING LOOP
329 | # we'll never go through all the data, so no need for epochs
330 | # ##############################
331 |
332 | for batch_idx, batch in enumerate(dataloader):
333 |
334 | global_step += 1
335 | local_step += 1
336 | if update_step > args.num_training_steps:
337 | logger.info(f"Reached max number of update steps (f{args.num_training_steps}). Stopping training.")
338 | print(f"Rank {global_rank} stopping training.")
339 | break
340 |
341 | batch = {k: v.to(device) for k, v in batch.items()}
342 | labels = batch["input_ids"].clone()
343 | labels[labels == pad_idx] = -100
344 | tokens_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size
345 |
346 | loss = model(**batch, labels=labels).loss
347 | scaled_loss = loss / args.gradient_accumulation
348 | scaled_loss.backward()
349 |
350 | if global_step % args.gradient_accumulation != 0:
351 | continue
352 |
353 | # The below code is only executed during the update step
354 |
355 | # add grad clipping
356 | if args.grad_clipping != 0.0: torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clipping)
357 |
358 | if global_rank == 0: pbar.update(1)
359 |
360 | if not layer_wise_flag:
361 | optimizer.step()
362 | scheduler.step()
363 | optimizer.zero_grad()
364 |
365 | update_step += 1
366 | update_time = time.time() - update_time
367 |
368 | # save checkpoint by save_every
369 | if local_step > args.gradient_accumulation and update_step % args.save_every == 0 and global_rank == 0:
370 | current_model_directory = f"{args.save_dir}/model_{update_step}"
371 | logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
372 | tokenizer.save_pretrained(current_model_directory)
373 | os.makedirs(args.save_dir, exist_ok=True)
374 | model.module.save_pretrained(current_model_directory, max_shard_size='100GB')
375 | optimizer_checkpoint = {
376 | "optimizer": optimizer.state_dict(),
377 | "scheduler": scheduler.state_dict(),
378 | "update_step": update_step,
379 | "global_step": global_step,
380 | "config": run_config,
381 | "wandb": wandb.run.dir,
382 | "dtype": args.dtype,
383 | }
384 | torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
385 |
386 | training_state_checkpoint = {
387 | "global_step": global_step,
388 | "update_step": update_step,
389 | "tokens_seen": tokens_seen,
390 | "tokens_seen_before": tokens_seen_before,
391 | "update_time": update_time,
392 | }
393 | with open(f"{current_model_directory}/training_state.json", "w") as f:
394 | json.dump(training_state_checkpoint, f, indent=4)
395 |
396 | # save wandb related info
397 | wandb_info = {
398 | "wandb_id": wandb.run.id,
399 | }
400 | with open(f"{args.save_dir}/wandb.json", "w") as f:
401 | json.dump(wandb_info, f, indent=4)
402 |
403 | # evaluation
404 | if update_step % args.eval_every == 0:
405 | logger.info(f"Performing evaluation at step {update_step}")
406 | total_loss, evaluated_on_tokens = evaluate_model(
407 | model, preprocess_batched, pad_idx, global_rank, world_size, device, args.batch_size
408 | )
409 | if global_rank == 0:
410 | wandb.log({
411 | "final_eval_loss": total_loss,
412 | "final_eval_tokens": evaluated_on_tokens,
413 | },
414 | step=global_step,
415 | )
416 | logger.info(f"Eval loss at step {update_step}: {total_loss}")
417 |
418 | if not layer_wise_flag:
419 | lr = optimizer.param_groups[0]["lr"]
420 | else:
421 |             lr = 0.0  # layer-wise optimizers are not used in this script; keep lr defined for the logging below
422 | tokens_in_update = tokens_seen - tokens_seen_before
423 | tokens_seen_before = tokens_seen
424 | batches_in_update = args.gradient_accumulation * world_size
425 |
426 | if global_rank == 0:
427 | wandb.log({
428 | "loss": loss.item(),
429 | "lr": lr,
430 | "update_step": update_step,
431 | "tokens_seen": tokens_seen,
432 | "throughput_tokens": tokens_in_update / update_time,
433 | "throughput_examples": args.total_batch_size / update_time,
434 | "throughput_batches": batches_in_update / update_time,
435 | },
436 | step=global_step,
437 | )
438 | update_time = time.time()
439 |
440 | # ##############################
441 | # END of training loop
442 | # ##############################
443 | logger.info("Training finished")
444 | if global_rank == 0: pbar.close()
445 |
446 | current_model_directory = f"{args.save_dir}/model_{update_step}"
447 | if global_rank == 0 and not os.path.exists(current_model_directory):
448 | logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
449 | os.makedirs(args.save_dir, exist_ok=True)
450 | model.module.save_pretrained(current_model_directory)
451 |
452 | optimizer_checkpoint = {
453 | "optimizer": optimizer.state_dict(),
454 | "scheduler": scheduler.state_dict(),
455 | "update_step": update_step,
456 | "global_step": global_step,
457 | "config": run_config,
458 | "wandb": wandb.run.dir,
459 | "dtype": args.dtype,
460 | }
461 | torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
462 |
463 | training_state_checkpoint = {
464 | "global_step": global_step,
465 | "update_step": update_step,
466 | "tokens_seen": tokens_seen,
467 | "tokens_seen_before": tokens_seen_before,
468 | "update_time": update_time,
469 | }
470 | with open(f"{current_model_directory}/training_state.json", "w") as f:
471 | json.dump(training_state_checkpoint, f, indent=4)
472 |
473 | # Final evaluation
474 | logger.info("Running final evaluation")
475 | model.eval()
476 | del loss, optimizer, scheduler
477 | import gc; gc.collect()
478 | torch.cuda.empty_cache()
479 |
480 | total_loss, evaluated_on_tokens = evaluate_model(
481 | model, preprocess_batched, pad_idx, global_rank, world_size, device, args.batch_size
482 | )
483 |
484 | if global_rank == 0:
485 | wandb.log({
486 | "final_eval_loss": total_loss,
487 | "final_eval_tokens": evaluated_on_tokens,
488 | },
489 | step=global_step,
490 | )
491 | logger.info(f"Final eval loss: {total_loss}")
492 |
493 | logger.info("Script finished successfully")
494 | print(f"Rank {global_rank} finished successfully")
495 |
496 |
497 | if __name__ == "__main__":
498 | print("Starting script")
499 | args = parse_args(None)
500 | main(args)
501 |
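502 | # Example launch (a sketch, not a verified command): the flags correspond to the arguments
503 | # referenced in this script (see parse_args and the run_*.sh scripts for the exact set); the
504 | # config path and hyperparameter values are placeholders to adapt to your setup.
505 | #
506 | #   torchrun --nproc_per_node=8 torchrun_main.py \
507 | #       --model_config configs/llama_130m.json \
508 | #       --lr 1e-3 --batch_size 32 --total_batch_size 512 \
509 | #       --num_training_steps 20000 --warmup_steps 2000 \
510 | #       --optimizer adam --scheduler cosine --weight_decay 0.0 \
511 | #       --grad_clipping 1.0 --eval_every 1000 --save_every 1000 \
512 | #       --save_dir checkpoints/llama_130m --dtype bfloat16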
--------------------------------------------------------------------------------
/utils/angular_distance.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import matplotlib.pyplot as plt
3 | from datasets import load_dataset
4 | from torch.utils.data import DataLoader
5 | from tqdm import tqdm
6 |
7 | from short_hf import ShortHFModel
8 |
9 | def compute_angular_distance(model, data, max_seq_len=1024, stride=256, n_samples=50):
10 | angular_distances = []
11 | for i, batch in enumerate(tqdm(data)):
12 | if i >= n_samples:
13 | break
14 |         prompts = batch['text']  # a single streamed C4 document; the tokenizer accepts a str or a list of str
15 |
16 | model.eval_importance(
17 | prompts=prompts,
18 | max_seq_len=max_seq_len,
19 | stride=stride,
20 | max_gen_len=0,
21 | angular=True
22 | )
23 | angular_distances.extend(model.importances)
24 | return angular_distances
25 |
26 | def plot_angular_distances(distances, output_path):
27 | plt.figure(figsize=(10, 6))
28 | plt.plot(distances, label='Angular Distances')
29 | plt.xlabel('Sample Index')
30 | plt.ylabel('Angular Distance')
31 | plt.title('Angular Distances Across Samples')
32 | plt.legend()
33 | plt.savefig(output_path)
34 | plt.show()
35 |
36 | def main(args):
37 | model = ShortHFModel(
38 | model_name=args.model_path,
39 | layers_path="model.layers",
40 |         n_prune_layers=1,  # dummy value: we only read the importance scores and never prune
41 | )
42 |
43 | data = load_dataset("allenai/c4", "en", split="train", streaming=True)
44 |
45 | angular_distances = compute_angular_distance(model, data, args.max_seq_len, args.stride, args.n_samples)
46 |
47 | plot_angular_distances(angular_distances, args.output_path)
48 |
49 | if __name__ == "__main__":
50 | parser = argparse.ArgumentParser(description="Compute and plot angular distances")
51 | parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
52 | parser.add_argument("--n_samples", type=int, default=50, help="Number of samples to process")
53 | parser.add_argument("--max_seq_len", type=int, default=1024, help="Maximum sequence length")
54 | parser.add_argument("--stride", type=int, default=256, help="Stride for processing sequences")
55 | parser.add_argument("--output_path", type=str, default="angular_distances.png", help="Output path for the plot")
56 |
57 | args = parser.parse_args()
58 |
59 | main(args)
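60 |
61 | # Example invocation (a sketch): all flags are defined by the argparse setup above; the
62 | # model path is a placeholder for a local checkpoint or a Hugging Face model id.
63 | #
64 | #   python utils/angular_distance.py \
65 | #       --model_path checkpoints/llama_130m/model_20000 \
66 | #       --n_samples 50 --max_seq_len 1024 --stride 256 \
67 | #       --output_path angular_distances.png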
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def block_influence(
5 | input_hidden_state: torch.Tensor,
6 | output_hidden_state: torch.Tensor,
7 | angular=False,
8 | ):
9 | """
10 | input_hidden_state: B, S, D
11 | output_hidden_state: B, S, D
12 | """
13 | _, _, d = input_hidden_state.shape
14 | input_hidden_state = input_hidden_state.reshape(-1, d)
15 | output_hidden_state = output_hidden_state.reshape(-1, d)
16 |
17 | norm_input = input_hidden_state.norm(dim=-1, keepdim=True)
18 | norm_output = output_hidden_state.norm(dim=-1, keepdim=True)
19 |
20 | sim = (input_hidden_state @ output_hidden_state.T) / (norm_input * norm_output)
21 | sim = sim.diagonal().nan_to_num(nan=0.5)
22 |
23 | if angular:
24 | return (torch.arccos(sim) / torch.pi)
25 |
26 | return 1 - sim
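27 |
28 |
29 | if __name__ == "__main__":
30 |     # Minimal smoke test with hypothetical hidden states (not part of the original pipeline):
31 |     # block influence is the per-token 1 - cosine similarity between a block's input and
32 |     # output; angular=True returns the normalized angular distance arccos(sim) / pi instead.
33 |     torch.manual_seed(0)
34 |     h_in = torch.randn(2, 8, 64)                # (batch, seq, hidden)
35 |     h_out = h_in + 0.1 * torch.randn(2, 8, 64)  # a block that changes its input only slightly
36 |     print(block_influence(h_in, h_out).mean())                # small value: low influence
37 |     print(block_influence(h_in, h_out, angular=True).mean())  # normalized angular distance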
--------------------------------------------------------------------------------
/utils/short_hf.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from transformers import AutoTokenizer, AutoModelForCausalLM
7 |
8 | from metrics import block_influence
9 |
10 |
11 | class ShortHFModel():
12 |
13 | def __init__(self, model_name: str, layers_path: str, n_prune_layers: Optional[int] = None):
14 | """
15 | HuggingFace Model Wrapper
16 |
17 | Args:
18 | model_name (str): HuggingFace model name
19 |             layers_path (str): Dot-notation path to the model's decoder layers, e.g. "model.layers"
20 | (Optional) n_prune_layers (int): Number of layers to prune. Defaults to None.
21 | """
22 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23 | self.tokenizer.pad_token = self.tokenizer.eos_token
24 | self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
25 | # self.model.params = self.model.to_fp16(self.model.params)
26 | self.model.to("cuda")
27 |
28 | modules = layers_path.split(".")
29 | mod = self.model
30 | for m in modules:
31 | mod = getattr(mod, m)
32 | self.layers = mod
33 |
34 | self.n_prune_layers = n_prune_layers
35 | self.importances = [0 for _ in self.layers] # layer-wise importance scores
36 |
37 | def remove_layers(
38 | self,
39 |         layers_to_remove: Optional[List[int]] = None,  # avoid a mutable default argument
40 | angular: Optional[bool] = False
41 | ):
42 | if angular:
43 | assert self.importances, "Need to compute importances with eval_importance()"
44 | assert self.n_prune_layers, "Need number of layers to prune, set `n_prune_layers`"
45 |             start_layer = np.argsort(np.array(self.importances[:len(self.importances) - self.n_prune_layers + 1]))[0]  # valid window starts; also correct when n_prune_layers == 1
46 | layers_to_remove = list(range(start_layer, start_layer + self.n_prune_layers))
47 | elif not layers_to_remove and self.n_prune_layers:
48 | assert self.importances, "Need to compute importances with eval_importance()"
49 | layers_to_remove = np.argsort(np.array(self.importances))[:self.n_prune_layers].tolist()
50 |
51 | # remove layers in reverse to avoid indexing errors
52 |         for layer_idx in sorted(layers_to_remove or [], reverse=True):
53 | try:
54 | del self.layers[layer_idx]
55 | except IndexError:
56 | print(f"layer {layer_idx} does not exist, function may have already been called")
57 | return []
58 |
59 |         return layers_to_remove or []
60 |
61 | def compute_bi(self, hiddens: List[torch.Tensor], angular: bool):
62 | n = 1
63 | if angular:
64 | assert self.n_prune_layers is not None, "Set number of layers to prune to use angular importance"
65 | n = self.n_prune_layers
66 |
67 | for i in range(len(hiddens) - n):
68 | in_hidden = hiddens[i]
69 | out_hidden = hiddens[i+n]
70 | if angular:
71 | # use only last token for angular distance as described in section 3.2
72 | # https://arxiv.org/pdf/2403.17887.pdf
73 | in_hidden = in_hidden[:,-1:]
74 | out_hidden = out_hidden[:,-1:]
75 |
76 | self.importances[i] += block_influence(
77 | in_hidden,
78 | out_hidden,
79 | angular=angular
80 | ).sum().cpu().item()
81 |
82 | @torch.inference_mode()
83 | def eval_importance(
84 | self,
85 | prompts: List[str],
86 | max_seq_len: int,
87 | stride: int = 256,
88 | max_gen_len: int = 0,
89 | temperature: float = 0.6,
90 | top_p: float = 0.9,
91 | angular: Optional[bool] = False
92 | ):
93 | """
94 | Computes layer-wise importances over input texts.
95 |
96 |         NOTE: the ShortGPT paper performs no generation during importance computation, which corresponds to `max_gen_len` = 0.
97 |
98 | Args:
99 | prompts (List[str]): List of prompts.
100 | max_seq_len (int): Maximum sequence length for model input, the sliding window size.
101 | (Optional) stride (int): Number of tokens to skip/shift between each window inference.
102 | (Optional) max_gen_len (int): Maximum length of the generated text sequence.
103 | (Optional) temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
104 | (Optional) top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
105 |             (Optional) angular (bool): Whether to use angular distance. Defaults to False.
106 |
107 | Returns:
108 | None
109 | """
110 | prompt_tokens = self.tokenizer(
111 | prompts,
112 | padding=True,
113 | return_attention_mask=True,
114 | return_tensors='pt'
115 | )
116 | input_ids = prompt_tokens.input_ids
117 | attn_mask = prompt_tokens.attention_mask
118 |
119 | max_prompt_len = max(len(t) for t in input_ids)
120 |
121 | # authors use a sliding window of size 1024 with a shift of 256
122 | for start in range(0, max_prompt_len, stride):
123 | seq_ids = (attn_mask.sum(dim=-1) > start).nonzero().squeeze()
124 | seq_ids = seq_ids.unsqueeze(0) if seq_ids.dim() == 0 else seq_ids # ensure 2d
125 | inputs = input_ids[seq_ids, start:start+max_seq_len]
126 | attn = attn_mask[seq_ids, start:start+max_seq_len]
127 |
128 | if max_gen_len == 0:
129 | outputs = self.model(
130 | input_ids=inputs.to("cuda"),
131 | attention_mask=attn.to("cuda"),
132 | output_hidden_states=True,
133 | )
134 | else:
135 | outputs = self.model.generate(
136 | input_ids=inputs.to("cuda"),
137 | attention_mask=attn.to("cuda"),
138 | max_new_tokens=max_gen_len,
139 | do_sample=True,
140 | temperature=temperature,
141 | top_p=top_p,
142 | output_hidden_states=True,
143 | return_dict_in_generate=True,
144 | )
145 |
146 | self.compute_bi(outputs.hidden_states, angular=angular)
147 |
148 | return
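149 |
150 |
151 | # Hypothetical usage sketch (model id, prompts, and layer count are placeholders, not values
152 | # used by the paper): score layers by block influence, then drop the least important ones.
153 | #
154 | #   short_model = ShortHFModel(
155 | #       model_name="meta-llama/Llama-2-7b-hf",
156 | #       layers_path="model.layers",
157 | #       n_prune_layers=2,
158 | #   )
159 | #   short_model.eval_importance(
160 | #       prompts=["Some sample text ..."],
161 | #       max_seq_len=1024,
162 | #       stride=256,
163 | #       max_gen_len=0,  # ShortGPT-style: no generation, just forward passes
164 | #   )
165 | #   removed = short_model.remove_layers()  # removes the n_prune_layers lowest-importance layers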
--------------------------------------------------------------------------------