├── LICENSE ├── README.md ├── assets └── liger_framework.png ├── checkpoints ├── liger_gla_base │ └── config.json ├── liger_gsa_base │ └── config.json ├── liger_hgrn2_base │ └── config.json ├── liger_mistral_gla_base │ └── config.json ├── liger_qwen25_gla_base │ └── config.json └── lolcats_base │ └── config.json ├── configs ├── liger_gla.yaml ├── liger_gsa.yaml ├── liger_hgrn2.yaml ├── liger_mistral_gla.yaml ├── liger_qwen2_gla.yaml ├── lolcats_ar.yaml └── lolcats_at.yaml ├── eval └── harness.py ├── liger ├── __init__.py └── models │ ├── __init__.py │ ├── liger_gla │ ├── __init__.py │ ├── configuration_liger_gla.py │ └── modeling_liger_gla.py │ ├── liger_gsa │ ├── __init__.py │ ├── configuration_liger_gsa.py │ └── modeling_liger_gsa.py │ ├── liger_hgrn2 │ ├── __init__.py │ ├── configuration_liger_hgrn2.py │ └── modeling_liger_hgrn2.py │ ├── liger_mistral_gla │ ├── __init__.py │ ├── configuration_liger_mistral_gla.py │ └── modeling_liger_mistral_gla.py │ └── liger_qwen2_gla │ ├── __init__.py │ ├── configuration_liger_qwen2_gla.py │ └── modeling_liger_qwen2_gla.py ├── lolcats ├── __init__.py └── models │ ├── __init__.py │ └── lolcats │ ├── __init__.py │ ├── configuration_lolcats.py │ └── modeling_lolcats.py ├── requirements.txt ├── run.py ├── scripts ├── train_liger.sh ├── train_lolcats_stage1.sh └── train_lolcats_stage2.sh └── training ├── __init__.py ├── dataloader.py ├── train.py ├── trainer.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Liger: Linearizing Large Language Models to Gated Recurrent Structures 2 | 3 | [![arXiv](https://img.shields.io/badge/Arxiv-2503.01496-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2503.01496) 4 | [![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Weights-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/collections/linear-moe-hub/liger-67d904bffd7f9b77ade7747d) 5 | 6 | ## Framework 7 | 8 |
<div align="center"> 9 |   <img src="assets/liger_framework.png" alt="Liger Framework"> 10 | </div> 11 | <div align="center"> 12 |   Figure 1: Liger Framework 13 | </div>
14 | 15 | ## Environment 16 | 17 | ```bash 18 | git clone --recurse-submodules https://github.com/OpenSparseLLMs/Linearization.git 19 | conda create -n liger python=3.10 20 | conda activate liger 21 | pip install -r requirements.txt 22 | pip install flash-attn --no-build-isolation 23 | cd third_party/flash-linear-attention 24 | pip install -e . 25 | ``` 26 | 27 | ## Linearization 28 | 29 | 1. Copy your pre-trained base model directory (e.g. Meta-Llama-3-8B) to `./checkpoints/`; 30 | 2. Replace the original Llama-3 base model's `config.json` with the Liger model's `config.json` (see `./checkpoints/liger_gla_base/config.json` for an example); 31 | 3. Modify the linearization settings in the corresponding YAML file under `./configs/` (e.g. `liger_gla.yaml`); 32 | 4. Run the linearization script: 33 | 34 | ```bash 35 | sh scripts/train_liger.sh 36 | ``` 37 | 38 | ## Evaluation 39 | 40 | You need to install [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) for evaluation: 41 | 42 | ```bash 43 | cd third_party/lm-evaluation-harness 44 | pip install -e . 45 | ``` 46 | 47 | ```bash 48 | python -m eval.harness --model hf \ 49 | --model_args pretrained=/your/Liger/checkpoints/liger_base_model,peft=/your/Liger/checkpoints/lora_adapter_path \ 50 | --tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 51 | --batch_size 64 \ 52 | --device cuda \ 53 | --seed 0 54 | ``` 55 | 56 | ## Acknowledgements 57 | 58 | We use the Triton-implemented linear attention kernels from [fla-org/flash-linear-attention](https://github.com/fla-org/flash-linear-attention). We refer to [HazyResearch/lolcats](https://github.com/HazyResearch/lolcats) to construct our linearization training process. The evaluation is supported by [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). We sincerely thank them for their contributions!
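As a quick sanity check outside the harness, a linearized checkpoint can also be loaded directly with `transformers` and `peft`. The snippet below is a minimal sketch rather than part of the repo: the two paths are placeholders mirroring the `--model_args` of the evaluation command above, the checkpoint directory is assumed to contain the tokenizer files copied from the base model, and importing `liger` is what registers the `Liger*` architectures with the `Auto*` classes (see `liger/models/*/__init__.py`).

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

import liger  # noqa: F401 -- registers the Liger configs/models with AutoConfig/AutoModelForCausalLM

# Placeholder paths, analogous to the harness `--model_args` above
base_path = "/your/Liger/checkpoints/liger_base_model"
adapter_path = "/your/Liger/checkpoints/lora_adapter_path"

tokenizer = AutoTokenizer.from_pretrained(base_path)
model = AutoModelForCausalLM.from_pretrained(base_path, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_path)  # attach the LoRA adapter produced by linearization
model.eval()

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```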
59 | 60 | ## Citation 61 | 62 | If you find this repo useful, please cite and star our work: 63 | 64 | ```bibtex 65 | @article{lan2025liger, 66 | title={Liger: Linearizing Large Language Models to Gated Recurrent Structures}, 67 | author={Lan, Disen and Sun, Weigao and Hu, Jiaxi and Du, Jusen and Cheng, Yu}, 68 | journal={arXiv preprint arXiv:2503.01496}, 69 | year={2025} 70 | } 71 | ``` -------------------------------------------------------------------------------- /assets/liger_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/Linearization/0e3cfae33a700fa5f644cf5752d8434c6afc2412/assets/liger_framework.png -------------------------------------------------------------------------------- /checkpoints/liger_gla_base/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LigerGLAForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "expand_k": 1, 10 | "expand_v": 1, 11 | "head_dim": 128, 12 | "hidden_act": "silu", 13 | "hidden_ratio": 4, 14 | "hidden_size": 4096, 15 | "initializer_range": 0.02, 16 | "intermediate_size": 14336, 17 | "max_position_embeddings": 8192, 18 | "mlp_bias": false, 19 | "model_type": "liger_gla", 20 | "num_attention_heads": 32, 21 | "num_hidden_layers": 32, 22 | "num_key_value_heads": 8, 23 | "pool_size": 128, 24 | "pretraining_tp": 1, 25 | "rms_norm_eps": 1e-05, 26 | "rope_scaling": null, 27 | "rope_theta": 500000.0, 28 | "tie_word_embeddings": false, 29 | "torch_dtype": "bfloat16", 30 | "transformers_version": "4.47.1", 31 | "use_cache": true, 32 | "vocab_size": 128256 33 | } 34 | -------------------------------------------------------------------------------- /checkpoints/liger_gsa_base/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LigerGSAForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "expand_k": 1, 10 | "expand_v": 1, 11 | "head_dim": 128, 12 | "hidden_act": "silu", 13 | "hidden_ratio": 4, 14 | "hidden_size": 4096, 15 | "initializer_range": 0.02, 16 | "intermediate_size": 14336, 17 | "max_position_embeddings": 8192, 18 | "mlp_bias": false, 19 | "model_type": "liger_gsa", 20 | "num_attention_heads": 32, 21 | "num_hidden_layers": 32, 22 | "num_key_value_heads": 8, 23 | "pool_size": 128, 24 | "pretraining_tp": 1, 25 | "rms_norm_eps": 1e-05, 26 | "rope_scaling": null, 27 | "rope_theta": 500000.0, 28 | "tie_word_embeddings": false, 29 | "torch_dtype": "bfloat16", 30 | "transformers_version": "4.47.1", 31 | "use_cache": true, 32 | "vocab_size": 128256 33 | } 34 | -------------------------------------------------------------------------------- /checkpoints/liger_hgrn2_base/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LigerHGRN2ForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "expand_k": 1, 10 | "expand_v": 1, 11 | "head_dim": 128, 12 | "hidden_act": "silu", 13 | "hidden_ratio": 4, 14 | "hidden_size": 4096, 15 | "initializer_range": 0.02, 16 | "intermediate_size": 14336, 17 | "max_position_embeddings": 8192, 18 | "mlp_bias": false, 19 | "model_type": "liger_hgrn2", 20 | 
"num_attention_heads": 32, 21 | "num_hidden_layers": 32, 22 | "num_key_value_heads": 8, 23 | "pool_size": 128, 24 | "pretraining_tp": 1, 25 | "rms_norm_eps": 1e-05, 26 | "rope_scaling": null, 27 | "rope_theta": 500000.0, 28 | "tie_word_embeddings": false, 29 | "torch_dtype": "bfloat16", 30 | "transformers_version": "4.47.1", 31 | "use_cache": true, 32 | "vocab_size": 128256 33 | } 34 | -------------------------------------------------------------------------------- /checkpoints/liger_mistral_gla_base/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LigerMistralGLAForCausalLM" 4 | ], 5 | "bos_token_id": 1, 6 | "eos_token_id": 2, 7 | "hidden_act": "silu", 8 | "hidden_size": 4096, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 14336, 11 | "max_position_embeddings": 32768, 12 | "model_type": "liger_mistral_gla", 13 | "num_attention_heads": 32, 14 | "num_hidden_layers": 32, 15 | "num_key_value_heads": 8, 16 | "rms_norm_eps": 1e-05, 17 | "rope_theta": 10000.0, 18 | "sliding_window": 4096, 19 | "tie_word_embeddings": false, 20 | "torch_dtype": "bfloat16", 21 | "transformers_version": "4.34.0.dev0", 22 | "use_cache": true, 23 | "vocab_size": 32000 24 | } 25 | -------------------------------------------------------------------------------- /checkpoints/liger_qwen25_gla_base/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LigerQwen2GLAForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 151643, 7 | "eos_token_id": 151643, 8 | "hidden_act": "silu", 9 | "hidden_size": 3584, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 18944, 12 | "max_position_embeddings": 131072, 13 | "max_window_layers": 28, 14 | "model_type": "liger_qwen2_gla", 15 | "num_attention_heads": 28, 16 | "num_hidden_layers": 28, 17 | "num_key_value_heads": 4, 18 | "rms_norm_eps": 1e-06, 19 | "rope_theta": 1000000.0, 20 | "sliding_window": 131072, 21 | "tie_word_embeddings": false, 22 | "torch_dtype": "bfloat16", 23 | "transformers_version": "4.40.1", 24 | "use_cache": true, 25 | "use_mrope": false, 26 | "use_sliding_window": false, 27 | "vocab_size": 152064 28 | } 29 | -------------------------------------------------------------------------------- /checkpoints/lolcats_base/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LolcatsModelForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "attn": null, 8 | "attn_mode": "fused_chunk", 9 | "bos_token_id": 128000, 10 | "elementwise_affine": true, 11 | "eos_token_id": 128001, 12 | "expand_k": 1, 13 | "expand_v": 1, 14 | "feature_map": "lolcats_hedgehog", 15 | "fuse_cross_entropy": true, 16 | "head_dim": 128, 17 | "hidden_act": "silu", 18 | "hidden_ratio": 4, 19 | "hidden_size": 4096, 20 | "initializer_range": 0.02, 21 | "intermediate_size": 14336, 22 | "max_position_embeddings": 8192, 23 | "mlp_bias": false, 24 | "model_type": "lolcats", 25 | "norm_eps": 1e-06, 26 | "norm_feature_map": false, 27 | "norm_k": false, 28 | "norm_q": false, 29 | "num_attention_heads": 32, 30 | "num_hidden_layers": 32, 31 | "num_key_value_heads": 8, 32 | "pretraining_tp": 1, 33 | "rms_norm_eps": 1e-05, 34 | "rope_scaling": null, 35 | "rope_theta": 500000.0, 36 | "tie_feature_map_qk": false, 37 | "tie_word_embeddings": false, 38 | "torch_dtype": "bfloat16", 39 | "transformers_version": "4.46.2", 40 | "use_cache": 
true, 41 | "vocab_size": 128256 42 | } 43 | -------------------------------------------------------------------------------- /configs/liger_gla.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: "alpaca_cleand" 3 | path: "yahma/alpaca-cleaned" 4 | batch_size: 8 5 | micro_batch_size: 1 6 | val_set_size: 200 7 | model: 8 | name: 'liger_gla' 9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/liger_gla_base' 10 | max_length: 1024 11 | # tokenizer 12 | add_eos_token: False 13 | 14 | device_map: 'auto' 15 | train: 16 | optim: 'adamw_torch' 17 | lr: 0.001 18 | epochs: 2 19 | max_grad_norm: 1.0 20 | output_dir: 'checkpoints' 21 | train_qk: True 22 | train_qk_lora: True 23 | train_v: True 24 | train_v_lora: True -------------------------------------------------------------------------------- /configs/liger_gsa.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: "alpaca_cleand" 3 | path: "yahma/alpaca-cleaned" 4 | batch_size: 8 5 | micro_batch_size: 1 6 | val_set_size: 200 7 | model: 8 | name: 'liger_gsa' 9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/liger_gsa_base' 10 | max_length: 1024 11 | # tokenizer 12 | add_eos_token: False 13 | 14 | device_map: 'auto' 15 | train: 16 | optim: 'adamw_torch' 17 | lr: 0.001 18 | epochs: 2 19 | max_grad_norm: 1.0 20 | output_dir: 'checkpoints' 21 | train_qk: True 22 | train_qk_lora: True 23 | train_v: True 24 | train_v_lora: True -------------------------------------------------------------------------------- /configs/liger_hgrn2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: "alpaca_cleand" 3 | path: "yahma/alpaca-cleaned" 4 | batch_size: 8 5 | micro_batch_size: 1 6 | val_set_size: 200 7 | model: 8 | name: 'liger_hgrn2' 9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/liger_hgrn2_base' 10 | max_length: 1024 11 | # tokenizer 12 | add_eos_token: False 13 | 14 | device_map: 'auto' 15 | train: 16 | optim: 'adamw_torch' 17 | lr: 0.001 18 | epochs: 2 19 | max_grad_norm: 1.0 20 | output_dir: 'checkpoints' 21 | train_qk: True 22 | train_qk_lora: True 23 | train_v: True 24 | train_v_lora: True -------------------------------------------------------------------------------- /configs/liger_mistral_gla.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: "alpaca_cleand" 3 | path: "yahma/alpaca-cleaned" 4 | batch_size: 8 5 | micro_batch_size: 1 6 | val_set_size: 200 7 | model: 8 | name: 'liger_mistral_gla' 9 | pretrained_model_name_or_path: '/your/checkpoints/liger_mistral_gla_base' 10 | max_length: 1024 11 | # tokenizer 12 | add_eos_token: False 13 | 14 | device_map: 'auto' 15 | train: 16 | optim: 'adamw_torch' 17 | lr: 0.001 18 | epochs: 2 19 | max_grad_norm: 1.0 20 | output_dir: 'checkpoints' 21 | train_qk: True 22 | train_qk_lora: True 23 | train_v: True 24 | train_v_lora: True -------------------------------------------------------------------------------- /configs/liger_qwen2_gla.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: "alpaca_cleand" 3 | path: "yahma/alpaca-cleaned" 4 | batch_size: 8 5 | micro_batch_size: 1 6 | val_set_size: 200 7 | model: 8 | name: 'liger_qwen2_gla' 9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/liger_qwen25_gla_base' 10 | max_length: 1024 11 | # tokenizer 12 | add_eos_token: False 13 | 14 | device_map: 'auto' 15 | 
train: 16 | optim: 'adamw_torch' 17 | lr: 0.001 18 | epochs: 2 19 | max_grad_norm: 1.0 20 | output_dir: 'checkpoints' 21 | train_qk: True 22 | train_qk_lora: True 23 | train_v: True 24 | train_v_lora: True -------------------------------------------------------------------------------- /configs/lolcats_ar.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: "alpaca_cleand" 3 | path: "yahma/alpaca-cleaned" 4 | batch_size: 8 5 | micro_batch_size: 1 6 | val_set_size: 200 7 | model: 8 | name: 'lolcats_ar' 9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/lolcats_at' 10 | max_length: 1024 11 | # tokenizer 12 | add_eos_token: False 13 | 14 | device_map: 'auto' 15 | train: 16 | optim: 'adamw_torch' 17 | lr: 0.0001 18 | epochs: 2 19 | max_grad_norm: 1.0 20 | output_dir: 'checkpoints' 21 | train_qk: True 22 | train_qk_lora: True 23 | train_v: True 24 | train_v_lora: True -------------------------------------------------------------------------------- /configs/lolcats_at.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: "alpaca_cleand" 3 | path: "yahma/alpaca-cleaned" 4 | batch_size: 8 5 | micro_batch_size: 1 6 | val_set_size: 200 7 | model: 8 | name: 'lolcats_at' 9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/lolcats_base' 10 | max_length: 1024 11 | # tokenizer 12 | add_eos_token: False 13 | 14 | device_map: 'auto' 15 | train: 16 | optim: 'adamw_torch' 17 | lr: 0.01 18 | epochs: 2 19 | max_grad_norm: 1.0 20 | output_dir: 'checkpoints' 21 | train_qk: False 22 | train_qk_lora: False 23 | train_v: False 24 | train_v_lora: False -------------------------------------------------------------------------------- /eval/harness.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | import fla # noqa 6 | import liger 7 | import lolcats 8 | from lm_eval.__main__ import cli_evaluate 9 | from lm_eval.api.registry import register_model 10 | from lm_eval.models.huggingface import HFLM 11 | 12 | 13 | @register_model('fla') 14 | class FlashLinearAttentionLMWrapper(HFLM): 15 | def __init__(self, **kwargs) -> FlashLinearAttentionLMWrapper: 16 | 17 | # TODO: provide options for doing inference with different kernels 18 | 19 | super().__init__(**kwargs) 20 | 21 | 22 | if __name__ == "__main__": 23 | cli_evaluate() 24 | -------------------------------------------------------------------------------- /liger/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from liger.models.liger_gla import LigerGLAConfig, LigerGLAForCausalLM, LigerGLAModel 4 | from liger.models.liger_gsa import LigerGSAConfig, LigerGSAForCausalLM, LigerGSAModel 5 | from liger.models.liger_hgrn2 import LigerHGRN2Config, LigerHGRN2ForCausalLM, LigerHGRN2Model 6 | from liger.models.liger_mistral_gla import LigerMistralGLAConfig, LigerMistralGLAForCausalLM, LigerMistralGLAModel 7 | from liger.models.liger_qwen2_gla import LigerQwen2GLAConfig, LigerQwen2GLAForCausalLM, LigerQwen2GLAModel 8 | 9 | __all__ = [ 10 | 'LigerGLAConfig', 'LigerGLAForCausalLM', 'LigerGLAModel', 11 | 'LigerGSAConfig', 'LigerGSAForCausalLM', 'LigerGSAModel', 12 | 'LigerHGRN2Config', 'LigerHGRN2ForCausalLM', 'LigerHGRN2Model', 13 | 'LigerMistralGLAConfig', 'LigerMistralGLAForCausalLM', 'LigerMistralGLAModel', 14 | 'LigerQwen2GLAConfig', 'LigerQwen2GLAForCausalLM', 
'LigerQwen2GLAModel', 15 | ] -------------------------------------------------------------------------------- /liger/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from liger.models.liger_gla import LigerGLAConfig, LigerGLAForCausalLM, LigerGLAModel 4 | from liger.models.liger_gsa import LigerGSAConfig, LigerGSAForCausalLM, LigerGSAModel 5 | from liger.models.liger_hgrn2 import LigerHGRN2Config, LigerHGRN2ForCausalLM, LigerHGRN2Model 6 | from liger.models.liger_mistral_gla import LigerMistralGLAConfig, LigerMistralGLAForCausalLM, LigerMistralGLAModel 7 | from liger.models.liger_qwen2_gla import LigerQwen2GLAConfig, LigerQwen2GLAForCausalLM, LigerQwen2GLAModel 8 | 9 | __all__ = [ 10 | 'LigerGLAConfig', 'LigerGLAForCausalLM', 'LigerGLAModel', 11 | 'LigerGSAConfig', 'LigerGSAForCausalLM', 'LigerGSAModel', 12 | 'LigerHGRN2Config', 'LigerHGRN2ForCausalLM', 'LigerHGRN2Model', 13 | 'LigerMistralGLAConfig', 'LigerMistralGLAForCausalLM', 'LigerMistralGLAModel', 14 | 'LigerQwen2GLAConfig', 'LigerQwen2GLAForCausalLM', 'LigerQwen2GLAModel', 15 | ] -------------------------------------------------------------------------------- /liger/models/liger_gla/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from liger.models.liger_gla.configuration_liger_gla import LigerGLAConfig 4 | from liger.models.liger_gla.modeling_liger_gla import LigerGLAForCausalLM, LigerGLAModel 5 | 6 | AutoConfig.register(LigerGLAConfig.model_type, LigerGLAConfig) 7 | AutoModel.register(LigerGLAConfig, LigerGLAModel) 8 | AutoModelForCausalLM.register(LigerGLAConfig, LigerGLAForCausalLM) 9 | 10 | 11 | __all__ = ['LigerGLAConfig', 'LigerGLAForCausalLM', 'LigerGLAModel'] -------------------------------------------------------------------------------- /liger/models/liger_gla/configuration_liger_gla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.models.llama.configuration_llama import LlamaConfig 7 | 8 | class LigerGLAConfig(LlamaConfig, PretrainedConfig): 9 | model_type = 'liger_gla' 10 | keys_to_ignore_at_inference = ['past_key_values'] 11 | 12 | def __init__( 13 | self, 14 | # llama config 15 | vocab_size=32000, 16 | hidden_size=4096, 17 | intermediate_size=11008, 18 | num_hidden_layers=32, 19 | num_attention_heads=32, 20 | num_key_value_heads=None, 21 | hidden_act="silu", 22 | max_position_embeddings=2048, 23 | initializer_range=0.02, 24 | rms_norm_eps=1e-6, 25 | use_cache=True, 26 | pad_token_id=None, 27 | bos_token_id=1, 28 | eos_token_id=2, 29 | pretraining_tp=1, 30 | tie_word_embeddings=False, 31 | rope_theta=10000.0, 32 | rope_scaling=None, 33 | attention_bias=False, 34 | attention_dropout=0.0, 35 | mlp_bias=False, 36 | head_dim=None, 37 | **kwargs, 38 | ): 39 | super().__init__( 40 | vocab_size=vocab_size, 41 | hidden_size=hidden_size, 42 | intermediate_size=intermediate_size, 43 | num_hidden_layers=num_hidden_layers, 44 | num_attention_heads=num_attention_heads, 45 | num_key_value_heads=num_key_value_heads, 46 | hidden_act=hidden_act, 47 | max_position_embeddings=max_position_embeddings, 48 | initializer_range=initializer_range, 49 | rms_norm_eps=rms_norm_eps, 50 | use_cache=use_cache, 51 | pad_token_id=pad_token_id, 52 
| bos_token_id=bos_token_id, 53 | eos_token_id=eos_token_id, 54 | pretraining_tp=pretraining_tp, 55 | tie_word_embeddings=tie_word_embeddings, 56 | rope_theta=rope_theta, 57 | rope_scaling=rope_scaling, 58 | attention_bias=attention_bias, 59 | attention_dropout=attention_dropout, 60 | mlp_bias=mlp_bias, 61 | head_dim=head_dim, 62 | **kwargs, 63 | ) -------------------------------------------------------------------------------- /liger/models/liger_gla/modeling_liger_gla.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | import copy 4 | from typing import List, Optional, Tuple, Union 5 | from einops import rearrange, repeat 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 12 | from transformers.generation import GenerationMixin 13 | from transformers.modeling_outputs import ( 14 | BaseModelOutputWithPast, 15 | CausalLMOutputWithPast, 16 | ) 17 | from transformers.models.llama.modeling_llama import ( 18 | LlamaRMSNorm, 19 | LlamaRotaryEmbedding, 20 | repeat_kv, 21 | apply_rotary_pos_emb, 22 | LlamaMLP, 23 | LlamaAttention, 24 | LlamaFlashAttention2, 25 | LlamaSdpaAttention, 26 | LlamaDecoderLayer, 27 | LlamaForCausalLM, 28 | LlamaModel, 29 | LlamaPreTrainedModel, 30 | LLAMA_INPUTS_DOCSTRING, 31 | ) 32 | 33 | from transformers.utils import logging, add_start_docstrings_to_model_forward 34 | from transformers.utils import is_flash_attn_2_available 35 | 36 | if is_flash_attn_2_available(): 37 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 38 | else: 39 | print("flash_attn_2 is not available") 40 | 41 | from fla.models.utils import Cache as FlaCache 42 | from fla.ops.gla import fused_chunk_gla, fused_recurrent_gla 43 | 44 | from .configuration_liger_gla import LigerGLAConfig 45 | 46 | logger = logging.get_logger(__name__) 47 | 48 | class LigerGatedLinearAttention(nn.Module): 49 | def __init__( 50 | self, 51 | config: LigerGLAConfig, 52 | layer_idx: Optional[int] = None, 53 | ): 54 | super().__init__() 55 | self.config = config 56 | self.layer_idx = layer_idx 57 | 58 | self.attention_dropout = config.attention_dropout 59 | self.hidden_size = config.hidden_size 60 | self.num_heads = config.num_attention_heads 61 | self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) 62 | self.num_key_value_heads = config.num_key_value_heads 63 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 64 | 65 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 66 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 67 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 68 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 69 | 70 | 71 | self.rotary_emb = LlamaRotaryEmbedding(config=self.config) 72 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.head_dim * self.num_key_value_heads) 73 | 74 | def forward( 75 | self, 76 | hidden_states: torch.Tensor, 77 | attention_mask: Optional[torch.Tensor] = None, 78 | position_ids: Optional[torch.LongTensor] = None, 79 | past_key_value: Optional[FlaCache] = None, 80 | output_attentions: bool = False, 81 | use_cache: bool = False, 82 | cache_position: Optional[torch.LongTensor] = None, 83 | position_embeddings: 
Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 84 | **kwargs, 85 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 86 | last_state = None 87 | if past_key_value is not None and len(past_key_value) > self.layer_idx: 88 | last_state = past_key_value[self.layer_idx] 89 | 90 | q = self.q_proj(hidden_states) 91 | k = self.k_proj(hidden_states) 92 | v = self.v_proj(hidden_states) 93 | g = self.pool_g(k) 94 | 95 | # dealing with left-padding 96 | if attention_mask is not None: 97 | v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) 98 | 99 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads) 100 | k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_key_value_heads) 101 | v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_key_value_heads) 102 | g = rearrange(g, 'b n (h m) -> b h n m', h=self.num_key_value_heads) 103 | 104 | k = repeat_kv(k, self.num_key_value_groups) 105 | v = repeat_kv(v, self.num_key_value_groups) 106 | g = repeat_kv(g, self.num_key_value_groups) 107 | 108 | sq, sk, sv = q, k, v 109 | 110 | # norm 111 | q = F.softmax(q, dim=-1) 112 | k = F.softmax(k, dim=-1) 113 | 114 | gate_logit_normalizer = 16 115 | g = F.logsigmoid(g) / gate_logit_normalizer # (b, h, n, m) 116 | 117 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None 118 | offsets = kwargs.get('offsets', None) 119 | scale = 1 120 | q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g)) 121 | 122 | if self.training or q.shape[-2] > 1: 123 | o_, recurrent_state = fused_chunk_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True) 124 | else: 125 | o_, recurrent_state = fused_recurrent_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True) 126 | 127 | if past_key_value is not None: 128 | past_key_value.update( 129 | recurrent_state=recurrent_state, 130 | layer_idx=self.layer_idx, 131 | offset=q.shape[1] 132 | ) 133 | 134 | q_len = hidden_states.size(-2) 135 | 136 | if position_embeddings is None: 137 | logger.warning_once( 138 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 139 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 140 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " 141 | "removed and `position_embeddings` will be mandatory." 142 | ) 143 | cos, sin = self.rotary_emb(sv, position_ids) 144 | else: 145 | cos, sin = position_embeddings 146 | sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin) 147 | 148 | input_dtype = sq.dtype 149 | if input_dtype == torch.float32: 150 | if torch.is_autocast_enabled(): 151 | target_dtype = torch.get_autocast_gpu_dtype() 152 | # Handle the case where the model is quantized 153 | elif hasattr(self.config, "_pre_quantization_dtype"): 154 | target_dtype = self.config._pre_quantization_dtype 155 | else: 156 | target_dtype = self.q_proj.weight.dtype 157 | 158 | logger.warning_once( 159 | f"The input hidden states seems to be silently casted in float32, this might be related to" 160 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 161 | f" {target_dtype}." 
162 | ) 163 | 164 | sq = sq.to(target_dtype) 165 | sk = sk.to(target_dtype) 166 | sv = sv.to(target_dtype) 167 | 168 | window_size = 64 169 | if attention_mask is not None and 0.0 in attention_mask: 170 | pass 171 | else: 172 | attention_mask = None 173 | 174 | y = _flash_attention_forward( # Reashape to the expected shape for Flash Attention 175 | sq.transpose(1, 2), 176 | sk.transpose(1, 2), 177 | sv.transpose(1, 2), 178 | attention_mask, 179 | q_len, 180 | position_ids=position_ids, 181 | dropout=0.0, 182 | sliding_window=window_size, 183 | use_top_left_mask=False, 184 | is_causal=True, 185 | target_dtype=torch.float32, 186 | ).transpose(1, 2) 187 | o_ = 0.5 * y + 0.5 * o_ 188 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)') 189 | o = self.o_proj(o) 190 | 191 | return o, None, past_key_value 192 | 193 | class LigerGLADecoderLayer(LlamaDecoderLayer): 194 | def __init__(self, config: LigerGLAConfig, layer_idx: int): 195 | super().__init__(config, layer_idx) 196 | self.hidden_size = config.hidden_size 197 | self.self_attn = LigerGatedLinearAttention(config=config, layer_idx=layer_idx) 198 | self.mlp = LlamaMLP(config) 199 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 200 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 201 | 202 | class LigerGLAPreTrainedModel(LlamaPreTrainedModel): 203 | 204 | config_class = LigerGLAConfig 205 | base_model_prefix = "model" 206 | supports_gradient_checkpointing = True 207 | _no_split_modules = ['LigerGLADecoderLayer'] 208 | _skip_keys_device_placement = "past_key_values" 209 | 210 | class LigerGLAModel(LlamaModel, LigerGLAPreTrainedModel): 211 | 212 | def __init__(self, config: LigerGLAConfig): 213 | super().__init__(config) 214 | self.padding_idx = config.pad_token_id 215 | self.vocab_size = config.vocab_size 216 | 217 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 218 | self.layers = nn.ModuleList( 219 | [LigerGLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 220 | ) 221 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 222 | self.rotary_emb = LlamaRotaryEmbedding(config=config) 223 | self.gradient_checkpointing = False 224 | 225 | # Initialize weights and apply final processing 226 | self.post_init() 227 | 228 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 229 | def forward( 230 | self, 231 | input_ids: torch.LongTensor = None, 232 | attention_mask: Optional[torch.Tensor] = None, 233 | position_ids: Optional[torch.LongTensor] = None, 234 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None, 235 | inputs_embeds: Optional[torch.FloatTensor] = None, 236 | use_cache: Optional[bool] = None, 237 | output_attentions: Optional[bool] = None, 238 | output_hidden_states: Optional[bool] = None, 239 | return_dict: Optional[bool] = None, 240 | cache_position: Optional[torch.LongTensor] = None, 241 | **kwargs, 242 | ) -> Union[Tuple, BaseModelOutputWithPast]: 243 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 244 | output_hidden_states = ( 245 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 246 | ) 247 | use_cache = use_cache if use_cache is not None else self.config.use_cache 248 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 249 | 250 | if (input_ids is None) ^ (inputs_embeds is not 
None): 251 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds") 252 | 253 | if self.gradient_checkpointing and self.training and use_cache: 254 | logger.warning_once( 255 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 256 | ) 257 | use_cache = False 258 | 259 | if inputs_embeds is None: 260 | inputs_embeds = self.embed_tokens(input_ids) 261 | 262 | 263 | # kept for BC (non `Cache` `past_key_values` inputs) 264 | return_legacy_cache = False 265 | if use_cache and not isinstance(past_key_values, FlaCache): 266 | past_key_values = FlaCache.from_legacy_cache(past_key_values) 267 | 268 | if cache_position is None: 269 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 270 | cache_position = torch.arange( 271 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 272 | ) 273 | 274 | if position_ids is None: 275 | position_ids = cache_position.unsqueeze(0) 276 | 277 | causal_mask = attention_mask 278 | 279 | hidden_states = inputs_embeds 280 | 281 | # create position embeddings to be shared across the decoder layers 282 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 283 | 284 | # decoder layers 285 | all_hidden_states = () if output_hidden_states else None 286 | all_self_attns = () if output_attentions else None 287 | next_decoder_cache = None 288 | 289 | if output_attentions: 290 | all_softmax_hidden_states = () 291 | 292 | for decoder_layer in self.layers: 293 | if output_hidden_states: 294 | all_hidden_states += (hidden_states,) 295 | if all_softmax_hidden_states is not None: 296 | all_softmax_hidden_states += (hidden_states,) 297 | 298 | if self.gradient_checkpointing and self.training: 299 | layer_outputs = self._gradient_checkpointing_func( 300 | decoder_layer.__call__, 301 | hidden_states, 302 | causal_mask, 303 | position_ids, 304 | past_key_values, 305 | output_attentions, 306 | use_cache, 307 | cache_position, 308 | position_embeddings, 309 | ) 310 | 311 | else: 312 | layer_outputs = decoder_layer( 313 | hidden_states, 314 | attention_mask=causal_mask, 315 | position_ids=position_ids, 316 | past_key_value=past_key_values, 317 | output_attentions=output_attentions, 318 | use_cache=use_cache, 319 | cache_position=cache_position, 320 | position_embeddings=position_embeddings, 321 | ) 322 | hidden_states = layer_outputs[0] 323 | 324 | if use_cache: 325 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 326 | 327 | if output_attentions: 328 | all_self_attns += (layer_outputs[1],) 329 | 330 | hidden_states = self.norm(hidden_states) 331 | 332 | # add hidden states from the last decoder layer 333 | if output_hidden_states: 334 | all_hidden_states += (hidden_states,) 335 | 336 | next_cache = next_decoder_cache if use_cache else None 337 | 338 | if not return_dict: 339 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 340 | return BaseModelOutputWithPast( 341 | last_hidden_state=hidden_states, 342 | past_key_values=next_cache, 343 | hidden_states=all_hidden_states, 344 | attentions=all_self_attns, 345 | ) 346 | 347 | class LigerGLAForCausalLM(LlamaForCausalLM, LigerGLAPreTrainedModel, GenerationMixin): 348 | _tied_weights_keys = ["lm_head.weight"] 349 | 350 | def __init__(self, config): 351 | super().__init__(config) 352 | self.model = LigerGLAModel(config) 353 | self.vocab_size = config.vocab_size 354 | self.lm_head = nn.Linear(config.hidden_size, 
config.vocab_size, bias=False) 355 | 356 | # Initialize weights and apply final processing 357 | self.post_init() -------------------------------------------------------------------------------- /liger/models/liger_gsa/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from liger.models.liger_gsa.configuration_liger_gsa import LigerGSAConfig 4 | from liger.models.liger_gsa.modeling_liger_gsa import LigerGSAForCausalLM, LigerGSAModel 5 | 6 | AutoConfig.register(LigerGSAConfig.model_type, LigerGSAConfig) 7 | AutoModel.register(LigerGSAConfig, LigerGSAModel) 8 | AutoModelForCausalLM.register(LigerGSAConfig, LigerGSAForCausalLM) 9 | 10 | 11 | __all__ = ['LigerGSAConfig', 'LigerGSAForCausalLM', 'LigerGSAModel'] -------------------------------------------------------------------------------- /liger/models/liger_gsa/configuration_liger_gsa.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.models.llama.configuration_llama import LlamaConfig 7 | 8 | class LigerGSAConfig(LlamaConfig, PretrainedConfig): 9 | model_type = 'liger_gsa' 10 | keys_to_ignore_at_inference = ['past_key_values'] 11 | 12 | def __init__( 13 | self, 14 | # llama config 15 | vocab_size=32000, 16 | hidden_size=4096, 17 | intermediate_size=11008, 18 | num_hidden_layers=32, 19 | num_attention_heads=32, 20 | num_key_value_heads=None, 21 | hidden_act="silu", 22 | max_position_embeddings=2048, 23 | initializer_range=0.02, 24 | rms_norm_eps=1e-6, 25 | use_cache=True, 26 | pad_token_id=None, 27 | bos_token_id=1, 28 | eos_token_id=2, 29 | pretraining_tp=1, 30 | tie_word_embeddings=False, 31 | rope_theta=10000.0, 32 | rope_scaling=None, 33 | attention_bias=False, 34 | attention_dropout=0.0, 35 | mlp_bias=False, 36 | head_dim=None, 37 | pool_size: int = 64, # pooling 38 | **kwargs, 39 | ): 40 | self.pool_size = pool_size 41 | 42 | super().__init__( 43 | vocab_size=vocab_size, 44 | hidden_size=hidden_size, 45 | intermediate_size=intermediate_size, 46 | num_hidden_layers=num_hidden_layers, 47 | num_attention_heads=num_attention_heads, 48 | num_key_value_heads=num_key_value_heads, 49 | hidden_act=hidden_act, 50 | max_position_embeddings=max_position_embeddings, 51 | initializer_range=initializer_range, 52 | rms_norm_eps=rms_norm_eps, 53 | use_cache=use_cache, 54 | pad_token_id=pad_token_id, 55 | bos_token_id=bos_token_id, 56 | eos_token_id=eos_token_id, 57 | pretraining_tp=pretraining_tp, 58 | tie_word_embeddings=tie_word_embeddings, 59 | rope_theta=rope_theta, 60 | rope_scaling=rope_scaling, 61 | attention_bias=attention_bias, 62 | attention_dropout=attention_dropout, 63 | mlp_bias=mlp_bias, 64 | head_dim=head_dim, 65 | **kwargs, 66 | ) -------------------------------------------------------------------------------- /liger/models/liger_gsa/modeling_liger_gsa.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import warnings 5 | from typing import List, Optional, Tuple, Union 6 | from einops import rearrange, repeat 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint 12 | from transformers.activations import ACT2FN 13 | from transformers.cache_utils import Cache, DynamicCache, 
StaticCache 14 | from transformers.modeling_outputs import ( 15 | BaseModelOutputWithPast, 16 | CausalLMOutputWithPast, 17 | ) 18 | from transformers.models.llama.modeling_llama import ( 19 | LlamaRMSNorm, 20 | LlamaRotaryEmbedding, 21 | apply_rotary_pos_emb, 22 | repeat_kv, 23 | LlamaMLP, 24 | LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention, 25 | LlamaDecoderLayer, 26 | LlamaForCausalLM, 27 | LlamaModel, 28 | LlamaPreTrainedModel, 29 | LLAMA_INPUTS_DOCSTRING, 30 | ) 31 | from transformers.utils import logging, add_start_docstrings_to_model_forward 32 | from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 33 | 34 | if is_flash_attn_2_available(): 35 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 36 | 37 | from fla.models.utils import Cache as FlaCache 38 | from fla.modules.activations import swish 39 | from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa 40 | 41 | from .configuration_liger_gsa import LigerGSAConfig 42 | 43 | logger = logging.get_logger(__name__) 44 | 45 | 46 | class LigerGatedSlotAttention(nn.Module): 47 | def __init__( 48 | self, 49 | config: LigerGSAConfig, 50 | layer_idx: Optional[int] = None, 51 | ): 52 | super().__init__() 53 | self.config = config 54 | self.layer_idx = layer_idx 55 | 56 | self.attention_dropout = config.attention_dropout 57 | self.hidden_size = config.hidden_size 58 | self.num_heads = config.num_attention_heads 59 | self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) 60 | self.num_key_value_heads = config.num_key_value_heads 61 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 62 | 63 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 64 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 65 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 66 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 67 | 68 | self.rotary_emb = LlamaRotaryEmbedding(config=self.config) 69 | 70 | self.pool_size = config.pool_size 71 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.pool_size * self.num_key_value_heads) 72 | 73 | def forward( 74 | self, 75 | hidden_states: torch.Tensor, 76 | attention_mask: Optional[torch.Tensor] = None, 77 | position_ids: Optional[torch.LongTensor] = None, 78 | past_key_value: Optional[FlaCache] = None, 79 | output_attentions: bool = False, 80 | use_cache: bool = False, 81 | cache_position: Optional[torch.LongTensor] = None, 82 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 83 | **kwargs, 84 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 85 | last_state = None 86 | if past_key_value is not None and len(past_key_value) > self.layer_idx: 87 | last_state = past_key_value[self.layer_idx] 88 | 89 | q = self.q_proj(hidden_states) 90 | k = self.k_proj(hidden_states) 91 | v = self.v_proj(hidden_states) 92 | g = self.pool_g(k) 93 | 94 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads) 95 | k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_key_value_heads) 96 | v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_key_value_heads) 97 | g = rearrange(g, 'b n (h m) -> b h n m', h=self.num_key_value_heads) 98 | 99 | k = repeat_kv(k, self.num_key_value_groups) 100 | v = repeat_kv(v, self.num_key_value_groups) 101 | g = repeat_kv(g, 
self.num_key_value_groups) 102 | 103 | sq, sk, sv = q, k, v 104 | 105 | gate_logit_normalizer = 16 106 | g = F.logsigmoid(g) / gate_logit_normalizer # (b, h, n, m) 107 | s = 1 - torch.exp(g) 108 | # dealing with left-padding 109 | if attention_mask is not None: 110 | s = s.mul_(attention_mask[:, None, -s.shape[2]:, None]) 111 | v = v.mul_(attention_mask[:, None, -v.shape[2]:, None]) 112 | 113 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None 114 | scale = 1 115 | 116 | q, k, v, s, g = (x.to(torch.float32).contiguous() for x in (q, k, v, s, g)) 117 | 118 | if self.training or q.shape[-2] > 1: 119 | o_, recurrent_state = chunk_gsa(q, k, v, s, g, scale=scale, initial_state=recurrent_state, output_final_state=True) 120 | else: 121 | o_, recurrent_state = fused_recurrent_gsa(q, k, v, s, g, scale=scale, initial_state=recurrent_state, output_final_state=True) 122 | 123 | if past_key_value is not None: 124 | past_key_value.update( 125 | recurrent_state=recurrent_state, 126 | layer_idx=self.layer_idx, 127 | offset=q.shape[1] 128 | ) 129 | 130 | q_len = hidden_states.size(-2) 131 | 132 | if position_embeddings is None: 133 | logger.warning_once( 134 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 135 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 136 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " 137 | "removed and `position_embeddings` will be mandatory." 138 | ) 139 | cos, sin = self.rotary_emb(sv, position_ids) 140 | else: 141 | cos, sin = position_embeddings 142 | sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin) 143 | 144 | 145 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 146 | # therefore the input hidden states gets silently casted in float32. Hence, we need 147 | # cast them back in the correct dtype just to be sure everything works as expected. 148 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 149 | # in fp32. (LlamaRMSNorm handles it correctly) 150 | 151 | input_dtype = sq.dtype 152 | if input_dtype == torch.float32: 153 | if torch.is_autocast_enabled(): 154 | target_dtype = torch.get_autocast_gpu_dtype() 155 | # Handle the case where the model is quantized 156 | elif hasattr(self.config, "_pre_quantization_dtype"): 157 | target_dtype = self.config._pre_quantization_dtype 158 | else: 159 | target_dtype = self.q_proj.weight.dtype 160 | 161 | logger.warning_once( 162 | f"The input hidden states seems to be silently casted in float32, this might be related to" 163 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 164 | f" {target_dtype}." 
165 | ) 166 | 167 | sq = sq.to(target_dtype) 168 | sk = sk.to(target_dtype) 169 | sv = sv.to(target_dtype) 170 | 171 | window_size = 64 172 | y = _flash_attention_forward( # Reashape to the expected shape for Flash Attention 173 | sq.transpose(1, 2), 174 | sk.transpose(1, 2), 175 | sv.transpose(1, 2), 176 | attention_mask, 177 | q_len, 178 | position_ids=position_ids, 179 | dropout=0.0, 180 | sliding_window=window_size, 181 | use_top_left_mask=not is_flash_attn_greater_or_equal_2_10(), 182 | is_causal=True, 183 | target_dtype=torch.float32, 184 | **kwargs, 185 | ).transpose(1, 2) 186 | 187 | o_ = 0.5 * y + 0.5 * o_ # 0.5 is important 188 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)') 189 | o = self.o_proj(o) 190 | 191 | return o, None, past_key_value 192 | 193 | 194 | 195 | class LigerGSADecoderLayer(LlamaDecoderLayer): 196 | def __init__(self, config: LigerGSAConfig, layer_idx: int): 197 | super().__init__(config, layer_idx) # layer_idx 198 | self.hidden_size = config.hidden_size 199 | self.self_attn = LigerGatedSlotAttention(config=config, layer_idx=layer_idx) 200 | self.mlp = LlamaMLP(config) 201 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 202 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 203 | 204 | def forward( 205 | self, 206 | hidden_states: torch.Tensor, 207 | attention_mask: Optional[torch.Tensor] = None, 208 | position_ids: Optional[torch.LongTensor] = None, 209 | past_key_value: Optional[Union[FlaCache, Tuple]] = None, 210 | output_attentions: Optional[bool] = False, 211 | use_cache: Optional[bool] = False, 212 | cache_position: Optional[torch.LongTensor] = None, 213 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 214 | **kwargs, 215 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 216 | outputs = super().forward( 217 | hidden_states, 218 | attention_mask, 219 | position_ids, 220 | past_key_value, 221 | output_attentions, 222 | use_cache, 223 | cache_position, 224 | position_embeddings, 225 | **kwargs 226 | ) 227 | return outputs 228 | 229 | 230 | class LigerGSAPreTrainedModel(LlamaPreTrainedModel): 231 | 232 | config_class = LigerGSAConfig 233 | base_model_prefix = "model" 234 | supports_gradient_checkpointing = True 235 | _no_split_modules = ['LigerGSADecoderLayer'] 236 | _skip_keys_device_placement = "past_key_values" 237 | 238 | def _init_weights( 239 | self, 240 | module, 241 | ): 242 | std = self.config.initializer_range 243 | if isinstance(module, nn.Linear): 244 | module.weight.data.normal_(mean=0.0, std=std) 245 | if module.bias is not None: 246 | module.bias.data.zero_() 247 | elif isinstance(module, nn.Embedding): 248 | module.weight.data.normal_(mean=0.0, std=std) 249 | if module.padding_idx is not None: 250 | module.weight.data[module.padding_idx].zero_() 251 | 252 | class LigerGSAModel(LlamaModel, LigerGSAPreTrainedModel): 253 | 254 | def __init__(self, config: LigerGSAConfig): 255 | super().__init__(config) 256 | self.padding_idx = config.pad_token_id 257 | self.vocab_size = config.vocab_size 258 | 259 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 260 | self.layers = nn.ModuleList( 261 | [LigerGSADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 262 | ) 263 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 264 | self.rotary_emb = LlamaRotaryEmbedding(config=config) 265 | 
self.gradient_checkpointing = False 266 | 267 | # Initialize weights and apply final processing 268 | self.post_init() 269 | 270 | def get_input_embeddings(self): 271 | return self.embed_tokens 272 | 273 | def set_input_embeddings(self, value): 274 | self.embed_tokens = value 275 | 276 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 277 | def forward( 278 | self, 279 | input_ids: torch.LongTensor = None, 280 | attention_mask: Optional[torch.Tensor] = None, 281 | position_ids: Optional[torch.LongTensor] = None, 282 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None, 283 | inputs_embeds: Optional[torch.FloatTensor] = None, 284 | use_cache: Optional[bool] = None, 285 | output_attentions: Optional[bool] = None, 286 | output_hidden_states: Optional[bool] = None, 287 | return_dict: Optional[bool] = None, 288 | cache_position: Optional[torch.LongTensor] = None, 289 | ) -> Union[Tuple, BaseModelOutputWithPast]: 290 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 291 | output_hidden_states = ( 292 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 293 | ) 294 | use_cache = use_cache if use_cache is not None else self.config.use_cache 295 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 296 | 297 | if (input_ids is None) ^ (inputs_embeds is not None): 298 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds") 299 | 300 | if self.gradient_checkpointing and self.training and use_cache: 301 | logger.warning_once( 302 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 303 | ) 304 | use_cache = False 305 | 306 | if inputs_embeds is None: 307 | inputs_embeds = self.embed_tokens(input_ids) 308 | 309 | # kept for BC (non `Cache` `past_key_values` inputs) 310 | return_legacy_cache = False 311 | if use_cache: 312 | if output_attentions: 313 | # LigerGSA kv 314 | LigerGSA_past_key_values, softmax_past_key_values = None, None 315 | if past_key_values is None: 316 | LigerGSA_past_key_values = FlaCache.from_legacy_cache(past_key_values) 317 | softmax_past_key_values = DynamicCache() 318 | else: 319 | if not isinstance(past_key_values[0], FlaCache): 320 | LigerGSA_past_key_values = FlaCache.from_legacy_cache(past_key_values[0]) 321 | # softmax kv 322 | if not isinstance(past_key_values[1], Cache): 323 | return_legacy_cache = True 324 | if past_key_values[1] is None: 325 | softmax_past_key_values = DynamicCache() 326 | else: 327 | softmax_past_key_values = DynamicCache.from_legacy_cache(past_key_values[1]) 328 | logger.warning_once( 329 | "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " 330 | "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " 331 | "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" 332 | ) 333 | 334 | past_key_values = (LigerGSA_past_key_values, softmax_past_key_values) 335 | else: 336 | # only LigerGSA kv 337 | if not isinstance(past_key_values, FlaCache): 338 | past_key_values = FlaCache.from_legacy_cache(past_key_values) 339 | 340 | if cache_position is None: 341 | if output_attentions: 342 | past_seen_tokens = past_key_values[1].get_seq_length() if past_key_values is not None else 0 343 | else: 344 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 345 | cache_position = torch.arange( 346 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 347 | ) 348 | if position_ids is None: 349 | position_ids = cache_position.unsqueeze(0) 350 | 351 | if output_attentions: 352 | causal_mask = self._update_causal_mask( 353 | attention_mask, inputs_embeds, cache_position, past_key_values[1], output_attentions 354 | ) 355 | causal_mask = (attention_mask, causal_mask) 356 | else: 357 | causal_mask = attention_mask 358 | hidden_states = inputs_embeds 359 | 360 | # create position embeddings to be shared across the decoder layers 361 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 362 | 363 | # decoder layers 364 | all_hidden_states = () if output_hidden_states else None 365 | all_self_attns = () if output_attentions else None 366 | next_decoder_cache = None 367 | 368 | if output_attentions: 369 | all_softmax_hidden_states = () 370 | 371 | for decoder_layer in self.layers: 372 | if output_hidden_states: 373 | all_hidden_states += (hidden_states,) 374 | if all_softmax_hidden_states is not None: 375 | all_softmax_hidden_states += (hidden_states,) 376 | 377 | if self.gradient_checkpointing and self.training: 378 | layer_outputs = self._gradient_checkpointing_func( 379 | decoder_layer.__call__, 380 | hidden_states, 381 | causal_mask, 382 | position_ids, 383 | past_key_values, 384 | output_attentions, 385 | use_cache, 386 | cache_position, 387 | position_embeddings, 388 | ) 389 | 390 | else: 391 | layer_outputs = decoder_layer( 392 | hidden_states, 393 | attention_mask=causal_mask, 394 | position_ids=position_ids, 395 | past_key_value=past_key_values, 396 | output_attentions=output_attentions, 397 | use_cache=use_cache, 398 | cache_position=cache_position, 399 | position_embeddings=position_embeddings, 400 | ) 401 | hidden_states = layer_outputs[0] 402 | 403 | if use_cache: 404 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 405 | 406 | if output_attentions: 407 | all_self_attns += (layer_outputs[1],) 408 | 409 | hidden_states = self.norm(hidden_states) 410 | 411 | # add hidden states from the last decoder layer 412 | if output_hidden_states: 413 | all_hidden_states += (hidden_states,) 414 | 415 | if output_attentions: 416 | next_cache = next_decoder_cache[1] if use_cache else None 417 | if return_legacy_cache: 418 | next_cache = next_cache.to_legacy_cache() 419 | 420 | next_cache = (next_decoder_cache[0], next_cache) 421 | else: 422 | next_cache = next_decoder_cache if use_cache else None 423 | 424 | if not return_dict: 425 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 426 | return BaseModelOutputWithPast( 427 | last_hidden_state=hidden_states, 428 | past_key_values=next_cache, 429 | hidden_states=all_hidden_states, 430 | attentions=all_self_attns, 431 | ) 432 | 433 
| class LigerGSAForCausalLM(LlamaForCausalLM, LigerGSAPreTrainedModel): 434 | _tied_weights_keys = ["lm_head.weight"] 435 | 436 | def __init__(self, config): 437 | super().__init__(config) 438 | self.model = LigerGSAModel(config) 439 | self.vocab_size = config.vocab_size 440 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 441 | 442 | # Initialize weights and apply final processing 443 | self.post_init() -------------------------------------------------------------------------------- /liger/models/liger_hgrn2/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from liger.models.liger_hgrn2.configuration_liger_hgrn2 import LigerHGRN2Config 4 | from liger.models.liger_hgrn2.modeling_liger_hgrn2 import LigerHGRN2ForCausalLM, LigerHGRN2Model 5 | 6 | AutoConfig.register(LigerHGRN2Config.model_type, LigerHGRN2Config) 7 | AutoModel.register(LigerHGRN2Config, LigerHGRN2Model) 8 | AutoModelForCausalLM.register(LigerHGRN2Config, LigerHGRN2ForCausalLM) 9 | 10 | 11 | __all__ = ['LigerHGRN2Config', 'LigerHGRN2ForCausalLM', 'LigerHGRN2Model'] -------------------------------------------------------------------------------- /liger/models/liger_hgrn2/configuration_liger_hgrn2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.models.llama.configuration_llama import LlamaConfig 7 | 8 | class LigerHGRN2Config(LlamaConfig, PretrainedConfig): 9 | model_type = 'liger_hgrn2' 10 | keys_to_ignore_at_inference = ['past_key_values'] 11 | 12 | def __init__( 13 | self, 14 | # llama config 15 | vocab_size=32000, 16 | hidden_size=4096, 17 | intermediate_size=11008, 18 | num_hidden_layers=32, 19 | num_attention_heads=32, 20 | num_key_value_heads=None, 21 | hidden_act="silu", 22 | max_position_embeddings=2048, 23 | initializer_range=0.02, 24 | rms_norm_eps=1e-6, 25 | use_cache=True, 26 | pad_token_id=None, 27 | bos_token_id=1, 28 | eos_token_id=2, 29 | pretraining_tp=1, 30 | tie_word_embeddings=False, 31 | rope_theta=10000.0, 32 | rope_scaling=None, 33 | attention_bias=False, 34 | attention_dropout=0.0, 35 | mlp_bias=False, 36 | head_dim=None, 37 | pool_size: int = 128, # pooling 38 | **kwargs, 39 | ): 40 | self.pool_size = pool_size 41 | 42 | super().__init__( 43 | vocab_size=vocab_size, 44 | hidden_size=hidden_size, 45 | intermediate_size=intermediate_size, 46 | num_hidden_layers=num_hidden_layers, 47 | num_attention_heads=num_attention_heads, 48 | num_key_value_heads=num_key_value_heads, 49 | hidden_act=hidden_act, 50 | max_position_embeddings=max_position_embeddings, 51 | initializer_range=initializer_range, 52 | rms_norm_eps=rms_norm_eps, 53 | use_cache=use_cache, 54 | pad_token_id=pad_token_id, 55 | bos_token_id=bos_token_id, 56 | eos_token_id=eos_token_id, 57 | pretraining_tp=pretraining_tp, 58 | tie_word_embeddings=tie_word_embeddings, 59 | rope_theta=rope_theta, 60 | rope_scaling=rope_scaling, 61 | attention_bias=attention_bias, 62 | attention_dropout=attention_dropout, 63 | mlp_bias=mlp_bias, 64 | head_dim=head_dim, 65 | **kwargs, 66 | ) -------------------------------------------------------------------------------- /liger/models/liger_hgrn2/modeling_liger_hgrn2.py: -------------------------------------------------------------------------------- 1 | import math 2 | 
import warnings 3 | from typing import List, Optional, Tuple, Union 4 | from einops import rearrange, repeat 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint 10 | from transformers.activations import ACT2FN 11 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 12 | from transformers.modeling_outputs import ( 13 | BaseModelOutputWithPast, 14 | CausalLMOutputWithPast, 15 | ) 16 | from transformers.models.llama.modeling_llama import ( 17 | LlamaRMSNorm, 18 | LlamaRotaryEmbedding, 19 | apply_rotary_pos_emb, 20 | repeat_kv, 21 | LlamaMLP, 22 | LlamaAttention, # LlamaFlashAttention2, LlamaSdpaAttention, 23 | LlamaDecoderLayer, 24 | LlamaForCausalLM, 25 | LlamaModel, 26 | LlamaPreTrainedModel, 27 | LLAMA_INPUTS_DOCSTRING, 28 | ) 29 | 30 | 31 | from transformers.utils import logging, add_start_docstrings_to_model_forward 32 | from transformers.utils import is_flash_attn_2_available 33 | 34 | if is_flash_attn_2_available(): 35 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 36 | 37 | from fla.modules.activations import swish 38 | from fla.models.utils import Cache as FlaCache 39 | from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla 40 | 41 | from liger.models.liger_hgrn2 import LigerHGRN2Config 42 | 43 | logger = logging.get_logger(__name__) 44 | 45 | 46 | class LigerHGRN2Attention(nn.Module): 47 | def __init__( 48 | self, 49 | config: LigerHGRN2Config, 50 | layer_idx: Optional[int] = None, 51 | ): 52 | super().__init__() 53 | self.config = config 54 | self.layer_idx = layer_idx 55 | 56 | self.attention_dropout = config.attention_dropout 57 | self.hidden_size = config.hidden_size 58 | self.num_heads = config.num_attention_heads 59 | self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) 60 | self.num_key_value_heads = config.num_key_value_heads 61 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 62 | 63 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 64 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 65 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 66 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 67 | 68 | self.rotary_emb = LlamaRotaryEmbedding(config=self.config) 69 | 70 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.head_dim * self.num_key_value_heads) 71 | 72 | def forward( 73 | self, 74 | hidden_states: torch.Tensor, 75 | attention_mask: Optional[torch.Tensor] = None, 76 | position_ids: Optional[torch.LongTensor] = None, 77 | past_key_value: Optional[FlaCache] = None, 78 | output_attentions: bool = False, 79 | use_cache: bool = False, 80 | cache_position: Optional[torch.LongTensor] = None, 81 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 82 | **kwargs, 83 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 84 | last_state = None 85 | if past_key_value is not None and len(past_key_value) > self.layer_idx: 86 | last_state = past_key_value[self.layer_idx] 87 | 88 | q = self.q_proj(hidden_states) 89 | k = self.k_proj(hidden_states) 90 | v = self.v_proj(hidden_states) 91 | f = self.pool_g(k) 92 | 93 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads) 94 | k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_key_value_heads) 
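# --- editorial annotation (not part of the original source) ---
# rearrange 'b n (h d) -> b h n d' unpacks each packed projection into per-head
# tensors of shape (batch, heads, seq_len, head_dim); f, pooled from the key
# projection by the AdaptiveAvgPool1d above, carries the HGRN2 forget-gate
# logits and is reshaped the same way before repeat_kv expands the KV heads.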
95 | v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_key_value_heads) 96 | f = rearrange(f, 'b n (h m) -> b h n m', h=self.num_key_value_heads) 97 | 98 | k = repeat_kv(k, self.num_key_value_groups) 99 | v = repeat_kv(v, self.num_key_value_groups) 100 | f = repeat_kv(f, self.num_key_value_groups) 101 | 102 | sq, sk, sv = q, k, v 103 | 104 | # the lower bound for the first layer is zero 105 | lower_bound = None 106 | if lower_bound is None or self.layer_idx == 0: 107 | k, g = 1 - f.sigmoid(), F.logsigmoid(f) 108 | else: 109 | g = lower_bound + (1 - lower_bound) * f.sigmoid() 110 | k, g = 1 - g, g.log() 111 | 112 | # norm 113 | q = F.softmax(q, dim=-1) 114 | # k = F.softmax(k, dim=-1) 115 | 116 | # dealing with left-padding 117 | if attention_mask is not None: 118 | v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) 119 | 120 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None 121 | offsets = kwargs.get('offsets', None) 122 | scale = 1 # default 123 | q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g)) 124 | 125 | if self.training or q.shape[-2] > 1: 126 | o_, recurrent_state = fused_chunk_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True) 127 | else: 128 | o_, recurrent_state = fused_recurrent_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True, offsets=offsets) 129 | 130 | if past_key_value is not None: 131 | past_key_value.update( 132 | recurrent_state=recurrent_state, 133 | layer_idx=self.layer_idx, 134 | offset=q.shape[1] 135 | ) 136 | 137 | q_len = hidden_states.size(-2) 138 | 139 | if position_embeddings is None: 140 | logger.warning_once( 141 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 142 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 143 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " 144 | "removed and `position_embeddings` will be mandatory." 145 | ) 146 | cos, sin = self.rotary_emb(sv, position_ids) 147 | else: 148 | cos, sin = position_embeddings 149 | sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin) 150 | 151 | # if past_key_value is not None: 152 | # # sin and cos are specific to RoPE models; cache_position needed for the static cache 153 | # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 154 | # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 155 | 156 | input_dtype = sq.dtype 157 | if input_dtype == torch.float32: 158 | if torch.is_autocast_enabled(): 159 | target_dtype = torch.get_autocast_gpu_dtype() 160 | # Handle the case where the model is quantized 161 | elif hasattr(self.config, "_pre_quantization_dtype"): 162 | target_dtype = self.config._pre_quantization_dtype 163 | else: 164 | target_dtype = self.q_proj.weight.dtype 165 | 166 | logger.warning_once( 167 | f"The input hidden states seems to be silently casted in float32, this might be related to" 168 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 169 | f" {target_dtype}." 
170 | ) 171 | 172 | sq = sq.to(target_dtype) 173 | sk = sk.to(target_dtype) 174 | sv = sv.to(target_dtype) 175 | 176 | window_size = 64 177 | 178 | y = _flash_attention_forward( # Reashape to the expected shape for Flash Attention 179 | sq.transpose(1, 2), 180 | sk.transpose(1, 2), 181 | sv.transpose(1, 2), 182 | attention_mask, 183 | q_len, 184 | position_ids=position_ids, 185 | dropout=0.0, 186 | sliding_window=window_size, 187 | use_top_left_mask=False, 188 | is_causal=True, 189 | target_dtype=torch.float32, 190 | ).transpose(1, 2) 191 | 192 | o_ = 0.5 * y + 0.5 * o_ 193 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)') 194 | o = self.o_proj(o) 195 | 196 | return o, o_, past_key_value 197 | 198 | class LigerHGRN2DecoderLayer(LlamaDecoderLayer): 199 | def __init__(self, config: LigerHGRN2Config, layer_idx: int): 200 | super().__init__(config, layer_idx) 201 | self.hidden_size = config.hidden_size 202 | self.self_attn = LigerHGRN2Attention(config=config, layer_idx=layer_idx) 203 | self.mlp = LlamaMLP(config) 204 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 205 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 206 | 207 | 208 | class LigerHGRN2PreTrainedModel(LlamaPreTrainedModel): 209 | 210 | config_class = LigerHGRN2Config 211 | base_model_prefix = "model" 212 | supports_gradient_checkpointing = True 213 | _no_split_modules = ['LigerHGRN2DecoderLayer'] 214 | _skip_keys_device_placement = "past_key_values" 215 | 216 | class LigerHGRN2Model(LlamaModel, LigerHGRN2PreTrainedModel): 217 | 218 | def __init__(self, config: LigerHGRN2Config): 219 | super().__init__(config) 220 | self.padding_idx = config.pad_token_id 221 | self.vocab_size = config.vocab_size 222 | 223 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 224 | self.layers = nn.ModuleList( 225 | [LigerHGRN2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 226 | ) 227 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 228 | self.rotary_emb = LlamaRotaryEmbedding(config=config) 229 | self.gradient_checkpointing = False 230 | 231 | # Initialize weights and apply final processing 232 | self.post_init() 233 | 234 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 235 | def forward( 236 | self, 237 | input_ids: torch.LongTensor = None, 238 | attention_mask: Optional[torch.Tensor] = None, 239 | position_ids: Optional[torch.LongTensor] = None, 240 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None, 241 | inputs_embeds: Optional[torch.FloatTensor] = None, 242 | use_cache: Optional[bool] = None, 243 | output_attentions: Optional[bool] = None, 244 | output_hidden_states: Optional[bool] = None, 245 | return_dict: Optional[bool] = None, 246 | cache_position: Optional[torch.LongTensor] = None, 247 | **kwargs, 248 | ) -> Union[Tuple, BaseModelOutputWithPast]: 249 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 250 | output_hidden_states = ( 251 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 252 | ) 253 | use_cache = use_cache if use_cache is not None else self.config.use_cache 254 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 255 | 256 | if (input_ids is None) ^ (inputs_embeds is not None): 257 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds") 
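# --- editorial annotation (not part of the original source) ---
# The cache handling below keeps a single FlaCache for the linear-attention
# recurrent state. Unlike LigerGSAModel, which can additionally carry a
# DynamicCache for the softmax branch when output_attentions is set, only the
# recurrent state persists across decoding steps here; the sliding-window
# flash-attention branch inside LigerHGRN2Attention sees only the current chunk.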
258 | 259 | if self.gradient_checkpointing and self.training and use_cache: 260 | logger.warning_once( 261 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 262 | ) 263 | use_cache = False 264 | 265 | if inputs_embeds is None: 266 | inputs_embeds = self.embed_tokens(input_ids) 267 | 268 | 269 | # kept for BC (non `Cache` `past_key_values` inputs) 270 | return_legacy_cache = False 271 | if use_cache and not isinstance(past_key_values, FlaCache): 272 | past_key_values = FlaCache.from_legacy_cache(past_key_values) 273 | 274 | if cache_position is None: 275 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 276 | cache_position = torch.arange( 277 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 278 | ) 279 | 280 | if position_ids is None: 281 | position_ids = cache_position.unsqueeze(0) 282 | 283 | causal_mask = attention_mask 284 | 285 | hidden_states = inputs_embeds 286 | 287 | # create position embeddings to be shared across the decoder layers 288 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 289 | 290 | # decoder layers 291 | all_hidden_states = () if output_hidden_states else None 292 | all_self_attns = () if output_attentions else None 293 | next_decoder_cache = None 294 | 295 | if output_attentions: 296 | all_softmax_hidden_states = () 297 | 298 | for decoder_layer in self.layers: 299 | if output_hidden_states: 300 | all_hidden_states += (hidden_states,) 301 | if all_softmax_hidden_states is not None: 302 | all_softmax_hidden_states += (hidden_states,) 303 | 304 | if self.gradient_checkpointing and self.training: 305 | layer_outputs = self._gradient_checkpointing_func( 306 | decoder_layer.__call__, 307 | hidden_states, 308 | causal_mask, 309 | position_ids, 310 | past_key_values, 311 | output_attentions, 312 | use_cache, 313 | cache_position, 314 | position_embeddings, 315 | ) 316 | 317 | else: 318 | layer_outputs = decoder_layer( 319 | hidden_states, 320 | attention_mask=causal_mask, 321 | position_ids=position_ids, 322 | past_key_value=past_key_values, 323 | output_attentions=output_attentions, 324 | use_cache=use_cache, 325 | cache_position=cache_position, 326 | position_embeddings=position_embeddings, 327 | ) 328 | hidden_states = layer_outputs[0] 329 | 330 | if use_cache: 331 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 332 | 333 | if output_attentions: 334 | all_self_attns += (layer_outputs[1],) 335 | 336 | hidden_states = self.norm(hidden_states) 337 | 338 | # add hidden states from the last decoder layer 339 | if output_hidden_states: 340 | all_hidden_states += (hidden_states,) 341 | 342 | next_cache = next_decoder_cache if use_cache else None 343 | 344 | if not return_dict: 345 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 346 | return BaseModelOutputWithPast( 347 | last_hidden_state=hidden_states, 348 | past_key_values=next_cache, 349 | hidden_states=all_hidden_states, 350 | attentions=all_self_attns, 351 | ) 352 | 353 | class LigerHGRN2ForCausalLM(LlamaForCausalLM, LigerHGRN2PreTrainedModel): 354 | _tied_weights_keys = ["lm_head.weight"] 355 | 356 | def __init__(self, config): 357 | super().__init__(config) 358 | self.model = LigerHGRN2Model(config) 359 | self.vocab_size = config.vocab_size 360 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 361 | 362 | # Initialize weights and apply final processing 363 | 
self.post_init() -------------------------------------------------------------------------------- /liger/models/liger_mistral_gla/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from liger.models.liger_mistral_gla.configuration_liger_mistral_gla import LigerMistralGLAConfig 4 | from liger.models.liger_mistral_gla.modeling_liger_mistral_gla import LigerMistralGLAForCausalLM, LigerMistralGLAModel 5 | 6 | AutoConfig.register(LigerMistralGLAConfig.model_type, LigerMistralGLAConfig) 7 | AutoModel.register(LigerMistralGLAConfig, LigerMistralGLAModel) 8 | AutoModelForCausalLM.register(LigerMistralGLAConfig, LigerMistralGLAForCausalLM) 9 | 10 | 11 | __all__ = ['LigerMistralGLAConfig', 'LigerMistralGLAForCausalLM', 'LigerMistralGLAModel'] -------------------------------------------------------------------------------- /liger/models/liger_mistral_gla/configuration_liger_mistral_gla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.utils import logging 7 | from transformers.models.mistral.configuration_mistral import MistralConfig 8 | 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | class LigerMistralGLAConfig(MistralConfig, PretrainedConfig): 13 | model_type = "liger_mistral_gla" 14 | keys_to_ignore_at_inference = ["past_key_values"] 15 | 16 | base_model_tp_plan = { 17 | "layers.*.self_attn.q_proj": "colwise", 18 | "layers.*.self_attn.k_proj": "colwise", 19 | "layers.*.self_attn.v_proj": "colwise", 20 | "layers.*.self_attn.o_proj": "rowwise", 21 | "layers.*.mlp.gate_proj": "colwise", 22 | "layers.*.mlp.up_proj": "colwise", 23 | "layers.*.mlp.down_proj": "rowwise", 24 | } 25 | def __init__( 26 | self, 27 | vocab_size=32000, 28 | hidden_size=4096, 29 | intermediate_size=14336, 30 | num_hidden_layers=32, 31 | num_attention_heads=32, 32 | num_key_value_heads=8, 33 | head_dim=None, 34 | hidden_act="silu", 35 | max_position_embeddings=4096 * 32, 36 | initializer_range=0.02, 37 | rms_norm_eps=1e-6, 38 | use_cache=True, 39 | pad_token_id=None, 40 | bos_token_id=1, 41 | eos_token_id=2, 42 | tie_word_embeddings=False, 43 | rope_theta=10000.0, 44 | sliding_window=4096, 45 | attention_dropout=0.0, 46 | # linear attention 47 | expand_k: int = 1, 48 | expand_v: int = 1, 49 | hidden_ratio: Optional[int] = 4, 50 | **kwargs, 51 | ): 52 | self.expand_k = expand_k 53 | self.expand_v = expand_v 54 | self.hidden_ratio = hidden_ratio 55 | 56 | super().__init__( 57 | vocab_size=vocab_size, 58 | hidden_size=hidden_size, 59 | intermediate_size=intermediate_size, 60 | num_hidden_layers=num_hidden_layers, 61 | num_attention_heads=num_attention_heads, 62 | num_key_value_heads=num_key_value_heads, 63 | head_dim=head_dim, 64 | hidden_act=hidden_act, 65 | max_position_embeddings=max_position_embeddings, 66 | initializer_range=initializer_range, 67 | rms_norm_eps=rms_norm_eps, 68 | use_cache=use_cache, 69 | pad_token_id=pad_token_id, 70 | bos_token_id=bos_token_id, 71 | eos_token_id=eos_token_id, 72 | tie_word_embeddings=tie_word_embeddings, 73 | rope_theta=rope_theta, 74 | sliding_window=sliding_window, 75 | attention_dropout=attention_dropout, 76 | **kwargs, 77 | ) 78 | -------------------------------------------------------------------------------- 
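Note (editorial addition): the AutoConfig/AutoModel registrations at the top of liger/models/liger_mistral_gla/__init__.py make this variant loadable through the standard transformers Auto classes. A minimal smoke-test sketch, assuming the dependencies pinned in requirements.txt (transformers, fla, flash-attn) are installed and using deliberately tiny, illustrative sizes rather than the shipped Mistral-7B defaults:

# Editorial sketch, not part of the repository.
import liger.models.liger_mistral_gla  # noqa: F401  (runs the Auto* registrations)
from transformers import AutoModelForCausalLM

from liger.models.liger_mistral_gla import LigerMistralGLAConfig

# Tiny illustrative configuration; real checkpoints use the Mistral-7B sizes above.
config = LigerMistralGLAConfig(
    vocab_size=32000,
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=4,
    head_dim=64,
)
model = AutoModelForCausalLM.from_config(config)
print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")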
/liger/models/liger_mistral_gla/modeling_liger_mistral_gla.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | import copy 4 | from typing import List, Optional, Tuple, Union 5 | from einops import rearrange, repeat 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 12 | from transformers.modeling_outputs import ( 13 | BaseModelOutputWithPast, 14 | CausalLMOutputWithPast, 15 | ) 16 | 17 | from transformers.models.mistral.modeling_mistral import ( 18 | MistralRMSNorm, 19 | MistralMLP, 20 | repeat_kv, 21 | apply_rotary_pos_emb, 22 | MistralRotaryEmbedding, 23 | MistralDecoderLayer, 24 | MistralForCausalLM, 25 | MistralModel, 26 | MistralPreTrainedModel, 27 | ) 28 | from transformers.utils import logging, add_start_docstrings_to_model_forward 29 | from transformers.utils import is_flash_attn_2_available 30 | 31 | if is_flash_attn_2_available(): 32 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 33 | 34 | from fla.modules.activations import swish 35 | from fla.models.utils import Cache as FlaCache 36 | from fla.ops.gla import fused_chunk_gla, fused_recurrent_gla 37 | 38 | from liger.models.liger_mistral_gla import LigerMistralGLAConfig 39 | 40 | logger = logging.get_logger(__name__) 41 | 42 | 43 | class LigerMistralGatedLinearAttention(nn.Module): 44 | def __init__( 45 | self, 46 | config: LigerMistralGLAConfig, 47 | layer_idx: Optional[int] = None, 48 | ): 49 | super().__init__() 50 | self.config = config 51 | self.layer_idx = layer_idx 52 | if layer_idx is None: 53 | logger.warning_once( 54 | f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " 55 | "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " 56 | "when creating this class." 
57 | ) 58 | 59 | self.attention_dropout = config.attention_dropout 60 | self.hidden_size = config.hidden_size 61 | self.num_heads = config.num_attention_heads 62 | self.head_dim = config.head_dim 63 | self.num_key_value_heads = config.num_key_value_heads 64 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 65 | self.max_position_embeddings = config.max_position_embeddings 66 | self.rope_theta = config.rope_theta 67 | self.is_causal = True 68 | 69 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 70 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 71 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 72 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 73 | 74 | self.rotary_emb = MistralRotaryEmbedding( 75 | self.head_dim, 76 | max_position_embeddings=self.max_position_embeddings, 77 | base=self.rope_theta, 78 | ) 79 | 80 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.head_dim * self.num_key_value_heads) 81 | 82 | def forward( 83 | self, 84 | hidden_states: torch.Tensor, 85 | attention_mask: Optional[torch.Tensor] = None, 86 | position_ids: Optional[torch.LongTensor] = None, 87 | past_key_value: Optional[FlaCache] = None, 88 | output_attentions: bool = False, 89 | use_cache: bool = False, 90 | cache_position: Optional[torch.LongTensor] = None, 91 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 92 | **kwargs, 93 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 94 | last_state = None 95 | if past_key_value is not None and len(past_key_value) > self.layer_idx: 96 | last_state = past_key_value[self.layer_idx] 97 | 98 | bsz, q_len, _ = hidden_states.size() 99 | 100 | q = self.q_proj(hidden_states) 101 | k = self.k_proj(hidden_states) 102 | v = self.v_proj(hidden_states) 103 | g = self.pool_g(k) 104 | 105 | q = q.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 106 | k = k.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 107 | v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 108 | g = g.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 109 | 110 | sv = v 111 | cos, sin = self.rotary_emb(sv, position_ids) 112 | sq, sk = apply_rotary_pos_emb(q, k, cos, sin) 113 | 114 | k = repeat_kv(k, self.num_key_value_groups) 115 | v = repeat_kv(v, self.num_key_value_groups) 116 | g = repeat_kv(g, self.num_key_value_groups) 117 | 118 | sk = repeat_kv(sk, self.num_key_value_groups) 119 | sv = repeat_kv(sv, self.num_key_value_groups) 120 | 121 | # norm 122 | q = F.softmax(q, dim=-1) 123 | k = F.softmax(k, dim=-1) 124 | 125 | # dealing with left-padding 126 | if attention_mask is not None: 127 | v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) 128 | 129 | gate_logit_normalizer = 16 130 | g = F.logsigmoid(g) / gate_logit_normalizer # (b, h, n, m) 131 | 132 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None 133 | offsets = kwargs.get('offsets', None) 134 | scale = 1 135 | q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g)) 136 | 137 | if self.training or q.shape[-2] > 1: 138 | o_, recurrent_state = fused_chunk_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True) 139 | else: 140 | o_, recurrent_state = fused_recurrent_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True, offsets=offsets) 141 | 
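# --- editorial annotation (not part of the original source) ---
# o_ now holds the gated-linear-attention output computed in float32 by the
# fla chunk/recurrent kernels. Below, the recurrent state is written back to
# the FlaCache, and o_ is averaged 0.5/0.5 with a sliding-window
# (window_size=64) flash-attention output over the RoPE-rotated sq/sk/sv
# before the final o_proj.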
142 | if past_key_value is not None: 143 | past_key_value.update( 144 | recurrent_state=recurrent_state, 145 | layer_idx=self.layer_idx, 146 | offset=q.shape[1] 147 | ) 148 | 149 | # if past_key_value is not None: 150 | # # sin and cos are specific to RoPE models; cache_position needed for the static cache 151 | # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 152 | # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 153 | 154 | input_dtype = sq.dtype 155 | if input_dtype == torch.float32: 156 | if torch.is_autocast_enabled(): 157 | target_dtype = torch.get_autocast_gpu_dtype() 158 | # Handle the case where the model is quantized 159 | elif hasattr(self.config, "_pre_quantization_dtype"): 160 | target_dtype = self.config._pre_quantization_dtype 161 | else: 162 | target_dtype = self.q_proj.weight.dtype 163 | 164 | logger.warning_once( 165 | f"The input hidden states seems to be silently casted in float32, this might be related to" 166 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 167 | f" {target_dtype}." 168 | ) 169 | 170 | sq = sq.to(target_dtype) 171 | sk = sk.to(target_dtype) 172 | sv = sv.to(target_dtype) 173 | 174 | window_size = 64 175 | 176 | y = _flash_attention_forward( # Reashape to the expected shape for Flash Attention 177 | sq.transpose(1, 2), 178 | sk.transpose(1, 2), 179 | sv.transpose(1, 2), 180 | attention_mask, 181 | q_len, 182 | position_ids=position_ids, 183 | dropout=0.0, 184 | sliding_window=window_size, 185 | use_top_left_mask=False, 186 | is_causal=True, 187 | target_dtype=torch.float32, 188 | ).transpose(1, 2) 189 | 190 | o_ = 0.5 * y + 0.5 * o_ 191 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)') 192 | o = self.o_proj(o) 193 | 194 | return o, o_, past_key_value 195 | 196 | 197 | class LigerMistralGLADecoderLayer(MistralDecoderLayer): 198 | def __init__(self, config: LigerMistralGLAConfig, layer_idx: int): 199 | super().__init__(config, layer_idx) 200 | self.hidden_size = config.hidden_size 201 | self.self_attn = LigerMistralGatedLinearAttention(config=config, layer_idx=layer_idx) 202 | self.mlp = MistralMLP(config) 203 | self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 204 | self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 205 | 206 | 207 | class LigerMistralPreTrainedModel(MistralPreTrainedModel): 208 | config_class = LigerMistralGLAConfig 209 | base_model_prefix = "model" 210 | supports_gradient_checkpointing = True 211 | _no_split_modules = ["MistralDecoderLayer"] 212 | _skip_keys_device_placement = "past_key_values" 213 | 214 | 215 | class LigerMistralGLAModel(MistralModel, LigerMistralPreTrainedModel): 216 | def __init__(self, config: LigerMistralGLAConfig): 217 | super().__init__(config) 218 | self.layers = nn.ModuleList( 219 | [LigerMistralGLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 220 | ) 221 | 222 | def forward( 223 | self, 224 | input_ids: torch.LongTensor = None, 225 | attention_mask: Optional[torch.Tensor] = None, 226 | position_ids: Optional[torch.LongTensor] = None, 227 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None, 228 | inputs_embeds: Optional[torch.FloatTensor] = None, 229 | use_cache: Optional[bool] = None, 230 | output_attentions: Optional[bool] = None, 231 | output_hidden_states: Optional[bool] = None, 232 | return_dict: Optional[bool] = None, 
233 | cache_position: Optional[torch.LongTensor] = None, 234 | ) -> Union[Tuple, BaseModelOutputWithPast]: 235 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 236 | output_hidden_states = ( 237 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 238 | ) 239 | use_cache = use_cache if use_cache is not None else self.config.use_cache 240 | 241 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 242 | 243 | # retrieve input_ids and inputs_embeds 244 | if (input_ids is None) ^ (inputs_embeds is not None): 245 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds") 246 | 247 | if self.gradient_checkpointing and self.training and use_cache: 248 | logger.warning_once( 249 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 250 | ) 251 | use_cache = False 252 | 253 | if inputs_embeds is None: 254 | inputs_embeds = self.embed_tokens(input_ids) 255 | 256 | # kept for BC (non `Cache` `past_key_values` inputs) 257 | return_legacy_cache = False 258 | if use_cache and not isinstance(past_key_values, FlaCache): 259 | past_key_values = FlaCache.from_legacy_cache(past_key_values) 260 | 261 | if cache_position is None: 262 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 263 | cache_position = torch.arange( 264 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 265 | ) 266 | 267 | if position_ids is None: 268 | position_ids = cache_position.unsqueeze(0) 269 | 270 | causal_mask = self._update_causal_mask( 271 | attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions 272 | ) 273 | 274 | hidden_states = inputs_embeds 275 | 276 | # decoder layers 277 | all_hidden_states = () if output_hidden_states else None 278 | all_self_attns = () if output_attentions else None 279 | next_decoder_cache = None 280 | 281 | for decoder_layer in self.layers: 282 | if output_hidden_states: 283 | all_hidden_states += (hidden_states,) 284 | 285 | if self.gradient_checkpointing and self.training: 286 | layer_outputs = self._gradient_checkpointing_func( 287 | decoder_layer.__call__, 288 | hidden_states, 289 | causal_mask, 290 | position_ids, 291 | past_key_values, 292 | output_attentions, 293 | use_cache, 294 | cache_position, 295 | ) 296 | else: 297 | layer_outputs = decoder_layer( 298 | hidden_states, 299 | attention_mask=causal_mask, 300 | position_ids=position_ids, 301 | past_key_value=past_key_values, 302 | output_attentions=output_attentions, 303 | use_cache=use_cache, 304 | cache_position=cache_position, 305 | ) 306 | 307 | hidden_states = layer_outputs[0] 308 | 309 | if use_cache: 310 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 311 | 312 | if output_attentions: 313 | all_self_attns += (layer_outputs[1],) 314 | 315 | hidden_states = self.norm(hidden_states) 316 | 317 | # add hidden states from the last decoder layer 318 | if output_hidden_states: 319 | all_hidden_states += (hidden_states,) 320 | 321 | # next_cache = next_decoder_cache if use_cache else None 322 | # if return_legacy_cache: 323 | # next_cache = next_cache.to_legacy_cache() 324 | next_cache = next_decoder_cache if use_cache else None 325 | 326 | if not return_dict: 327 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 328 | return BaseModelOutputWithPast( 329 | 
last_hidden_state=hidden_states, 330 | past_key_values=next_cache, 331 | hidden_states=all_hidden_states, 332 | attentions=all_self_attns, 333 | ) 334 | 335 | class LigerMistralGLAForCausalLM(MistralForCausalLM, LigerMistralPreTrainedModel): 336 | _tied_weights_keys = ["lm_head.weight"] 337 | _tp_plan = {"lm_head": "colwise_rep"} 338 | 339 | def __init__(self, config: LigerMistralGLAConfig): 340 | super().__init__(config) 341 | self.model = LigerMistralGLAModel(config) 342 | self.vocab_size = config.vocab_size 343 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 344 | 345 | # Initialize weights and apply final processing 346 | self.post_init() -------------------------------------------------------------------------------- /liger/models/liger_qwen2_gla/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from liger.models.liger_qwen2_gla.configuration_liger_qwen2_gla import LigerQwen2GLAConfig 4 | from liger.models.liger_qwen2_gla.modeling_liger_qwen2_gla import LigerQwen2GLAForCausalLM, LigerQwen2GLAModel 5 | 6 | AutoConfig.register(LigerQwen2GLAConfig.model_type, LigerQwen2GLAConfig) 7 | AutoModel.register(LigerQwen2GLAConfig, LigerQwen2GLAModel) 8 | AutoModelForCausalLM.register(LigerQwen2GLAConfig, LigerQwen2GLAForCausalLM) 9 | 10 | 11 | __all__ = ['LigerQwen2GLAConfig', 'LigerQwen2GLAForCausalLM', 'LigerQwen2GLAModel'] -------------------------------------------------------------------------------- /liger/models/liger_qwen2_gla/configuration_liger_qwen2_gla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.models.qwen2.configuration_qwen2 import Qwen2Config 7 | 8 | class LigerQwen2GLAConfig(Qwen2Config, PretrainedConfig): 9 | model_type = "liger_qwen2_gla" 10 | keys_to_ignore_at_inference = ["past_key_values"] 11 | 12 | def __init__( 13 | self, 14 | vocab_size=151936, 15 | hidden_size=4096, 16 | intermediate_size=22016, 17 | num_hidden_layers=32, 18 | num_attention_heads=32, 19 | num_key_value_heads=32, 20 | hidden_act="silu", 21 | max_position_embeddings=32768, 22 | initializer_range=0.02, 23 | rms_norm_eps=1e-6, 24 | use_cache=True, 25 | tie_word_embeddings=False, 26 | rope_theta=10000.0, 27 | rope_scaling=None, 28 | use_sliding_window=False, 29 | sliding_window=4096, 30 | max_window_layers=28, 31 | attention_dropout=0.0, 32 | **kwargs, 33 | ): 34 | super().__init__( 35 | vocab_size=vocab_size, 36 | hidden_size=hidden_size, 37 | intermediate_size=intermediate_size, 38 | num_hidden_layers=num_hidden_layers, 39 | num_attention_heads=num_attention_heads, 40 | num_key_value_heads=num_key_value_heads, 41 | hidden_act=hidden_act, 42 | max_position_embeddings=max_position_embeddings, 43 | initializer_range=initializer_range, 44 | rms_norm_eps=rms_norm_eps, 45 | use_cache=use_cache, 46 | tie_word_embeddings=tie_word_embeddings, 47 | rope_theta=rope_theta, 48 | rope_scaling=rope_scaling, 49 | use_sliding_window=use_sliding_window, 50 | sliding_window=sliding_window, 51 | max_window_layers=max_window_layers, 52 | attention_dropout=attention_dropout, 53 | **kwargs, 54 | ) -------------------------------------------------------------------------------- /liger/models/liger_qwen2_gla/modeling_liger_qwen2_gla.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | import copy 4 | from typing import List, Optional, Tuple, Union 5 | from einops import rearrange, repeat 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 12 | from transformers.generation import GenerationMixin 13 | from transformers.modeling_outputs import ( 14 | BaseModelOutputWithPast, 15 | CausalLMOutputWithPast, 16 | ) 17 | from transformers.models.qwen2.modeling_qwen2 import ( 18 | repeat_kv, 19 | apply_rotary_pos_emb, 20 | Qwen2RotaryEmbedding, 21 | Qwen2RMSNorm, 22 | Qwen2MLP, 23 | Qwen2Attention, 24 | Qwen2DecoderLayer, 25 | Qwen2PreTrainedModel, 26 | Qwen2Model, 27 | Qwen2ForCausalLM, 28 | ) 29 | from transformers.utils import ( 30 | LossKwargs, 31 | add_code_sample_docstrings, 32 | add_start_docstrings, 33 | add_start_docstrings_to_model_forward, 34 | logging, 35 | replace_return_docstrings, 36 | ) 37 | 38 | from transformers.utils import is_flash_attn_2_available 39 | 40 | if is_flash_attn_2_available(): 41 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 42 | else: 43 | print("flash_attn_2 is not available") 44 | 45 | from fla.models.utils import Cache as FlaCache 46 | from fla.ops.gla import fused_chunk_gla, fused_recurrent_gla 47 | 48 | from .configuration_liger_qwen2_gla import LigerQwen2GLAConfig 49 | 50 | logger = logging.get_logger(__name__) 51 | 52 | class LigerQwen2GatedLinearAttention(nn.Module): 53 | def __init__( 54 | self, 55 | config: LigerQwen2GLAConfig, 56 | layer_idx: Optional[int] = None, 57 | ): 58 | super().__init__() 59 | self.config = config 60 | self.layer_idx = layer_idx 61 | if layer_idx is None: 62 | logger.warning_once( 63 | f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " 64 | "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " 65 | "when creating this class." 66 | ) 67 | 68 | self.hidden_size = config.hidden_size 69 | self.num_heads = config.num_attention_heads 70 | self.head_dim = self.hidden_size // self.num_heads 71 | self.num_key_value_heads = config.num_key_value_heads 72 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 73 | self.max_position_embeddings = config.max_position_embeddings 74 | self.rope_theta = config.rope_theta 75 | self.is_causal = True 76 | self.attention_dropout = config.attention_dropout 77 | 78 | if (self.head_dim * self.num_heads) != self.hidden_size: 79 | raise ValueError( 80 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 81 | f" and `num_heads`: {self.num_heads})." 
82 | ) 83 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) 84 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) 85 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) 86 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 87 | 88 | self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) 89 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.head_dim * self.num_key_value_heads) 90 | 91 | 92 | def forward( 93 | self, 94 | hidden_states: torch.Tensor, 95 | attention_mask: Optional[torch.Tensor] = None, 96 | position_ids: Optional[torch.LongTensor] = None, 97 | past_key_value: Optional[FlaCache] = None, 98 | output_attentions: bool = False, 99 | use_cache: bool = False, 100 | cache_position: Optional[torch.LongTensor] = None, 101 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 102 | **kwargs, 103 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 104 | last_state = None 105 | if past_key_value is not None and len(past_key_value) > self.layer_idx: 106 | last_state = past_key_value[self.layer_idx] 107 | 108 | bsz, q_len, _ = hidden_states.size() 109 | 110 | query_states = self.q_proj(hidden_states) 111 | key_states = self.k_proj(hidden_states) 112 | value_states = self.v_proj(hidden_states) 113 | g = self.pool_g(key_states) 114 | 115 | query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 116 | key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 117 | value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 118 | g = g.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 119 | 120 | # if position_embeddings is None: 121 | # logger.warning_once( 122 | # "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 123 | # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 124 | # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " 125 | # "removed and `position_embeddings` will be mandatory." 
126 | # ) 127 | # cos, sin = self.rotary_emb(value_states, position_ids) 128 | # else: 129 | # cos, sin = position_embeddings 130 | # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 131 | 132 | # if past_key_value is not None: 133 | # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models 134 | # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs) 135 | 136 | # repeat k/v heads if n_kv_heads < n_heads 137 | q = query_states 138 | k = repeat_kv(key_states, self.num_key_value_groups) 139 | v = repeat_kv(value_states, self.num_key_value_groups) 140 | g = repeat_kv(g, self.num_key_value_groups) 141 | 142 | sq, sk, sv = q, k, v 143 | 144 | # norm 145 | q = F.softmax(q, dim=-1) 146 | k = F.softmax(k, dim=-1) 147 | 148 | gate_logit_normalizer = 16 149 | g = F.logsigmoid(g) / gate_logit_normalizer # (b, h, n, m) 150 | 151 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None 152 | offsets = kwargs.get('offsets', None) 153 | scale = 1 154 | q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g)) 155 | 156 | if self.training or q.shape[-2] > 1: 157 | o_, recurrent_state = fused_chunk_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True) 158 | else: 159 | o_, recurrent_state = fused_recurrent_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True) 160 | 161 | if past_key_value is not None: 162 | past_key_value.update( 163 | recurrent_state=recurrent_state, 164 | layer_idx=self.layer_idx, 165 | offset=q.shape[1] 166 | ) 167 | 168 | if position_embeddings is None: 169 | logger.warning_once( 170 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 171 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 172 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " 173 | "removed and `position_embeddings` will be mandatory." 174 | ) 175 | cos, sin = self.rotary_emb(sv, position_ids) 176 | else: 177 | cos, sin = position_embeddings 178 | sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin) 179 | 180 | input_dtype = sq.dtype 181 | if input_dtype == torch.float32: 182 | if torch.is_autocast_enabled(): 183 | target_dtype = torch.get_autocast_gpu_dtype() 184 | # Handle the case where the model is quantized 185 | elif hasattr(self.config, "_pre_quantization_dtype"): 186 | target_dtype = self.config._pre_quantization_dtype 187 | else: 188 | target_dtype = self.q_proj.weight.dtype 189 | 190 | logger.warning_once( 191 | f"The input hidden states seems to be silently casted in float32, this might be related to" 192 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 193 | f" {target_dtype}." 
194 | ) 195 | 196 | sq = sq.to(target_dtype) 197 | sk = sk.to(target_dtype) 198 | sv = sv.to(target_dtype) 199 | 200 | window_size = 64 201 | if attention_mask is not None and 0.0 in attention_mask: 202 | pass 203 | else: 204 | attention_mask = None 205 | 206 | y = _flash_attention_forward( # Reashape to the expected shape for Flash Attention 207 | sq.transpose(1, 2), 208 | sk.transpose(1, 2), 209 | sv.transpose(1, 2), 210 | attention_mask, 211 | q_len, 212 | position_ids=position_ids, 213 | dropout=0.0, 214 | sliding_window=window_size, 215 | use_top_left_mask=False, 216 | is_causal=True, 217 | target_dtype=torch.float32, 218 | ).transpose(1, 2) 219 | o_ = 0.5 * y + 0.5 * o_ 220 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)') 221 | o = self.o_proj(o) 222 | 223 | return o, None, past_key_value 224 | 225 | class LigerQwen2DecoderLayer(Qwen2DecoderLayer): 226 | def __init__(self, config: LigerQwen2GLAConfig, layer_idx: int): 227 | super().__init__(config, layer_idx) 228 | self.hidden_size = config.hidden_size 229 | 230 | if config.sliding_window and config._attn_implementation != "flash_attention_2": 231 | logger.warning_once( 232 | f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " 233 | "unexpected results may be encountered." 234 | ) 235 | self.self_attn = LigerQwen2GatedLinearAttention(config, layer_idx) 236 | self.mlp = Qwen2MLP(config) 237 | self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 238 | self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 239 | 240 | class LigerQwen2PreTrainedModel(Qwen2PreTrainedModel): 241 | 242 | config_class = LigerQwen2GLAConfig 243 | base_model_prefix = "model" 244 | supports_gradient_checkpointing = True 245 | _no_split_modules = ["Qwen2DecoderLayer"] 246 | _skip_keys_device_placement = "past_key_values" 247 | 248 | class LigerQwen2GLAModel(Qwen2Model, LigerQwen2PreTrainedModel): 249 | 250 | def __init__(self, config: LigerQwen2GLAConfig): 251 | super().__init__(config) 252 | self.padding_idx = config.pad_token_id 253 | self.vocab_size = config.vocab_size 254 | 255 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 256 | self.layers = nn.ModuleList( 257 | [LigerQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 258 | ) 259 | self._attn_implementation = config._attn_implementation 260 | self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 261 | self.rotary_emb = Qwen2RotaryEmbedding(config=config) 262 | 263 | self.gradient_checkpointing = False 264 | # Initialize weights and apply final processing 265 | self.post_init() 266 | 267 | def forward( 268 | self, 269 | input_ids: torch.LongTensor = None, 270 | attention_mask: Optional[torch.Tensor] = None, 271 | position_ids: Optional[torch.LongTensor] = None, 272 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None, 273 | inputs_embeds: Optional[torch.FloatTensor] = None, 274 | use_cache: Optional[bool] = None, 275 | output_attentions: Optional[bool] = None, 276 | output_hidden_states: Optional[bool] = None, 277 | return_dict: Optional[bool] = None, 278 | cache_position: Optional[torch.LongTensor] = None, 279 | **kwargs, 280 | ) -> Union[Tuple, BaseModelOutputWithPast]: 281 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 282 | output_hidden_states = ( 283 | output_hidden_states if output_hidden_states is not 
None else self.config.output_hidden_states 284 | ) 285 | use_cache = use_cache if use_cache is not None else self.config.use_cache 286 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 287 | 288 | if (input_ids is None) ^ (inputs_embeds is not None): 289 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds") 290 | 291 | if self.gradient_checkpointing and self.training and use_cache: 292 | logger.warning_once( 293 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 294 | ) 295 | use_cache = False 296 | 297 | 298 | # kept for BC (non `Cache` `past_key_values` inputs) 299 | return_legacy_cache = False 300 | if use_cache and not isinstance(past_key_values, FlaCache): 301 | past_key_values = FlaCache.from_legacy_cache(past_key_values) 302 | 303 | if inputs_embeds is None: 304 | inputs_embeds = self.embed_tokens(input_ids) 305 | 306 | if cache_position is None: 307 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 308 | cache_position = torch.arange( 309 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 310 | ) 311 | 312 | if position_ids is None: 313 | position_ids = cache_position.unsqueeze(0) 314 | 315 | causal_mask = attention_mask 316 | 317 | hidden_states = inputs_embeds 318 | 319 | # create position embeddings to be shared across the decoder layers 320 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 321 | 322 | # decoder layers 323 | all_hidden_states = () if output_hidden_states else None 324 | all_self_attns = () if output_attentions else None 325 | next_decoder_cache = None 326 | 327 | # None when attentions are not requested, so the check inside the loop is a safe no-op 328 | all_softmax_hidden_states = () if output_attentions else None 329 | 330 | for decoder_layer in self.layers: 331 | if output_hidden_states: 332 | all_hidden_states += (hidden_states,) 333 | if all_softmax_hidden_states is not None: 334 | all_softmax_hidden_states += (hidden_states,) 335 | 336 | if self.gradient_checkpointing and self.training: 337 | layer_outputs = self._gradient_checkpointing_func( 338 | decoder_layer.__call__, 339 | hidden_states, 340 | causal_mask, 341 | position_ids, 342 | past_key_values, 343 | output_attentions, 344 | use_cache, 345 | cache_position, 346 | position_embeddings, 347 | ) 348 | 349 | else: 350 | layer_outputs = decoder_layer( 351 | hidden_states, 352 | attention_mask=causal_mask, 353 | position_ids=position_ids, 354 | past_key_value=past_key_values, 355 | output_attentions=output_attentions, 356 | use_cache=use_cache, 357 | cache_position=cache_position, 358 | position_embeddings=position_embeddings, 359 | ) 360 | hidden_states = layer_outputs[0] 361 | 362 | if use_cache: 363 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 364 | 365 | if output_attentions: 366 | all_self_attns += (layer_outputs[1],) 367 | 368 | hidden_states = self.norm(hidden_states) 369 | 370 | # add hidden states from the last decoder layer 371 | if output_hidden_states: 372 | all_hidden_states += (hidden_states,) 373 | 374 | next_cache = next_decoder_cache if use_cache else None 375 | 376 | if not return_dict: 377 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 378 | return BaseModelOutputWithPast( 379 | last_hidden_state=hidden_states, 380 | past_key_values=next_cache, 381 | hidden_states=all_hidden_states, 382 | attentions=all_self_attns, 383 | ) 384 | 385 | class LigerQwen2GLAForCausalLM(LigerQwen2PreTrainedModel,
Qwen2ForCausalLM, GenerationMixin): 386 | _tied_weights_keys = ["lm_head.weight"] 387 | _tp_plan = {"lm_head": "colwise_rep"} 388 | 389 | def __init__(self, config): 390 | super().__init__(config) 391 | self.model = LigerQwen2GLAModel(config) 392 | self.vocab_size = config.vocab_size 393 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 394 | 395 | # Initialize weights and apply final processing 396 | self.post_init() -------------------------------------------------------------------------------- /lolcats/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from lolcats.models.lolcats import LolcatsConfig, LolcatsModel, LolcatsModelForCausalLM 4 | 5 | __all__ = [ 6 | 'LolcatsConfig', 'LolcatsModel', 'LolcatsModelForCausalLM' 7 | ] -------------------------------------------------------------------------------- /lolcats/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from lolcats.models.lolcats import LolcatsConfig, LolcatsModel, LolcatsModelForCausalLM 4 | 5 | __all__ = [ 6 | 'LolcatsConfig', 'LolcatsModel', 'LolcatsModelForCausalLM' 7 | ] -------------------------------------------------------------------------------- /lolcats/models/lolcats/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from lolcats.models.lolcats.configuration_lolcats import LolcatsConfig 4 | from lolcats.models.lolcats.modeling_lolcats import LolcatsModel, LolcatsModelForCausalLM 5 | 6 | AutoConfig.register(LolcatsConfig.model_type, LolcatsConfig) 7 | AutoModel.register(LolcatsConfig, LolcatsModel) 8 | AutoModelForCausalLM.register(LolcatsConfig, LolcatsModelForCausalLM) 9 | 10 | 11 | __all__ = ['LolcatsConfig', 'LolcatsModelForCausalLM', 'LolcatsModel'] -------------------------------------------------------------------------------- /lolcats/models/lolcats/configuration_lolcats.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.modeling_rope_utils import rope_config_validation 7 | from transformers.models.llama.configuration_llama import LlamaConfig 8 | 9 | class LolcatsConfig(LlamaConfig, PretrainedConfig): 10 | 11 | model_type = 'lolcats' 12 | keys_to_ignore_at_inference = ['past_key_values'] 13 | 14 | def __init__( 15 | self, 16 | # llama config 17 | vocab_size=32000, 18 | hidden_size=4096, 19 | intermediate_size=11008, 20 | num_hidden_layers=32, 21 | num_attention_heads=32, 22 | num_key_value_heads=None, 23 | hidden_act="silu", 24 | max_position_embeddings=2048, 25 | initializer_range=0.02, 26 | rms_norm_eps=1e-6, 27 | use_cache=True, 28 | pad_token_id=None, 29 | bos_token_id=1, 30 | eos_token_id=2, 31 | pretraining_tp=1, 32 | tie_word_embeddings=False, 33 | rope_theta=10000.0, 34 | rope_scaling=None, 35 | attention_bias=False, 36 | attention_dropout=0.0, 37 | mlp_bias=False, 38 | head_dim=None, 39 | # linear attention 40 | attn_mode: str = "fused_chunk", 41 | expand_k: int = 1, 42 | expand_v: int = 1, 43 | hidden_ratio: Optional[int] = 4, 44 | # num_heads: int = 4, 45 | # num_kv_heads: Optional[int] = None, 46 | feature_map: str = "lolcats_t2r", 47 | tie_feature_map_qk: bool = False, 48 | norm_q: bool = 
True, 49 | norm_k: bool = True, 50 | norm_feature_map: bool = False, 51 | elementwise_affine: Optional[bool] = True, 52 | norm_eps: float = 1e-6, 53 | attn: Optional[Dict] = None, 54 | fuse_cross_entropy: bool = True, 55 | **kwargs 56 | ): 57 | 58 | # linear attention settings 59 | self.attn_mode = attn_mode 60 | self.expand_k = expand_k 61 | self.expand_v = expand_v 62 | self.hidden_ratio = hidden_ratio 63 | # self.num_heads = num_heads 64 | # self.num_kv_heads = num_kv_heads 65 | self.feature_map = feature_map 66 | self.tie_feature_map_qk = tie_feature_map_qk 67 | self.norm_q = norm_q 68 | self.norm_k = norm_k 69 | self.norm_feature_map = norm_feature_map 70 | self.max_position_embeddings = max_position_embeddings 71 | self.elementwise_affine = elementwise_affine 72 | self.norm_eps = norm_eps 73 | self.attn = attn 74 | self.fuse_cross_entropy = fuse_cross_entropy 75 | 76 | if attn is not None: 77 | if not isinstance(attn, Dict): 78 | raise ValueError("attn must be a dictionary") 79 | if 'layers' not in attn: 80 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 81 | if 'num_heads' not in attn: 82 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 83 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 84 | attn['window_size'] = attn.get('window_size', None) 85 | 86 | super().__init__( 87 | vocab_size=vocab_size, 88 | hidden_size=hidden_size, 89 | intermediate_size=intermediate_size, 90 | num_hidden_layers=num_hidden_layers, 91 | num_attention_heads=num_attention_heads, 92 | num_key_value_heads=num_key_value_heads, 93 | hidden_act=hidden_act, 94 | max_position_embeddings=max_position_embeddings, 95 | initializer_range=initializer_range, 96 | rms_norm_eps=rms_norm_eps, 97 | use_cache=use_cache, 98 | pad_token_id=pad_token_id, 99 | bos_token_id=bos_token_id, 100 | eos_token_id=eos_token_id, 101 | pretraining_tp=pretraining_tp, 102 | tie_word_embeddings=tie_word_embeddings, 103 | rope_theta=rope_theta, 104 | rope_scaling=rope_scaling, 105 | attention_bias=attention_bias, 106 | attention_dropout=attention_dropout, 107 | mlp_bias=mlp_bias, 108 | head_dim=head_dim, 109 | **kwargs, 110 | ) 111 | 112 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | triton==3.1.0 3 | transformers==4.47.1 4 | datasets 5 | peft 6 | evaluate 7 | omegaconf -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import random 5 | import numpy as np 6 | from omegaconf import OmegaConf 7 | import torch 8 | 9 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 10 | 11 | from training.train import train 12 | 13 | 14 | def set_random_seed(seed=0): 15 | os.environ['PYTHONHASHSEED'] = str(seed) 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | torch.backends.cudnn.benchmark = False # if benchmark=True, deterministic will be False 22 | torch.backends.cudnn.deterministic = True # choose a deterministic algorithm 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--cfg", type=str, default="configs/liger.yaml") 27 | args = parser.parse_args() 28 | 
return args 29 | 30 | def main(): 31 | set_random_seed(seed=0) 32 | args = get_args() 33 | config = OmegaConf.load(args.cfg) 34 | output_dir = args.cfg.split('/')[-1].split('.')[0] 35 | config.train.output_dir = os.path.join(config.train.output_dir, output_dir) # 'checkpoints/${filename}' 36 | train(config) 37 | 38 | if __name__ == "__main__": 39 | main() -------------------------------------------------------------------------------- /scripts/train_liger.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | python run.py --cfg configs/liger_gla.yaml -------------------------------------------------------------------------------- /scripts/train_lolcats_stage1.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | python run.py --cfg configs/lolcats_at.yaml -------------------------------------------------------------------------------- /scripts/train_lolcats_stage2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=7 2 | 3 | python run.py --cfg configs/lolcats_ar.yaml -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/Linearization/0e3cfae33a700fa5f644cf5752d8434c6afc2412/training/__init__.py -------------------------------------------------------------------------------- /training/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import random 4 | from tqdm import tqdm 5 | from functools import partial 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from datasets import Dataset as HFDataset 9 | from datasets import load_dataset, load_from_disk 10 | import evaluate 11 | from huggingface_hub import hf_hub_download 12 | 13 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM 14 | from transformers import DataCollatorForSeq2Seq 15 | 16 | 17 | PROMPT_DICT = { 18 | "prompt_input": ( 19 | "Below is an instruction that describes a task, paired with an input that provides further context. " 20 | "Write a response that appropriately completes the request.\n\n" 21 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" 22 | ), 23 | "prompt_no_input": ( 24 | "Below is an instruction that describes a task. 
" 25 | "Write a response that appropriately completes the request.\n\n" 26 | "### Instruction:\n{instruction}\n\n### Response:\n" 27 | ), 28 | } 29 | 30 | def encode_response(response: str, tokenizer) -> list[int]: 31 | tokens = tokenizer.encode(response.strip(), add_special_tokens=False) 32 | # For Llama 3 Instruct: tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"]) 33 | tokens.append(tokenizer.eos_token_id) 34 | try: # Llama 3 Instruct 35 | tokens.append(tokenizer.get_added_vocab()["<|end_of_text|>"]) 36 | except KeyError: 37 | pass 38 | return tokens 39 | 40 | def load_data(config): 41 | cache_dir = "/root/.cache" 42 | input_len = config.model.max_length 43 | concat_data = True 44 | 45 | tokenizer_path = config.model.pretrained_model_name_or_path 46 | tokenizer_name = tokenizer_path.split('/')[-1] 47 | 48 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 49 | if tokenizer.pad_token is None: 50 | tokenizer.pad_token = tokenizer.eos_token 51 | print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') 52 | 53 | tokenizer.padding_side = 'left' # for decoder-only generation 54 | # Get initial data 55 | ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs'] 56 | dataset_config = { 57 | "name": "default", 58 | "path": "yahma/alpaca-cleaned", 59 | "chunk_size": input_len, 60 | "concat_data": concat_data, 61 | "cache_dir": cache_dir, 62 | } 63 | dataset = load_dataset( 64 | **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs} 65 | ) 66 | dataset = dataset['train'] 67 | train_set = convert_to_hf_dataset([dataset[ix] for ix in range(200, len(dataset))], cache_dir) 68 | val_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir) 69 | test_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir) 70 | 71 | # Convert to dicts of {input_ids, attention_mask, labels} 72 | train_set = train_set.map( 73 | partial(template_and_tokenize, tokenizer=tokenizer, include_label=True), 74 | remove_columns=list(dataset.features),) 75 | val_set = val_set.map( 76 | partial(template_and_tokenize, tokenizer=tokenizer, include_label=True), 77 | remove_columns=list(dataset.features),) 78 | test_set = test_set.map( 79 | partial(template_and_tokenize, tokenizer=tokenizer, include_label=False), 80 | remove_columns=list(dataset.features),) 81 | 82 | # Chunk together train and val sets 83 | if concat_data: 84 | train_set = ConcatDataset(train_set, chunk_size=input_len) 85 | val_set = ConcatDataset(val_set, chunk_size=input_len) 86 | 87 | loader_kwargs = { 88 | "batch_size": config.data.micro_batch_size, 89 | "num_workers": 0, 90 | "drop_last": False, 91 | "pin_memory": True, 92 | } 93 | 94 | # Get dataloaders 95 | dataloaders = { 96 | 'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs), 97 | 'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs), 98 | 'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs), 99 | } 100 | # Evaluation metric 101 | try: 102 | # metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge 103 | metric = evaluate.load(download_metric(), 'gov_report') 104 | except Exception as e: 105 | print(f'Error loading metric: {e}') 106 | metric = None 107 | 108 | # Finishing touches 109 | for k, v in dataloaders.items(): # Make tokenizer accessible 110 | dataloaders[k].dataset.tokenizer = tokenizer 111 | dataloaders[k].dataset.metric = metric 112 | return dataloaders 113 | 114 | 115 | def convert_to_hf_dataset(dataset, cache_dir: 
str): 116 | """ 117 | Convert iterable dataset to HuggingFace HFDataset object 118 | """ 119 | def gen(): 120 | for _, sample in enumerate(dataset): 121 | yield sample # dataset[idx] 122 | return HFDataset.from_generator(gen, cache_dir=cache_dir) 123 | 124 | def template_and_tokenize(sample, tokenizer, include_label: bool = True): 125 | """ 126 | Format dataset context and answers into single-sequence prompts 127 | """ 128 | if sample.get('input', '') == '': 129 | prompt = PROMPT_DICT["prompt_no_input"].format_map(sample) 130 | else: 131 | prompt = PROMPT_DICT["prompt_input"].format_map(sample) 132 | 133 | prompt = tokenizer.encode(prompt, add_special_tokens=True) 134 | if include_label: 135 | answer = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}', 136 | add_special_tokens=False) 137 | target = None 138 | else: 139 | answer = [] 140 | target = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}', 141 | add_special_tokens=False) 142 | input_ids = prompt + answer 143 | attn_mask = [1] * len(input_ids) 144 | 145 | sample = { 146 | "input_ids": input_ids, 147 | "attention_mask" : attn_mask, 148 | "labels": [-100] * len(prompt) + answer if include_label else target, 149 | } 150 | return sample 151 | 152 | def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer, 153 | split: str, max_length: int = None, **loader_kwargs: any): 154 | """ 155 | Get dataloader for language modeling (training) 156 | -> Currently this ends up being the same as get_seq2seq_loader 157 | """ 158 | # collate_fn = DefaultDataCollator(return_tensors='pt') 159 | # collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, 160 | # max_length=max_length, return_tensors='pt') 161 | collate_fn = DataCollatorForSeq2Seq( 162 | tokenizer, label_pad_token_id=-100, return_tensors='pt') 163 | return DataLoader( 164 | dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs) 165 | 166 | def get_seq2seq_loader(dataset: Dataset, tokenizer: AutoTokenizer, 167 | split: str, **loader_kwargs: any): 168 | """ 169 | Get dataloader for seq2seq tasks (evaluation) 170 | """ 171 | tokenizer.padding_side = 'right' 172 | collate_fn = DataCollatorForSeq2Seq( 173 | tokenizer, label_pad_token_id=-100, return_tensors='pt') 174 | return DataLoader( 175 | dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs) 176 | 177 | def download_metric(): 178 | """ 179 | Download ROUGE, F1, and other accuracy metrics included in the SCROLLS dataset 180 | """ 181 | scrolls_metric_path = hf_hub_download( 182 | repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset" 183 | ) 184 | updated_scrolls_metric_path = ( 185 | os.path.dirname(scrolls_metric_path) + 186 | os.path.basename(scrolls_metric_path).replace(".", "_") + ".py" 187 | ) 188 | shutil.copy(scrolls_metric_path, updated_scrolls_metric_path) 189 | return updated_scrolls_metric_path 190 | 191 | class ConcatDataset(Dataset): 192 | """ 193 | Concatenates or packs samples of a dataset into chunks of size `chunk_size` 194 | """ 195 | def __init__(self, dataset, chunk_size: int = 1024, seed: int = 42,) -> None: 196 | self.dataset = dataset 197 | self.chunk_size = chunk_size 198 | self.samples = [] 199 | buffer = { 200 | "input_ids": [], 201 | "attention_mask": [], 202 | "labels": [], 203 | } 204 | random.seed(seed) 205 | for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): 206 | buffer = {k: v + sample[k] for k,v in buffer.items()} 207 | 208 | while len(next(iter(buffer.values()))) > 
self.chunk_size: 209 | self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) 210 | buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} 211 | # Slow hack, but filter out any samples without valid labels (all -100) 212 | self.filtered_samples = [] 213 | for s in self.samples: 214 | if sum(s['labels']) != chunk_size * -100: 215 | self.filtered_samples.append(s) 216 | if len(self.filtered_samples) < len(self.samples): 217 | print(f'OG dataset: {len(self.samples)} samples -> Filtered dataset: {len(self.filtered_samples)}') 218 | print(f'-> Filtered out {len(self.samples) - len(self.filtered_samples)} samples') 219 | 220 | def __getitem__(self, idx): 221 | return self.filtered_samples[idx] 222 | 223 | def __len__(self): 224 | return len(self.filtered_samples) 225 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import fla 10 | import liger 11 | import lolcats 12 | import torch.utils 13 | import torch.utils.data 14 | import torch.utils.data.dataloader 15 | from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig 16 | from transformers import Trainer, TrainingArguments 17 | from peft import LoraConfig, TaskType, PeftModel, get_peft_model 18 | 19 | from training.trainer import DefaultTrainer, FinetuneTrainer 20 | from training.utils import get_optimizer_and_scheduler, count_model_params 21 | from training.dataloader import load_data 22 | 23 | 24 | def train(config): 25 | 26 | trainer = FinetuneTrainer 27 | model_config = AutoConfig.from_pretrained(config.model.pretrained_model_name_or_path) 28 | if config.model.name == "liger_gla": 29 | from liger.models.liger_gla import LigerGLAConfig 30 | liger_model_config = LigerGLAConfig() 31 | elif config.model.name == "liger_gsa": 32 | from liger.models.liger_gsa import LigerGSAConfig 33 | liger_model_config = LigerGSAConfig() 34 | elif config.model.name == "lolcats_at": 35 | # first stage: attention transfer 36 | from lolcats.models.lolcats import LolcatsConfig 37 | liger_model_config = LolcatsConfig() 38 | trainer = DefaultTrainer 39 | elif config.model.name == "lolcats_ar": 40 | # second stage 41 | from lolcats.models.lolcats import LolcatsConfig 42 | liger_model_config = LolcatsConfig() 43 | else: 44 | raise NotImplementedError(config.model.name) 45 | 46 | liger_model_config.__dict__.update(model_config.__dict__) 47 | model_config = liger_model_config 48 | model = AutoModelForCausalLM.from_pretrained( 49 | config.model.pretrained_model_name_or_path, 50 | config=model_config, 51 | device_map="cuda" 52 | ).to(torch.bfloat16) 53 | 54 | 55 | print("Model config:") 56 | print(model_config) 57 | print("Model:") 58 | print(model) 59 | 60 | tokenizer = AutoTokenizer.from_pretrained(config.model.pretrained_model_name_or_path) 61 | tokenizer.pad_token_id = tokenizer.eos_token_id 62 | tokenizer.padding_side = "left" # Allow batched inference 63 | 64 | for name, param in model.named_parameters(): 65 | param.requires_grad = False 66 | if "train_qk" in config.train and config.train.train_qk: 67 | if "self_attn.q_proj" in name: 68 | param.requires_grad = True 69 | elif "self_attn.k_proj" in name: 70 | param.requires_grad = True 71 | if "train_v" in config.train and config.train.train_v and "self_attn.v_proj" in name: 72 | param.requires_grad = 
True 73 | if "train_o" in config.train and config.train.train_o and "self_attn.o_proj" in name: 74 | param.requires_grad = True 75 | 76 | # LoRA finetune 77 | target_modules = [] 78 | if "train_qk" in config.train and config.train.train_qk and config.train.train_qk_lora: 79 | target_modules.append("self_attn.q_proj") 80 | target_modules.append("self_attn.k_proj") 81 | if "train_v" in config.train and config.train.train_v and config.train.train_v_lora: 82 | target_modules.append("self_attn.v_proj") 83 | if "train_o" in config.train and config.train.train_o and config.train.train_o_lora: 84 | target_modules.append("self_attn.o_proj") 85 | # lolcats attention transfer 86 | if config.model.name == "lolcats_at": 87 | for name, param in model.named_parameters(): 88 | if "feature_map" in name: 89 | param.requires_grad = True 90 | else: 91 | param.requires_grad = False 92 | 93 | 94 | if len(target_modules) != 0: 95 | lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, target_modules=target_modules) 96 | model = get_peft_model(model, peft_config=lora_config) 97 | 98 | # print trainable params count 99 | trainable_params = count_model_params(model, requires_grad=True) 100 | total_params = count_model_params(model, requires_grad=False) 101 | print(f"Model trainable params: {trainable_params}") 102 | print(f"Model total params: {total_params}") 103 | print(f"trainable%: {trainable_params / total_params}") 104 | 105 | gradient_accumulation_steps = config.data.batch_size // config.data.micro_batch_size 106 | 107 | print("Preparing data...") 108 | 109 | dataloaders = load_data(config) 110 | train_loader = dataloaders["train"] 111 | eval_loader = dataloaders["validation"] 112 | 113 | print("Building trainer...") 114 | 115 | training_args = TrainingArguments( 116 | per_device_train_batch_size=config.data.micro_batch_size, 117 | gradient_accumulation_steps=gradient_accumulation_steps, 118 | warmup_steps=0, 119 | num_train_epochs=config.train.epochs, 120 | learning_rate=config.train.lr, 121 | bf16=True, 122 | max_grad_norm=config.train.max_grad_norm, 123 | logging_steps=1, 124 | optim=config.train.optim, 125 | evaluation_strategy="steps" if config.data.val_set_size > 0 else "no", 126 | save_strategy="steps", 127 | eval_steps=200 if config.data.val_set_size > 0 else None, 128 | save_steps=1000, 129 | logging_dir=config.train.output_dir, 130 | output_dir=config.train.output_dir, 131 | save_total_limit=3, 132 | load_best_model_at_end=True if config.data.val_set_size > 0 else False, 133 | # default trainer args 134 | greater_is_better = False, 135 | metric_for_best_model = 'eval/loss', 136 | # wandb 137 | report_to="none" # wandb off "wandb" 138 | ) 139 | 140 | trainer = trainer( 141 | model=model, 142 | train_loader=train_loader, 143 | eval_loader=eval_loader, 144 | args=training_args, 145 | optimizers=get_optimizer_and_scheduler(model, config), 146 | tokenizer=tokenizer, 147 | config=config 148 | ) 149 | 150 | print("Train start") 151 | best_model = trainer.train() 152 | save_path = trainer.save_path + '/best' 153 | best_model.save_pretrained(save_path) 154 | tokenizer.save_pretrained(save_path) 155 | print(f'\n-> Saved best model checkpoint to: {save_path}!') 156 | 157 | print("Train over") 158 | -------------------------------------------------------------------------------- /training/trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pandas as pd 4 | 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torch.nn as 
nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | 12 | import sys 13 | import os 14 | import pandas as pd 15 | 16 | from tqdm import tqdm 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.utils.data import DataLoader 22 | 23 | from transformers import AutoModel, AutoModelForCausalLM 24 | from transformers import Trainer, TrainingArguments 25 | 26 | from peft import PeftModel 27 | 28 | class DefaultTrainer(): 29 | # code is modified from: https://github.com/HazyResearch/lolcats/blob/main/src/trainer/default_lm.py 30 | def __init__(self, model, train_loader, eval_loader, args, optimizers, tokenizer, config): 31 | super().__init__() 32 | self.model = model 33 | self.args = args 34 | self.tokenizer = tokenizer 35 | self.config = config 36 | self.type = 'default' 37 | 38 | self.step = 0 # Total steps taken 39 | self.grad_step = 0 # Total gradient updates 40 | self.compute_loss_backprop = False # Whether we backprop in self.compute_loss 41 | 42 | self.optimizer, self.scheduler = optimizers 43 | self.scheduler_step_after_epoch = True 44 | # Dataloaders 45 | self.train_loader = train_loader 46 | self.eval_loader = eval_loader 47 | 48 | self.device = model.device 49 | wandb = None 50 | self.wandb = wandb 51 | 52 | # args 53 | self.metric_for_best_model = self.args.metric_for_best_model 54 | self.num_train_epochs = self.args.num_train_epochs 55 | self.gradient_accumulation_steps = self.args.gradient_accumulation_steps 56 | self.evaluation_strategy = self.args.evaluation_strategy 57 | self.greater_is_better = self.args.greater_is_better 58 | self.is_better = (lambda x, y: x > y if self.args.greater_is_better else x < y) 59 | self.load_best_model_at_end = self.args.load_best_model_at_end 60 | self.logging_steps = self.args.logging_steps 61 | self.max_steps = self.args.max_steps 62 | self.eval_steps = self.args.eval_steps 63 | 64 | max_eval_batches = -1 65 | print_samples = False 66 | initial_eval = True 67 | self.max_eval_batches = max_eval_batches 68 | self.print_samples = print_samples 69 | self.initial_eval = initial_eval 70 | self.save_total_limit = self.args.save_total_limit 71 | self.save_steps = self.args.save_steps # num_save_ckpt_steps 72 | 73 | # Saving metrics 74 | self.train_metrics = {'train/loss': None, 75 | 'train/epoch': None, 76 | 'train/step': None} 77 | self.eval_metrics = {self.metric_for_best_model: None} 78 | self.eval_metrics_by_step = {'eval_step': []} # save all eval metrics 79 | self.criterion = nn.CrossEntropyLoss(reduction='mean') 80 | 81 | save_results = True 82 | save_checkpoints = True 83 | 84 | self.save_results = save_results 85 | self.results_path = None 86 | self.best_val_metric = 0 if self.greater_is_better else 1e10 87 | self.best_val_metric_epoch = 0 88 | self.best_val_metric_step = 0 89 | if save_checkpoints: # Also initializes best_val_metrics 90 | self.init_checkpointing(config=config) 91 | 92 | def train(self) -> nn.Module: 93 | """ 94 | Entire training run 95 | """ 96 | model = self.model 97 | pbar = tqdm(range(self.num_train_epochs), leave=False, colour='white', desc='Training') 98 | for ix, epoch in enumerate(pbar): 99 | model, early_stopping = self.train_step(model, epoch) 100 | if self.evaluation_strategy == 'epoch': 101 | _eval_metrics = self.eval_step(model, step=self.grad_step) 102 | print(f'Epoch {ix} metrics:', _eval_metrics) 103 | if early_stopping: 104 | break 105 | 106 | if self.load_best_model_at_end: # Return best checkpoint 107 | try: 108 | 
model = model.from_pretrained(self.best_val_checkpoint_path) # assign the reloaded weights so the best checkpoint is actually returned 109 | print(f'-> Loading best checkpoint from {self.best_val_checkpoint_path}') 110 | except FileNotFoundError as e: 111 | print(e) 112 | print('-> Returning most recent model instead') 113 | return model 114 | 115 | def train_step(self, model, epoch) -> nn.Module: 116 | if self.gradient_accumulation_steps is None: 117 | accum_iter = 1 118 | else: 119 | accum_iter = self.gradient_accumulation_steps 120 | 121 | model.train() 122 | model.zero_grad() 123 | pbar = tqdm(self.train_loader, leave=False, colour='blue', desc=f'-> Training (epoch {epoch} / {self.args.num_train_epochs})') 124 | total_loss = 0 125 | eval_for_step = False 126 | 127 | # Initial eval 128 | if self.initial_eval: 129 | print('') 130 | print('-> Initial eval') 131 | self.compute_eval_metrics(model, step=self.grad_step) 132 | 133 | # model.to(self.device) 134 | for ix, data in enumerate(pbar): 135 | loss, train_metrics = self.compute_loss(model, data, return_outputs=True) 136 | loss /= accum_iter 137 | if not self.compute_loss_backprop: 138 | # loss.backward() did not occur in compute_loss 139 | try: 140 | with torch.autograd.set_detect_anomaly(True): 141 | loss.backward() 142 | except Exception as e: 143 | raise RuntimeError(f'loss.backward() failed: {e}') from e 144 | if (self.step + 1) % accum_iter == 0: # and self.step != 0: 145 | self.optimizer.step() 146 | if not self.scheduler_step_after_epoch and self.scheduler is not None: 147 | self.scheduler.step() 148 | self.optimizer.zero_grad() 149 | self.grad_step += 1 150 | if not self.compute_loss_backprop: 151 | loss = loss.detach().cpu().item() 152 | 153 | self.step += 1 154 | if not isinstance(loss, float): 155 | total_loss += loss.item() 156 | else: 157 | total_loss += loss 158 | desc = f"Training epoch {epoch} | loss: {total_loss / (ix + 1):.3f} | lr: {self.optimizer.param_groups[0]['lr']:.5f}" 159 | desc += f' | gradient step: {self.grad_step}' 160 | for k, v in train_metrics.items(): 161 | desc += f' | {k}: {v:.3f}' 162 | pbar.set_description(desc) 163 | 164 | # Logging 165 | if self.grad_step % self.logging_steps == 0: # log every `logging_steps` gradient updates 166 | self.train_metrics['train/loss'] = loss.item() if not isinstance(loss, float) else loss 167 | self.train_metrics['train/epoch'] = epoch 168 | self.train_metrics['train/step'] = self.grad_step 169 | self.train_metrics['train/lr'] = self.optimizer.param_groups[0]['lr'] 170 | for k, v in train_metrics.items(): 171 | self.train_metrics[f'train/{k}'] = v 172 | 173 | if self.wandb is not None: 174 | self.wandb.log(self.train_metrics, step=self.grad_step) 175 | 176 | if self.evaluation_strategy == 'steps': 177 | if (self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and not eval_for_step): 178 | _eval_metrics = self.eval_step(model, step=self.grad_step) 179 | print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics) 180 | eval_for_step = True 181 | model.train() # Need to set back to train mode 182 | elif self.grad_step == 0 and self.save_steps < 1000 and not eval_for_step: # hack for micros 183 | _eval_metrics = self.eval_step(model, step=self.grad_step) 184 | print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics) 185 | eval_for_step = True 186 | model.train() # Need to set back to train mode 187 | 188 | elif self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and eval_for_step: 189 | pass 190 | else: 191 | if self.grad_step > 0: 192 | eval_for_step = False 193 | if self.grad_step == self.max_steps: 194 | early_stopping = True 195 | return model, early_stopping 196 | 197 | early_stopping = False 198 | return
model, early_stopping 199 | 200 | 201 | def eval_step(self, model: nn.Module, step: int = None, **kwargs: any) -> dict[any]: 202 | """ 203 | Evaluation loop over one epoch 204 | """ 205 | with torch.no_grad(): 206 | self.eval_metrics = self.compute_eval_metrics(model, step=step, **kwargs) 207 | val_metric = self.eval_metrics[self.metric_for_best_model] 208 | 209 | # Save results 210 | if self.wandb is not None: # log to WandB 211 | self.wandb.log(self.eval_metrics, step=self.grad_step) 212 | 213 | if self.results_path is not None: # log to local file 214 | self.eval_metrics_by_step['eval_step'].append(step) 215 | for k, v in self.eval_metrics.items(): 216 | if k not in self.eval_metrics_by_step: 217 | self.eval_metrics_by_step[k] = [v] 218 | else: 219 | self.eval_metrics_by_step[k].append(v) 220 | # Inefficient, but log for experiments results 221 | pd.DataFrame(self.eval_metrics_by_step).to_csv(self.results_path) 222 | 223 | # Save best metric and checkpoint 224 | if self.grad_step % self.eval_steps == 0 and step > 0: 225 | if self.is_better(val_metric, self.best_val_metric): 226 | self.best_val_metric = val_metric 227 | self.best_val_metric_step = self.grad_step 228 | 229 | save_path = self.save_path + '/iter' + '_' + str(step) 230 | self.best_val_checkpoint_path = save_path 231 | model.save_pretrained(save_path) 232 | self.tokenizer.save_pretrained(save_path) 233 | print(f'\n-> Saved best model checkpoint to: {save_path}!') 234 | 235 | if self.grad_step % self.save_steps == 0 and step > 0: 236 | 237 | save_path = self.save_path + '/' + self.type + '_' + str(step) 238 | self.best_val_checkpoint_path = save_path 239 | model.save_pretrained(save_path) 240 | self.tokenizer.save_pretrained(save_path) 241 | print(f'\n-> Saved model checkpoint to: {save_path}!') 242 | 243 | if self.scheduler_step_after_epoch and self.scheduler is not None: 244 | self.scheduler.step(val_metric) 245 | return self.eval_metrics 246 | 247 | def compute_eval_metrics(self, 248 | model: nn.Module, step: int, 249 | max_batches: int = None, 250 | dataloader: DataLoader = None, 251 | **kwargs: any) -> dict[any]: 252 | """ 253 | One evaluation loop over a validation dataset 254 | """ 255 | max_batches = (self.max_eval_batches if max_batches is None else max_batches) 256 | dataloader = self.eval_loader if dataloader is None else dataloader 257 | pbar = tqdm(dataloader, leave=False, colour='green', desc=f'Evaluating at step {step}') 258 | 259 | model.eval() 260 | step_loss = 0 261 | step_eval_metrics = {} 262 | with torch.no_grad(): 263 | for ix, data in enumerate(pbar): 264 | loss, eval_metrics = self.compute_loss(model, data, return_outputs=True) 265 | if not self.compute_loss_backprop: 266 | loss = loss.item() # otherwise already float 267 | if ix == 0: 268 | step_eval_metrics[self.metric_for_best_model] = [loss] 269 | for k, v in eval_metrics.items(): 270 | step_eval_metrics[f'eval/{k}'] = [v] 271 | else: 272 | step_eval_metrics[self.metric_for_best_model].append(loss) 273 | for k, v in eval_metrics.items(): 274 | step_eval_metrics[f'eval/{k}'].append(v) 275 | 276 | step_loss += loss 277 | desc = f"Evaluating at step {step} | loss: {step_loss / (ix + 1):.3f}" 278 | if self.optimizer is not None: 279 | desc += f" | lr: {self.optimizer.param_groups[0]['lr']:.5f}" 280 | pbar.set_description(desc) 281 | if ix == max_batches: 282 | break 283 | 284 | # Average over batches 285 | for k, v in step_eval_metrics.items(): 286 | step_eval_metrics[k] = sum(v) / len(v) 287 | print(f'Eval step {step}:', step_eval_metrics) 288 | del 
loss 289 | torch.cuda.empty_cache() 290 | return step_eval_metrics 291 | 292 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 293 | inputs = {k: v.to(model.device) for k, v in inputs.items() if k != 'labels'} 294 | outputs = model(**inputs, output_attentions=True) 295 | 296 | outputs = outputs.attentions # tuple [num_decoder_layers, 2, B, H, L, L] 297 | loss_mse = 0 298 | self.mse_factor = 1000 299 | self.criterion_mse = nn.MSELoss(reduction='mean') 300 | n_layers = 0 # Number of layers to distill 301 | 302 | for layer_idx, attns in enumerate(outputs): 303 | if attns is not None: 304 | loss_mse += self.criterion_mse(attns[0], attns[1]) 305 | n_layers += 1 306 | 307 | if n_layers > 0: 308 | loss_mse = loss_mse / n_layers * self.mse_factor 309 | loss = loss_mse 310 | outputs = {'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, 311 | 'mse_factor': self.mse_factor} 312 | 313 | return (loss, outputs) if return_outputs else loss 314 | 315 | def init_checkpointing(self, config) -> None: 316 | self.save_path = config.train.output_dir 317 | self.best_val_checkpoint_path = config.train.output_dir 318 | 319 | # Best metric setup 320 | self.best_val_metric = 0 if self.greater_is_better else 1e10 321 | self.best_val_metric_epoch = 0 322 | self.best_val_metric_step = 0 323 | self.best_train_metric = 0 if self.greater_is_better else 1e10 324 | self.best_train_metric_epoch = 0 325 | self.best_train_metric_step = 0 326 | # prefix the metric name with 'eval/' so it matches the keys logged during evaluation 327 | if self.metric_for_best_model is not None: 328 | if 'eval' not in self.metric_for_best_model: 329 | self.metric_for_best_model = f'eval/{self.metric_for_best_model}' 330 | 331 | class FinetuneTrainer(DefaultTrainer): 332 | def __init__(self, model, train_loader, eval_loader, args, optimizers, tokenizer, config): 333 | super().__init__(model, train_loader, eval_loader, args, optimizers, tokenizer, config) 334 | self.type = 'finetune' 335 | 336 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 337 | input_keys = {'input_ids', 'attention_mask'} 338 | data = {k: v.to(model.device) for k, v in inputs.items() if k in input_keys} 339 | outputs = model(**data, output_attentions=False) 340 | outputs = outputs.get('logits')[..., :-1, :].contiguous() 341 | targets = inputs.get('labels')[..., 1:].contiguous() 342 | # Flatten and compute cross-entropy loss 343 | outputs = outputs.view(-1, outputs.shape[-1]) 344 | targets = targets.view(-1).to(outputs.device) 345 | loss = self.criterion(outputs, targets) 346 | 347 | targets = targets.cpu() 348 | outputs = outputs.cpu() 349 | outputs = {'ppl': torch.exp(loss).item(), 'seq_len': targets.shape[-1] + 1} 350 | return (loss, outputs) if return_outputs else loss 351 | -------------------------------------------------------------------------------- /training/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim 4 | 5 | def get_optimizer_and_scheduler(model, config): 6 | params = [p for p in model.parameters() if p.requires_grad] 7 | optimizer = torch.optim.AdamW(params, lr=config.train.lr, fused=True) # optimize only the trainable parameters collected above 8 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 9 | optimizer=optimizer, 10 | mode='min', 11 | factor=0.1, 12 | patience=10, 13 | min_lr=0.00001 14 | ) 15 | return optimizer, scheduler 16 | 17 | def count_model_params(model, requires_grad: bool = True): 18 | # code from lolcats 19 | """ 20 | 
Return the total number of model parameters (only trainable ones when `requires_grad=True`) 21 | """ 22 | if requires_grad: 23 | model_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 24 | else: 25 | model_parameters = list(model.parameters()) 26 | try: 27 | return sum([np.prod(p.size()) for p in model_parameters]).item() 28 | except AttributeError: # the sum may already be a plain Python int (e.g. no matching parameters) 29 | return sum([np.prod(p.size()) for p in model_parameters]) --------------------------------------------------------------------------------
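Usage sketch (editor's illustration, not a file in the repository): the snippet below shows how a checkpoint written by the trainer might be reloaded and inspected. The checkpoint path is an assumed example; importing the `lolcats` package is what registers `LolcatsConfig`/`LolcatsModelForCausalLM` with the transformers Auto* factories, as done in `lolcats/models/lolcats/__init__.py`. If LoRA adapters were enabled for the run, loading would instead go through `peft.PeftModel`.

# Hedged usage sketch: assumes a full-model checkpoint directory such as
# "checkpoints/lolcats_ar/best" (hypothetical path produced by scripts/train_lolcats_stage2.sh
# together with the trainer's save_pretrained calls).
import lolcats  # noqa: F401 -- side effect: registers the Lolcats config/model classes with AutoConfig/AutoModel*
from transformers import AutoModelForCausalLM, AutoTokenizer

from training.utils import count_model_params

ckpt = "checkpoints/lolcats_ar/best"  # assumed output directory of the second-stage run
model = AutoModelForCausalLM.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

# Report how many parameters were actually trained during linearization.
trainable = count_model_params(model, requires_grad=True)
total = count_model_params(model, requires_grad=False)
print(f"trainable params: {trainable} / {total} ({trainable / total:.2%})")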